S3/Utils.py
ec50b5a7
 ## Amazon S3 manager
 ## Author: Michal Ludvig <michal@logix.cz>
 ##         http://www.logix.cz/michal
 ## License: GPL Version 2
 
8ec1807f
 import os
df9fa4b5
 import time
 import re
8ec1807f
 import string
 import random
227fabf8
 import rfc822
0b2aefe3
 try:
0b8ea559
 	from hashlib import md5, sha1
0b2aefe3
 except ImportError:
3c682455
 	from md5 import md5
0b8ea559
 	import sha as sha1
 import hmac
 import base64
ac9940ec
 import errno
df9fa4b5
 
ed27a45e
 from logging import debug, info, warning, error
 
82d9eafa
 import Config
3c07424d
 import Exceptions
82d9eafa
 
7bae4e19
 try:
 	import xml.etree.ElementTree as ET
 except ImportError:
 	import elementtree.ElementTree as ET
3c07424d
 from xml.parsers.expat import ExpatError
7bae4e19
 
cb0bbaef
 ## Public names of this module; every definition below registers itself.
 __all__ = []
 def parseNodes(nodes):
 	"""
 	parseNodes(nodes) -> list of dicts

 	Convert each element in 'nodes' into a dict keyed by the tag names
 	of its direct children. A child that has sub-elements is converted
 	recursively (its value is the one-item list parseNodes([child])),
 	a leaf child maps to its own text or "" when it has none.
 	When several children share a tag, the last one wins.

 	WARNING: Ignores text nodes from mixed xml/text content.
 	For instance in <tag1>some text<tag2>other text</tag2></tag1>
 	the "some text" node is ignored.
 	"""
 	retval = []
 	for node in nodes:
 		retval_item = {}
 		## Iterate the element directly: works on all ElementTree versions,
 		## unlike getchildren() which newer Pythons no longer provide.
 		for child in node:
 			name = child.tag
 			if len(child):
 				retval_item[name] = parseNodes([child])
 			else:
 				## Take this child's own text. A descendant search on 'node'
 				## could return a same-named element nested elsewhere
 				## (e.g. Owner/ID instead of the sibling <ID>).
 				retval_item[name] = child.text or ""
 		retval.append(retval_item)
 	return retval
 __all__.append("parseNodes")
df9fa4b5
 
cb64ca9e
 def stripNameSpace(xml):
 	"""
 	stripNameSpace(xml) -> (xml, xmlns)

 	Remove the xmlns="..." declaration from the top-level element of an
 	XML document, so that ElementTree tags are not namespace-qualified.
 	Only applies when the root element follows a leading <...> prolog
 	(e.g. <?xml ...?>), optionally separated by one whitespace character.

 	Returns the (possibly modified) xml and the removed namespace URI,
 	or None as the URI when nothing was stripped.
 	"""
 	r = re.compile(r'^(<?[^>]+?>\s?)(<\w+) xmlns=[\'"](https?://[^\'"]+)[\'"](.*)')
 	m = r.match(xml)
 	if not m:
 		return xml, None
 	## Rebuild from this single anchored match so only the root element's
 	## namespace is removed; everything after it (including any nested
 	## xmlns declarations) is kept verbatim.
 	xml = m.expand(r"\1\2\4") + xml[m.end():]
 	return xml, m.group(3)
 __all__.append("stripNameSpace")
cb64ca9e
 
67a8d099
 def getTreeFromXml(xml):
cb64ca9e
 	xml, xmlns = stripNameSpace(xml)
3c07424d
 	try:
 		tree = ET.fromstring(xml)
 		if xmlns:
 			tree.attrib['xmlns'] = xmlns
 		return tree
 	except ExpatError, e:
 		error(e)
 		raise Exceptions.ParameterError("Bucket contains invalid filenames. Please run: s3cmd fixbucket s3://your-bucket/")
cb0bbaef
 __all__.append("getTreeFromXml")
67a8d099
 	
 def getListFromXml(xml, node):
 	"""
 	Parse 'xml' and return parseNodes() applied to every element
 	named 'node' found below the root.
 	"""
 	matches = getTreeFromXml(xml).findall('.//%s' % (node))
 	return parseNodes(matches)
 __all__.append("getListFromXml")
