S3/S3.py
ec50b5a7
 ## Amazon S3 manager
 ## Author: Michal Ludvig <michal@logix.cz>
 ##         http://www.logix.cz/michal
 ## License: GPL Version 2
 
f98a27f2
 import sys
f81e7fba
 import os, os.path
 import base64
a0fc5bca
 import md5
 import sha
f81e7fba
 import hmac
 import httplib
 import logging
 from logging import debug, info, warning, error
 from stat import ST_SIZE
 
 from Utils import *
 from SortedDict import SortedDict
 from BidirMap import BidirMap
b008e471
 from Config import Config
f81e7fba
 
 class S3Error (Exception):
 	"""Error raised for S3 responses whose HTTP status is not 2xx.
 
 	Attributes:
 	  status -- HTTP status code (response["status"])
 	  reason -- HTTP reason phrase (response["reason"])
 	  info   -- tag -> text for the non-empty child elements of the XML
 	            error document in response["data"] (e.g. info["Code"]);
 	            empty when the response carried no body.
 	"""
 	def __init__(self, response):
 		self.status = response["status"]
 		self.reason = response["reason"]
 		self.info = {}
 		debug("S3Error: %s (%s)" % (self.status, self.reason))
 		if "headers" in response:
 			for header in response["headers"]:
 				debug("HttpHeader: %s: %s" % (header, response["headers"][header]))
 		## Responses without a body (e.g. to HEAD requests) have no XML to
 		## parse and ET.fromstring("") would raise, masking the real error.
 		if response.get("data"):
 			tree = ET.fromstring(response["data"])
 			## Iterate the element directly; getchildren() is deprecated.
 			for child in tree:
 				## Skip empty elements, whose text is None or "".
 				if child.text:
 					debug("ErrorXML: " + child.tag + ": " + repr(child.text))
 					self.info[child.tag] = child.text
 
 	def __str__(self):
 		"""Return "<status> (<reason>)", plus ": <Code>" when S3 sent an error code."""
 		retval = "%d (%s)" % (self.status, self.reason)
 		## 'info' may be missing if a subclass skipped __init__.
 		details = getattr(self, "info", {})
 		if "Code" in details:
 			retval += ": %s" % details["Code"]
 		return retval
 
 class ParameterError(Exception):
 	"""Raised for invalid caller-supplied input, such as a malformed bucket
 	name or a local file that is missing, not regular, or cannot be opened."""
 
eb9c54ec
 class S3(object):
 	"""Client for the Amazon S3 REST API.
 
 	create_request()/sign_headers() build signed requests, send_request(),
 	send_file() and recv_file() perform them over httplib connections, and
 	the bucket_* / object_* methods wrap those for individual operations.
 	All settings (credentials, endpoints, proxy, chunk sizes) come from the
 	config object passed to __init__.
 	"""
f81e7fba
 	## HTTP verbs as bit flags; MASK (0x0F) selects them from an operation code.
 	http_methods = BidirMap(
 		GET = 0x01,
 		PUT = 0x02,
 		HEAD = 0x04,
 		DELETE = 0x08,
 		MASK = 0x0F,
 		)
 	
 	## Request targets as bit flags; MASK (0x0700) selects them from an operation code.
 	targets = BidirMap(
 		SERVICE = 0x0100,
 		BUCKET = 0x0200,
 		OBJECT = 0x0400,
 		MASK = 0x0700,
 		)
 
 	## Each operation is (target | http method). create_request() recovers the
 	## HTTP verb with `operations[op] & http_methods["MASK"]`.
 	## "UNDFINED" (sic) is kept as spelled: the key name is part of this map.
 	operations = BidirMap(
 		UNDFINED = 0x0000,
 		LIST_ALL_BUCKETS = targets["SERVICE"] | http_methods["GET"],
 		BUCKET_CREATE = targets["BUCKET"] | http_methods["PUT"],
 		BUCKET_LIST = targets["BUCKET"] | http_methods["GET"],
 		BUCKET_DELETE = targets["BUCKET"] | http_methods["DELETE"],
 		OBJECT_PUT = targets["OBJECT"] | http_methods["PUT"],
 		OBJECT_GET = targets["OBJECT"] | http_methods["GET"],
 		OBJECT_HEAD = targets["OBJECT"] | http_methods["HEAD"],
 		OBJECT_DELETE = targets["OBJECT"] | http_methods["DELETE"],
 	)
 
 	## Friendly messages for some S3 error codes; the '%s' placeholder is
 	## presumably filled with the bucket name by callers (not used in this
 	## part of the file) -- NOTE(review): confirm against callers.
 	codes = {
 		"NoSuchBucket" : "Bucket '%s' does not exist",
 		"AccessDenied" : "Access to bucket '%s' was denied",
 		"BucketAlreadyExists" : "Bucket '%s' already exists",
 		}
 
 	def __init__(self, config):
 		self.config = config
 
