Browse code

Fixes #908 - Fixes unicode in req headers for py2 and py3

Florent Viard authored on 2017/08/19 23:36:15
Showing 6 changed files
... ...
@@ -26,6 +26,7 @@ from .Utils import getTreeFromXml, appendXmlTextNode, getDictFromTree, dateS3toP
26 26
 from .Crypto import sign_string_v2
27 27
 from .S3Uri import S3Uri, S3UriS3
28 28
 from .ConnMan import ConnMan
29
+from .SortedDict import SortedDict
29 30
 
30 31
 cloudfront_api_version = "2010-11-01"
31 32
 cloudfront_resource = "/%(api_ver)s/distribution" % { 'api_ver' : cloudfront_api_version }
... ...
@@ -397,7 +398,7 @@ class CloudFront(object):
397 397
                     break
398 398
                 warning("Still waiting...")
399 399
                 time.sleep(10)
400
-        headers = {}
400
+        headers = SortedDict(ignore_case = True)
401 401
         headers['if-match'] = response['headers']['etag']
402 402
         response = self.send_request("DeleteDist", dist_id = cfuri.dist_id(),
403 403
                                      headers = headers)
... ...
@@ -424,7 +425,7 @@ class CloudFront(object):
424 424
         debug("SetDistConfig(): Etag = %s" % etag)
425 425
         request_body = str(dist_config)
426 426
         debug("SetDistConfig(): request_body: %s" % request_body)
427
-        headers = {}
427
+        headers = SortedDict(ignore_case = True)
428 428
         headers['if-match'] = etag
429 429
         response = self.send_request("SetDistConfig", dist_id = cfuri.dist_id(),
430 430
                                      body = request_body, headers = headers)
... ...
@@ -498,7 +499,7 @@ class CloudFront(object):
498 498
 
499 499
     def send_request(self, op_name, dist_id = None, request_id = None, body = None, headers = None, retries = _max_retries):
500 500
         if headers is None:
501
-            headers = {}
501
+            headers = SortedDict(ignore_case = True)
502 502
         operation = self.operations[op_name]
503 503
         if body:
504 504
             headers['content-type'] = 'text/plain'
... ...
@@ -537,7 +538,7 @@ class CloudFront(object):
537 537
                    operation['resource'] % { 'dist_id' : dist_id, 'request_id' : request_id })
538 538
 
539 539
         if not headers:
540
-            headers = {}
540
+            headers = SortedDict(ignore_case = True)
541 541
 
542 542
         if "date" in headers:
543 543
             if "x-amz-date" not in headers:
... ...
@@ -69,7 +69,7 @@ def sign_string_v2(string_to_sign):
69 69
     return signature
70 70
 __all__.append("sign_string_v2")
71 71
 