c3f0b06a
 
 def getDictFromTree(tree):
 	"""
 	getDictFromTree(tree) -> dict

 	Map the tag of each leaf child of 'tree' to its text ("" if empty).
 	Children that have sub-elements are skipped. A tag that occurs more
 	than once maps to a list of texts in document order.
 	"""
 	result = {}
 	for leaf in list(tree):
 		if len(leaf):
 			## Complex-type child - not interesting here
 			continue
 		tag = leaf.tag
 		text = leaf.text or ""
 		if tag not in result:
 			result[tag] = text
 			continue
 		previous = result[tag]
 		if not isinstance(previous, list):
 			result[tag] = [previous]
 		result[tag].append(text)
 	return result
 __all__.append("getDictFromTree")
c3f0b06a
 
0d91ff3f
 def getTextFromXml(xml, xpath):
 	"""
 	Parse 'xml' and return the text selected by 'xpath'.

 	When the root tag itself ends with 'xpath' the root's text is
 	returned, otherwise root.findtext(xpath) (None if nothing matches).
 	"""
 	root = getTreeFromXml(xml)
 	if not root.tag.endswith(xpath):
 		return root.findtext(xpath)
 	return root.text
 __all__.append("getTextFromXml")
67a8d099
 
 def getRootTagName(xml):
 	"""
 	Parse 'xml' and return the tag name of its root element.
 	"""
 	root = getTreeFromXml(xml)
 	tag_name = root.tag
 	return tag_name
 __all__.append("getRootTagName")
0d91ff3f
 
c3f0b06a
 def xmlTextNode(tag_name, text):
 	"""
 	Create and return a new <tag_name> element whose text is unicode(text).
 	"""
 	node = ET.Element(tag_name)
 	node.text = unicode(text)
 	return node
 __all__.append("xmlTextNode")
c3f0b06a
 
 def appendXmlTextNode(tag_name, text, parent = None):
 	"""
 	Creates a new <tag_name> Node and sets
 	its content to 'text'. Then appends the
 	created Node to 'parent' element if given
 	(i.e. if 'parent' is not None).
 	Returns the newly created Node.
 	"""
 	el = xmlTextNode(tag_name, text)
 	## Compare with None explicitly: an Element without children is
 	## false in a boolean context, so "if parent:" would skip it.
 	if parent is not None:
 		parent.append(el)
 	return el
 __all__.append("appendXmlTextNode")
c3f0b06a
 
df9fa4b5
 def dateS3toPython(date):
 	"""
 	Parse an S3 timestamp such as "2008-01-02T03:04:05.678Z" into a
 	time.struct_time. The fractional-seconds part (present or not) is
 	normalised to ".000" before parsing.
 	"""
 	normalised = re.sub(r"(\.\d*)?Z", ".000Z", date)
 	return time.strptime(normalised, "%Y-%m-%dT%H:%M:%S.000Z")
 __all__.append("dateS3toPython")
df9fa4b5
 
 def dateS3toUnix(date):
 	"""
 	Convert an S3 timestamp (GMT, "Z" suffix) to seconds since the epoch.

 	calendar.timegm() interprets the parsed struct_time as UTC, whereas
 	time.mktime() would treat it as local time and shift the result by
 	the local UTC offset. The value is returned as a float, the type
 	callers previously received from time.mktime().
 	"""
 	import calendar
 	return float(calendar.timegm(dateS3toPython(date)))
 __all__.append("dateS3toUnix")
df9fa4b5
 
227fabf8
 def dateRFC822toPython(date):
 	"""
 	Parse an RFC 822 date string with rfc822.parsedate() and return the
 	resulting 9-tuple (or None if the string cannot be parsed).
 	"""
 	parsed = rfc822.parsedate(date)
 	return parsed
 __all__.append("dateRFC822toPython")
227fabf8
 
 def dateRFC822toUnix(date):
 	"""
 	Convert an RFC 822 date string to seconds since the epoch (float).

 	The time zone given in the string (e.g. "GMT" or "+0200") is taken
 	into account; a string without zone information is treated as UTC.
 	calendar.timegm() is used because time.mktime() would interpret the
 	parsed fields as local time.
 	"""
 	import calendar
 	parsed = rfc822.parsedate_tz(date)
 	## parsedate_tz() puts the zone offset in seconds at index 9 (or None).
 	offset = parsed[9] or 0
 	return float(calendar.timegm(parsed[:9]) - offset)
 __all__.append("dateRFC822toUnix")
227fabf8
 