dc758146
 	def get_connection(self, bucket):
afe194f8
 		if self.config.proxy_host != "":
 			return httplib.HTTPConnection(self.config.proxy_host, self.config.proxy_port)
 		else:
dc758146
 			if self.config.use_https:
 				return httplib.HTTPSConnection(self.get_hostname(bucket))
 			else:
 				return httplib.HTTPConnection(self.get_hostname(bucket))
afe194f8
 
dc758146
 	def get_hostname(self, bucket):
 		if bucket:
 			host = self.config.host_bucket % { 'bucket' : bucket }
 		else:
 			host = self.config.host_base
 		debug('get_hostname(): ' + host)
 		return host
afe194f8
 
dc758146
 	def format_uri(self, resource):
 		if self.config.proxy_host != "":
 			uri = "http://%s%s" % (self.get_hostname(resource['bucket']), resource['uri'])
 		else:
 			uri = resource['uri']
 		debug('format_uri(): ' + uri)
 		return uri
afe194f8
 
ec50b5a7
 	## Commands / Actions
f81e7fba
 	def list_all_buckets(self):
 		"""Return the LIST_ALL_BUCKETS response, with the parsed <Bucket>
 		entries stored under response["list"]."""
 		response = self.send_request(self.create_request("LIST_ALL_BUCKETS"))
 		response["list"] = getListFromXml(response["data"], "Bucket")
 		return response
 	
9081133d
 	def bucket_list(self, bucket, prefix = None):
 		"""List the keys of 'bucket', optionally restricted to 'prefix'.
 
 		Follows S3's paginated listing: while a page reports <IsTruncated>,
 		the next request resumes after the last key of that page (marker).
 		Returns the response of the last page with the <Contents> entries of
 		all pages stored in response['list'].
 		"""
 		def _list_truncated(data):
 			## <IsTruncated> can either be "true" or "false" or be missing completely
 			is_truncated = getTextFromXml(data, ".//IsTruncated") or "false"
 			return is_truncated.lower() != "false"
 
 		def _get_contents(data):
 			return getListFromXml(data, "Contents")
 
 		request = self.create_request("BUCKET_LIST", bucket = bucket, prefix = prefix)
 		response = self.send_request(request)
 		page = _get_contents(response["data"])
 		entries = page
 		while _list_truncated(response["data"]):
 			if not page:
 				## A truncated page without keys gives no marker to resume
 				## from; re-using an older marker would loop forever.
 				warning("Listing is truncated but the last page had no keys; stopping")
 				break
 			marker = page[-1]["Key"]
 			info("Listing continues after '%s'" % marker)
 			request = self.create_request("BUCKET_LIST", bucket = bucket, prefix = prefix, marker = marker)
 			response = self.send_request(request)
 			page = _get_contents(response["data"])
 			entries = entries + page
 		response['list'] = entries
 		return response
 
dc758146
 	def bucket_create(self, bucket, bucket_location = None):
 		"""Create 'bucket', optionally constrained to 'bucket_location'.
 
 		The location is stripped and upper-cased. Any non-empty value other
 		than "US" is sent as a <CreateBucketConfiguration> request body;
 		otherwise the body is empty. Returns send_request()'s response.
 		Raises ParameterError (via check_bucket_name) for invalid names.
 		"""
 		self.check_bucket_name(bucket)
 		headers = SortedDict()
 		body = ""
 		if bucket_location:
 			location = bucket_location.strip().upper()
 			if location != "US":
 				body = "<CreateBucketConfiguration><LocationConstraint>" + location + "</LocationConstraint></CreateBucketConfiguration>"
 				debug("bucket_location: " + body)
 		headers["content-length"] = len(body)
 		request = self.create_request("BUCKET_CREATE", bucket = bucket, headers = headers)
 		return self.send_request(request, body)
 
 	def bucket_delete(self, bucket):
 		request = self.create_request("BUCKET_DELETE", bucket = bucket)
 		response = self.send_request(request)
 		return response
 
