test fixes
[wolnelektury.git] / apps / piston / oauth.py
1 """
2 The MIT License
3
4 Copyright (c) 2007 Leah Culver
5
6 Permission is hereby granted, free of charge, to any person obtaining a copy
7 of this software and associated documentation files (the "Software"), to deal
8 in the Software without restriction, including without limitation the rights
9 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 copies of the Software, and to permit persons to whom the Software is
11 furnished to do so, subject to the following conditions:
12
13 The above copyright notice and this permission notice shall be included in
14 all copies or substantial portions of the Software.
15
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 THE SOFTWARE.
23 """
24
25 import cgi
26 import urllib
27 import time
28 import random
29 import urlparse
30 import hmac
31 import binascii
32
33
34 VERSION = '1.0' # Hi Blaine!
35 HTTP_METHOD = 'GET'
36 SIGNATURE_METHOD = 'PLAINTEXT'
37
38
39 class OAuthError(RuntimeError):
40     """Generic exception class."""
41     def __init__(self, message='OAuth error occured.'):
42         self.message = message
43
44 def build_authenticate_header(realm=''):
45     """Optional WWW-Authenticate header (401 error)"""
46     return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
47
48 def escape(s):
49     """Escape a URL including any /."""
50     return urllib.quote(s, safe='~')
51
52 def _utf8_str(s):
53     """Convert unicode to utf-8."""
54     if isinstance(s, unicode):
55         return s.encode("utf-8")
56     else:
57         return str(s)
58
59 def generate_timestamp():
60     """Get seconds since epoch (UTC)."""
61     return int(time.time())
62
63 def generate_nonce(length=8):
64     """Generate pseudorandom number."""
65     return ''.join([str(random.randint(0, 9)) for i in range(length)])
66
67 def generate_verifier(length=8):
68     """Generate pseudorandom number."""
69     return ''.join([str(random.randint(0, 9)) for i in range(length)])
70
71
72 class OAuthConsumer(object):
73     """Consumer of OAuth authentication.
74
75     OAuthConsumer is a data type that represents the identity of the Consumer
76     via its shared secret with the Service Provider.
77
78     """
79     key = None
80     secret = None
81
82     def __init__(self, key, secret):
83         self.key = key
84         self.secret = secret
85
86
87 class OAuthToken(object):
88     """OAuthToken is a data type that represents an End User via either an access
89     or request token.
90
91     key -- the token
92     secret -- the token secret
93
94     """
95     key = None
96     secret = None
97     callback = None
98     callback_confirmed = None
99     verifier = None
100
101     def __init__(self, key, secret):
102         self.key = key
103         self.secret = secret
104
105     def set_callback(self, callback):
106         self.callback = callback
107         self.callback_confirmed = 'true'
108
109     def set_verifier(self, verifier=None):
110         if verifier is not None:
111             self.verifier = verifier
112         else:
113             self.verifier = generate_verifier()
114
115     def get_callback_url(self):
116         if self.callback and self.verifier:
117             # Append the oauth_verifier.
118             parts = urlparse.urlparse(self.callback)
119             scheme, netloc, path, params, query, fragment = parts[:6]
120             if query:
121                 query = '%s&oauth_verifier=%s' % (query, self.verifier)
122             else:
123                 query = 'oauth_verifier=%s' % self.verifier
124             return urlparse.urlunparse((scheme, netloc, path, params,
125                 query, fragment))
126         return self.callback
127
128     def to_string(self):
129         data = {
130             'oauth_token': self.key,
131             'oauth_token_secret': self.secret,
132         }
133         if self.callback_confirmed is not None:
134             data['oauth_callback_confirmed'] = self.callback_confirmed
135         return urllib.urlencode(data)
136
137     def from_string(s):
138         """ Returns a token from something like:
139         oauth_token_secret=xxx&oauth_token=xxx
140         """
141         params = cgi.parse_qs(s, keep_blank_values=False)
142         key = params['oauth_token'][0]
143         secret = params['oauth_token_secret'][0]
144         token = OAuthToken(key, secret)
145         try:
146             token.callback_confirmed = params['oauth_callback_confirmed'][0]
147         except KeyError:
148             pass # 1.0, no callback confirmed.
149         return token
150     from_string = staticmethod(from_string)
151
152     def __str__(self):
153         return self.to_string()
154
155
156 class OAuthRequest(object):
157     """OAuthRequest represents the request and can be serialized.
158
159     OAuth parameters:
160         - oauth_consumer_key
161         - oauth_token
162         - oauth_signature_method
163         - oauth_signature
164         - oauth_timestamp
165         - oauth_nonce
166         - oauth_version
167         - oauth_verifier
168         ... any additional parameters, as defined by the Service Provider.
169     """
170     parameters = None # OAuth parameters.
171     http_method = HTTP_METHOD
172     http_url = None
173     version = VERSION
174
175     def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
176         self.http_method = http_method
177         self.http_url = http_url
178         self.parameters = parameters or {}
179
180     def set_parameter(self, parameter, value):
181         self.parameters[parameter] = value
182
183     def get_parameter(self, parameter):
184         try:
185             return self.parameters[parameter]
186         except:
187             raise OAuthError('Parameter not found: %s' % parameter)
188
189     def _get_timestamp_nonce(self):
190         return self.get_parameter('oauth_timestamp'), self.get_parameter(
191             'oauth_nonce')
192
193     def get_nonoauth_parameters(self):
194         """Get any non-OAuth parameters."""
195         parameters = {}
196         for k, v in self.parameters.iteritems():
197             # Ignore oauth parameters.
198             if k.find('oauth_') < 0:
199                 parameters[k] = v
200         return parameters
201
202     def to_header(self, realm=''):
203         """Serialize as a header for an HTTPAuth request."""
204         auth_header = 'OAuth realm="%s"' % realm
205         # Add the oauth parameters.
206         if self.parameters:
207             for k, v in self.parameters.iteritems():
208                 if k[:6] == 'oauth_':
209                     auth_header += ', %s="%s"' % (k, escape(str(v)))
210         return {'Authorization': auth_header}
211
212     def to_postdata(self):
213         """Serialize as post data for a POST request."""
214         return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \
215             for k, v in self.parameters.iteritems()])
216
217     def to_url(self):
218         """Serialize as a URL for a GET request."""
219         return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())
220
221     def get_normalized_parameters(self):
222         """Return a string that contains the parameters that must be signed."""
223         params = self.parameters
224         try:
225             # Exclude the signature if it exists.
226             del params['oauth_signature']
227         except:
228             pass
229         # Escape key values before sorting.
230         key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \
231             for k,v in params.items()]
232         # Sort lexicographically, first after key, then after value.
233         key_values.sort()
234         # Combine key value pairs into a string.
235         return '&'.join(['%s=%s' % (k, v) for k, v in key_values])
236
237     def get_normalized_http_method(self):
238         """Uppercases the http method."""
239         return self.http_method.upper()
240
241     def get_normalized_http_url(self):
242         """Parses the URL and rebuilds it to be scheme://host/path."""
243         parts = urlparse.urlparse(self.http_url)
244         scheme, netloc, path = parts[:3]
245         # Exclude default port numbers.
246         if scheme == 'http' and netloc[-3:] == ':80':
247             netloc = netloc[:-3]
248         elif scheme == 'https' and netloc[-4:] == ':443':
249             netloc = netloc[:-4]
250         return '%s://%s%s' % (scheme, netloc, path)
251
252     def sign_request(self, signature_method, consumer, token):
253         """Set the signature parameter to the result of build_signature."""
254         # Set the signature method.
255         self.set_parameter('oauth_signature_method',
256             signature_method.get_name())
257         # Set the signature.
258         self.set_parameter('oauth_signature',
259             self.build_signature(signature_method, consumer, token))
260
261     def build_signature(self, signature_method, consumer, token):
262         """Calls the build signature method within the signature method."""
263         return signature_method.build_signature(self, consumer, token)
264
265     def from_request(http_method, http_url, headers=None, parameters=None,
266             query_string=None):
267         """Combines multiple parameter sources."""
268         if parameters is None:
269             parameters = {}
270
271         # Headers
272         if headers and 'Authorization' in headers:
273             auth_header = headers['Authorization']
274             # Check that the authorization header is OAuth.
275             if auth_header[:6] == 'OAuth ':
276                 auth_header = auth_header[6:]
277                 try:
278                     # Get the parameters from the header.
279                     header_params = OAuthRequest._split_header(auth_header)
280                     parameters.update(header_params)
281                 except:
282                     raise OAuthError('Unable to parse OAuth parameters from '
283                         'Authorization header.')
284
285         # GET or POST query string.
286         if query_string:
287             query_params = OAuthRequest._split_url_string(query_string)
288             parameters.update(query_params)
289
290         # URL parameters.
291         param_str = urlparse.urlparse(http_url)[4] # query
292         url_params = OAuthRequest._split_url_string(param_str)
293         parameters.update(url_params)
294
295         if parameters:
296             return OAuthRequest(http_method, http_url, parameters)
297
298         return None
299     from_request = staticmethod(from_request)
300
301     def from_consumer_and_token(oauth_consumer, token=None,
302             callback=None, verifier=None, http_method=HTTP_METHOD,
303             http_url=None, parameters=None):
304         if not parameters:
305             parameters = {}
306
307         defaults = {
308             'oauth_consumer_key': oauth_consumer.key,
309             'oauth_timestamp': generate_timestamp(),
310             'oauth_nonce': generate_nonce(),
311             'oauth_version': OAuthRequest.version,
312         }
313
314         defaults.update(parameters)
315         parameters = defaults
316
317         if token:
318             parameters['oauth_token'] = token.key
319             parameters['oauth_callback'] = token.callback
320             # 1.0a support for verifier.
321             parameters['oauth_verifier'] = verifier
322         elif callback:
323             # 1.0a support for callback in the request token request.
324             parameters['oauth_callback'] = callback
325
326         return OAuthRequest(http_method, http_url, parameters)
327     from_consumer_and_token = staticmethod(from_consumer_and_token)
328
329     def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,
330             http_url=None, parameters=None):
331         if not parameters:
332             parameters = {}
333
334         parameters['oauth_token'] = token.key
335
336         if callback:
337             parameters['oauth_callback'] = callback
338
339         return OAuthRequest(http_method, http_url, parameters)
340     from_token_and_callback = staticmethod(from_token_and_callback)
341
342     def _split_header(header):
343         """Turn Authorization: header into parameters."""
344         params = {}
345         parts = header.split(',')
346         for param in parts:
347             # Ignore realm parameter.
348             if param.find('realm') > -1:
349                 continue
350             # Remove whitespace.
351             param = param.strip()
352             # Split key-value.
353             param_parts = param.split('=', 1)
354             # Remove quotes and unescape the value.
355             params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
356         return params
357     _split_header = staticmethod(_split_header)
358
359     def _split_url_string(param_str):
360         """Turn URL string into parameters."""
361         parameters = cgi.parse_qs(param_str, keep_blank_values=False)
362         for k, v in parameters.iteritems():
363             parameters[k] = urllib.unquote(v[0])
364         return parameters
365     _split_url_string = staticmethod(_split_url_string)
366
367 class OAuthServer(object):
368     """A worker to check the validity of a request against a data store."""
369     timestamp_threshold = 300 # In seconds, five minutes.
370     version = VERSION
371     signature_methods = None
372     data_store = None
373
374     def __init__(self, data_store=None, signature_methods=None):
375         self.data_store = data_store
376         self.signature_methods = signature_methods or {}
377
378     def set_data_store(self, data_store):
379         self.data_store = data_store
380
381     def get_data_store(self):
382         return self.data_store
383
384     def add_signature_method(self, signature_method):
385         self.signature_methods[signature_method.get_name()] = signature_method
386         return self.signature_methods
387
388     def fetch_request_token(self, oauth_request):
389         """Processes a request_token request and returns the
390         request token on success.
391         """
392         try:
393             # Get the request token for authorization.
394             token = self._get_token(oauth_request, 'request')
395         except OAuthError:
396             # No token required for the initial token request.
397             version = self._get_version(oauth_request)
398             consumer = self._get_consumer(oauth_request)
399             try:
400                 callback = self.get_callback(oauth_request)
401             except OAuthError:
402                 callback = None # 1.0, no callback specified.
403             self._check_signature(oauth_request, consumer, None)
404             # Fetch a new token.
405             token = self.data_store.fetch_request_token(consumer, callback)
406         return token
407
408     def fetch_access_token(self, oauth_request):
409         """Processes an access_token request and returns the
410         access token on success.
411         """
412         version = self._get_version(oauth_request)
413         consumer = self._get_consumer(oauth_request)
414         verifier = self._get_verifier(oauth_request)
415         # Get the request token.
416         token = self._get_token(oauth_request, 'request')
417         self._check_signature(oauth_request, consumer, token)
418         new_token = self.data_store.fetch_access_token(consumer, token, verifier)
419         return new_token
420
421     def verify_request(self, oauth_request):
422         """Verifies an api call and checks all the parameters."""
423         # -> consumer and token
424         version = self._get_version(oauth_request)
425         consumer = self._get_consumer(oauth_request)
426         # Get the access token.
427         token = self._get_token(oauth_request, 'access')
428         self._check_signature(oauth_request, consumer, token)
429         parameters = oauth_request.get_nonoauth_parameters()
430         return consumer, token, parameters
431
432     def authorize_token(self, token, user):
433         """Authorize a request token."""
434         return self.data_store.authorize_request_token(token, user)
435
436     def get_callback(self, oauth_request):
437         """Get the callback URL."""
438         return oauth_request.get_parameter('oauth_callback')
439
440     def build_authenticate_header(self, realm=''):
441         """Optional support for the authenticate header."""
442         return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
443
444     def _get_version(self, oauth_request):
445         """Verify the correct version request for this server."""
446         try:
447             version = oauth_request.get_parameter('oauth_version')
448         except:
449             version = VERSION
450         if version and version != self.version:
451             raise OAuthError('OAuth version %s not supported.' % str(version))
452         return version
453
454     def _get_signature_method(self, oauth_request):
455         """Figure out the signature with some defaults."""
456         try:
457             signature_method = oauth_request.get_parameter(
458                 'oauth_signature_method')
459         except:
460             signature_method = SIGNATURE_METHOD
461         try:
462             # Get the signature method object.
463             signature_method = self.signature_methods[signature_method]
464         except:
465             signature_method_names = ', '.join(self.signature_methods.keys())
466             raise OAuthError('Signature method %s not supported try one of the '
467                 'following: %s' % (signature_method, signature_method_names))
468
469         return signature_method
470
471     def _get_consumer(self, oauth_request):
472         consumer_key = oauth_request.get_parameter('oauth_consumer_key')
473         consumer = self.data_store.lookup_consumer(consumer_key)
474         if not consumer:
475             raise OAuthError('Invalid consumer.')
476         return consumer
477
478     def _get_token(self, oauth_request, token_type='access'):
479         """Try to find the token for the provided request token key."""
480         token_field = oauth_request.get_parameter('oauth_token')
481         token = self.data_store.lookup_token(token_type, token_field)
482         if not token:
483             raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
484         return token
485
486     def _get_verifier(self, oauth_request):
487         return oauth_request.get_parameter('oauth_verifier')
488
489     def _check_signature(self, oauth_request, consumer, token):
490         timestamp, nonce = oauth_request._get_timestamp_nonce()
491         self._check_timestamp(timestamp)
492         self._check_nonce(consumer, token, nonce)
493         signature_method = self._get_signature_method(oauth_request)
494         try:
495             signature = oauth_request.get_parameter('oauth_signature')
496         except:
497             raise OAuthError('Missing signature.')
498         # Validate the signature.
499         valid_sig = signature_method.check_signature(oauth_request, consumer,
500             token, signature)
501         if not valid_sig:
502             key, base = signature_method.build_signature_base_string(
503                 oauth_request, consumer, token)
504             raise OAuthError('Invalid signature. Expected signature base '
505                 'string: %s' % base)
506         built = signature_method.build_signature(oauth_request, consumer, token)
507
508     def _check_timestamp(self, timestamp):
509         """Verify that timestamp is recentish."""
510         timestamp = int(timestamp)
511         now = int(time.time())
512         lapsed = now - timestamp
513         if lapsed > self.timestamp_threshold:
514             raise OAuthError('Expired timestamp: given %d and now %s has a '
515                 'greater difference than threshold %d' %
516                 (timestamp, now, self.timestamp_threshold))
517
518     def _check_nonce(self, consumer, token, nonce):
519         """Verify that the nonce is uniqueish."""
520         nonce = self.data_store.lookup_nonce(consumer, token, nonce)
521         if nonce:
522             raise OAuthError('Nonce already used: %s' % str(nonce))
523
524
525 class OAuthClient(object):
526     """OAuthClient is a worker to attempt to execute a request."""
527     consumer = None
528     token = None
529
530     def __init__(self, oauth_consumer, oauth_token):
531         self.consumer = oauth_consumer
532         self.token = oauth_token
533
534     def get_consumer(self):
535         return self.consumer
536
537     def get_token(self):
538         return self.token
539
540     def fetch_request_token(self, oauth_request):
541         """-> OAuthToken."""
542         raise NotImplementedError
543
544     def fetch_access_token(self, oauth_request):
545         """-> OAuthToken."""
546         raise NotImplementedError
547
548     def access_resource(self, oauth_request):
549         """-> Some protected resource."""
550         raise NotImplementedError
551
552
553 class OAuthDataStore(object):
554     """A database abstraction used to lookup consumers and tokens."""
555
556     def lookup_consumer(self, key):
557         """-> OAuthConsumer."""
558         raise NotImplementedError
559
560     def lookup_token(self, oauth_consumer, token_type, token_token):
561         """-> OAuthToken."""
562         raise NotImplementedError
563
564     def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
565         """-> OAuthToken."""
566         raise NotImplementedError
567
568     def fetch_request_token(self, oauth_consumer, oauth_callback):
569         """-> OAuthToken."""
570         raise NotImplementedError
571
572     def fetch_access_token(self, oauth_consumer, oauth_token, oauth_verifier):
573         """-> OAuthToken."""
574         raise NotImplementedError
575
576     def authorize_request_token(self, oauth_token, user):
577         """-> OAuthToken."""
578         raise NotImplementedError
579
580
581 class OAuthSignatureMethod(object):
582     """A strategy class that implements a signature method."""
583     def get_name(self):
584         """-> str."""
585         raise NotImplementedError
586
587     def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
588         """-> str key, str raw."""
589         raise NotImplementedError
590
591     def build_signature(self, oauth_request, oauth_consumer, oauth_token):
592         """-> str."""
593         raise NotImplementedError
594
595     def check_signature(self, oauth_request, consumer, token, signature):
596         built = self.build_signature(oauth_request, consumer, token)
597         return built == signature
598
599
600 class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
601
602     def get_name(self):
603         return 'HMAC-SHA1'
604
605     def build_signature_base_string(self, oauth_request, consumer, token):
606         sig = (
607             escape(oauth_request.get_normalized_http_method()),
608             escape(oauth_request.get_normalized_http_url()),
609             escape(oauth_request.get_normalized_parameters()),
610         )
611
612         key = '%s&' % escape(consumer.secret)
613         if token:
614             key += escape(token.secret)
615         raw = '&'.join(sig)
616         return key, raw
617
618     def build_signature(self, oauth_request, consumer, token):
619         """Builds the base signature string."""
620         key, raw = self.build_signature_base_string(oauth_request, consumer,
621             token)
622
623         # HMAC object.
624         try:
625             import hashlib # 2.5
626             hashed = hmac.new(key, raw, hashlib.sha1)
627         except:
628             import sha # Deprecated
629             hashed = hmac.new(key, raw, sha)
630
631         # Calculate the digest base 64.
632         return binascii.b2a_base64(hashed.digest())[:-1]
633
634
635 class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
636
637     def get_name(self):
638         return 'PLAINTEXT'
639
640     def build_signature_base_string(self, oauth_request, consumer, token):
641         """Concatenates the consumer key and secret."""
642         sig = '%s&' % escape(consumer.secret)
643         if token:
644             sig = sig + escape(token.secret)
645         return sig, sig
646
647     def build_signature(self, oauth_request, consumer, token):
648         key, raw = self.build_signature_base_string(oauth_request, consumer,
649             token)
650         return key