63ba9974
 def formatSize(size, human_readable = False, floating_point = False):
 	"""
 	formatSize(size, human_readable = False, floating_point = False)
 	    -> (size, unit)

 	Convert 'size' to int, or to float when 'floating_point' is set.
 	With 'human_readable', divide by 1024 while the value exceeds 2048
 	and return the matching unit prefix ('k', 'M', 'G', 'T', 'P', 'E');
 	otherwise the unit is "".
 	"""
 	## Plain if/else: with the "a and b or c" idiom a size of 0 would
 	## come back as int, because float(0) == 0.0 is false.
 	if floating_point:
 		size = float(size)
 	else:
 		size = int(size)
 	if not human_readable:
 		return (size, "")
 	coeffs = ['k', 'M', 'G', 'T', 'P', 'E']
 	coeff = ""
 	## Stop once the unit list is exhausted instead of popping from
 	## an empty list on huge values.
 	while size > 2048 and coeffs:
 		if floating_point:
 			size = size / 1024
 		else:
 			## Integer sizes stay integers.
 			size = size // 1024
 		coeff = coeffs.pop(0)
 	return (size, coeff)
 __all__.append("formatSize")
df9fa4b5
 
 def formatDateTime(s3timestamp):
 	"""
 	Render an S3 timestamp as "YYYY-MM-DD HH:MM".
 	"""
 	parsed = dateS3toPython(s3timestamp)
 	return time.strftime("%Y-%m-%d %H:%M", parsed)
 __all__.append("formatDateTime")
b5fe5ac4
 
 def convertTupleListToDict(list):
 	"""
 	Build a dict from a sequence of tuples, mapping item[0] to item[1].
 	Extra tuple elements are ignored; for repeated keys the last one wins.
 	"""
 	## 'list' shadows the builtin but stays as the parameter name to keep
 	## the existing signature.
 	return dict([(item[0], item[1]) for item in list])
 __all__.append("convertTupleListToDict")
8ec1807f
 
 _rnd_chars = string.ascii_letters+string.digits
 _rnd_chars_len = len(_rnd_chars)
 def rndstr(len):
 	"""
 	rndstr(len) -> str

 	Return 'len' characters picked at random from ASCII letters and
 	digits using the 'random' module (not suitable for secrets).
 	"""
 	picked = []
 	remaining = len
 	while remaining > 0:
 		picked.append(_rnd_chars[random.randint(0, _rnd_chars_len - 1)])
 		remaining -= 1
 	return "".join(picked)
 __all__.append("rndstr")
8ec1807f
 
 def mktmpsomething(prefix, randchars, createfunc):
 	old_umask = os.umask(0077)
 	tries = 5
 	while tries > 0:
 		dirname = prefix + rndstr(randchars)
 		try:
 			createfunc(dirname)
 			break
 		except OSError, e:
 			if e.errno != errno.EEXIST:
 				os.umask(old_umask)
 				raise
 		tries -= 1
 
 	os.umask(old_umask)
 	return dirname
cb0bbaef
 __all__.append("mktmpsomething")
8ec1807f
 
 def mktmpdir(prefix = "/tmp/tmpdir-", randchars = 10):
 	"""
 	Create a new directory named prefix + 'randchars' random characters
 	(via mktmpsomething with os.mkdir) and return its path.
 	"""
 	create = os.mkdir
 	return mktmpsomething(prefix, randchars, create)
 __all__.append("mktmpdir")
8ec1807f
 
 def mktmpfile(prefix = "/tmp/tmpfile-", randchars = 20):
 	"""
 	Create a new empty file named prefix + 'randchars' random characters
 	and return its path. O_EXCL makes creation fail with EEXIST when the
 	name is already taken, which mktmpsomething retries.
 	"""
 	def _create_exclusive(filename):
 		fd = os.open(filename, os.O_CREAT | os.O_EXCL)
 		os.close(fd)
 	return mktmpsomething(prefix, randchars, _create_exclusive)
 __all__.append("mktmpfile")
49731b40
 
 def hash_file_md5(filename):
 	"""
 	hash_file_md5(filename) -> str

 	Return the hex MD5 digest of the file's contents, read in 32kB
 	chunks so large files are never loaded into memory at once.
 	"""
 	h = md5()
 	f = open(filename, "rb")
 	## try/finally rather than 'with', matching the old Pythons the md5/sha
 	## import fallback caters for; the file is closed even if a read fails.
 	try:
 		while True:
 			# Hash 32kB chunks
 			data = f.read(32*1024)
 			if not data:
 				break
 			h.update(data)
 	finally:
 		f.close()
 	return h.hexdigest()
 __all__.append("hash_file_md5")
ed27a45e
 
bc4c306d
 def mkdir_with_parents(dir_name):