82157846
 	def bucket_info(self, bucket):
 		"""Query the ?location sub-resource of 'bucket'.
 
 		Returns the response with response['bucket-location'] set to the
 		<LocationConstraint> text, or "any" when it is empty or missing.
 		"""
 		response = self.send_request(self.create_request("BUCKET_LIST", bucket = bucket, extra = "?location"))
 		location = getTextFromXml(response['data'], "LocationConstraint")
 		response['bucket-location'] = location or "any"
 		return response
 
8ec1807f
 	def object_put(self, filename, bucket, object, extra_headers = None):
f81e7fba
 		if not os.path.isfile(filename):
 			raise ParameterError("%s is not a regular file" % filename)
 		try:
 			file = open(filename, "r")
 			size = os.stat(filename)[ST_SIZE]
 		except IOError, e:
 			raise ParameterError("%s: %s" % (filename, e.strerror))
 		headers = SortedDict()
8ec1807f
 		if extra_headers:
 			headers.update(extra_headers)
f81e7fba
 		headers["content-length"] = size
9b7618ae
 		if self.config.acl_public:
9081133d
 			headers["x-amz-acl"] = "public-read"
f81e7fba
 		request = self.create_request("OBJECT_PUT", bucket = bucket, object = object, headers = headers)
 		response = self.send_file(request, file)
 		response["size"] = size
 		return response
 
f98a27f2
 	def object_get_file(self, bucket, object, filename):
f81e7fba
 		try:
f98a27f2
 			stream = open(filename, "w")
f81e7fba
 		except IOError, e:
 			raise ParameterError("%s: %s" % (filename, e.strerror))
f98a27f2
 		return self.object_get_stream(bucket, object, stream)
 
 	def object_get_stream(self, bucket, object, stream):
f81e7fba
 		request = self.create_request("OBJECT_GET", bucket = bucket, object = object)
f98a27f2
 		response = self.recv_file(request, stream)
f81e7fba
 		return response
f98a27f2
 		
f81e7fba
 	def object_delete(self, bucket, object):
 		request = self.create_request("OBJECT_DELETE", bucket = bucket, object = object)
 		response = self.send_request(request)
 		return response
 
8ec1807f
 	def object_put_uri(self, filename, uri, extra_headers = None):
b819c70c
 		if uri.type != "s3":
 			raise ValueError("Expected URI type 's3', got '%s'" % uri.type)
8ec1807f
 		return self.object_put(filename, uri.bucket(), uri.object(), extra_headers)
b819c70c
 
f98a27f2
 	def object_get_uri(self, uri, filename):
b819c70c
 		if uri.type != "s3":
 			raise ValueError("Expected URI type 's3', got '%s'" % uri.type)
f98a27f2
 		if filename == "-":
 			return self.object_get_stream(uri.bucket(), uri.object(), sys.stdout)
 		else:
 			return self.object_get_file(uri.bucket(), uri.object(), filename)
b819c70c
 
 	def object_delete_uri(self, uri):
 		if uri.type != "s3":
 			raise ValueError("Expected URI type 's3', got '%s'" % uri.type)
42b24cac
 		return self.object_delete(uri.bucket(), uri.object())
b819c70c
 
ec50b5a7
 	## Low level methods
c0e0c042
 	def urlencode_string(self, string):
 		encoded = ""
 		## List of characters that must be escaped for S3
 		## Haven't found this in any official docs
 		## but my tests show it's more less correct.
 		## If you start getting InvalidSignature errors
 		## from S3 check the error headers returned
 		## from S3 to see whether the list hasn't
 		## changed.
 		for c in string:	# I'm not sure how to know in what encoding 
 					# 'object' is. Apparently "type(object)==str"
 					# but the contents is a string of unicode
 					# bytes, e.g. '\xc4\x8d\xc5\xafr\xc3\xa1k'
 					# Don't know what it will do on non-utf8 
 					# systems.
 					#           [hope that sounds reassuring ;-)]
 			o = ord(c)
 			if (o <= 32 or		# Space and below
 			    o == 0x22 or	# "
 			    o == 0x23 or	# #
 			    o == 0x25 or	# %
 			    o == 0x2B or	# + (or it would become <space>)
 			    o == 0x3C or	# <
 			    o == 0x3E or	# >
 			    o == 0x3F or	# ?
 			    o == 0x5B or	# [
 			    o == 0x5C or	# \
 			    o == 0x5D or	# ]
 			    o == 0x5E or	# ^
 			    o == 0x60 or	# `
 			    o >= 123):   	# { and above, including >= 128 for UTF-8
 				encoded += "%%%02X" % o
 			else:
 				encoded += c
 		debug("String '%s' encoded to '%s'" % (string, encoded))
 		return encoded
 