72
-def sign_request_v2(method='GET', canonical_uri='/', params=None, cur_headers={}):
72
+def sign_request_v2(method='GET', canonical_uri='/', params=None, cur_headers=None):
73 73
     """Sign a string with the secret key, returning base64 encoded results.
74 74
     By default the configured secret key is used, but may be overridden as
75 75
     an argument.
... ...
@@ -86,6 +86,9 @@ def sign_request_v2(method='GET', canonical_uri='/', params=None, cur_headers={}
86 86
                                # Missing of aws s3 doc but needed
87 87
                                'delete', 'cors']
88 88
 
89
+    if cur_headers is None:
90
+        cur_headers = SortedDict(ignore_case = True)
91
+
89 92
     access_key = Config.Config().access_key
90 93
 
91 94
     string_to_sign  = method + "\n"
... ...
@@ -111,7 +114,7 @@ def sign_request_v2(method='GET', canonical_uri='/', params=None, cur_headers={}
111 111
     debug("SignHeaders: " + repr(string_to_sign))
112 112
     signature = decode_from_s3(sign_string_v2(encode_to_s3(string_to_sign)))
113 113
 
114
-    new_headers = dict(list(cur_headers.items()))
114
+    new_headers = SortedDict(list(cur_headers.items()), ignore_case=True)
115 115
     new_headers["Authorization"] = "AWS " + access_key + ":" + signature
116 116
 
117 117
     return new_headers
... ...
@@ -174,8 +177,10 @@ def getSignatureKey(key, dateStamp, regionName, serviceName):
174 174
     return kSigning
175 175
 
176 176
 def sign_request_v4(method='GET', host='', canonical_uri='/', params=None,
177
-                    region='us-east-1', cur_headers={}, body=b''):
177
+                    region='us-east-1', cur_headers=None, body=b''):
178 178
     service = 's3'
179
+    if cur_headers is None:
180
+        cur_headers = SortedDict(ignore_case = True)
179 181
 
180 182
     cfg = Config.Config()
181 183
     access_key = cfg.access_key
... ...
@@ -207,7 +212,7 @@ def sign_request_v4(method='GET', host='', canonical_uri='/', params=None,
207 207
         # avoid duplicate headers and previous Authorization
208 208
         if header == 'Authorization' or header in signed_headers.split(';'):
209 209
             continue
210
-        canonical_headers[header.strip()] = str(cur_headers[header]).strip()
210
+        canonical_headers[header.strip()] = cur_headers[header].strip()
211 211
         signed_headers += ';' + header.strip()
212 212
 
213 213
     # sort headers into a string
... ...
@@ -138,12 +138,13 @@ def httpconnection_patched_send_request(self, method, url, body, headers):
138 138
         if 'expect' == hdr.lower() and '100-continue' in value.lower():
139 139
             expect_continue = True
140 140
 
141
+    url = encode_to_s3(url)
141 142
     self.putrequest(method, url, **skips)
142 143
 
143 144
     if 'content-length' not in header_names:
144 145
         self._set_content_length(body, method)
145 146
     for hdr, value in headers.iteritems():
146
-        self.putheader(hdr, value)
147
+        self.putheader(encode_to_s3(hdr), encode_to_s3(value))
147 148
 
148 149
     # If an Expect: 100-continue was sent, we need to check for a 417
149 150
     # Expectation Failed to avoid unecessarily sending the body
... ...
@@ -11,6 +11,8 @@ from http.client import (_CS_REQ_SENT, _CS_REQ_STARTED, CONTINUE, UnknownProtoco
11 11
 
12 12
 from io import StringIO
13 13
 
14
+from .Utils import encode_to_s3
15
+
14 16
 
15 17
 _METHODS_EXPECTING_BODY = ['PATCH', 'POST', 'PUT']
16 18
 
... ...
@@ -162,7 +164,6 @@ def httpconnection_patched_send_request(self, method, url, body, headers,
162 162
     # 1. content-length has not been explicitly set
163 163
     # 2. the body is a file or iterable, but not a str or bytes-like
164 164
     # 3. Transfer-Encoding has NOT been explicitly set by the caller
165
-
166 165
     if 'content-length' not in header_names:
167 166
         # only chunk body if not explicitly set for backwards
168 167
         # compatibility, assuming the client code is already handling the
... ...
@@ -184,7 +185,7 @@ def httpconnection_patched_send_request(self, method, url, body, headers,
184 184
         encode_chunked = False
185 185
 
186 186
     for hdr, value in headers.items():
187
-        self.putheader(hdr, value)
187
+        self.putheader(encode_to_s3(hdr), encode_to_s3(value))
188 188
 
189 189
     if isinstance(body, str):
190 190
         # RFC 2616 Section 3.7.1 says that text default has a
... ...
@@ -740,8 +740,8 @@ class S3(object):
740 740
             raise ValueError("Key list is empty")
741 741
         bucket = S3Uri(batch[0]).bucket()
742 742
         request_body = compose_batch_del_xml(bucket, batch)
743
-        headers = {'content-md5': compute_content_md5(request_body),
744
-                   'content-type': 'application/xml'}
743
+        headers = SortedDict({'content-md5': compute_content_md5(request_body),
744
+                   'content-type': 'application/xml'}, ignore_case=True)
745 745
         request = self.create_request("BATCH_DELETE", bucket = bucket,
746 746
                                       headers = headers, body = request_body,
747 747
                                       uri_params = {'delete': None})
... ...
@@ -944,7 +944,7 @@ class S3(object):
944 944
         body = u"%s"% acl
945 945
         debug(u"set_acl(%s): acl-xml: %s" % (uri, body))
946 946
 
947
-        headers = {'content-type': 'application/xml'}
947
+        headers = SortedDict({'content-type': 'application/xml'}, ignore_case = True)
948 948
         if uri.has_object():
949 949
             request = self.create_request("OBJECT_PUT", uri = uri,
950 950
                                           headers = headers, body = body,
... ...
@@ -964,7 +964,7 @@ class S3(object):
964 964
         return response['data']
965 965
 
966 966
     def set_policy(self, uri, policy):
967
-        headers = {}
967
+        headers = SortedDict(ignore_case = True)
968 968
         # TODO check policy is proper json string
969 969
         headers['content-type'] = 'application/json'
970 970
         request = self.create_request("BUCKET_CREATE", uri = uri,
... ...
@@ -987,7 +987,7 @@ class S3(object):
987 987
         return response['data']
988 988
 
989 989
     def set_cors(self, uri, cors):
990
-        headers = {}
990
+        headers = SortedDict(ignore_case = True)
991 991
         # TODO check cors is proper json string
992 992
         headers['content-type'] = 'application/xml'
993 993
         headers['content-md5'] = compute_content_md5(cors)
... ...
@@ -1015,7 +1015,7 @@ class S3(object):
1015 1015
         return response
1016 1016
 
1017 1017
     def set_payer(self, uri):
1018
-        headers = {}
1018
+        headers = SortedDict(ignore_case = True)
1019 1019
         headers['content-type'] = 'application/xml'
1020 1020
         body = '<RequestPaymentConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">\n'
1021 1021
         if self.config.requester_pays:
... ...
@@ -1367,7 +1367,7 @@ class S3(object):
1367 1367
             conn = ConnMan.get(self.get_hostname(resource['bucket']))
1368 1368
             conn.c.putrequest(method_string, self.format_uri(resource, conn.path))
1369 1369
             for header in headers.keys():
1370
-                conn.c.putheader(header, str(headers[header]))
1370
+                conn.c.putheader(encode_to_s3(header), encode_to_s3(headers[header]))
1371 1371
             conn.c.endheaders()
1372 1372
         except ParameterError as e:
1373 1373
             raise
... ...
@@ -1600,7 +1600,7 @@ class S3(object):
1600 1600
         try:
1601 1601
             conn.c.putrequest(method_string, self.format_uri(resource, conn.path))
1602 1602
             for header in headers.keys():
1603
-                conn.c.putheader(header, str(headers[header]))
1603
+                conn.c.putheader(encode_to_s3(header), encode_to_s3(headers[header]))
1604 1604
             if start_position > 0:
1605 1605
                 debug("Requesting Range: %d .. end" % start_position)
1606 1606
                 conn.c.putheader("Range", "bytes=%d-" % start_position)
... ...
@@ -2801,16 +2801,16 @@ def main():
2801 2801
     if options.add_header:
2802 2802
         for hdr in options.add_header:
2803 2803
             try:
2804
-                key, val = hdr.split(":", 1)
2804
+                key, val = unicodise_s(hdr).split(":", 1)
2805 2805
             except ValueError:
2806
-                raise ParameterError("Invalid header format: %s" % hdr)
2806
+                raise ParameterError("Invalid header format: %s" % unicodise_s(hdr))
2807 2807
             key_inval = re.sub("[a-zA-Z0-9-.]", "", key)
2808 2808
             if key_inval:
2809 2809
                 key_inval = key_inval.replace(" ", "<space>")
2810 2810
                 key_inval = key_inval.replace("\t", "<tab>")
2811 2811
                 raise ParameterError("Invalid character(s) in header name '%s': \"%s\"" % (key, key_inval))
2812
-            debug(u"Updating Config.Config extra_headers[%s] -> %s" % (key.strip().lower(), val.strip()))
2813
-            cfg.extra_headers[key.strip().lower()] = val.strip()
2812
+            debug(u"Updating Config.Config extra_headers[%s] -> %s" % (key.replace('_', '-').strip().lower(), val.strip()))
2813
+            cfg.extra_headers[key.replace('_', '-').strip().lower()] = val.strip()
2814 2814
 
2815 2815
     # Process --remove-header
2816 2816
     if options.remove_headers:
... ...
@@ -2841,9 +2841,12 @@ def main():
2841 2841
     ## Update Config with other parameters
2842 2842
     for option in cfg.option_list():
2843 2843
         try:
2844
-            if getattr(options, option) != None:
2845
-                debug(u"Updating Config.Config %s -> %s" % (option, getattr(options, option)))
2846
-                cfg.update_option(option, getattr(options, option))
2844
+            value = getattr(options, option)
2845
+            if value != None:
2846
+                if type(value) == type(b''):
2847
+                    value = unicodise_s(value)
2848
+                debug(u"Updating Config.Config %s -> %s" % (option, value))
2849
+                cfg.update_option(option, value)
2847 2850
         except AttributeError:
2848 2851
             ## Some Config() options are not settable from command line
2849 2852
             pass
... ...
@@ -2876,9 +2879,13 @@ def main():
2876 2876
     ## Update CloudFront options if some were set
2877 2877
     for option in CfCmd.options.option_list():
2878 2878
         try:
2879
-            if getattr(options, option) != None:
2880
-                debug(u"Updating CloudFront.Cmd %s -> %s" % (option, getattr(options, option)))
2881
-                CfCmd.options.update_option(option, getattr(options, option))
2879
+            value = getattr(options, option)
2880
+            if value != None:
2881
+                if type(value) == type(b''):
2882
+                    value = unicodise_s(value)
2883
+            if value != None:
2884
+                debug(u"Updating CloudFront.Cmd %s -> %s" % (option, value))
2885
+                CfCmd.options.update_option(option, value)
2882 2886
         except AttributeError:
2883 2887
             ## Some CloudFront.Cmd.Options() options are not settable from command line
2884 2888
             pass