ed27a45e
 	"""
bc4c306d
 	mkdir_with_parents(dst_dir)
ed27a45e
 	
 	Create directory 'dir_name' with all parent directories
 
 	Returns True on success, False otherwise.
 	"""
 	pathmembers = dir_name.split(os.sep)
 	tmp_stack = []
 	while pathmembers and not os.path.isdir(os.sep.join(pathmembers)):
 		tmp_stack.append(pathmembers.pop())
 	while tmp_stack:
 		pathmembers.append(tmp_stack.pop())
 		cur_dir = os.sep.join(pathmembers)
 		try:
 			debug("mkdir(%s)" % cur_dir)
 			os.mkdir(cur_dir)
bc4c306d
 		except (OSError, IOError), e:
 			warning("%s: can not make directory: %s" % (cur_dir, e.strerror))
ed27a45e
 			return False
 		except Exception, e:
bc4c306d
 			warning("%s: %s" % (cur_dir, e))
ed27a45e
 			return False
 	return True
cb0bbaef
 __all__.append("mkdir_with_parents")
d90a7929
 
82d9eafa
 def unicodise(string, encoding = None, errors = "replace"):
d90a7929
 	"""
 	Convert 'string' to Unicode or raise an exception.
 	"""
82d9eafa
 
 	if not encoding:
 		encoding = Config.Config().encoding
 
d90a7929
 	if type(string) == unicode:
 		return string
227fabf8
 	debug("Unicodising %r using %s" % (string, encoding))
d90a7929
 	try:
82d9eafa
 		return string.decode(encoding, errors)
d90a7929
 	except UnicodeDecodeError:
 		raise UnicodeDecodeError("Conversion to unicode failed: %r" % string)
cb0bbaef
 __all__.append("unicodise")
d90a7929
 
82d9eafa
 def deunicodise(string, encoding = None, errors = "replace"):
 	"""
 	Convert unicode 'string' to <type str>, by default replacing
 	all invalid characters with '?' or raise an exception.
 	"""
 
 	if not encoding:
 		encoding = Config.Config().encoding
 
 	if type(string) != unicode:
 		return str(string)
227fabf8
 	debug("DeUnicodising %r using %s" % (string, encoding))
d90a7929
 	try:
82d9eafa
 		return string.encode(encoding, errors)
 	except UnicodeEncodeError:
 		raise UnicodeEncodeError("Conversion from unicode failed: %r" % string)
cb0bbaef
 __all__.append("deunicodise")
82d9eafa
 
 def unicodise_safe(string, encoding = None):
 	"""
 	Return 'string' as Unicode after a round trip through the given (or
 	configured) encoding, with every replacement character turned into '?'.
 	"""
 	as_bytes = deunicodise(string, encoding)
 	as_text = unicodise(as_bytes, encoding)
 	## U+FFFD is what the "replace" decode error handler inserts.
 	return as_text.replace(u'\ufffd', '?')
 __all__.append("unicodise_safe")
d90a7929
 
b40dd815
 def replace_nonprintables(string):
 	"""
 	replace_nonprintables(string)

 	Replaces all non-printable characters 'ch' in 'string'
 	where ord(ch) <= 31 with caret notation ^@, ^A, ... ^Z, ^[ ... ^_
 	and DEL (ord 127) with ^?. A warning with the number of replaced
 	characters is logged unless Config().urlencoding_mode is "fixbucket".
 	"""
 	pieces = []
 	modified = 0
 	for c in string:
 		o = ord(c)
 		if (o <= 31):
 			pieces.append("^" + chr(ord('@') + o))
 			modified += 1
 		elif (o == 127):
 			pieces.append("^?")
 			modified += 1
 		else:
 			pieces.append(c)
 	## Join once instead of growing a string with '+=' per character.
 	new_string = "".join(pieces)
 	if modified and Config.Config().urlencoding_mode != "fixbucket":
 		warning("%d non-printable characters replaced in: %s" % (modified, new_string))
 	return new_string
 __all__.append("replace_nonprintables")
b40dd815
 
0b8ea559
 def sign_string(string_to_sign):
 	"""
 	Return the base64-encoded HMAC-SHA1 of 'string_to_sign', keyed with
 	Config().secret_key.
 	"""
 	key = Config.Config().secret_key
 	digest = hmac.new(key, string_to_sign, sha1).digest()
 	## b64encode emits no trailing newline, so no strip() is needed.
 	return base64.b64encode(digest)
 __all__.append("sign_string")