dc758146
 	def create_request(self, operation, bucket = None, object = None, headers = None, extra = None, **params):
 		resource = { 'bucket' : None, 'uri' : "/" }
f81e7fba
 		if bucket:
dc758146
 			resource['bucket'] = str(bucket)
f81e7fba
 			if object:
dc758146
 				resource['uri'] = "/" + self.urlencode_string(object)
 		if extra:
 			resource['uri'] += extra
f81e7fba
 
 		if not headers:
 			headers = SortedDict()
 
 		if headers.has_key("date"):
 			if not headers.has_key("x-amz-date"):
 				headers["x-amz-date"] = headers["date"]
 			del(headers["date"])
 		
 		if not headers.has_key("x-amz-date"):
 			headers["x-amz-date"] = time.strftime("%a, %d %b %Y %H:%M:%S %z", time.gmtime(time.time()))
 
 		method_string = S3.http_methods.getkey(S3.operations[operation] & S3.http_methods["MASK"])
 		signature = self.sign_headers(method_string, resource, headers)
 		headers["Authorization"] = "AWS "+self.config.access_key+":"+signature
f4555c39
 		param_str = ""
 		for param in params:
 			if params[param] not in (None, ""):
 				param_str += "&%s=%s" % (param, params[param])
82157846
 			else:
 				param_str += "&%s" % param
f4555c39
 		if param_str != "":
dc758146
 			resource['uri'] += "?" + param_str[1:]
 		debug("CreateRequest: resource[uri]=" + resource['uri'])
f81e7fba
 		return (method_string, resource, headers)
 	
dc758146
 	def send_request(self, request, body = None):
 		"""Perform 'request' (a create_request() tuple) with optional 'body'.
 
 		Returns a dict with "status", "reason", "headers" (converted from the
 		response header list) and "data" (the full response body).
 		Raises S3Error when the status is outside 200-299.
 		"""
 		method_string, resource, headers = request
 		info("Processing request, please wait...")
 		conn = self.get_connection(resource['bucket'])
 		try:
 			conn.request(method_string, self.format_uri(resource), body, headers)
 			response = {}
 			http_response = conn.getresponse()
 			response["status"] = http_response.status
 			response["reason"] = http_response.reason
 			response["headers"] = convertTupleListToDict(http_response.getheaders())
 			response["data"] = http_response.read()
 			debug("Response: " + str(response))
 		finally:
 			## Release the socket even when the request or the read fails.
 			conn.close()
 		if response["status"] < 200 or response["status"] > 299:
 			raise S3Error(response)
 		return response
 
 	def send_file(self, request, file):
 		"""Perform 'request', streaming the contents of open file 'file' as body.
 
 		Exactly headers["content-length"] bytes are sent, read in chunks of
 		config.send_chunk bytes. Returns a dict with "status", "reason",
 		"headers" and "data".
 		Raises ParameterError if the file yields fewer bytes than announced,
 		and S3Error when the status is outside 200-299.
 		"""
 		method_string, resource, headers = request
 		info("Sending file '%s', please wait..." % file.name)
 		conn = self.get_connection(resource['bucket'])
 		try:
 			conn.connect()
 			conn.putrequest(method_string, self.format_uri(resource))
 			for header in headers.keys():
 				conn.putheader(header, str(headers[header]))
 			conn.endheaders()
 			size_left = size_total = headers.get("content-length")
 			while (size_left > 0):
 				debug("SendFile: Reading up to %d bytes from '%s'" % (self.config.send_chunk, file.name))
 				data = file.read(self.config.send_chunk)
 				if not data:
 					## The file shrank after content-length was computed. Without
 					## this check the loop would spin forever, and the server
 					## would keep waiting for the missing bytes anyway.
 					raise ParameterError("%s: file ended after %d of %d bytes" % (
 						file.name, size_total - size_left, size_total))
 				debug("SendFile: Sending %d bytes to the server" % len(data))
 				conn.send(data)
 				size_left -= len(data)
 				info("Sent %d bytes (%d %% of %d)" % (
 					(size_total - size_left),
 					(size_total - size_left) * 100 / size_total,
 					size_total))
 			response = {}
 			http_response = conn.getresponse()
 			response["status"] = http_response.status
 			response["reason"] = http_response.reason
 			response["headers"] = convertTupleListToDict(http_response.getheaders())
 			response["data"] = http_response.read()
 		finally:
 			## Release the socket even when sending or reading fails.
 			conn.close()
 		if response["status"] < 200 or response["status"] > 299:
 			raise S3Error(response)
 		return response
 
