import repository from arizona
[raven.git] / 2.0 / python / s3 / S3.py
1 #!/usr/bin/env python
2
3 #  This software code is made available "AS IS" without warranties of any
4 #  kind.  You may copy, display, modify and redistribute the software
5 #  code either by itself or as incorporated into your code; provided that
6 #  you do not remove any proprietary notices.  Your use of this software
7 #  code is at your own risk and you waive any claim against Amazon
8 #  Digital Services, Inc. or its affiliates with respect to your use of
9 #  this software code. (c) 2006 Amazon Digital Services, Inc. or its
10 #  affiliates.
11
12 import base64
13 import hmac
14 import httplib
15 import re
16 import sha
17 import sys
18 import time
19 import urllib
20 import xml.sax
21
22 DEFAULT_HOST = 's3.amazonaws.com'
23 PORTS_BY_SECURITY = { True: 443, False: 80 }
24 METADATA_PREFIX = 'x-amz-meta-'
25 AMAZON_HEADER_PREFIX = 'x-amz-'
26
27 # generates the aws canonical string for the given parameters
28 def canonical_string(method, bucket="", key="", query_args={}, headers={}, expires=None):
29     interesting_headers = {}
30     for header_key in headers:
31         lk = header_key.lower()
32         if lk in ['content-md5', 'content-type', 'date'] or lk.startswith(AMAZON_HEADER_PREFIX):
33             interesting_headers[lk] = headers[header_key].strip()
34
35     # these keys get empty strings if they don't exist
36     if not interesting_headers.has_key('content-type'):
37         interesting_headers['content-type'] = ''
38     if not interesting_headers.has_key('content-md5'):
39         interesting_headers['content-md5'] = ''
40
41     # just in case someone used this.  it's not necessary in this lib.
42     if interesting_headers.has_key('x-amz-date'):
43         interesting_headers['date'] = ''
44
45     # if you're using expires for query string auth, then it trumps date
46     # (and x-amz-date)
47     if expires:
48         interesting_headers['date'] = str(expires)
49
50     sorted_header_keys = interesting_headers.keys()
51     sorted_header_keys.sort()
52
53     buf = "%s\n" % method
54     for header_key in sorted_header_keys:
55         if header_key.startswith(AMAZON_HEADER_PREFIX):
56             buf += "%s:%s\n" % (header_key, interesting_headers[header_key])
57         else:
58             buf += "%s\n" % interesting_headers[header_key]
59
60     # append the bucket if it exists
61     if bucket != "":
62         buf += "/%s" % bucket
63
64     # add the key.  even if it doesn't exist, add the slash
65     buf += "/%s" % urllib.quote_plus(key)
66
67     # handle special query string arguments
68
69     if query_args.has_key("acl"):
70         buf += "?acl"
71     elif query_args.has_key("torrent"):
72         buf += "?torrent"
73     elif query_args.has_key("logging"):
74         buf += "?logging"
75
76     return buf
77
78 # computes the base64'ed hmac-sha hash of the canonical string and the secret
79 # access key, optionally urlencoding the result
80 def encode(aws_secret_access_key, str, urlencode=False):
81     b64_hmac = base64.encodestring(hmac.new(aws_secret_access_key, str, sha).digest()).strip()
82     if urlencode:
83         return urllib.quote_plus(b64_hmac)
84     else:
85         return b64_hmac
86
87 def merge_meta(headers, metadata):
88     final_headers = headers.copy()
89     for k in metadata.keys():
90         final_headers[METADATA_PREFIX + k] = metadata[k]
91
92     return final_headers
93
94 # builds the query arg string
95 def query_args_hash_to_string(query_args):
96     query_string = ""
97     pairs = []
98     for k, v in query_args.items():
99         piece = k
100         if v != None:
101             piece += "=%s" % urllib.quote_plus(str(v))
102         pairs.append(piece)
103
104     return '&'.join(pairs)
105
106
107 class CallingFormat:
108     REGULAR = 1
109     SUBDOMAIN = 2
110     VANITY = 3
111
112     def build_url_base(protocol, server, port, bucket, calling_format):
113         url_base = '%s://' % protocol
114
115         if bucket == '':
116             url_base += server
117         elif calling_format == CallingFormat.SUBDOMAIN:
118             url_base += "%s.%s" % (bucket, server)
119         elif calling_format == CallingFormat.VANITY:
120             url_base += bucket
121         else:
122             url_base += server
123
124         url_base += ":%s" % port
125
126         if (bucket != '') and (calling_format == CallingFormat.REGULAR):
127             url_base += "/%s" % bucket
128
129         return url_base
130
131     build_url_base = staticmethod(build_url_base)
132
133
134
135 class AWSAuthConnection:
136     def __init__(self, aws_access_key_id, aws_secret_access_key, is_secure=True,
137             server=DEFAULT_HOST, port=None, calling_format=CallingFormat.REGULAR):
138
139         if not port:
140             port = PORTS_BY_SECURITY[is_secure]
141
142         self.aws_access_key_id = aws_access_key_id
143         self.aws_secret_access_key = aws_secret_access_key
144         self.is_secure = is_secure
145         self.server = server
146         self.port = port
147         self.calling_format = calling_format
148
149     def create_bucket(self, bucket, headers={}):
150         return Response(self.make_request('PUT', bucket, '', {}, headers))
151
152     def list_bucket(self, bucket, options={}, headers={}):
153         return ListBucketResponse(self.make_request('GET', bucket, '', options, headers))
154
155     def delete_bucket(self, bucket, headers={}):
156         return Response(self.make_request('DELETE', bucket, '', {}, headers))
157
158     def put(self, bucket, key, object, headers={}):
159         if not isinstance(object, S3Object):
160             object = S3Object(object)
161
162         return Response(
163                 self.make_request(
164                     'PUT',
165                     bucket,
166                     key,
167                     {},
168                     headers,
169                     object.data,
170                     object.metadata))
171
172     def get(self, bucket, key, headers={}):
173         return GetResponse(
174                 self.make_request('GET', bucket, key, {}, headers))
175
176     def delete(self, bucket, key, headers={}):
177         return Response(
178                 self.make_request('DELETE', bucket, key, {}, headers))
179
180     def get_bucket_logging(self, bucket, headers={}):
181         return GetResponse(self.make_request('GET', bucket, '', { 'logging': None }, headers))
182
183     def put_bucket_logging(self, bucket, logging_xml_doc, headers={}):
184         return Response(self.make_request('PUT', bucket, '', { 'logging': None }, headers, logging_xml_doc))
185
186     def get_bucket_acl(self, bucket, headers={}):
187         return self.get_acl(bucket, '', headers)
188
189     def get_acl(self, bucket, key, headers={}):
190         return GetResponse(
191                 self.make_request('GET', bucket, key, { 'acl': None }, headers))
192
193     def put_bucket_acl(self, bucket, acl_xml_document, headers={}):
194         return self.put_acl(bucket, '', acl_xml_document, headers)
195
196     def put_acl(self, bucket, key, acl_xml_document, headers={}):
197         return Response(
198                 self.make_request(
199                     'PUT',
200                     bucket,
201                     key,
202                     { 'acl': None },
203                     headers,
204                     acl_xml_document))
205
206     def list_all_my_buckets(self, headers={}):
207         return ListAllMyBucketsResponse(self.make_request('GET', '', '', {}, headers))
208
209     def make_request(self, method, bucket='', key='', query_args={}, headers={}, data='', metadata={}):
210
211         server = ''
212         if bucket == '':
213             server = self.server
214         elif self.calling_format == CallingFormat.SUBDOMAIN:
215             server = "%s.%s" % (bucket, self.server)
216         elif self.calling_format == CallingFormat.VANITY:
217             server = bucket
218         else:
219             server = self.server
220
221         path = ''
222
223         if (bucket != '') and (self.calling_format == CallingFormat.REGULAR):
224             path += "/%s" % bucket
225
226         # add the slash after the bucket regardless
227         # the key will be appended if it is non-empty
228         path += "/%s" % urllib.quote_plus(key)
229
230
231         # build the path_argument string
232         # add the ? in all cases since 
233         # signature and credentials follow path args
234         path += "?"
235         path += query_args_hash_to_string(query_args)
236
237
238         final_headers = merge_meta(headers, metadata);
239         # add auth header
240         self.add_aws_auth_header(final_headers, method, bucket, key, query_args)
241
242         if (self.is_secure):
243             connection = httplib.HTTPSConnection("%s:%d" % (server, self.port))
244         else:
245             connection = httplib.HTTPConnection("%s:%d" % (server, self.port))
246
247         connection.request(method, path, data, final_headers)
248         return connection.getresponse()
249
250
251     def add_aws_auth_header(self, headers, method, bucket, key, query_args):
252         if not headers.has_key('Date'):
253             headers['Date'] = time.strftime("%a, %d %b %Y %X GMT", time.gmtime())
254
255         c_string = canonical_string(method, bucket, key, query_args, headers)
256         headers['Authorization'] = \
257             "AWS %s:%s" % (self.aws_access_key_id, encode(self.aws_secret_access_key, c_string))
258
259
260 class QueryStringAuthGenerator:
261     # by default, expire in 1 minute
262     DEFAULT_EXPIRES_IN = 60
263
264     def __init__(self, aws_access_key_id, aws_secret_access_key, is_secure=True,
265                  server=DEFAULT_HOST, port=None, calling_format=CallingFormat.REGULAR):
266
267         if not port:
268             port = PORTS_BY_SECURITY[is_secure]
269
270         self.aws_access_key_id = aws_access_key_id
271         self.aws_secret_access_key = aws_secret_access_key
272         if (is_secure):
273             self.protocol = 'https'
274         else:
275             self.protocol = 'http'
276
277         self.is_secure = is_secure
278         self.server = server
279         self.port = port
280         self.calling_format = calling_format
281         self.__expires_in = QueryStringAuthGenerator.DEFAULT_EXPIRES_IN
282         self.__expires = None
283
284         # for backwards compatibility with older versions
285         self.server_name = "%s:%s" % (self.server, self.port)
286
287     def set_expires_in(self, expires_in):
288         self.__expires_in = expires_in
289         self.__expires = None
290
291     def set_expires(self, expires):
292         self.__expires = expires
293         self.__expires_in = None
294
295     def create_bucket(self, bucket, headers={}):
296         return self.generate_url('PUT', bucket, '', {}, headers)
297
298     def list_bucket(self, bucket, options={}, headers={}):
299         return self.generate_url('GET', bucket, '', options, headers)
300
301     def delete_bucket(self, bucket, headers={}):
302         return self.generate_url('DELETE', bucket, '', {}, headers)
303
304     def put(self, bucket, key, object, headers={}):
305         if not isinstance(object, S3Object):
306             object = S3Object(object)
307
308         return self.generate_url(
309                 'PUT',
310                 bucket,
311                 key,
312                 {},
313                 merge_meta(headers, object.metadata))
314
315     def get(self, bucket, key, headers={}):
316         return self.generate_url('GET', bucket, key, {}, headers)
317
318     def delete(self, bucket, key, headers={}):
319         return self.generate_url('DELETE', bucket, key, {}, headers)
320
321     def get_bucket_logging(self, bucket, headers={}):
322         return self.generate_url('GET', bucket, '', { 'logging': None }, headers)
323
324     def put_bucket_logging(self, bucket, logging_xml_doc, headers={}):
325         return self.generate_url('PUT', bucket, '', { 'logging': None }, headers)
326
327     def get_bucket_acl(self, bucket, headers={}):
328         return self.get_acl(bucket, '', headers)
329
330     def get_acl(self, bucket, key='', headers={}):
331         return self.generate_url('GET', bucket, key, { 'acl': None }, headers)
332
333     def put_bucket_acl(self, bucket, acl_xml_document, headers={}):
334         return self.put_acl(bucket, '', acl_xml_document, headers)
335
336     # don't really care what the doc is here.
337     def put_acl(self, bucket, key, acl_xml_document, headers={}):
338         return self.generate_url('PUT', bucket, key, { 'acl': None }, headers)
339
340     def list_all_my_buckets(self, headers={}):
341         return self.generate_url('GET', '', '', {}, headers)
342
343     def make_bare_url(self, bucket, key=''):
344         full_url = self.generate_url(self, bucket, key)
345         return full_url[:full_url.index('?')]
346
347     def generate_url(self, method, bucket='', key='', query_args={}, headers={}):
348         expires = 0
349         if self.__expires_in != None:
350             expires = int(time.time() + self.__expires_in)
351         elif self.__expires != None:
352             expires = int(self.__expires)
353         else:
354             raise "Invalid expires state"
355
356         canonical_str = canonical_string(method, bucket, key, query_args, headers, expires)
357         encoded_canonical = encode(self.aws_secret_access_key, canonical_str)
358
359         url = CallingFormat.build_url_base(self.protocol, self.server, self.port, bucket, self.calling_format)
360
361         url += "/%s" % urllib.quote_plus(key)
362
363         query_args['Signature'] = encoded_canonical
364         query_args['Expires'] = expires
365         query_args['AWSAccessKeyId'] = self.aws_access_key_id
366
367         url += "?%s" % query_args_hash_to_string(query_args)
368
369         return url
370
371
372 class S3Object:
373     def __init__(self, data, metadata={}):
374         self.data = data
375         self.metadata = metadata
376
377 class Owner:
378     def __init__(self, id='', display_name=''):
379         self.id = id
380         self.display_name = display_name
381
382 class ListEntry:
383     def __init__(self, key='', last_modified=None, etag='', size=0, storage_class='', owner=None):
384         self.key = key
385         self.last_modified = last_modified
386         self.etag = etag
387         self.size = size
388         self.storage_class = storage_class
389         self.owner = owner
390
391 class CommonPrefixEntry:
392     def __init(self, prefix=''):
393         self.prefix = prefix
394
395 class Bucket:
396     def __init__(self, name='', creation_date=''):
397         self.name = name
398         self.creation_date = creation_date
399
400 class Response:
401     def __init__(self, http_response):
402         self.http_response = http_response
403         # you have to do this read, even if you don't expect a body.
404         # otherwise, the next request fails.
405         self.body = http_response.read()
406
407 class ListBucketResponse(Response):
408     def __init__(self, http_response):
409         Response.__init__(self, http_response)
410         if http_response.status < 300:
411             handler = ListBucketHandler()
412             xml.sax.parseString(self.body, handler)
413             self.entries = handler.entries
414             self.common_prefixes = handler.common_prefixes
415             self.name = handler.name
416             self.marker = handler.marker
417             self.prefix = handler.prefix
418             self.is_truncated = handler.is_truncated
419             self.delimiter = handler.delimiter
420             self.max_keys = handler.max_keys
421             self.next_marker = handler.next_marker
422         else:
423             self.entries = []
424
425 class ListAllMyBucketsResponse(Response):
426     def __init__(self, http_response):
427         Response.__init__(self, http_response)
428         if http_response.status < 300: 
429             handler = ListAllMyBucketsHandler()
430             xml.sax.parseString(self.body, handler)
431             self.entries = handler.entries
432         else:
433             self.entries = []
434
435 class GetResponse(Response):
436     def __init__(self, http_response):
437         Response.__init__(self, http_response)
438         response_headers = http_response.msg   # older pythons don't have getheaders
439         metadata = self.get_aws_metadata(response_headers)
440         self.object = S3Object(self.body, metadata)
441
442     def get_aws_metadata(self, headers):
443         metadata = {}
444         for hkey in headers.keys():
445             if hkey.lower().startswith(METADATA_PREFIX):
446                 metadata[hkey[len(METADATA_PREFIX):]] = headers[hkey]
447                 del headers[hkey]
448
449         return metadata
450
451 class ListBucketHandler(xml.sax.ContentHandler):
452     def __init__(self):
453         self.entries = []
454         self.curr_entry = None
455         self.curr_text = ''
456         self.common_prefixes = []
457         self.curr_common_prefix = None
458         self.name = ''
459         self.marker = ''
460         self.prefix = ''
461         self.is_truncated = False
462         self.delimiter = ''
463         self.max_keys = 0
464         self.next_marker = ''
465         self.is_echoed_prefix_set = False
466
467     def startElement(self, name, attrs):
468         if name == 'Contents':
469             self.curr_entry = ListEntry()
470         elif name == 'Owner':
471             self.curr_entry.owner = Owner()
472         elif name == 'CommonPrefixes':
473             self.curr_common_prefix = CommonPrefixEntry()
474
475
476     def endElement(self, name):
477         if name == 'Contents':
478             self.entries.append(self.curr_entry)
479         elif name == 'CommonPrefixes':
480             self.common_prefixes.append(self.curr_common_prefix)
481         elif name == 'Key':
482             self.curr_entry.key = self.curr_text
483         elif name == 'LastModified':
484             self.curr_entry.last_modified = self.curr_text
485         elif name == 'ETag':
486             self.curr_entry.etag = self.curr_text
487         elif name == 'Size':
488             self.curr_entry.size = int(self.curr_text)
489         elif name == 'ID':
490             self.curr_entry.owner.id = self.curr_text
491         elif name == 'DisplayName':
492             self.curr_entry.owner.display_name = self.curr_text
493         elif name == 'StorageClass':
494             self.curr_entry.storage_class = self.curr_text
495         elif name == 'Name':
496             self.name = self.curr_text
497         elif name == 'Prefix' and self.is_echoed_prefix_set:
498             self.curr_common_prefix.prefix = self.curr_text
499         elif name == 'Prefix':
500             self.prefix = self.curr_text
501             self.is_echoed_prefix_set = True
502         elif name == 'Marker':
503             self.marker = self.curr_text
504         elif name == 'IsTruncated':
505             self.is_truncated = self.curr_text == 'true'
506         elif name == 'Delimiter':
507             self.delimiter = self.curr_text
508         elif name == 'MaxKeys':
509             self.max_keys = int(self.curr_text)
510         elif name == 'NextMarker':
511             self.next_marker = self.curr_text
512
513         self.curr_text = ''
514
515     def characters(self, content):
516         self.curr_text += content
517
518
519 class ListAllMyBucketsHandler(xml.sax.ContentHandler):
520     def __init__(self):
521         self.entries = []
522         self.curr_entry = None
523         self.curr_text = ''
524
525     def startElement(self, name, attrs):
526         if name == 'Bucket':
527             self.curr_entry = Bucket()
528
529     def endElement(self, name):
530         if name == 'Name':
531             self.curr_entry.name = self.curr_text
532         elif name == 'CreationDate':
533             self.curr_entry.creation_date = self.curr_text
534         elif name == 'Bucket':
535             self.entries.append(self.curr_entry)
536
537     def characters(self, content):
538         self.curr_text = content