b020ea02
 
 def check_bucket_name(bucket, dns_strict = True):
 	"""
 	check_bucket_name(bucket, dns_strict = True) -> True

 	Validate 'bucket' as an S3 bucket name (3-255 characters of letters,
 	digits, '.', '-' and '_'). With 'dns_strict' the name must also be
 	DNS compatible: lowercase letters, digits, '.' and '-' only, at most
 	63 characters, no '-.', '..' or '.-' sequences, and starting and
 	ending with a letter or digit.

 	Returns True when valid, raises Exceptions.ParameterError otherwise.
 	"""
 	if dns_strict:
 		invalid = re.search(r"([^a-z0-9\.-])", bucket)
 		if invalid:
 			raise Exceptions.ParameterError("Bucket name '%s' contains disallowed character '%s'. The only supported ones are: lowercase us-ascii letters (a-z), digits (0-9), dot (.) and hyphen (-)." % (bucket, invalid.groups()[0]))
 	else:
 		invalid = re.search(r"([^A-Za-z0-9\._-])", bucket)
 		if invalid:
 			raise Exceptions.ParameterError("Bucket name '%s' contains disallowed character '%s'. The only supported ones are: us-ascii letters (a-z, A-Z), digits (0-9), dot (.), hyphen (-) and underscore (_)." % (bucket, invalid.groups()[0]))

 	if len(bucket) < 3:
 		raise Exceptions.ParameterError("Bucket name '%s' is too short (min 3 characters)" % bucket)
 	if len(bucket) > 255:
 		raise Exceptions.ParameterError("Bucket name '%s' is too long (max 255 characters)" % bucket)
 	if dns_strict:
 		if len(bucket) > 63:
 			raise Exceptions.ParameterError("Bucket name '%s' is too long (max 63 characters)" % bucket)
 		if re.search(r"-\.", bucket):
 			raise Exceptions.ParameterError("Bucket name '%s' must not contain sequence '-.' for DNS compatibility" % bucket)
 		if re.search(r"\.\.", bucket):
 			raise Exceptions.ParameterError("Bucket name '%s' must not contain sequence '..' for DNS compatibility" % bucket)
 		## '.-' makes the following DNS label start with a hyphen, which is
 		## just as invalid as a label ending with one ('-.').
 		if re.search(r"\.-", bucket):
 			raise Exceptions.ParameterError("Bucket name '%s' must not contain sequence '.-' for DNS compatibility" % bucket)
 		if not re.search(r"^[0-9a-z]", bucket):
 			raise Exceptions.ParameterError("Bucket name '%s' must start with a letter or a digit" % bucket)
 		if not re.search(r"[0-9a-z]$", bucket):
 			raise Exceptions.ParameterError("Bucket name '%s' must end with a letter or a digit" % bucket)
 	return True
 __all__.append("check_bucket_name")
 
 def check_bucket_name_dns_conformity(bucket):
 	"""
 	Return True if 'bucket' passes check_bucket_name() in DNS-strict
 	mode, False if that check raises Exceptions.ParameterError.
 	"""
 	try:
 		valid = check_bucket_name(bucket, dns_strict = True)
 	except Exceptions.ParameterError:
 		valid = False
 	return valid
 __all__.append("check_bucket_name_dns_conformity")
 
 def getBucketFromHostname(hostname):
 	"""
 	bucket, success = getBucketFromHostname(hostname)

 	Only works for hostnames derived from bucket names
 	using Config.host_bucket pattern.

 	Returns bucket name and a boolean success flag; on failure the
 	hostname itself is returned with False.
 	"""

 	# Create RE pattern from Config.host_bucket: literal parts are escaped
 	# (so '.' matches only a dot) and the '%(bucket)s' placeholder becomes
 	# a named capture group.
 	literal_parts = Config.Config().host_bucket.split('%(bucket)s')
 	if len(literal_parts) < 2:
 		## Template has no bucket placeholder - nothing can be extracted.
 		return (hostname, False)
 	pattern = '(?P<bucket>.*)'.join([re.escape(part) for part in literal_parts])
 	m = re.match(pattern, hostname)
 	if not m:
 		return (hostname, False)
 	return m.group('bucket'), True
 __all__.append("getBucketFromHostname")
 
 def getHostnameFromBucket(bucket):
 	"""
 	Return the hostname for 'bucket' by substituting it into
 	Config().host_bucket.
 	"""
 	template = Config.Config().host_bucket
 	return template % dict(bucket = bucket)
 __all__.append("getHostnameFromBucket")