f98a27f2
 	def recv_file(self, request, stream):
 		"""Perform 'request' and write the response body into open 'stream'.
 
 		The body is read in chunks of at most config.recv_chunk bytes while
 		an MD5 of the data is computed. Returns a dict with "status",
 		"reason", "headers", "md5" (hex digest), "md5match" (digest found in
 		the ETag header) and "size" (bytes received). Size or MD5 mismatches
 		are logged as warnings, not raised.
 
 		Raises S3Error when the status is outside 200-299. 'stream' is left open.
 		"""
 		method_string, resource, headers = request
 		info("Receiving file '%s', please wait..." % stream.name)
 		conn = self.get_connection(resource['bucket'])
 		try:
 			conn.connect()
 			conn.putrequest(method_string, self.format_uri(resource))
 			for header in headers.keys():
 				conn.putheader(header, str(headers[header]))
 			conn.endheaders()
 			response = {}
 			http_response = conn.getresponse()
 			response["status"] = http_response.status
 			response["reason"] = http_response.reason
 			response["headers"] = convertTupleListToDict(http_response.getheaders())
 			if response["status"] < 200 or response["status"] > 299:
 				## Read the error body so S3Error can report S3's error code.
 				response["data"] = http_response.read()
 				raise S3Error(response)
 
 			md5_hash = md5.new()
 			size_total = int(response["headers"]["content-length"])
 			size_recvd = 0
 			while (size_recvd < size_total):
 				this_chunk = min(size_total - size_recvd, self.config.recv_chunk)
 				debug("ReceiveFile: Receiving up to %d bytes from the server" % this_chunk)
 				data = http_response.read(this_chunk)
 				if not data:
 					## Connection closed early: stop instead of looping forever.
 					## The size check below reports the shortfall.
 					break
 				debug("ReceiveFile: Writing %d bytes to file '%s'" % (len(data), stream.name))
 				stream.write(data)
 				md5_hash.update(data)
 				size_recvd += len(data)
 				info("Received %d bytes (%d %% of %d)" % (
 					size_recvd,
 					size_recvd * 100 / size_total,
 					size_total))
 		finally:
 			## Release the socket on success, S3 errors and I/O failures alike.
 			conn.close()
 		response["md5"] = md5_hash.hexdigest()
 		etag = response["headers"].get("etag", "")
 		response["md5match"] = etag.find(response["md5"]) >= 0
 		response["size"] = size_recvd
 		if response["size"] != long(response["headers"]["content-length"]):
 			warning("Reported size (%s) does not match received size (%s)" % (
 				response["headers"]["content-length"], response["size"]))
 		debug("ReceiveFile: Computed MD5 = %s" % response["md5"])
 		if not response["md5match"]:
 			warning("MD5 signatures do not match: computed=%s, received=%s" % (
 				response["md5"], etag))
 		return response
 
 	def sign_headers(self, method, resource, headers):
 		h  = method+"\n"
 		h += headers.get("content-md5", "")+"\n"
 		h += headers.get("content-type", "")+"\n"
 		h += headers.get("date", "")+"\n"
 		for header in headers.keys():
 			if header.startswith("x-amz-"):
 				h += header+":"+str(headers[header])+"\n"
dc758146
 		if resource['bucket']:
 			h += "/" + resource['bucket']
 		h += resource['uri']
f4555c39
 		debug("SignHeaders: " + repr(h))
a0fc5bca
 		return base64.encodestring(hmac.new(self.config.secret_key, h, sha).digest()).strip()
f81e7fba
 
 	def check_bucket_name(self, bucket):
 		if re.compile("[^A-Za-z0-9\._-]").search(bucket):
 			raise ParameterError("Bucket name '%s' contains unallowed characters" % bucket)
 		if len(bucket) < 3:
 			raise ParameterError("Bucket name '%s' is too short (min 3 characters)" % bucket)
 		if len(bucket) > 255:
 			raise ParameterError("Bucket name '%s' is too long (max 255 characters)" % bucket)
 		return True