--- /dev/null
+from django.shortcuts import get_object_or_404
+from django.contrib.auth.decorators import login_required, user_passes_test
+from piston.handler import BaseHandler
+from piston.utils import rc, validate
+from catalogue.models import Book
+from catalogue.forms import BookImportForm
+
+
def _is_staff(user):
    # Predicate handed to `user_passes_test`: only staff accounts pass.
    return user.is_staff

# View decorator admitting only staff users (built with Django's
# `user_passes_test`).
staff_required = user_passes_test(_is_staff)
+
+
class BookHandler(BaseHandler):
    """
    Staff-only API resource for books.

    GET returns one book when `slug` is given, otherwise all books.
    POST imports a book through `BookImportForm`.

    Only these two HTTP methods are exposed. piston's BaseHandler otherwise
    allows PUT and DELETE too, served by its generic model-based
    update/delete for `model`, which would bypass the staff check.
    """
    model = Book
    fields = ('slug', 'title')
    # Limit the resource to the methods implemented (and guarded) here.
    allowed_methods = ('GET', 'POST')

    # The staff check is done inline rather than with `user_passes_test`.
    # That decorator expects a view whose first argument is the request,
    # and it answers with a login redirect rather than an API status code.

    def read(self, request, slug=None):
        """Return the book identified by `slug`, or all books."""
        if not request.user.is_staff:
            return rc.FORBIDDEN
        if slug:
            return get_object_or_404(Book, slug=slug)
        return Book.objects.all()

    def create(self, request):
        """Import a book from the submitted form data and files."""
        if not request.user.is_staff:
            return rc.FORBIDDEN
        form = BookImportForm(request.POST, request.FILES)
        if not form.is_valid():
            return rc.BAD_REQUEST
        form.save()
        return rc.CREATED
+
--- /dev/null
+# -*- coding: utf-8 -*-
+from django.conf.urls.defaults import *
+from piston.resource import Resource
+from piston.authentication import HttpBasicAuthentication
+
+from api.handlers import BookHandler
+
+
# All book URLs share one piston resource protected by HTTP Basic auth.
auth = HttpBasicAuthentication(realm='My sample API')
book_resource = Resource(handler=BookHandler, authentication=auth)

# Every URL ends in ".<format>", which selects the piston emitter.
_format_suffix = r'\.(?P<emitter_format>xml|json|yaml)$'

urlpatterns = patterns('',
    # A single book, addressed by its slug.
    url(r'^books/(?P<slug>[^/]+)' + _format_suffix, book_resource),
    # The whole collection.
    url(r'^books' + _format_suffix, book_resource),
)
+
--- /dev/null
+import binascii
+
+import oauth
+from django.http import HttpResponse, HttpResponseRedirect
+from django.contrib.auth.models import User, AnonymousUser
+from django.contrib.auth.decorators import login_required
+from django.template import loader
+from django.contrib.auth import authenticate
+from django.conf import settings
+from django.core.urlresolvers import get_callable
+from django.core.exceptions import ImproperlyConfigured
+from django.shortcuts import render_to_response
+from django.template import RequestContext
+
+from piston import forms
+
class NoAuthentication(object):
    """
    Authentication handler that lets every request through.

    `is_authenticated` always answers True. There is deliberately no
    `challenge` method, because a challenge can never be needed.
    """
    def is_authenticated(self, request):
        # Nothing to verify: every request is accepted.
        return True
+
class HttpBasicAuthentication(object):
    """
    HTTP Basic authentication handler.

    Every authentication handler provides two methods:

    - `is_authenticated(request)`: checks the credentials. On success it
      stores the authenticated `User` on `request.user` and returns a true
      value; otherwise it returns something false.
    - `challenge()`: its result is sent to the client whenever
      `is_authenticated` fails. Here that is a 401 response carrying a
      `WWW-Authenticate` header for `realm`.

    `auth_func` is called as `auth_func(username=..., password=...)` and
    should return a user object, or something false on failure.
    """
    def __init__(self, auth_func=authenticate, realm='API'):
        self.auth_func = auth_func
        self.realm = realm

    def is_authenticated(self, request):
        header = request.META.get('HTTP_AUTHORIZATION', None)
        if not header:
            return False

        try:
            scheme, encoded = header.split(" ", 1)
            if scheme.lower() != 'basic':
                return False
            decoded = encoded.strip().decode('base64')
            username, password = decoded.split(':', 1)
        except (ValueError, binascii.Error):
            # Missing separator, bad base64 or no ':' in the credentials.
            return False

        user = self.auth_func(username=username, password=password)
        request.user = user or AnonymousUser()
        return request.user not in (False, None, AnonymousUser())

    def challenge(self):
        response = HttpResponse("Authorization Required")
        response['WWW-Authenticate'] = 'Basic realm="%s"' % self.realm
        response.status_code = 401
        return response

    def __repr__(self):
        return u'<HTTPBasic: realm=%s>' % self.realm
+
class HttpBasicSimple(HttpBasicAuthentication):
    """
    HTTP Basic authentication against a single fixed account.

    The `User` named `username` is fetched once, when the handler is
    constructed. After that, only that username combined with the
    configured `password` is accepted.
    """
    def __init__(self, realm, username, password):
        self.user = User.objects.get(username=username)
        self.password = password
        super(HttpBasicSimple, self).__init__(auth_func=self.hash, realm=realm)

    def hash(self, username, password):
        # Used as `auth_func`: the stored user on a match, None otherwise.
        matches = (username == self.user.username
                   and password == self.password)
        if matches:
            return self.user
        return None
+
+def load_data_store():
+ '''Load data store for OAuth Consumers, Tokens, Nonces and Resources
+ '''
+ path = getattr(settings, 'OAUTH_DATA_STORE', 'piston.store.DataStore')
+
+ # stolen from django.contrib.auth.load_backend
+ i = path.rfind('.')
+ module, attr = path[:i], path[i+1:]
+
+ try:
+ mod = __import__(module, {}, {}, attr)
+ except ImportError, e:
+ raise ImproperlyConfigured, 'Error importing OAuth data store %s: "%s"' % (module, e)
+
+ try:
+ cls = getattr(mod, attr)
+ except AttributeError:
+ raise ImproperlyConfigured, 'Module %s does not define a "%s" OAuth data store' % (module, attr)
+
+ return cls
+
+# Set the datastore here.
+oauth_datastore = load_data_store()
+
def initialize_server_request(request):
    """
    Shortcut for initialization: build the OAuth server and request objects.

    Returns `(oauth_server, oauth_request)`. `oauth_request` is whatever
    `oauth.OAuthRequest.from_request` returned. When that value is false
    (presumably None, meaning no OAuth parameters were found),
    `oauth_server` is None too. Otherwise the server wraps the configured
    `oauth_datastore` and accepts PLAINTEXT and HMAC-SHA1 signatures.

    Side effect: copies the HTTP_AUTHORIZATION header into
    `request.META['Authorization']`.
    """
    if request.method == "POST":
        # POSTed OAuth parameters may be in the body; request.REQUEST merges
        # GET and POST. (Restricting this to form-encoded bodies was
        # considered but is not enforced.)
        params = dict(request.REQUEST.items())
    else:
        params = {}

    # oauth.py looks for the header under 'Authorization', while Django
    # exposes it in META as 'HTTP_AUTHORIZATION'.
    request.META['Authorization'] = request.META.get('HTTP_AUTHORIZATION', '')

    oauth_request = oauth.OAuthRequest.from_request(
        request.method, request.build_absolute_uri(),
        headers=request.META, parameters=params,
        # META['QUERY_STRING'] exists on every Django request; the
        # `environ` attribute is only present on WSGI requests.
        query_string=request.META.get('QUERY_STRING', ''))

    if not oauth_request:
        return None, oauth_request

    oauth_server = oauth.OAuthServer(oauth_datastore(oauth_request))
    oauth_server.add_signature_method(oauth.OAuthSignatureMethod_PLAINTEXT())
    oauth_server.add_signature_method(oauth.OAuthSignatureMethod_HMAC_SHA1())

    return oauth_server, oauth_request
+
def send_oauth_error(err=None):
    """
    Shortcut for sending an error: build a 401 response for an OAuth failure.

    `err` is expected to be an `oauth.OAuthError`, or anything with a
    `message` attribute. Its message, UTF-8 encoded, becomes the body. If
    `err` is omitted, a generic message is used rather than crashing on
    `None.message`.

    The response carries the headers produced by
    `oauth.build_authenticate_header` for the 'OAuth' realm.
    """
    if err is None:
        message = u'OAuth error.'
    else:
        message = err.message

    response = HttpResponse(message.encode('utf-8'))
    response.status_code = 401

    realm = 'OAuth'
    header = oauth.build_authenticate_header(realm=realm)

    for k, v in header.iteritems():
        response[k] = v

    return response
+
+def oauth_request_token(request):
+ oauth_server, oauth_request = initialize_server_request(request)
+
+ if oauth_server is None:
+ return INVALID_PARAMS_RESPONSE
+ try:
+ token = oauth_server.fetch_request_token(oauth_request)
+
+ response = HttpResponse(token.to_string())
+ except oauth.OAuthError, err:
+ response = send_oauth_error(err)
+
+ return response
+
def oauth_auth_view(request, token, callback, params):
    """
    Default page asking the logged-in user to authorize a request token.

    Renders 'piston/authorize_token.html' with an OAuthAuthenticationForm.
    The form is pre-filled with the token key and a callback URL; the
    callback stored on the token wins over `callback`.

    `params` is accepted so this view has the same call signature as an
    OAUTH_AUTH_VIEW replacement, but it is not used here.
    """
    initial = {
        'oauth_token': token.key,
        'oauth_callback': token.get_callback_url() or callback,
    }
    context = {'form': forms.OAuthAuthenticationForm(initial=initial)}
    return render_to_response('piston/authorize_token.html', context,
                              RequestContext(request))
+
+@login_required
+def oauth_user_auth(request):
+ oauth_server, oauth_request = initialize_server_request(request)
+
+ if oauth_request is None:
+ return INVALID_PARAMS_RESPONSE
+
+ try:
+ token = oauth_server.fetch_request_token(oauth_request)
+ except oauth.OAuthError, err:
+ return send_oauth_error(err)
+
+ try:
+ callback = oauth_server.get_callback(oauth_request)
+ except:
+ callback = None
+
+ if request.method == "GET":
+ params = oauth_request.get_normalized_parameters()
+
+ oauth_view = getattr(settings, 'OAUTH_AUTH_VIEW', None)
+ if oauth_view is None:
+ return oauth_auth_view(request, token, callback, params)
+ else:
+ return get_callable(oauth_view)(request, token, callback, params)
+ elif request.method == "POST":
+ try:
+ form = forms.OAuthAuthenticationForm(request.POST)
+ if form.is_valid():
+ token = oauth_server.authorize_token(token, request.user)
+ args = '?'+token.to_string(only_key=True)
+ else:
+ args = '?error=%s' % 'Access not granted by user.'
+ print "FORM ERROR", form.errors
+
+ if not callback:
+ callback = getattr(settings, 'OAUTH_CALLBACK_VIEW')
+ return get_callable(callback)(request, token)
+
+ response = HttpResponseRedirect(callback+args)
+
+ except oauth.OAuthError, err:
+ response = send_oauth_error(err)
+ else:
+ response = HttpResponse('Action not allowed.')
+
+ return response
+
+def oauth_access_token(request):
+ oauth_server, oauth_request = initialize_server_request(request)
+
+ if oauth_request is None:
+ return INVALID_PARAMS_RESPONSE
+
+ try:
+ token = oauth_server.fetch_access_token(oauth_request)
+ return HttpResponse(token.to_string())
+ except oauth.OAuthError, err:
+ return send_oauth_error(err)
+
# 401 response for requests lacking valid OAuth parameters.
# NOTE: this is one shared HttpResponse object, built once at import time.
# Anything that mutates a returned response (e.g. middleware adding headers
# or cookies) affects every later use, so prefer building a fresh one.
INVALID_PARAMS_RESPONSE = send_oauth_error(
    err=oauth.OAuthError('Invalid request parameters.'))
+
class OAuthAuthentication(object):
    """
    OAuth authentication. Based on work by Leah Culver.

    Follows the same `is_authenticated` / `challenge` protocol as
    `HttpBasicAuthentication`.
    """
    def __init__(self, realm='API'):
        self.realm = realm
        self.builder = oauth.build_authenticate_header

    def is_authenticated(self, request):
        """
        Check whether OAuth credentials are supplied and valid.

        On success, sets:
        - `request.user` to the token's user,
        - `request.consumer` to the consumer,
        - `request.throttle_extra` to the token's consumer id,
        and returns True.

        Returns False when required parameters are missing, when
        verification raises `oauth.OAuthError`, or when no consumer/token
        results.
        """
        if not self.is_valid_request(request):
            return False

        try:
            consumer, token, parameters = self.validate_token(request)
        except oauth.OAuthError:
            # Bad signature, unknown token, etc.: the caller responds
            # with `challenge()`.
            return False

        if consumer and token:
            request.user = token.user
            request.consumer = consumer
            request.throttle_extra = token.consumer.id
            return True

        return False

    def challenge(self):
        """
        Returns a 401 response with a small bit on
        what OAuth is, and where to learn more about it.

        The authenticate headers advertise this handler's `realm`. When this
        was written, browsers did not understand OAuth authentication, hence
        the helpful template we render ('oauth/challenge.html').
        """
        response = HttpResponse()
        response.status_code = 401

        for k, v in self.builder(realm=self.realm).iteritems():
            response[k] = v

        response.content = loader.render_to_string('oauth/challenge.html',
            {'MEDIA_URL': settings.MEDIA_URL})

        return response

    @staticmethod
    def is_valid_request(request):
        """
        Checks whether all required OAuth parameters are present.

        Looks first in the HTTP Authorization header (the method the OAuth
        spec prefers; a substring check), then falls back to the merged
        GET/POST parameters.
        """
        must_have = ['oauth_' + s for s in (
            'consumer_key', 'token', 'signature',
            'signature_method', 'timestamp', 'nonce')]

        def has_all(params):
            return all(p in params for p in must_have)

        auth_params = request.META.get("HTTP_AUTHORIZATION", "")
        req_params = request.REQUEST

        return has_all(auth_params) or has_all(req_params)

    @staticmethod
    def validate_token(request, check_timestamp=True, check_nonce=True):
        """
        Verify the OAuth request and return `(consumer, token, parameters)`.

        Raises `oauth.OAuthError` when the request carries no OAuth
        parameters at all (there is then no server to verify with), and
        otherwise whatever `OAuthServer.verify_request` raises.

        NOTE: `check_timestamp` and `check_nonce` are accepted for API
        compatibility but are currently ignored.
        """
        oauth_server, oauth_request = initialize_server_request(request)
        if oauth_server is None:
            raise oauth.OAuthError('Invalid request parameters.')
        return oauth_server.verify_request(oauth_request)
+
--- /dev/null
+"""
+Decorator module, see
+http://www.phyast.pitt.edu/~micheles/python/documentation.html
+for the documentation and below for the licence.
+"""
+
## The basic trick is to generate the source code for the decorated function
## with the right signature and to evaluate it.
## Uncomment the statement 'print >> sys.stderr, src' in _decorator
## to understand what is going on.
+
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["decorator", "new_wrapper", "getinfo"]
+
+import inspect, sys
+
+try:
+ set
+except NameError:
+ from sets import Set as set
+
def getinfo(func):
    """
    Collect introspection data about a function or method.

    Returns a dictionary with the keys:
      - name (the function's name : str)
      - argnames (all argument names, including the *args/**kw names : list)
      - defaults (the default values of the arguments : tuple or None)
      - signature (the argument list as source text, without defaults : str)
      - doc (the docstring : str or None)
      - module (the defining module's name : str)
      - dict (the function's __dict__ : dict)
      - globals (the function's global namespace : dict)
      - closure (the function's closure cells : tuple or None)

    >>> def f(self, x=1, y=2, *args, **kw): pass

    >>> info = getinfo(f)

    >>> info["name"]
    'f'
    >>> info["argnames"]
    ['self', 'x', 'y', 'args', 'kw']

    >>> info["defaults"]
    (1, 2)

    >>> info["signature"]
    'self, x, y, *args, **kw'
    """
    assert inspect.ismethod(func) or inspect.isfunction(func)
    positional, star_args, star_kwargs, default_values = inspect.getargspec(func)

    names = list(positional)
    for extra in (star_args, star_kwargs):
        if extra:
            names.append(extra)

    # formatargspec yields "(a, b, ...)". Blank out the default values and
    # strip the surrounding parentheses.
    sig = inspect.formatargspec(positional, star_args, star_kwargs,
                                default_values,
                                formatvalue=lambda value: "")[1:-1]

    return {
        'name': func.__name__,
        'argnames': names,
        'signature': sig,
        'defaults': func.func_defaults,
        'doc': func.__doc__,
        'module': func.__module__,
        'dict': func.__dict__,
        'globals': func.func_globals,
        'closure': func.func_closure,
    }
+
# akin to functools.update_wrapper
def update_wrapper(wrapper, model, infodict=None):
    """
    Copy metadata from `model` onto `wrapper`, then return `wrapper`.

    `infodict` is a `getinfo`-style dictionary. When it is omitted or
    empty, it is computed from `model`.

    `wrapper` receives the recorded name, docstring, module, `__dict__`
    entries and `func_defaults`, plus an `undecorated` attribute that
    points back at `model`.
    """
    infodict = infodict or getinfo(model)
    try:
        wrapper.__name__ = infodict['name']
    except (TypeError, AttributeError):
        # Python < 2.4: function __name__ is read-only.
        pass
    wrapper.__doc__ = infodict['doc']
    wrapper.__module__ = infodict['module']
    wrapper.__dict__.update(infodict['dict'])
    wrapper.func_defaults = infodict['defaults']
    wrapper.undecorated = model
    return wrapper
+
def new_wrapper(wrapper, model):
    """
    Return a signature-preserving copy of `wrapper`.

    Works like an improved functools.update_wrapper, where `wrapper` may be
    any callable. A lambda with exactly the signature of `model` is
    generated (via eval) that forwards every argument to `wrapper`. That
    copy, not the original `wrapper`, is then updated with `model`'s
    metadata.

    `model` is either a function or a `getinfo`-style dictionary. A
    dictionary needs 'argnames' and 'signature' to build the copy, plus
    'name', 'doc', 'module', 'dict' and 'defaults' for the metadata.
    """
    infodict = model
    if not isinstance(model, dict):  # assume model is a function
        infodict = getinfo(model)

    assert '_wrapper_' not in infodict["argnames"], (
        '"_wrapper_" is a reserved argument name!')

    forwarding_src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict
    funcopy = eval(forwarding_src, dict(_wrapper_=wrapper))
    return update_wrapper(funcopy, model, infodict)
+
# helper used in decorator_factory: installed as the __call__ method of
# decorator classes.
def __call__(self, func):
    """Decorate `func` so each call is routed through `self.call(func, ...)`."""
    info = getinfo(func)
    reserved = ('_func_', '_self_')
    for argname in reserved:
        assert argname not in info["argnames"], (
            '%s is a reserved argument name!' % argname)
    template = "lambda %(signature)s: _self_.call(_func_, %(signature)s)"
    decorated = eval(template % info, dict(_func_=func, _self_=self))
    return update_wrapper(decorated, func, info)
+
+def decorator_factory(cls):
+ """
+ Take a class with a ``.caller`` method and return a callable decorator
+ object. It works by adding a suitable __call__ method to the class;
+ it raises a TypeError if the class already has a nontrivial __call__
+ method.
+ """
+ attrs = set(dir(cls))
+ if '__call__' in attrs:
+ raise TypeError('You cannot decorate a class with a nontrivial '
+ '__call__ method')
+ if 'call' not in attrs:
+ raise TypeError('You cannot decorate a class without a '
+ '.call method')
+ cls.__call__ = __call__
+ return cls
+
def decorator(caller):
    """
    General purpose decorator factory: takes a caller function as
    input and returns a decorator with the same attributes.
    A caller function is any function like this::

     def caller(func, *args, **kw):
         # do something
         return func(*args, **kw)

    Here is an example of usage:

    >>> @decorator
    ... def chatty(f, *args, **kw):
    ...     print "Calling %r" % f.__name__
    ...     return f(*args, **kw)

    >>> chatty.__name__
    'chatty'

    >>> @chatty
    ... def f(): pass
    ...
    >>> f()
    Calling 'f'

    `caller` may also be a class with a ``call`` method. In that case the
    class is turned into a factory of callable decorator objects (see
    `decorator_factory`).
    """
    if inspect.isclass(caller):
        return decorator_factory(caller)

    def _decorator(func):
        # Generate, then eval, a lambda with exactly func's signature that
        # forwards to caller(func, ...).
        info = getinfo(func)
        names = info['argnames']
        assert '_call_' not in names and '_func_' not in names, (
            'You cannot use _call_ or _func_ as argument names!')
        src = "lambda %(signature)s: _call_(_func_, %(signature)s)" % info
        # import sys; print >> sys.stderr, src # for debugging purposes
        wrapped = eval(src, dict(_func_=func, _call_=caller))
        return update_wrapper(wrapped, func, info)

    return update_wrapper(_decorator, caller)
+
if __name__ == "__main__":
    # Running this module directly executes the doctests above.
    import doctest
    doctest.testmod()
+
+########################## LEGALESE ###############################
+
+## Redistributions of source code must retain the above copyright
+## notice, this list of conditions and the following disclaimer.
+## Redistributions in bytecode form must reproduce the above copyright
+## notice, this list of conditions and the following disclaimer in
+## the documentation and/or other materials provided with the
+## distribution.
+
+## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+## DAMAGE.
--- /dev/null
+import inspect, handler
+
+from piston.handler import typemapper
+from piston.handler import handler_tracker
+
+from django.core.urlresolvers import get_resolver, get_callable, get_script_prefix
+from django.shortcuts import render_to_response
+from django.template import RequestContext
+
def generate_doc(handler_cls):
    """
    Build a `HandlerDocumentation` object for `handler_cls`.

    Use this to generate documentation for your API. Raises ValueError
    unless the type of `handler_cls` is exactly `handler.HandlerMetaClass`.
    """
    handler_type = type(handler_cls)
    if handler_type is not handler.HandlerMetaClass:
        raise ValueError("Give me handler, not %s" % handler_type)
    return HandlerDocumentation(handler_cls)
+
class HandlerMethod(object):
    """
    Documentation wrapper around a single handler method.

    `method` is the method object being documented. `stale` is a flag set
    by the caller: `HandlerDocumentation.get_methods` sets it when the
    method comes from the piston `handler` module itself.
    """
    def __init__(self, method, stale=False):
        self.method = method
        self.stale = stale

    def iter_args(self):
        """
        Yield `(name, default)` for each documented argument.

        `self`, `request` and `form` are skipped. `default` is the str() of
        the argument's default value, or None when it has no default.
        """
        names, _, _, defaults = inspect.getargspec(self.method)
        defaults = defaults or ()
        # Defaults always belong to the trailing arguments.
        first_default = len(names) - len(defaults)

        for position, name in enumerate(names):
            if name in ('self', 'request', 'form'):
                continue
            if position >= first_default:
                yield (name, str(defaults[position - first_default]))
            else:
                yield (name, None)

    @property
    def signature(self, parse_optional=True):
        """Human-readable argument list, e.g. "slug=<optional>"."""
        parts = []
        for argn, argdef in self.iter_args():
            if argdef:
                parts.append('%s=%s' % (argn, argdef))
            else:
                parts.append(argn)

        # Trailing commas/spaces are stripped, as the separator could leave them.
        spec = ', '.join(parts).rstrip(", ")

        if parse_optional:
            return spec.replace("=None", "=<optional>")
        return spec

    @property
    def doc(self):
        return inspect.getdoc(self.method)

    @property
    def name(self):
        return self.method.__name__

    @property
    def http_name(self):
        """HTTP verb served by this method, or None for other names."""
        verbs = {'read': 'GET', 'create': 'POST',
                 'delete': 'DELETE', 'update': 'PUT'}
        return verbs.get(self.name)

    def __repr__(self):
        return "<Method: %s>" % self.name
+
class HandlerDocumentation(object):
    """
    Introspects a piston handler class to produce its documentation.
    """
    def __init__(self, handler):
        self.handler = handler

    def get_methods(self, include_default=False):
        """
        Yield a `HandlerMethod` for each of read/create/update/delete
        present on the handler.

        A method is "stale" when it is defined in the piston `handler`
        module rather than in the user's handler.
        - Non-anonymous handlers: yield non-stale methods, plus stale ones
          when `include_default` is true.
        - Anonymous handlers: yield non-stale methods, plus `read` when
          'GET' is in `allowed_methods`.
        """
        for method in "read create update delete".split():
            met = getattr(self.handler, method, None)

            if not met:
                continue

            # `handler` here is the piston.handler module.
            stale = inspect.getmodule(met) is handler

            if not self.handler.is_anonymous:
                if not stale or include_default:
                    yield HandlerMethod(met, stale)
            else:
                if not stale or (met.__name__ == "read"
                                 and 'GET' in self.allowed_methods):
                    yield HandlerMethod(met, stale)

    def get_all_methods(self):
        return self.get_methods(include_default=True)

    @property
    def is_anonymous(self):
        # Read the flag from the documented handler class, not from the
        # piston.handler module of the same name.
        return self.handler.is_anonymous

    def get_model(self):
        """The handler's `model` attribute, or None if it has none."""
        return getattr(self.handler, 'model', None)

    @property
    def has_anonymous(self):
        return self.handler.anonymous

    @property
    def anonymous(self):
        """Documentation for the handler's anonymous counterpart, if any."""
        if self.has_anonymous:
            return HandlerDocumentation(self.handler.anonymous)

    @property
    def doc(self):
        return self.handler.__doc__

    @property
    def name(self):
        return self.handler.__name__

    @property
    def allowed_methods(self):
        return self.handler.allowed_methods

    def get_resource_uri_template(self):
        """
        URI template processor.

        Returns a URI template for the handler's `resource_uri()`, or None
        if it cannot be determined.

        See http://bitworking.org/projects/URI-Templates/
        """
        def _convert(template, params=()):
            """URI template converter: substitute "{name}" for each parameter."""
            paths = template % dict([p, "{%s}" % p] for p in params)
            return u'%s%s' % (get_script_prefix(), paths)

        try:
            resource_uri = self.handler.resource_uri()

            # resource_uri() may return (view,), (view, args) or
            # (view, args, kwargs); pad the missing parts.
            components = [None, [], {}]

            for i, value in enumerate(resource_uri):
                components[i] = value

            lookup_view, args, kwargs = components
            lookup_view = get_callable(lookup_view, True)

            possibilities = get_resolver(None).reverse_dict.getlist(lookup_view)

            for possibility, pattern in possibilities:
                for result, params in possibility:
                    if args:
                        if len(args) != len(params):
                            continue
                        return _convert(result, params)
                    else:
                        if set(kwargs.keys()) != set(params):
                            continue
                        return _convert(result, params)
        except Exception:
            # Best effort: any failure simply means "no template".
            return None

    resource_uri_template = property(get_resource_uri_template)

    def __repr__(self):
        return u'<Documentation for "%s">' % self.name
+
def documentation_view(request):
    """
    Generic documentation view.

    Renders 'documentation.html' with `docs`: one HandlerDocumentation per
    handler in `handler_tracker`. The list is sorted by handler name with
    "Anonymous" removed, so each handler sits next to its anonymous
    counterpart.
    """
    docs = [generate_doc(handler_cls) for handler_cls in handler_tracker]

    # Handlers and their anonymous counterparts are put next to each other.
    docs.sort(key=lambda doc: doc.name.replace("Anonymous", ""))

    return render_to_response('documentation.html',
                              {'docs': docs}, RequestContext(request))
--- /dev/null
+from __future__ import generators
+
+import decimal, re, inspect
+import copy
+
+try:
+ # yaml isn't standard with python. It shouldn't be required if it
+ # isn't used.
+ import yaml
+except ImportError:
+ yaml = None
+
# Fallback since `any` isn't in Python <2.5
try:
    any
except NameError:
    def any(iterable):
        """Return True as soon as an element of `iterable` is true."""
        for item in iterable:
            if item:
                break
        else:
            return False
        return True
+
+from django.db.models.query import QuerySet
+from django.db.models import Model, permalink
+from django.utils import simplejson
+from django.utils.xmlutils import SimplerXMLGenerator
+from django.utils.encoding import smart_unicode
+from django.core.urlresolvers import reverse, NoReverseMatch
+from django.core.serializers.json import DateTimeAwareJSONEncoder
+from django.http import HttpResponse
+from django.core import serializers
+
+from utils import HttpStatusCode, Mimer
+
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+
# Allow people to change the reverser (default `permalink`).
# Emitter._model resolves resource URIs with
# `reverser(lambda: (url_id, fields))()`, so reassigning this module
# attribute customises how those URIs are built.
reverser = permalink
+
class Emitter(object):
    """
    Super emitter. All other emitters should subclass
    this one. It has the `construct` method which
    conveniently returns a serialized `dict`. This is
    usually the only method you want to use in your
    emitter. See below for examples.

    Emitters are registered by name with `register` and looked up with
    `get`, which returns an `(emitter_class, content_type)` pair.

    `RESERVED_FIELDS` was introduced when better resource
    method detection came, and we accidentally caught these
    as the methods on the handler. Issue58 says that's no good.
    """
    EMITTERS = { }
    RESERVED_FIELDS = set([ 'read', 'update', 'create',
                            'delete', 'model', 'anonymous',
                            'allowed_methods', 'fields', 'exclude' ])

    def __init__(self, payload, typemapper, handler, fields=(), anonymous=True):
        """
        Parameters:
        - `payload`: the object to serialize.
        - `typemapper`: maps handler classes to `(model, is_anonymous)`
          pairs (see `in_typemapper`).
        - `handler`: the handler the payload came from.
        - `fields`: optionally restricts the serialized fields.
        - `anonymous`: selects which handler mapping applies to models.
        """
        self.typemapper = typemapper
        self.data = payload
        self.handler = handler
        self.fields = fields
        self.anonymous = anonymous

        if isinstance(self.data, Exception):
            # A bare `raise` re-raises the exception currently being
            # handled. Presumably the caller is still inside the `except`
            # that produced this payload.
            raise

    def method_fields(self, handler, fields):
        """
        Return `{field: callable}` for each name in the set `fields` (minus
        `RESERVED_FIELDS`) that is a callable attribute of `handler`.

        These "resource methods" are later called with the model instance
        to compute that field's value.
        """
        if not handler:
            return { }

        ret = dict()

        for field in fields - Emitter.RESERVED_FIELDS:
            t = getattr(handler, str(field), None)

            if t and callable(t):
                ret[field] = t

        return ret

    def construct(self):
        """
        Recursively serialize a lot of types, and
        in cases where it doesn't recognize the type,
        it will fall back to Django's `smart_unicode`.

        Returns `dict`.
        """
        def _any(thing, fields=()):
            """
            Dispatch, all types are routed through here.
            """
            ret = None

            if isinstance(thing, QuerySet):
                ret = _qs(thing, fields=fields)
            elif isinstance(thing, (tuple, list)):
                ret = _list(thing)
            elif isinstance(thing, dict):
                ret = _dict(thing)
            elif isinstance(thing, decimal.Decimal):
                ret = str(thing)
            elif isinstance(thing, Model):
                ret = _model(thing, fields=fields)
            elif isinstance(thing, HttpResponse):
                raise HttpStatusCode(thing)
            elif inspect.isfunction(thing):
                # Zero-argument functions are called and their result is
                # serialized; other functions serialize to None.
                if not inspect.getargspec(thing)[0]:
                    ret = _any(thing())
            elif hasattr(thing, '__emittable__'):
                f = thing.__emittable__
                if inspect.ismethod(f) and len(inspect.getargspec(f)[0]) == 1:
                    ret = _any(f())
            elif repr(thing).startswith("<django.db.models.fields.related.RelatedManager"):
                ret = _any(thing.all())
            else:
                ret = smart_unicode(thing, strings_only=True)

            return ret

        def _fk(data, field):
            """
            Foreign keys.
            """
            return _any(getattr(data, field.name))

        def _related(data, fields=()):
            """
            Related objects: serialize each object from `data.iterator()`
            with `_model`.
            """
            return [ _model(m, fields) for m in data.iterator() ]

        def _m2m(data, field, fields=()):
            """
            Many to many (re-route to `_model`.)
            """
            return [ _model(m, fields) for m in getattr(data, field.name).iterator() ]

        def _model(data, fields=()):
            """
            Models. Will respect the `fields` and/or
            `exclude` on the handler (see `typemapper`.)
            """
            ret = { }
            handler = self.in_typemapper(type(data), self.anonymous)
            get_absolute_uri = False

            if handler or fields:
                v = lambda f: getattr(data, f.attname)

                if not fields:
                    # Fields were not specified, so use the ones declared on
                    # the handler mapped to this model (it exists, because
                    # `handler or fields` held).
                    get_fields = set(handler.fields)
                    exclude_fields = set(handler.exclude).difference(get_fields)

                    if 'absolute_uri' in get_fields:
                        get_absolute_uri = True

                    if not get_fields:
                        get_fields = set([ f.attname.replace("_id", "", 1)
                                           for f in data._meta.fields ])

                    # sets can be negated.
                    for exclude in exclude_fields:
                        if isinstance(exclude, basestring):
                            get_fields.discard(exclude)

                        elif isinstance(exclude, re._pattern_type):
                            for field in get_fields.copy():
                                if exclude.match(field):
                                    get_fields.discard(field)

                else:
                    get_fields = set(fields)

                met_fields = self.method_fields(handler, get_fields)

                for f in data._meta.local_fields:
                    if f.serialize and not any([ p in met_fields for p in [ f.attname, f.name ]]):
                        if not f.rel:
                            if f.attname in get_fields:
                                ret[f.attname] = _any(v(f))
                                get_fields.remove(f.attname)
                        else:
                            # Foreign key: attname is "<name>_id".
                            if f.attname[:-3] in get_fields:
                                ret[f.name] = _fk(data, f)
                                get_fields.remove(f.name)

                for mf in data._meta.many_to_many:
                    if mf.serialize and mf.attname not in met_fields:
                        if mf.attname in get_fields:
                            ret[mf.name] = _m2m(data, mf)
                            get_fields.remove(mf.name)

                # try to get the remainder of fields
                for maybe_field in get_fields:
                    if isinstance(maybe_field, (list, tuple)):
                        # (attribute, sub_fields) pair: serialize the related
                        # object(s) restricted to `sub_fields`.
                        model, sub_fields = maybe_field
                        inst = getattr(data, model, None)

                        if inst:
                            if hasattr(inst, 'all'):
                                ret[model] = _related(inst, sub_fields)
                            elif callable(inst):
                                if len(inspect.getargspec(inst)[0]) == 1:
                                    ret[model] = _any(inst(), sub_fields)
                            else:
                                ret[model] = _model(inst, sub_fields)

                    elif maybe_field in met_fields:
                        # Overriding normal field which has a "resource method"
                        # so you can alter the contents of certain fields without
                        # using different names.
                        ret[maybe_field] = _any(met_fields[maybe_field](data))

                    else:
                        maybe = getattr(data, maybe_field, None)
                        if maybe:
                            if callable(maybe):
                                if len(inspect.getargspec(maybe)[0]) == 1:
                                    ret[maybe_field] = _any(maybe())
                            else:
                                ret[maybe_field] = _any(maybe)
                        else:
                            handler_f = getattr(handler or self.handler, maybe_field, None)

                            if handler_f:
                                ret[maybe_field] = _any(handler_f(data))

            else:
                for f in data._meta.fields:
                    ret[f.attname] = _any(getattr(data, f.attname))

                # Also serialize attributes present on the instance but not
                # on its class (and not already emitted).
                known = set(dir(data.__class__))
                known.update(ret)
                add_ons = [k for k in dir(data) if k not in known]

                for k in add_ons:
                    ret[k] = _any(getattr(data, k))

            # resource uri
            if handler:
                if hasattr(handler, 'resource_uri'):
                    url_id, url_fields = handler.resource_uri(data)

                    try:
                        ret['resource_uri'] = reverser( lambda: (url_id, url_fields) )()
                    except NoReverseMatch:
                        pass

            if hasattr(data, 'get_api_url') and 'resource_uri' not in ret:
                try:
                    ret['resource_uri'] = data.get_api_url()
                except Exception:
                    # Best effort: the URI is optional.
                    pass

            # absolute uri
            if hasattr(data, 'get_absolute_url') and get_absolute_uri:
                try:
                    ret['absolute_uri'] = data.get_absolute_url()
                except Exception:
                    # Best effort: the URI is optional.
                    pass

            return ret

        def _qs(data, fields=()):
            """
            Querysets.
            """
            return [ _any(v, fields) for v in data ]

        def _list(data):
            """
            Lists.
            """
            return [ _any(v) for v in data ]

        def _dict(data):
            """
            Dictionaries.
            """
            return dict([ (k, _any(v)) for k, v in data.iteritems() ])

        # Kickstart the seralizin'.
        return _any(self.data, self.fields)

    def in_typemapper(self, model, anonymous):
        """
        Return the handler class mapped to `model` with the given anonymity
        flag, or None if there is none.
        """
        for klass, (km, is_anon) in self.typemapper.iteritems():
            if model is km and is_anon is anonymous:
                return klass

    def render(self):
        """
        This super emitter does not implement `render`,
        this is a job for the specific emitter below.
        """
        raise NotImplementedError("Please implement render.")

    def stream_render(self, request, stream=True):
        """
        Tells our patched middleware not to look
        at the contents, and returns a generator
        rather than the buffered string. Should be
        more memory friendly for large datasets.
        """
        yield self.render(request)

    @classmethod
    def get(cls, format):
        """
        Gets an emitter, returns the class and a content-type.

        Raises ValueError if no emitter is registered under `format`.
        """
        if format in cls.EMITTERS:
            return cls.EMITTERS.get(format)

        raise ValueError("No emitters found for type %s" % format)

    @classmethod
    def register(cls, name, klass, content_type='text/plain'):
        """
        Register an emitter.

        Parameters::
         - `name`: The name of the emitter ('json', 'xml', 'yaml', ...)
         - `klass`: The emitter class.
         - `content_type`: The content type to serve response as.
        """
        cls.EMITTERS[name] = (klass, content_type)

    @classmethod
    def unregister(cls, name):
        """
        Remove an emitter from the registry. Useful if you don't
        want to provide output in one of the built-in emitters.

        Returns the removed `(class, content_type)` pair, or None.
        """
        return cls.EMITTERS.pop(name, None)
+
class XMLEmitter(Emitter):
    """
    XML emitter.

    The payload is wrapped in a <response> element. Dictionary keys become
    element names, each list/tuple item becomes a <resource> element, and
    any other value is written as character data.
    """
    def _to_xml(self, xml, data):
        if isinstance(data, dict):
            for key, value in data.iteritems():
                xml.startElement(key, {})
                self._to_xml(xml, value)
                xml.endElement(key)
        elif isinstance(data, (list, tuple)):
            for item in data:
                xml.startElement("resource", {})
                self._to_xml(xml, item)
                xml.endElement("resource")
        else:
            xml.characters(smart_unicode(data))

    def render(self, request):
        buf = StringIO.StringIO()
        xml = SimplerXMLGenerator(buf, "utf-8")

        xml.startDocument()
        xml.startElement("response", {})
        self._to_xml(xml, self.construct())
        xml.endElement("response")
        xml.endDocument()

        return buf.getvalue()

Emitter.register('xml', XMLEmitter, 'text/xml; charset=utf-8')
# The registered XML loader ignores its input and returns None.
Mimer.register(lambda *a: None, ('text/xml',))
+
class JSONEmitter(Emitter):
    """
    JSON emitter, understands timestamps.

    If the request has a `callback` GET parameter that is a dotted
    JavaScript identifier (e.g. ``cb`` or ``ns.handlers.cb``), the JSON is
    wrapped as ``callback(json)`` (JSONP). Any other value is ignored, so
    request data cannot inject arbitrary script into the response body.
    """
    # Dotted JavaScript identifiers only (ASCII letters, digits, '_' and '$').
    # \Z rather than $ so a trailing newline is not accepted.
    _CALLBACK_RE = re.compile(
        r'^[A-Za-z_$][A-Za-z0-9_$]*(?:\.[A-Za-z_$][A-Za-z0-9_$]*)*\Z')

    def render(self, request):
        cb = request.GET.get('callback')
        seria = simplejson.dumps(self.construct(), cls=DateTimeAwareJSONEncoder, ensure_ascii=False, indent=4)

        # Callback (JSONP), only for well-formed identifiers.
        if cb and self._CALLBACK_RE.match(cb):
            return '%s(%s)' % (cb, seria)

        return seria

Emitter.register('json', JSONEmitter, 'application/json; charset=utf-8')
Mimer.register(simplejson.loads, ('application/json',))
+
class YAMLEmitter(Emitter):
    """
    YAML emitter, uses `safe_dump` to omit the
    specific types when outputting to non-Python.
    """
    def render(self, request):
        return yaml.safe_dump(self.construct())

if yaml: # Only register yaml if it was imported successfully.
    Emitter.register('yaml', YAMLEmitter, 'application/x-yaml; charset=utf-8')
    # Request bodies are untrusted: `safe_load` only builds plain Python
    # types, whereas `yaml.load` can construct arbitrary Python objects.
    Mimer.register(lambda s: dict(yaml.safe_load(s)), ('application/x-yaml',))
+
class PickleEmitter(Emitter):
    """
    Emitter that serializes the constructed data with `pickle`.
    """
    def render(self, request):
        payload = self.construct()
        return pickle.dumps(payload)

Emitter.register('pickle', PickleEmitter, content_type='application/python-pickle')

# WARNING: accepting arbitrary pickled data is a huge security concern,
# because unpickling untrusted input can execute arbitrary code. The
# unpickler is therefore disabled by default; enable it only if you
# understand the implications.
#
# Read more: http://nadiana.com/python-pickle-insecure
#
# Uncomment the line below to enable it. You're doing so at your own risk.
# Mimer.register(pickle.loads, ('application/python-pickle',))
+
class DjangoEmitter(Emitter):
    """
    Emitter for the Django serialized format.

    `HttpResponse` objects and plain scalars (integers and strings) are
    returned untouched. Everything else is passed to Django's serializer
    framework in `format` (XML by default).
    """
    def render(self, request, format='xml'):
        if isinstance(self.data, HttpResponse):
            return self.data
        # `serializers.serialize` iterates its argument expecting model
        # instances, so scalars must bypass it. `long` and `basestring` also
        # cover big integers and unicode text, which previously fell through
        # to the serializer and crashed.
        elif isinstance(self.data, (int, long, basestring)):
            response = self.data
        else:
            response = serializers.serialize(format, self.data, indent=True)

        return response

Emitter.register('django', DjangoEmitter, 'text/xml; charset=utf-8')
--- /dev/null
+[
+ {
+ "pk": 2,
+ "model": "auth.user",
+ "fields": {
+ "username": "pistontestuser",
+ "first_name": "Piston",
+ "last_name": "User",
+ "is_active": true,
+ "is_superuser": false,
+ "is_staff": false,
+ "last_login": "2009-08-03 13:11:53",
+ "groups": [],
+ "user_permissions": [],
+ "password": "sha1$b6c1f$83d5879f3854f6e9d27f393e3bcb4b8db05cf671",
+ "email": "pistontestuser@example.com",
+ "date_joined": "2009-08-03 13:11:53"
+ }
+ },
+ {
+ "pk": 3,
+ "model": "auth.user",
+ "fields": {
+ "username": "pistontestconsumer",
+ "first_name": "Piston",
+ "last_name": "Consumer",
+ "is_active": true,
+ "is_superuser": false,
+ "is_staff": false,
+ "last_login": "2009-08-03 13:11:53",
+ "groups": [],
+ "user_permissions": [],
+ "password": "sha1$b6c1f$83d5879f3854f6e9d27f393e3bcb4b8db05cf671",
+ "email": "pistontestconsumer@example.com",
+ "date_joined": "2009-08-03 13:11:53"
+ }
+ },
+ {
+ "pk": 1,
+ "model": "sites.site",
+ "fields": {
+ "domain": "example.com",
+ "name": "example.com"
+ }
+ }
+]
--- /dev/null
+[
+ {
+ "pk": 1,
+ "model": "piston.consumer",
+ "fields": {
+ "status": "accepted",
+ "name": "Piston Test Consumer",
+ "secret": "T5XkNMkcjffDpC9mNQJbyQnJXGsenYbz",
+ "user": 2,
+ "key": "8aZSFj3W54h8J8sCpx",
+ "description": "A test consumer record for Piston unit tests."
+ }
+ },
+ {
+ "pk": 1,
+ "model": "piston.token",
+ "fields": {
+ "is_approved": true,
+ "timestamp": 1249347414,
+ "token_type": 2,
+ "secret": "qSWZq36t7yvkBquetYBkd8JxnuCu9jKk",
+ "user": 2,
+ "key": "Y7358vL5hDBbeP3HHL",
+ "consumer": 1
+ }
+ }
+]
--- /dev/null
+import hmac, base64
+
+from django import forms
+from django.conf import settings
+
class Form(forms.Form):
    """
    Piston's plain form base class; currently identical to Django's
    `forms.Form`.
    """
+
class ModelForm(forms.ModelForm):
    """
    `forms.ModelForm` that can back-fill missing submitted values.

    `merge_from_initial` copies the initial value of every field listed in
    `Meta.fields` that is absent from the bound data, so a client need not
    resend unchanged values for the form to validate. Django does not do
    this by itself.
    """
    def merge_from_initial(self):
        # Unlock the bound data (presumably a QueryDict, which is immutable
        # by default) so missing fields can be filled in.
        self.data._mutable = True
        # Snapshot of the keys present before anything is added.
        present = self.data.keys()
        for field in getattr(self.Meta, 'fields', ()):
            if field not in present:
                self.data[field] = self.initial.get(field, None)
+
+
class OAuthAuthenticationForm(forms.Form):
    """
    Form a user submits to authorize an OAuth request token.

    A hidden `csrf_signature` (base64 HMAC-SHA1 of the `oauth_token` value,
    keyed with settings.SECRET_KEY) is rendered with the form. It is
    recomputed and compared on submit, tying the POST to the token the form
    was rendered for.
    """
    oauth_token = forms.CharField(widget=forms.HiddenInput)
    oauth_callback = forms.CharField(widget=forms.HiddenInput, required=False)
    authorize_access = forms.BooleanField(required=True)
    csrf_signature = forms.CharField(widget=forms.HiddenInput)

    def __init__(self, *args, **kwargs):
        forms.Form.__init__(self, *args, **kwargs)

        # Bound method used as a callable initial value: it is only
        # evaluated when the initial value is actually needed.
        self.fields['csrf_signature'].initial = self.initial_csrf_signature

    def clean_csrf_signature(self):
        """
        Validate `csrf_signature` against the submitted `oauth_token`.

        Raises ValidationError if the token is missing (e.g. it failed its
        own validation) or the signature does not match.
        """
        sig = self.cleaned_data['csrf_signature']
        # Previously a missing token raised KeyError (a server error).
        token = self.cleaned_data.get('oauth_token')
        if token is None:
            raise forms.ValidationError("CSRF signature is not valid")

        sig1 = OAuthAuthenticationForm.get_csrf_signature(settings.SECRET_KEY, token)

        # Constant-time comparison so response timing does not reveal how
        # many leading characters of a forged signature were correct.
        if not OAuthAuthenticationForm._signatures_match(sig, sig1):
            raise forms.ValidationError("CSRF signature is not valid")

        return sig

    def initial_csrf_signature(self):
        """Signature for the `oauth_token` given in the form's initial data."""
        token = self.initial['oauth_token']
        return OAuthAuthenticationForm.get_csrf_signature(settings.SECRET_KEY, token)

    @staticmethod
    def get_csrf_signature(key, token):
        """Return the base64-encoded HMAC-SHA1 of `token` keyed with `key`."""
        # hashlib exists from Python 2.5; older interpreters only have `sha`.
        try:
            import hashlib
            digestmod = hashlib.sha1
        except ImportError:
            import sha as digestmod # deprecated
        hashed = hmac.new(key, token, digestmod)

        # calculate the digest base 64
        return base64.b64encode(hashed.digest())

    @staticmethod
    def _signatures_match(a, b):
        """Return True if `a == b`, without stopping at the first mismatch."""
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= ord(x) ^ ord(y)
        return diff == 0
+
--- /dev/null
+import warnings
+
+from utils import rc
+from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
+from django.conf import settings
+
# handler class -> (model or None, is_anonymous); filled in by HandlerMetaClass.
typemapper = {}
# Handler classes created so far, other than the two base handlers.
handler_tracker = []
+
class HandlerMetaClass(type):
    """
    Metaclass for handlers.

    Every class it creates is recorded in `typemapper` as
    ``handler -> (model or None, is_anonymous)``. Every class other than
    `BaseHandler` / `AnonymousBaseHandler` is appended to
    `handler_tracker`. A warning is issued when a handler is declared for a
    (model, is_anonymous) pair that is already mapped, unless
    settings.PISTON_IGNORE_DUPE_MODELS is true.
    """
    def __new__(cls, name, bases, attrs):
        new_cls = type.__new__(cls, name, bases, attrs)

        def already_registered(model, anon):
            # First handler already mapped to exactly (model, anon), or None.
            for handler, (mapped_model, mapped_anon) in typemapper.iteritems():
                if model == mapped_model and anon == mapped_anon:
                    return handler
            return None

        if not hasattr(new_cls, 'model'):
            typemapper[new_cls] = (None, new_cls.is_anonymous)
        else:
            duplicate = already_registered(new_cls.model, new_cls.is_anonymous)
            # `settings` is only consulted when there is a duplicate.
            if duplicate and not getattr(settings, 'PISTON_IGNORE_DUPE_MODELS', False):
                warnings.warn("Handler already registered for model %s, "
                    "you may experience inconsistent results." % new_cls.model.__name__)
            typemapper[new_cls] = (new_cls.model, new_cls.is_anonymous)

        if name not in ('BaseHandler', 'AnonymousBaseHandler'):
            handler_tracker.append(new_cls)

        return new_cls
+
class BaseHandler(object):
    """
    Basehandler that gives you CRUD for free.
    You are supposed to subclass this for specific
    functionality.

    All CRUD methods (`read`/`update`/`create`/`delete`)
    receive a request as the first argument from the
    resource. Use this for checking `request.user`, etc.
    """
    __metaclass__ = HandlerMetaClass

    allowed_methods = ('GET', 'POST', 'PUT', 'DELETE')
    anonymous = is_anonymous = False
    exclude = ( 'id', )
    fields = ( )

    def flatten_dict(self, dct):
        """
        Return a plain dict with `str()` keys, taking ``dct.get(k)`` for
        every key of `dct`.
        """
        return dict([ (str(k), dct.get(k)) for k in dct.keys() ])

    def has_model(self):
        """
        Return True if this handler is bound to a `model`.

        `queryset` is a method defined on this class, so also checking for
        it (as before) made this always True. Handlers without a model then
        crashed with AttributeError in the CRUD methods instead of getting
        rc.NOT_IMPLEMENTED / NotImplementedError.
        """
        return hasattr(self, 'model')

    def queryset(self, request):
        """Base queryset used by the CRUD methods: all objects of `model`."""
        return self.model.objects.all()

    @staticmethod
    def value_from_tuple(tu, name):
        """
        Return the first element of the pair in `tu` whose second element
        equals `name` (e.g. a choice value looked up by its label), or None.

        Static because it takes no `self`. As a plain method, calling it on
        an instance bound the instance to `tu`.
        """
        for int_, n in tu:
            if n == name:
                return int_
        return None

    def exists(self, **kwargs):
        """
        Return True if at least one `model` object matches `kwargs`.

        Raises NotImplementedError if the handler has no model.
        """
        if not self.has_model():
            raise NotImplementedError

        try:
            self.model.objects.get(**kwargs)
            return True
        except self.model.DoesNotExist:
            return False
        except self.model.MultipleObjectsReturned:
            # Several matches still means "exists".
            return True

    def read(self, request, *args, **kwargs):
        """
        Return the object whose primary key is in `kwargs` (rc.NOT_FOUND if
        there is none), or else the queryset filtered by `args`/`kwargs`.
        """
        if not self.has_model():
            return rc.NOT_IMPLEMENTED

        pkfield = self.model._meta.pk.name

        if pkfield in kwargs:
            try:
                return self.queryset(request).get(pk=kwargs.get(pkfield))
            except ObjectDoesNotExist:
                return rc.NOT_FOUND
            except MultipleObjectsReturned: # should never happen, since we're using a PK
                return rc.BAD_REQUEST
        else:
            return self.queryset(request).filter(*args, **kwargs)

    def create(self, request, *args, **kwargs):
        """
        Create and return a `model` instance from the POSTed fields.

        Returns rc.DUPLICATE_ENTRY if an instance matching all of them
        already exists.
        """
        if not self.has_model():
            return rc.NOT_IMPLEMENTED

        attrs = self.flatten_dict(request.POST)

        try:
            self.queryset(request).get(**attrs)
            return rc.DUPLICATE_ENTRY
        except self.model.DoesNotExist:
            inst = self.model(**attrs)
            inst.save()
            return inst
        except self.model.MultipleObjectsReturned:
            return rc.DUPLICATE_ENTRY

    def update(self, request, *args, **kwargs):
        """
        Set the POSTed fields on the object whose primary key is in
        `kwargs` and save it.

        Returns rc.ALL_OK on success, rc.BAD_REQUEST if no primary key was
        given, and rc.NOT_FOUND if no such object exists.
        """
        if not self.has_model():
            return rc.NOT_IMPLEMENTED

        pkfield = self.model._meta.pk.name

        if pkfield not in kwargs:
            # No pk was specified
            return rc.BAD_REQUEST

        try:
            inst = self.queryset(request).get(pk=kwargs.get(pkfield))
        except ObjectDoesNotExist:
            return rc.NOT_FOUND
        except MultipleObjectsReturned: # should never happen, since we're using a PK
            return rc.BAD_REQUEST

        attrs = self.flatten_dict(request.POST)
        for k,v in attrs.iteritems():
            setattr( inst, k, v )

        inst.save()
        return rc.ALL_OK

    def delete(self, request, *args, **kwargs):
        """
        Delete the single object matching `args`/`kwargs`.

        Returns rc.DELETED, rc.DUPLICATE_ENTRY if several objects match, or
        rc.NOT_HERE if none does. Unlike the other CRUD methods, a handler
        without a model raises NotImplementedError here.
        """
        if not self.has_model():
            raise NotImplementedError

        try:
            inst = self.queryset(request).get(*args, **kwargs)

            inst.delete()

            return rc.DELETED
        except self.model.MultipleObjectsReturned:
            return rc.DUPLICATE_ENTRY
        except self.model.DoesNotExist:
            return rc.NOT_HERE
+
class AnonymousBaseHandler(BaseHandler):
    """
    Handler base class flagged as anonymous (`is_anonymous = True`);
    only GET is allowed by default.
    """
    allowed_methods = ('GET',)
    is_anonymous = True
--- /dev/null
+from django.db import models
+from django.contrib.auth.models import User
+
# Lengths of the random key / secret strings produced by KeyManager.
KEY_SIZE = 18
SECRET_SIZE = 32
+
class KeyManager(models.Manager):
    '''Add support for random key/secret generation
    '''
    def generate_random_codes(self):
        """
        Return a new ``(key, secret)`` pair of random strings, KEY_SIZE and
        SECRET_SIZE characters long.

        The key is regenerated until no existing row uses it. Previously the
        loop only checked the (key, secret) pair and only regenerated the
        secret, so a key collision could never be resolved. The key is the
        part that must be unique: it is the value clients send, e.g. as
        `oauth_token`.
        """
        key = User.objects.make_random_password(length=KEY_SIZE)
        while self.filter(key__exact=key).count():
            key = User.objects.make_random_password(length=KEY_SIZE)

        secret = User.objects.make_random_password(length=SECRET_SIZE)

        return key, secret
+
+
class ConsumerManager(KeyManager):
    def create_consumer(self, name, description=None, user=None):
        """
        Get or create the consumer called `name` and return it.

        A newly created consumer gets a random key/secret. A truthy `user`
        or `description` is applied whether the consumer is new or existing,
        and the change is saved. Previously an existing consumer was only
        updated in memory.
        """
        consumer, created = self.get_or_create(name=name)
        needs_save = created

        if user:
            consumer.user = user
            needs_save = True

        if description:
            consumer.description = description
            needs_save = True

        if created:
            consumer.key, consumer.secret = self.generate_random_codes()

        if needs_save:
            consumer.save()

        return consumer

    # NOTE(review): not used by this manager; presumably a leftover.
    _default_consumer = None
+
class ResourceManager(models.Manager):
    # Single-entry cache for get_default_resource: the last resource fetched
    # and the name it was fetched by.
    _default_resource = None
    _default_resource_name = None

    def get_default_resource(self, name):
        """
        Return the resource called `name`, caching it on the manager.

        The cache also remembers the name it holds. Previously the first
        resource looked up was returned for every later call, whatever
        `name` was passed.
        """
        if self._default_resource is None or self._default_resource_name != name:
            self._default_resource = self.get(name=name)
            self._default_resource_name = name

        return self._default_resource
+
class TokenManager(KeyManager):
    def create_token(self, consumer, token_type, timestamp, user=None):
        """
        Get or create the token matching (consumer, token_type, timestamp,
        user). A newly created token is given a random key/secret and saved.
        """
        lookup = dict(consumer=consumer, token_type=token_type,
                      timestamp=timestamp, user=user)
        token, created = self.get_or_create(**lookup)
        if not created:
            return token

        token.key, token.secret = self.generate_random_codes()
        token.save()
        return token
+
--- /dev/null
+from django.middleware.http import ConditionalGetMiddleware
+from django.middleware.common import CommonMiddleware
+
def compat_middleware_factory(klass):
    """
    Return a subclass of middleware `klass` whose `process_response` is
    skipped for responses that have a `streaming` attribute.

    Django middleware tends to inspect the response content, which would
    prematurely exhaust generator- or buffer-backed responses.
    """
    class compatwrapper(klass):
        def process_response(self, req, resp):
            # Streaming responses are passed through untouched.
            if hasattr(resp, 'streaming'):
                return resp
            return klass.process_response(self, req, resp)
    return compatwrapper
+
# Versions of Django's ConditionalGetMiddleware and CommonMiddleware whose
# process_response leaves responses with a `streaming` attribute untouched.
ConditionalMiddlewareCompatProxy = compat_middleware_factory(ConditionalGetMiddleware)
CommonMiddlewareCompatProxy = compat_middleware_factory(CommonMiddleware)
--- /dev/null
+import urllib, time, urlparse
+
+# Django imports
+from django.db.models.signals import post_save, post_delete
+from django.db import models
+from django.contrib.auth.models import User
+from django.contrib import admin
+from django.core.mail import send_mail, mail_admins
+
+# Piston imports
+from managers import TokenManager, ConsumerManager, ResourceManager
+from signals import consumer_post_save, consumer_post_delete
+
# Lengths of OAuth keys, secrets and verifiers.
KEY_SIZE = 18
SECRET_SIZE = 32
VERIFIER_SIZE = 10

# (stored value, label) choices for Consumer.status.
CONSUMER_STATES = (
    ('pending', 'Pending'),
    ('accepted', 'Accepted'),
    ('canceled', 'Canceled'),
    ('rejected', 'Rejected'),
)
+
def generate_random(length=SECRET_SIZE):
    """Return a random password-style string of `length` characters."""
    random_string = User.objects.make_random_password(length=length)
    return random_string
+
class Nonce(models.Model):
    """A nonce `key` recorded for a consumer key / token key pair."""
    token_key = models.CharField(max_length=KEY_SIZE)
    consumer_key = models.CharField(max_length=KEY_SIZE)
    key = models.CharField(max_length=255)

    def __unicode__(self):
        parts = (self.key, self.consumer_key)
        return u"Nonce %s for %s" % parts

admin.site.register(Nonce)
+
class Consumer(models.Model):
    """
    An OAuth consumer: a named client with a key/secret pair, a `status`
    (pending / accepted / canceled / rejected) and an optional owning user.
    """
    name = models.CharField(max_length=255)
    description = models.TextField()

    key = models.CharField(max_length=KEY_SIZE)
    secret = models.CharField(max_length=SECRET_SIZE)

    status = models.CharField(max_length=16, choices=CONSUMER_STATES, default='pending')
    user = models.ForeignKey(User, null=True, blank=True, related_name='consumers')

    objects = ConsumerManager()

    def __unicode__(self):
        return u"Consumer %s with key %s" % (self.name, self.key)

    def generate_random_codes(self):
        """
        Used to generate random key/secret pairings. Use this after you've
        added the other data in place of save().

            c = Consumer()
            c.name = "My consumer"
            c.description = "An app that makes ponies from the API."
            c.user = some_user_object
            c.generate_random_codes()

        The key is regenerated until no consumer row uses it. Previously
        only the key/secret pair was checked and only the secret was
        regenerated, so a key collision could never be resolved.
        """
        key = generate_random(KEY_SIZE)
        while Consumer.objects.filter(key__exact=key).count():
            key = generate_random(KEY_SIZE)

        secret = generate_random(SECRET_SIZE)

        self.key = key
        self.secret = secret
        self.save()

admin.site.register(Consumer)
+
def _current_timestamp():
    """Return the current Unix time in whole seconds, as a long."""
    return long(time.time())


class Token(models.Model):
    """An OAuth request or access token issued to a consumer."""
    REQUEST = 1
    ACCESS = 2
    TOKEN_TYPES = ((REQUEST, u'Request'), (ACCESS, u'Access'))

    key = models.CharField(max_length=KEY_SIZE)
    secret = models.CharField(max_length=SECRET_SIZE)
    verifier = models.CharField(max_length=VERIFIER_SIZE)
    token_type = models.IntegerField(choices=TOKEN_TYPES)
    # Callable default, evaluated for each new token. The previous
    # `default=long(time.time())` was evaluated once at import, so every
    # token created without an explicit timestamp got the process start time.
    timestamp = models.IntegerField(default=_current_timestamp)
    is_approved = models.BooleanField(default=False)

    user = models.ForeignKey(User, null=True, blank=True, related_name='tokens')
    consumer = models.ForeignKey(Consumer)

    callback = models.CharField(max_length=255, null=True, blank=True)
    callback_confirmed = models.BooleanField(default=False)

    objects = TokenManager()

    def __unicode__(self):
        return u"%s Token %s for %s" % (self.get_token_type_display(), self.key, self.consumer)

    def to_string(self, only_key=False):
        """
        Return the token as an urlencoded response body.

        The body always contains oauth_callback_confirmed=true, and contains
        oauth_verifier when a verifier is set. The secret is left out when
        `only_key` is true.
        """
        token_dict = {
            'oauth_token': self.key,
            'oauth_token_secret': self.secret,
            'oauth_callback_confirmed': 'true',
        }

        if self.verifier:
            token_dict.update({ 'oauth_verifier': self.verifier })

        if only_key:
            del token_dict['oauth_token_secret']

        return urllib.urlencode(token_dict)

    def generate_random_codes(self):
        """
        Assign a random key/secret and save the token.

        The key is regenerated until no token row uses it. Previously only
        the key/secret pair was checked, so a key collision was never
        resolved.
        """
        key = generate_random(KEY_SIZE)
        while Token.objects.filter(key__exact=key).count():
            key = generate_random(KEY_SIZE)

        secret = generate_random(SECRET_SIZE)

        self.key = key
        self.secret = secret
        self.save()

    # -- OAuth 1.0a stuff

    def get_callback_url(self):
        """
        Return the callback URL with oauth_verifier appended to its query
        string when both are set; otherwise return `callback` as is.
        """
        if self.callback and self.verifier:
            # Append the oauth_verifier.
            parts = urlparse.urlparse(self.callback)
            scheme, netloc, path, params, query, fragment = parts[:6]
            if query:
                query = '%s&oauth_verifier=%s' % (query, self.verifier)
            else:
                query = 'oauth_verifier=%s' % self.verifier
            return urlparse.urlunparse((scheme, netloc, path, params,
                query, fragment))
        return self.callback

    def set_callback(self, callback):
        """
        Store and confirm `callback`, then save.

        The value "oob" (out of band) means there is no callback and is not
        stored.
        """
        if callback != "oob": # out of band, says "we can't do this!"
            self.callback = callback
            self.callback_confirmed = True
            self.save()

admin.site.register(Token)
+
# Attach our signals: consumer_post_save / consumer_post_delete run after a
# Consumer is saved / deleted.
post_save.connect(consumer_post_save, sender=Consumer)
post_delete.connect(consumer_post_delete, sender=Consumer)
--- /dev/null
+"""
+The MIT License
+
+Copyright (c) 2007 Leah Culver
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import cgi
+import urllib
+import time
+import random
+import urlparse
+import hmac
+import binascii
+
+
# OAuth protocol version implemented (and required by OAuthServer).
VERSION = '1.0' # Hi Blaine!
# Default HTTP method for OAuthRequest objects.
HTTP_METHOD = 'GET'
# Signature method assumed when a request does not name one.
SIGNATURE_METHOD = 'PLAINTEXT'
+
+
class OAuthError(RuntimeError):
    """Generic exception class; the text is available as `message`."""
    def __init__(self, message='OAuth error occured.'):
        # Pass the message to the base class so `str(exc)` and `exc.args`
        # carry it; previously both were empty.
        RuntimeError.__init__(self, message)
        self.message = message
+
def build_authenticate_header(realm=''):
    """Return the optional WWW-Authenticate header (for 401 responses) as a dict."""
    header_value = 'OAuth realm="%s"' % realm
    return {'WWW-Authenticate': header_value}
+
+def escape(s):
+ """Escape a URL including any /."""
+ return urllib.quote(s, safe='~')
+
+def _utf8_str(s):
+ """Convert unicode to utf-8."""
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
+ else:
+ return str(s)
+
def generate_timestamp():
    """Return the current Unix time (seconds since the epoch) as an int."""
    now = time.time()
    return int(now)
+
def generate_nonce(length=8):
    """Return a string of `length` pseudorandom decimal digits."""
    return ''.join(str(random.randint(0, 9)) for _ in range(length))
+
def generate_verifier(length=8):
    """
    Return a random OAuth verifier made of `length` decimal digits.

    The verifier is what a consumer must present to exchange an authorized
    request token, so it has to be unpredictable. Digits come from
    random.SystemRandom (the OS CSPRNG) rather than the module-level
    Mersenne Twister, whose output can be reconstructed from observed
    values.
    """
    rng = random.SystemRandom()
    return ''.join([str(rng.randint(0, 9)) for i in range(length)])
+
+
class OAuthConsumer(object):
    """Consumer of OAuth authentication.

    Identifies a Consumer by the key/secret pair it shares with the
    Service Provider.
    """
    key = None
    secret = None

    def __init__(self, key, secret):
        self.key, self.secret = key, secret
+
+
+class OAuthToken(object):
+ """OAuthToken is a data type that represents an End User via either an access
+ or request token.
+
+ key -- the token
+ secret -- the token secret
+
+ """
+ key = None
+ secret = None
+ callback = None
+ callback_confirmed = None
+ verifier = None
+
+ def __init__(self, key, secret):
+ self.key = key
+ self.secret = secret
+
+ def set_callback(self, callback):
+ self.callback = callback
+ self.callback_confirmed = 'true'
+
+ def set_verifier(self, verifier=None):
+ if verifier is not None:
+ self.verifier = verifier
+ else:
+ self.verifier = generate_verifier()
+
+ def get_callback_url(self):
+ if self.callback and self.verifier:
+ # Append the oauth_verifier.
+ parts = urlparse.urlparse(self.callback)
+ scheme, netloc, path, params, query, fragment = parts[:6]
+ if query:
+ query = '%s&oauth_verifier=%s' % (query, self.verifier)
+ else:
+ query = 'oauth_verifier=%s' % self.verifier
+ return urlparse.urlunparse((scheme, netloc, path, params,
+ query, fragment))
+ return self.callback
+
+ def to_string(self):
+ data = {
+ 'oauth_token': self.key,
+ 'oauth_token_secret': self.secret,
+ }
+ if self.callback_confirmed is not None:
+ data['oauth_callback_confirmed'] = self.callback_confirmed
+ return urllib.urlencode(data)
+
+ def from_string(s):
+ """ Returns a token from something like:
+ oauth_token_secret=xxx&oauth_token=xxx
+ """
+ params = cgi.parse_qs(s, keep_blank_values=False)
+ key = params['oauth_token'][0]
+ secret = params['oauth_token_secret'][0]
+ token = OAuthToken(key, secret)
+ try:
+ token.callback_confirmed = params['oauth_callback_confirmed'][0]
+ except KeyError:
+ pass # 1.0, no callback confirmed.
+ return token
+ from_string = staticmethod(from_string)
+
+ def __str__(self):
+ return self.to_string()
+
+
class OAuthRequest(object):
    """OAuthRequest represents the request and can be serialized.

    OAuth parameters:
        - oauth_consumer_key
        - oauth_token
        - oauth_signature_method
        - oauth_signature
        - oauth_timestamp
        - oauth_nonce
        - oauth_version
        - oauth_verifier
        ... any additional parameters, as defined by the Service Provider.
    """
    parameters = None # OAuth parameters.
    http_method = HTTP_METHOD
    http_url = None
    version = VERSION

    def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
        self.http_method = http_method
        self.http_url = http_url
        self.parameters = parameters or {}

    def set_parameter(self, parameter, value):
        self.parameters[parameter] = value

    def get_parameter(self, parameter):
        """Return the value of `parameter`; raise OAuthError if it is absent."""
        try:
            return self.parameters[parameter]
        except KeyError:
            raise OAuthError('Parameter not found: %s' % parameter)

    def _get_timestamp_nonce(self):
        return self.get_parameter('oauth_timestamp'), self.get_parameter(
            'oauth_nonce')

    def get_nonoauth_parameters(self):
        """Get any non-OAuth parameters.

        A parameter is an OAuth parameter when its name starts with 'oauth_'
        (the same test `to_header` uses). Previously any name merely
        containing 'oauth_' was dropped too.
        """
        parameters = {}
        for k, v in self.parameters.iteritems():
            if not k.startswith('oauth_'):
                parameters[k] = v
        return parameters

    def to_header(self, realm=''):
        """Serialize as a header for an HTTPAuth request."""
        auth_header = 'OAuth realm="%s"' % realm
        # Add the oauth parameters.
        if self.parameters:
            for k, v in self.parameters.iteritems():
                if k[:6] == 'oauth_':
                    auth_header += ', %s="%s"' % (k, escape(str(v)))
        return {'Authorization': auth_header}

    def to_postdata(self):
        """Serialize as post data for a POST request."""
        return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \
            for k, v in self.parameters.iteritems()])

    def to_url(self):
        """Serialize as a URL for a GET request."""
        return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())

    def get_normalized_parameters(self):
        """Return a string that contains the parameters that must be signed.

        All parameters except `oauth_signature` are escaped, sorted by key
        and then value, and joined as ``k=v&k=v``.
        """
        # Work on a copy: the signature must be left out of the base string
        # but not removed from the request (it used to be deleted from
        # self.parameters). `.items()` keeps single values for multi-value
        # dicts.
        params = dict(self.parameters.items())
        params.pop('oauth_signature', None)
        # Escape key values before sorting.
        key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \
            for k, v in params.items()]
        # Sort lexicographically, first after key, then after value.
        key_values.sort()
        # Combine key value pairs into a string.
        return '&'.join(['%s=%s' % (k, v) for k, v in key_values])

    def get_normalized_http_method(self):
        """Uppercases the http method."""
        return self.http_method.upper()

    def get_normalized_http_url(self):
        """Parses the URL and rebuilds it to be scheme://host/path."""
        parts = urlparse.urlparse(self.http_url)
        scheme, netloc, path = parts[:3]
        # Exclude default port numbers.
        if scheme == 'http' and netloc[-3:] == ':80':
            netloc = netloc[:-3]
        elif scheme == 'https' and netloc[-4:] == ':443':
            netloc = netloc[:-4]
        return '%s://%s%s' % (scheme, netloc, path)

    def sign_request(self, signature_method, consumer, token):
        """Set the signature parameter to the result of build_signature."""
        # Set the signature method.
        self.set_parameter('oauth_signature_method',
            signature_method.get_name())
        # Set the signature.
        self.set_parameter('oauth_signature',
            self.build_signature(signature_method, consumer, token))

    def build_signature(self, signature_method, consumer, token):
        """Calls the build signature method within the signature method."""
        return signature_method.build_signature(self, consumer, token)

    def from_request(http_method, http_url, headers=None, parameters=None,
            query_string=None):
        """Combines multiple parameter sources.

        Merges `parameters`, the OAuth Authorization header, `query_string`
        and the URL's own query string (later sources win). Returns None if
        no parameters were found at all.
        """
        if parameters is None:
            parameters = {}

        # Headers
        if headers and 'Authorization' in headers:
            auth_header = headers['Authorization']
            # Check that the authorization header is OAuth.
            if auth_header[:6] == 'OAuth ':
                auth_header = auth_header[6:]
                try:
                    # Get the parameters from the header.
                    header_params = OAuthRequest._split_header(auth_header)
                    parameters.update(header_params)
                except Exception:
                    raise OAuthError('Unable to parse OAuth parameters from '
                        'Authorization header.')

        # GET or POST query string.
        if query_string:
            query_params = OAuthRequest._split_url_string(query_string)
            parameters.update(query_params)

        # URL parameters.
        param_str = urlparse.urlparse(http_url)[4] # query
        url_params = OAuthRequest._split_url_string(param_str)
        parameters.update(url_params)

        if parameters:
            return OAuthRequest(http_method, http_url, parameters)

        return None
    from_request = staticmethod(from_request)

    def from_consumer_and_token(oauth_consumer, token=None,
            callback=None, verifier=None, http_method=HTTP_METHOD,
            http_url=None, parameters=None):
        """Build a request for `oauth_consumer` with fresh timestamp/nonce.

        With a `token`, its key, callback and `verifier` are added; without
        one, a `callback` (if given) is added instead.
        """
        if not parameters:
            parameters = {}

        defaults = {
            'oauth_consumer_key': oauth_consumer.key,
            'oauth_timestamp': generate_timestamp(),
            'oauth_nonce': generate_nonce(),
            'oauth_version': OAuthRequest.version,
        }

        defaults.update(parameters)
        parameters = defaults

        if token:
            parameters['oauth_token'] = token.key
            parameters['oauth_callback'] = token.callback
            # 1.0a support for verifier.
            parameters['oauth_verifier'] = verifier
        elif callback:
            # 1.0a support for callback in the request token request.
            parameters['oauth_callback'] = callback

        return OAuthRequest(http_method, http_url, parameters)
    from_consumer_and_token = staticmethod(from_consumer_and_token)

    def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,
            http_url=None, parameters=None):
        """Build a request carrying `token`'s key and, if given, `callback`."""
        if not parameters:
            parameters = {}

        parameters['oauth_token'] = token.key

        if callback:
            parameters['oauth_callback'] = callback

        return OAuthRequest(http_method, http_url, parameters)
    from_token_and_callback = staticmethod(from_token_and_callback)

    def _split_header(header):
        """Turn Authorization: header into parameters."""
        params = {}
        parts = header.split(',')
        for param in parts:
            # Remove whitespace.
            param = param.strip()
            # Split key-value.
            param_parts = param.split('=', 1)
            # Ignore the realm parameter. Only the key is tested: previously
            # any part containing "realm" (e.g. inside a callback URL value)
            # was silently dropped.
            if param_parts[0] == 'realm':
                continue
            # Remove quotes and unescape the value.
            params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
        return params
    _split_header = staticmethod(_split_header)

    def _split_url_string(param_str):
        """Turn URL string into parameters (first value of each key)."""
        # parse_qs already percent-decodes keys and values. Decoding the
        # values again (as before) corrupted values containing a literal
        # '%', e.g. '%41' became 'A', breaking signature checks.
        parameters = cgi.parse_qs(param_str, keep_blank_values=False)
        for k, v in parameters.iteritems():
            parameters[k] = v[0]
        return parameters
    _split_url_string = staticmethod(_split_url_string)
+
+class OAuthServer(object):
+ """A worker to check the validity of a request against a data store."""
+ timestamp_threshold = 300 # In seconds, five minutes.
+ version = VERSION
+ signature_methods = None
+ data_store = None
+
+ def __init__(self, data_store=None, signature_methods=None):
+ self.data_store = data_store
+ self.signature_methods = signature_methods or {}
+
+ def set_data_store(self, data_store):
+ self.data_store = data_store
+
+ def get_data_store(self):
+ return self.data_store
+
+ def add_signature_method(self, signature_method):
+ self.signature_methods[signature_method.get_name()] = signature_method
+ return self.signature_methods
+
+ def fetch_request_token(self, oauth_request):
+ """Processes a request_token request and returns the
+ request token on success.
+ """
+ try:
+ # Get the request token for authorization.
+ token = self._get_token(oauth_request, 'request')
+ except OAuthError:
+ # No token required for the initial token request.
+ version = self._get_version(oauth_request)
+ consumer = self._get_consumer(oauth_request)
+ try:
+ callback = self.get_callback(oauth_request)
+ except OAuthError:
+ callback = None # 1.0, no callback specified.
+ self._check_signature(oauth_request, consumer, None)
+ # Fetch a new token.
+ token = self.data_store.fetch_request_token(consumer, callback)
+ return token
+
+ def fetch_access_token(self, oauth_request):
+ """Processes an access_token request and returns the
+ access token on success.
+ """
+ version = self._get_version(oauth_request)
+ consumer = self._get_consumer(oauth_request)
+ verifier = self._get_verifier(oauth_request)
+ # Get the request token.
+ token = self._get_token(oauth_request, 'request')
+ self._check_signature(oauth_request, consumer, token)
+ new_token = self.data_store.fetch_access_token(consumer, token, verifier)
+ return new_token
+
+ def verify_request(self, oauth_request):
+ """Verifies an api call and checks all the parameters."""
+ # -> consumer and token
+ version = self._get_version(oauth_request)
+ consumer = self._get_consumer(oauth_request)
+ # Get the access token.
+ token = self._get_token(oauth_request, 'access')
+ self._check_signature(oauth_request, consumer, token)
+ parameters = oauth_request.get_nonoauth_parameters()
+ return consumer, token, parameters
+
+ def authorize_token(self, token, user):
+ """Authorize a request token."""
+ return self.data_store.authorize_request_token(token, user)
+
+ def get_callback(self, oauth_request):
+ """Get the callback URL."""
+ return oauth_request.get_parameter('oauth_callback')
+
+ def build_authenticate_header(self, realm=''):
+ """Optional support for the authenticate header."""
+ return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
+
+ def _get_version(self, oauth_request):
+ """Verify the correct version request for this server."""
+ try:
+ version = oauth_request.get_parameter('oauth_version')
+ except:
+ version = VERSION
+ if version and version != self.version:
+ raise OAuthError('OAuth version %s not supported.' % str(version))
+ return version
+
+ def _get_signature_method(self, oauth_request):
+ """Figure out the signature with some defaults."""
+ try:
+ signature_method = oauth_request.get_parameter(
+ 'oauth_signature_method')
+ except:
+ signature_method = SIGNATURE_METHOD
+ try:
+ # Get the signature method object.
+ signature_method = self.signature_methods[signature_method]
+ except:
+ signature_method_names = ', '.join(self.signature_methods.keys())
+ raise OAuthError('Signature method %s not supported try one of the '
+ 'following: %s' % (signature_method, signature_method_names))
+
+ return signature_method
+
+ def _get_consumer(self, oauth_request):
+ consumer_key = oauth_request.get_parameter('oauth_consumer_key')
+ consumer = self.data_store.lookup_consumer(consumer_key)
+ if not consumer:
+ raise OAuthError('Invalid consumer.')
+ return consumer
+
+ def _get_token(self, oauth_request, token_type='access'):
+ """Try to find the token for the provided request token key."""
+ token_field = oauth_request.get_parameter('oauth_token')
+ token = self.data_store.lookup_token(token_type, token_field)
+ if not token:
+ raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
+ return token
+
+ def _get_verifier(self, oauth_request):
+ return oauth_request.get_parameter('oauth_verifier')
+
+ def _check_signature(self, oauth_request, consumer, token):
+ timestamp, nonce = oauth_request._get_timestamp_nonce()
+ self._check_timestamp(timestamp)
+ self._check_nonce(consumer, token, nonce)
+ signature_method = self._get_signature_method(oauth_request)
+ try:
+ signature = oauth_request.get_parameter('oauth_signature')
+ except:
+ raise OAuthError('Missing signature.')
+ # Validate the signature.
+ valid_sig = signature_method.check_signature(oauth_request, consumer,
+ token, signature)
+ if not valid_sig:
+ key, base = signature_method.build_signature_base_string(
+ oauth_request, consumer, token)
+ raise OAuthError('Invalid signature. Expected signature base '
+ 'string: %s' % base)
+ built = signature_method.build_signature(oauth_request, consumer, token)
+
+    def _check_timestamp(self, timestamp):
+        """Reject timestamps more than ``self.timestamp_threshold`` seconds old.
+
+        Only the age is checked; timestamps in the future are accepted.
+        """
+        given = int(timestamp)
+        current = int(time.time())
+        if current - given <= self.timestamp_threshold:
+            return
+        raise OAuthError('Expired timestamp: given %d and now %s has a '
+                         'greater difference than threshold %d' %
+                         (given, current, self.timestamp_threshold))
+
+    def _check_nonce(self, consumer, token, nonce):
+        """Raise OAuthError if the data store reports ``nonce`` as already seen."""
+        used = self.data_store.lookup_nonce(consumer, token, nonce)
+        if not used:
+            return
+        raise OAuthError('Nonce already used: %s' % str(used))
+
+
+class OAuthClient(object):
+    """Abstract worker that executes OAuth requests.
+
+    This base class only stores the consumer and token it is given; the
+    fetch/access operations must be provided by a subclass.
+    """
+    consumer = None
+    token = None
+
+    def __init__(self, oauth_consumer, oauth_token):
+        self.consumer, self.token = oauth_consumer, oauth_token
+
+    def get_consumer(self):
+        """Return the consumer given to the constructor."""
+        return self.consumer
+
+    def get_token(self):
+        """Return the token given to the constructor."""
+        return self.token
+
+    def fetch_request_token(self, oauth_request):
+        """Obtain a request token (an OAuthToken). Subclasses must override."""
+        raise NotImplementedError
+
+    def fetch_access_token(self, oauth_request):
+        """Obtain an access token (an OAuthToken). Subclasses must override."""
+        raise NotImplementedError
+
+    def access_resource(self, oauth_request):
+        """Fetch some protected resource. Subclasses must override."""
+        raise NotImplementedError
+
+
+class OAuthDataStore(object):
+    """A database abstraction used to lookup consumers and tokens.
+
+    Every method must be overridden by a concrete store.
+    """
+
+    def lookup_consumer(self, key):
+        """Return the OAuthConsumer for ``key``, or a false value if unknown."""
+        raise NotImplementedError
+
+    def lookup_token(self, token_type, token_token):
+        """Return the OAuthToken of ``token_type`` (e.g. 'request' or
+        'access') whose key is ``token_token``, or a false value if unknown.
+
+        The parameters match how the server calls this method:
+        ``lookup_token(token_type, token_key)``.
+        """
+        raise NotImplementedError
+
+    def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+        """Return a true value if ``nonce`` was already used for this
+        consumer/token pair, otherwise a false value.
+        """
+        raise NotImplementedError
+
+    def fetch_request_token(self, oauth_consumer, oauth_callback):
+        """-> OAuthToken."""
+        raise NotImplementedError
+
+    def fetch_access_token(self, oauth_consumer, oauth_token, oauth_verifier):
+        """-> OAuthToken."""
+        raise NotImplementedError
+
+    def authorize_request_token(self, oauth_token, user):
+        """-> OAuthToken."""
+        raise NotImplementedError
+
+
+class OAuthSignatureMethod(object):
+    """A strategy class that implements a signature method.
+
+    Subclasses provide `get_name`, `build_signature_base_string` and
+    `build_signature`; `check_signature` is shared by all of them.
+    """
+    def get_name(self):
+        """-> str."""
+        raise NotImplementedError
+
+    def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
+        """-> str key, str raw."""
+        raise NotImplementedError
+
+    def build_signature(self, oauth_request, oauth_consumer, oauth_token):
+        """-> str."""
+        raise NotImplementedError
+
+    def check_signature(self, oauth_request, consumer, token, signature):
+        """Return True if ``signature`` equals the signature we build.
+
+        The comparison does not stop at the first differing character, so
+        response timing does not reveal how much of a forged signature is
+        correct.
+        """
+        built = self.build_signature(oauth_request, consumer, token)
+        return self._constant_time_equals(built, signature)
+
+    @staticmethod
+    def _constant_time_equals(a, b):
+        """Compare two strings without an early exit on the first mismatch."""
+        import hmac
+        compare_digest = getattr(hmac, 'compare_digest', None)
+        if compare_digest is not None:
+            try:
+                return compare_digest(a, b)
+            except TypeError:
+                # Mixed str/unicode, non-ASCII unicode or None: use the
+                # manual comparison below.
+                pass
+        if a is None or b is None:
+            return a is b
+        if len(a) != len(b):
+            return False
+        result = 0
+        for x, y in zip(a, b):
+            result |= ord(x) ^ ord(y)
+        return result == 0
+
+
+class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
+    """HMAC-SHA1 signature method."""
+
+    def get_name(self):
+        return 'HMAC-SHA1'
+
+    def build_signature_base_string(self, oauth_request, consumer, token):
+        """Return ``(key, raw)``: the HMAC key and the signature base string.
+
+        ``key`` is the escaped consumer secret, '&', and the escaped token
+        secret if a token is given. ``raw`` is the escaped HTTP method, URL
+        and normalized parameters joined with '&'.
+        """
+        sig = (
+            escape(oauth_request.get_normalized_http_method()),
+            escape(oauth_request.get_normalized_http_url()),
+            escape(oauth_request.get_normalized_parameters()),
+        )
+
+        key = '%s&' % escape(consumer.secret)
+        if token:
+            key += escape(token.secret)
+        raw = '&'.join(sig)
+        return key, raw
+
+    def build_signature(self, oauth_request, consumer, token):
+        """Return the base64-encoded HMAC-SHA1 digest of the base string."""
+        key, raw = self.build_signature_base_string(oauth_request, consumer,
+                                                    token)
+
+        # Choose the SHA-1 implementation first; only a missing hashlib
+        # (Python < 2.5) should trigger the fallback, not a hashing error.
+        try:
+            import hashlib  # 2.5
+            digestmod = hashlib.sha1
+        except ImportError:
+            import sha  # Deprecated
+            digestmod = sha
+        hashed = hmac.new(key, raw, digestmod)
+
+        # b2a_base64 appends a newline; drop it.
+        return binascii.b2a_base64(hashed.digest())[:-1]
+
+
+class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
+    """PLAINTEXT signature method: the signature is the secrets themselves."""
+
+    def get_name(self):
+        return 'PLAINTEXT'
+
+    def build_signature_base_string(self, oauth_request, consumer, token):
+        """Join the escaped consumer secret and token secret with '&'.
+
+        The token part is empty when no token is given. The same string is
+        returned as both key and raw value.
+        """
+        parts = [escape(consumer.secret), '']
+        if token:
+            parts[1] = escape(token.secret)
+        joined = '&'.join(parts)
+        return joined, joined
+
+    def build_signature(self, oauth_request, consumer, token):
+        """Return the PLAINTEXT signature, i.e. the key built above."""
+        return self.build_signature_base_string(oauth_request, consumer,
+                                                token)[0]
\ No newline at end of file
--- /dev/null
+import sys, inspect
+
+from django.http import (HttpResponse, Http404, HttpResponseNotAllowed,
+ HttpResponseForbidden, HttpResponseServerError)
+from django.views.debug import ExceptionReporter
+from django.views.decorators.vary import vary_on_headers
+from django.conf import settings
+from django.core.mail import send_mail, EmailMessage
+from django.db.models.query import QuerySet
+from django.http import Http404
+
+from emitters import Emitter
+from handler import typemapper
+from doc import HandlerMethod
+from authentication import NoAuthentication
+from utils import coerce_put_post, FormValidationError, HttpStatusCode
+from utils import rc, format_error, translate_mime, MimerDataException
+
+# Sentinel placed in the `anonymous` slot of Resource.authenticate()'s
+# result when the caller must be challenged; the actor is then the
+# authenticator's `challenge` callable, which __call__ invokes directly.
+CHALLENGE = object()
+
+class Resource(object):
+    """
+    Resource. Create one for your URL mappings, just
+    like you would with Django. Takes one argument,
+    the handler. The second argument is optional, and
+    is an authentication handler (or a list/tuple of
+    them, tried in order). If not specified,
+    `NoAuthentication` will be used by default.
+    """
+    # HTTP method -> name of the handler method that serves it.
+    callmap = { 'GET': 'read', 'POST': 'create',
+                'PUT': 'update', 'DELETE': 'delete' }
+
+    def __init__(self, handler, authentication=None):
+        if not callable(handler):
+            raise AttributeError, "Handler not callable."
+
+        self.handler = handler()
+
+        if not authentication:
+            self.authentication = (NoAuthentication(),)
+        elif isinstance(authentication, (list, tuple)):
+            self.authentication = authentication
+        else:
+            self.authentication = (authentication,)
+
+        # Error reporting and output behaviour, overridable in settings.
+        self.email_errors = getattr(settings, 'PISTON_EMAIL_ERRORS', True)
+        self.display_errors = getattr(settings, 'PISTON_DISPLAY_ERRORS', True)
+        self.stream = getattr(settings, 'PISTON_STREAM_OUTPUT', False)
+
+    def determine_emitter(self, request, *args, **kwargs):
+        """
+        Function for determining which emitter to use
+        for output. It lives here so you can easily subclass
+        `Resource` in order to change how emission is detected.
+
+        The `emitter_format` URL keyword wins; otherwise the
+        `format` query parameter is used, defaulting to 'json'.
+
+        You could also check for the `Accept` HTTP header here,
+        since that pretty much makes sense. Refer to `Mimer` for
+        that as well.
+        """
+        em = kwargs.pop('emitter_format', None)
+
+        if not em:
+            em = request.GET.get('format', 'json')
+
+        return em
+
+    @property
+    def anonymous(self):
+        """
+        Gets the anonymous handler. Also tries to grab a class
+        if the `anonymous` value is a string, so that we can define
+        anonymous handlers that aren't defined yet (like, when
+        you're subclassing your basehandler into an anonymous one.)
+        Returns None when no anonymous handler can be found.
+        """
+        if hasattr(self.handler, 'anonymous'):
+            anon = self.handler.anonymous
+
+            if callable(anon):
+                return anon
+
+            for klass in typemapper.keys():
+                if anon == klass.__name__:
+                    return klass
+
+        return None
+
+    def authenticate(self, request, rm):
+        """
+        Run the configured authenticators in order for HTTP method `rm`.
+
+        Returns `(actor, anonymous)`:
+         - the first authenticator that accepts the request yields
+           `(self.handler, self.handler.is_anonymous)`;
+         - otherwise the last authenticator decides: an instance of the
+           anonymous handler and True if that handler allows `rm`, else
+           `(authenticator.challenge, CHALLENGE)`.
+        """
+        actor, anonymous = False, True
+
+        for authenticator in self.authentication:
+            if authenticator.is_authenticated(request):
+                return self.handler, self.handler.is_anonymous
+
+            if self.anonymous and rm in self.anonymous.allowed_methods:
+                actor, anonymous = self.anonymous(), True
+            else:
+                actor, anonymous = authenticator.challenge, CHALLENGE
+
+        return actor, anonymous
+
+    @vary_on_headers('Authorization')
+    def __call__(self, request, *args, **kwargs):
+        """
+        NB: Sends a `Vary` header so we don't cache requests
+        that are different (OAuth stuff in `Authorization` header.)
+        """
+        rm = request.method.upper()
+
+        # Django's internal mechanism doesn't pick up
+        # PUT request, so we trick it a little here.
+        if rm == "PUT":
+            coerce_put_post(request)
+
+        actor, anonymous = self.authenticate(request, rm)
+
+        if anonymous is CHALLENGE:
+            return actor()
+        else:
+            handler = actor
+
+        # Translate nested datastructs into `request.data` here.
+        if rm in ('POST', 'PUT'):
+            try:
+                translate_mime(request)
+            except MimerDataException:
+                return rc.BAD_REQUEST
+
+        if rm not in handler.allowed_methods:
+            return HttpResponseNotAllowed(handler.allowed_methods)
+
+        # A method may be allowed without a `callmap` entry (e.g. HEAD);
+        # getattr() with a None name would raise TypeError, so fall back
+        # to an empty name, which simply yields None here.
+        meth = getattr(handler, self.callmap.get(rm, ''), None)
+
+        if not meth:
+            raise Http404
+
+        # Support emitter both through (?P<emitter_format>) and ?format=emitter.
+        em_format = self.determine_emitter(request, *args, **kwargs)
+
+        # determine_emitter() popped from its own copy of kwargs.
+        kwargs.pop('emitter_format', None)
+
+        # Clean up the request object a bit, since we might
+        # very well have `oauth_`-headers in there, and we
+        # don't want to pass these along to the handler.
+        request = self.cleanup_request(request)
+
+        try:
+            result = meth(request, *args, **kwargs)
+        except FormValidationError, e:
+            resp = rc.BAD_REQUEST
+            resp.write(' '+str(e.form.errors))
+
+            return resp
+        except TypeError, e:
+            # Most likely the URL kwargs do not fit the handler method.
+            result = rc.BAD_REQUEST
+            hm = HandlerMethod(meth)
+            sig = hm.signature
+
+            msg = 'Method signature does not match.\n\n'
+
+            if sig:
+                msg += 'Signature should be: %s' % sig
+            else:
+                msg += 'Resource does not expect any parameters.'
+
+            if self.display_errors:
+                msg += '\n\nException was: %s' % str(e)
+
+            result.content = format_error(msg)
+        except Http404:
+            return rc.NOT_FOUND
+        except HttpStatusCode, e:
+            return e.response
+        except Exception, e:
+            # On errors (like code errors), we'd like to be able to give
+            # crash reports to both admins and also the calling user.
+            # Two settings control this:
+            #  - `PISTON_EMAIL_ERRORS`: send a Django formatted error email
+            #    to people in `settings.ADMINS`.
+            #  - `PISTON_DISPLAY_ERRORS`: return a simple traceback to the
+            #    caller, so they can tell you what error they got.
+            # If `PISTON_DISPLAY_ERRORS` is not enabled, the exception is
+            # re-raised and Django's own error handling takes over.
+            exc_type, exc_value, tb = sys.exc_info()
+            rep = ExceptionReporter(request, exc_type, exc_value, tb.tb_next)
+            if self.email_errors:
+                self.email_exception(rep)
+            if self.display_errors:
+                return HttpResponseServerError(
+                    format_error('\n'.join(rep.format_exception())))
+            else:
+                raise
+
+        emitter, ct = Emitter.get(em_format)
+        fields = handler.fields
+        if hasattr(handler, 'list_fields') and isinstance(result, (list, QuerySet)):
+            fields = handler.list_fields
+
+        srl = emitter(result, typemapper, handler, fields, anonymous)
+
+        try:
+            # Decide whether we want a generator here, or buffer up the
+            # entire result before sending it to the client. Won't matter
+            # for smaller datasets, but larger will have an impact.
+            if self.stream:
+                stream = srl.stream_render(request)
+            else:
+                stream = srl.render(request)
+
+            if not isinstance(stream, HttpResponse):
+                resp = HttpResponse(stream, mimetype=ct)
+            else:
+                resp = stream
+
+            resp.streaming = self.stream
+
+            return resp
+        except HttpStatusCode, e:
+            return e.response
+
+    @staticmethod
+    def cleanup_request(request):
+        """
+        Removes `oauth_` keys from various dicts on the
+        request object, and returns the sanitized version.
+        Only dicts that contain such keys are replaced (by a copy).
+        """
+        for method_type in ('GET', 'PUT', 'POST', 'DELETE'):
+            block = getattr(request, method_type, { })
+
+            if any(k.startswith("oauth_") for k in block.keys()):
+                sanitized = block.copy()
+
+                for k in list(sanitized.keys()):
+                    if k.startswith("oauth_"):
+                        sanitized.pop(k)
+
+                setattr(request, method_type, sanitized)
+
+        return request
+
+    # --
+
+    def email_exception(self, reporter):
+        """Mail the HTML traceback from `reporter` to `settings.ADMINS`."""
+        subject = "Piston crash report"
+        html = reporter.get_traceback_html()
+
+        message = EmailMessage(settings.EMAIL_SUBJECT_PREFIX+subject,
+                               html, settings.SERVER_EMAIL,
+                               [ admin[1] for admin in settings.ADMINS ])
+
+        message.content_subtype = 'html'
+        message.send(fail_silently=True)
--- /dev/null
+# Django imports
+import django.dispatch
+
+# Piston imports
+from utils import send_consumer_mail
+
+def consumer_post_save(sender, instance, created, **kwargs):
+    """Signal receiver (presumably for Consumer post_save): mail the
+    consumer about its current status."""
+    consumer = instance
+    send_consumer_mail(consumer)
+
+def consumer_post_delete(sender, instance, **kwargs):
+    """Signal receiver (presumably for Consumer post_delete): flag the
+    instance as canceled, then mail the consumer about it."""
+    canceled = 'canceled'
+    instance.status = canceled
+    send_consumer_mail(instance)
+
+
--- /dev/null
+import oauth
+
+from models import Nonce, Token, Consumer
+from models import generate_random, VERIFIER_SIZE
+
+class DataStore(oauth.OAuthDataStore):
+    """Layer between Python OAuth and Django database.
+
+    Lookups remember their results on the instance (``consumer`` and
+    ``request_token``); the fetch/authorize methods compare against them,
+    so the matching lookup must have run first.
+    """
+    def __init__(self, oauth_request):
+        params = oauth_request.parameters
+        self.signature = params.get('oauth_signature', None)
+        self.timestamp = params.get('oauth_timestamp', None)
+        self.scope = params.get('scope', 'all')
+
+    def lookup_consumer(self, key):
+        """Return (and remember) the Consumer with ``key``, or None."""
+        try:
+            self.consumer = Consumer.objects.get(key=key)
+        except Consumer.DoesNotExist:
+            return None
+        return self.consumer
+
+    def lookup_token(self, token_type, token):
+        """Return (and remember) the Token of ``token_type`` keyed ``token``.
+
+        'request' and 'access' map to the Token model constants; any other
+        value is used as-is. The result is stored on ``self.request_token``
+        regardless of its type. Returns None when no such token exists.
+        """
+        token_type = {'request': Token.REQUEST,
+                      'access': Token.ACCESS}.get(token_type, token_type)
+        try:
+            self.request_token = Token.objects.get(key=token,
+                                                   token_type=token_type)
+        except Token.DoesNotExist:
+            return None
+        return self.request_token
+
+    def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+        """Record ``nonce``; return its key if it had been recorded before.
+
+        Returns None when the nonce is new, and also when ``oauth_token`` is
+        None, in which case nothing is recorded.
+        NOTE(review): requests without a token therefore get no replay
+        protection from this check.
+        """
+        if oauth_token is None:
+            return None
+        record, created = Nonce.objects.get_or_create(
+            consumer_key=oauth_consumer.key,
+            token_key=oauth_token.key,
+            key=nonce)
+        return None if created else record.key
+
+    def fetch_request_token(self, oauth_consumer, oauth_callback):
+        """Create a request token for the looked-up consumer, or return None."""
+        if oauth_consumer.key != self.consumer.key:
+            return None
+        self.request_token = Token.objects.create_token(
+            consumer=self.consumer,
+            token_type=Token.REQUEST,
+            timestamp=self.timestamp)
+        if oauth_callback:
+            self.request_token.set_callback(oauth_callback)
+        return self.request_token
+
+    def fetch_access_token(self, oauth_consumer, oauth_token, oauth_verifier):
+        """Exchange the approved, verified request token for an access token.
+
+        Returns None unless the consumer, token key and verifier all match
+        the remembered ones and the request token is approved.
+        """
+        matches = (oauth_consumer.key == self.consumer.key
+                   and oauth_token.key == self.request_token.key
+                   and oauth_verifier == self.request_token.verifier
+                   and self.request_token.is_approved)
+        if not matches:
+            return None
+        self.access_token = Token.objects.create_token(
+            consumer=self.consumer,
+            token_type=Token.ACCESS,
+            timestamp=self.timestamp,
+            user=self.request_token.user)
+        return self.access_token
+
+    def authorize_request_token(self, oauth_token, user):
+        """Approve the remembered request token for ``user``, or return None."""
+        if oauth_token.key != self.request_token.key:
+            return None
+        token = self.request_token
+        token.is_approved = True
+        token.user = user
+        token.verifier = generate_random(VERIFIER_SIZE)
+        token.save()
+        return token
\ No newline at end of file
--- /dev/null
+{% load markup %}
+{# API documentation page: one section per handler in `docs`, listing its doc, URL template, accepted methods and handler methods. #}
+<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN"
+"http://www.w3.org/TR/html4/strict.dtd">
+<html>
+  <head>
+    <title>
+      Piston generated documentation
+    </title>
+    <style type="text/css">
+      body {
+        background: #fffff0;
+        font: 1em "Helvetica Neue", Verdana;
+        padding: 0 0 0 25px;
+      }
+    </style>
+  </head>
+  <body>
+    <h1>API Documentation</h1>
+
+    {% for doc in docs %}
+
+    <h3>{{ doc.name|cut:"Handler" }}:</h3>
+
+    {# restructuredtext emits block elements, which may not be nested in <p>. #}
+    <div>
+      {{ doc.get_doc|default:""|restructuredtext }}
+    </div>
+
+    <p>
+      URL: <b>{{ doc.get_resource_uri_template }}</b>
+    </p>
+
+    <p>
+      Accepted methods: {% for meth in doc.allowed_methods %}<b>{{ meth }}</b>{% if not forloop.last %}, {% endif %}{% endfor %}
+    </p>
+
+    <dl>
+      {% for method in doc.get_all_methods %}
+
+      <dt>
+        method <i>{{ method.name }}</i>({{ method.signature }}){% if method.stale %} <i>- inherited</i>{% else %}:{% endif %}
+
+      </dt>
+
+      {% if method.get_doc %}
+      <dd>
+        {{ method.get_doc|default:""|restructuredtext }}
+      </dd>
+      {% endif %}
+
+      {% endfor %}
+    </dl>
+
+    {% endfor %}
+  </body>
+</html>
--- /dev/null
+<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN"
+"http://www.w3.org/TR/html4/strict.dtd">
+<html>
+  <head>
+    <title>Authorize Token</title>
+  </head>
+  <body>
+    <h1>Authorize Token</h1>
+
+    {# form.as_table renders bare <tr> rows, so the <table> wrapper is required. #}
+    {# NOTE(review): add the csrf_token tag if CsrfViewMiddleware is enabled. #}
+    <form action="{% url piston.authentication.oauth_user_auth %}" method="POST">
+      <table>
+        {{ form.as_table }}
+      </table>
+      <p><input type="submit" value="Authorize"></p>
+    </form>
+
+  </body>
+</html>
--- /dev/null
+# Django imports
+import django.test.client as client
+import django.test as test
+from django.utils.http import urlencode
+
+# Piston imports
+from piston import oauth
+from piston.models import Consumer, Token
+
+# 3rd/Python party imports
+import httplib2, urllib, cgi
+
+# Content type whose bodies OAuthClient parses and includes in the OAuth
+# signature; also the default content type for OAuthClient.post().
+URLENCODED_FORM_CONTENT = 'application/x-www-form-urlencoded'
+
+class OAuthClient(client.Client):
+    """Django test client that signs every request with OAuth HMAC-SHA1.
+
+    ``consumer`` and ``token`` need ``key`` and ``secret`` attributes.
+    """
+    def __init__(self, consumer, token):
+        self.token = oauth.OAuthToken(token.key, token.secret)
+        self.consumer = oauth.OAuthConsumer(consumer.key, consumer.secret)
+        self.signature = oauth.OAuthSignatureMethod_HMAC_SHA1()
+
+        super(OAuthClient, self).__init__()
+
+    def request(self, **request):
+        """Add an OAuth ``Authorization`` header to the environ and dispatch.
+
+        Query-string parameters and url-encoded POST/PUT bodies are part of
+        the OAuth signature base string, so both are fed to the signer.
+        """
+        params = {}
+
+        query = request.get('QUERY_STRING')
+        if query:
+            params.update(cgi.parse_qs(query))
+
+        if request['REQUEST_METHOD'] in ('POST', 'PUT'):
+            if request.get('CONTENT_TYPE') == URLENCODED_FORM_CONTENT:
+                payload = request['wsgi.input'].read()
+                # Reading consumed the stream; hand Django a fresh copy.
+                request['wsgi.input'] = client.FakePayload(payload)
+                params.update(cgi.parse_qs(payload))
+
+        url = "http://testserver" + request['PATH_INFO']
+
+        req = oauth.OAuthRequest.from_consumer_and_token(
+            self.consumer, token=self.token,
+            http_method=request['REQUEST_METHOD'], http_url=url,
+            parameters=params
+        )
+
+        req.sign_request(self.signature, self.consumer, self.token)
+        headers = req.to_header()
+        request['HTTP_AUTHORIZATION'] = headers['Authorization']
+
+        return super(OAuthClient, self).request(**request)
+
+    def post(self, path, data=None, content_type=None, follow=False, **extra):
+        """POST ``data``, url-encoding dicts, with form encoding by default."""
+        if data is None:
+            data = {}
+        if content_type is None:
+            content_type = URLENCODED_FORM_CONTENT
+
+        if isinstance(data, dict):
+            data = urlencode(data)
+
+        return super(OAuthClient, self).post(path, data, content_type, follow, **extra)
+
+class TestCase(test.TestCase):
+    """Project-wide base test case; currently identical to Django's."""
+
+class OAuthTestCase(TestCase):
+    """TestCase exposing an OAuth-signing test client as ``self.oauth``.
+
+    ``self.consumer`` and ``self.token`` must be set before it is used;
+    a new client is built on every access.
+    """
+    @property
+    def oauth(self):
+        consumer, token = self.consumer, self.token
+        return OAuthClient(consumer, token)
+
--- /dev/null
+# Django imports
+from django.core import mail
+from django.contrib.auth.models import User
+from django.conf import settings
+
+# Piston imports
+from test import TestCase
+from models import Consumer
+
+class ConsumerTest(TestCase):
+    """Mails sent when a Consumer is created or deleted."""
+    fixtures = ['models.json']
+
+    def setUp(self):
+        consumer = Consumer()
+        self.consumer = consumer
+        consumer.name = "Piston Test Consumer"
+        consumer.description = "A test consumer for Piston."
+        consumer.user = User.objects.get(pk=3)
+        consumer.generate_random_codes()
+
+    def test_create_pending(self):
+        """ Ensure creating a pending Consumer sends proper emails """
+        # A pending consumer mails its owner, plus the site admins when
+        # settings.ADMINS is non-empty.
+        wanted = 2 if len(settings.ADMINS) else 1
+        self.assertEqual(len(mail.outbox), wanted)
+
+        self.assertEqual(mail.outbox[0].subject,
+                         "Your API Consumer for example.com is awaiting approval.")
+
+    def test_delete_consumer(self):
+        """ Ensure deleting a Consumer sends a cancel email """
+        # Start from an empty outbox so only the cancel email is counted.
+        mail.outbox = []
+
+        self.consumer.delete()
+
+        self.assertEqual(len(mail.outbox), 1)
+        self.assertEqual(mail.outbox[0].subject,
+                         "Your API Consumer for example.com has been canceled.")
+
--- /dev/null
+import time
+from django.http import HttpResponseNotAllowed, HttpResponseForbidden, HttpResponse, HttpResponseBadRequest
+from django.core.urlresolvers import reverse
+from django.core.cache import cache
+from django import get_version as django_version
+from django.core.mail import send_mail, mail_admins
+from django.conf import settings
+from django.utils.translation import ugettext as _
+from django.template import loader, TemplateDoesNotExist
+from django.contrib.sites.models import Site
+from decorator import decorator
+
+from datetime import datetime, timedelta
+
+__version__ = '0.2.3rc1'
+
+
+def get_version():
+    """Return the Piston version string (``__version__``)."""
+    version = __version__
+    return version
+
+def format_error(error):
+    """Prefix ``error`` with a crash-report header naming Piston and Django."""
+    header = u"Piston/%s (Django %s) crash report:" % (get_version(), django_version())
+    return header + u"\n\n%s" % (error,)
+
+class rc_factory(object):
+    """
+    Status codes. Reading ``rc.NAME`` builds a new text/plain
+    HttpResponse whose body and status come from ``CODES[NAME]``;
+    unknown names raise AttributeError.
+    """
+    # NOTE(review): FORBIDDEN uses 401 (Unauthorized), not 403; kept as-is
+    # because clients may rely on it.
+    CODES = dict(ALL_OK = ('OK', 200),
+                 CREATED = ('Created', 201),
+                 DELETED = ('', 204), # 204 says "Don't send a body!"
+                 BAD_REQUEST = ('Bad Request', 400),
+                 FORBIDDEN = ('Forbidden', 401),
+                 NOT_FOUND = ('Not Found', 404),
+                 DUPLICATE_ENTRY = ('Conflict/Duplicate', 409),
+                 NOT_HERE = ('Gone', 410),
+                 INTERNAL_ERROR = ('Internal Error', 500),
+                 NOT_IMPLEMENTED = ('Not Implemented', 501),
+                 THROTTLED = ('Throttled', 503))
+
+    def __getattr__(self, attr):
+        """
+        Build a fresh response for status name `attr`. A new object on
+        every access keeps callers from mutating a shared response, and
+        stays backwards compatible with 0.2.
+        """
+        if attr not in self.CODES:
+            raise AttributeError(attr)
+        body, status = self.CODES[attr]
+        return HttpResponse(body, content_type='text/plain', status=status)
+
+rc = rc_factory()
+
+class FormValidationError(Exception):
+    """Raised by `validate` when a form fails; carries the bound form."""
+    def __init__(self, form):
+        # Exception.__init__ is intentionally not called, so `args` stays empty.
+        self.form = form
+
+class HttpStatusCode(Exception):
+    """Raise to short-circuit a handler; Resource returns `.response`."""
+    def __init__(self, response):
+        # Exception.__init__ is intentionally not called, so `args` stays empty.
+        self.response = response
+
+def validate(v_form, operation='POST'):
+    """Decorator factory: validate ``request.<operation>`` with ``v_form``.
+
+    The decorated handler method runs only if the form is valid; otherwise
+    FormValidationError is raised with the bound form.
+    """
+    @decorator
+    def wrap(f, self, request, *a, **kwa):
+        form = v_form(getattr(request, operation))
+        if not form.is_valid():
+            raise FormValidationError(form)
+        return f(self, request, *a, **kwa)
+    return wrap
+
+def throttle(max_requests, timeout=60*60, extra=''):
+    """
+    Simple throttling decorator, caches
+    the amount of requests made in cache.
+
+    If used on a view where users are required to
+    log in, the username is used, otherwise the
+    IP address of the originating request is used.
+    Requests with no identifier are not throttled.
+
+    Parameters::
+     - `max_requests`: The maximum number of requests allowed per window
+     - `timeout`: The length of the window in seconds (default: 1 hour)
+     - `extra`: Extra string appended to the cache key
+    """
+    @decorator
+    def wrap(f, self, request, *args, **kwargs):
+        if request.user.is_authenticated():
+            ident = request.user.username
+        else:
+            ident = request.META.get('REMOTE_ADDR', None)
+
+        # Since we want to be able to throttle on a per-application basis,
+        # `throttle_extra` might be set on the request object. If so, append
+        # it to the identifier (a missing identifier cannot be extended).
+        if ident is not None and hasattr(request, 'throttle_extra'):
+            ident += ':%s' % str(request.throttle_extra)
+
+        if ident:
+            # NOTE: this read-modify-write is not atomic; cache incr/decr
+            # would be, so concurrent requests may be under-counted.
+            ident += ':%s' % extra
+
+            now = time.time()
+            count, expiration = cache.get(ident, (0, None))
+
+            # Open a fresh window on first use or once the old one lapsed.
+            if expiration is None or expiration <= now:
+                count, expiration = 0, now + timeout
+
+            # `count` requests were already served in this window.
+            if count >= max_requests:
+                t = rc.THROTTLED
+                wait = int(expiration - now)
+                t.content = 'Throttled, wait %d seconds.' % wait
+                t['Retry-After'] = str(wait)
+                return t
+
+            # Use a whole, positive number of seconds: memcached treats a
+            # 0 expiry as "never expire".
+            cache.set(ident, (count + 1, expiration),
+                      max(1, int(expiration - now)))
+
+        return f(self, request, *args, **kwargs)
+    return wrap
+
+def coerce_put_post(request):
+    """
+    Make Django parse the body of a PUT request.
+
+    Django only loads form data for POST, so the method is temporarily
+    switched to POST while `_load_post_and_files` runs, and the parsed
+    data is then exposed as `request.PUT`. If that raises AttributeError
+    (the original notes a mod_python bug), the method recorded in
+    `request.META` is switched instead. Non-PUT requests are untouched.
+    """
+    if request.method != "PUT":
+        return
+
+    try:
+        request.method = "POST"
+        request._load_post_and_files()
+        request.method = "PUT"
+    except AttributeError:
+        meta = request.META
+        meta['REQUEST_METHOD'] = 'POST'
+        request._load_post_and_files()
+        meta['REQUEST_METHOD'] = 'PUT'
+
+    request.PUT = request.POST
+
+
+class MimerDataException(Exception):
+    """Raised when a request body cannot be decoded as its declared content type."""
+
+class Mimer(object):
+    """Deserialize non-form request bodies according to their Content-Type.
+
+    Loaders are registered with `register(loadee, types)`; a loader is
+    chosen when one of its types is a prefix of the request's content type.
+    """
+    TYPES = dict()
+
+    def __init__(self, request):
+        self.request = request
+
+    def is_multipart(self):
+        """True if the (non-form) content type starts with 'multipart'."""
+        content_type = self.content_type()
+
+        if content_type is not None:
+            return content_type.lstrip().startswith('multipart')
+
+        return False
+
+    def loader_for_type(self, ctype):
+        """
+        Gets a function ref to deserialize content
+        for a certain mimetype, or None if no registered
+        type is a prefix of `ctype`.
+        """
+        for loadee, mimes in Mimer.TYPES.iteritems():
+            for mime in mimes:
+                if ctype.startswith(mime):
+                    return loadee
+        return None
+
+    def content_type(self):
+        """
+        Returns the content type of the request in all cases where it is
+        different than a submitted form - application/x-www-form-urlencoded.
+        A missing Content-Type counts as form-encoded (returns None).
+        """
+        type_formencoded = "application/x-www-form-urlencoded"
+
+        ctype = self.request.META.get('CONTENT_TYPE', type_formencoded)
+
+        if type_formencoded in ctype:
+            return None
+
+        return ctype
+
+    def translate(self):
+        """
+        Will look at the `Content-type` sent by the client, and maybe
+        deserialize the contents into the format they sent. This will
+        work for JSON, YAML, XML and Pickle. Since the data is not just
+        key-value (and maybe just a list), the data will be placed on
+        `request.data` instead, and the handler will have to read from
+        there.
+
+        It will also set `request.content_type` so the handler has an easy
+        way to tell what's going on. `request.content_type` is None for
+        form-encoded data (what your browser sends); `request.data` is
+        None unless a registered loader decoded the body.
+
+        Raises MimerDataException if the loader rejects the body.
+        """
+        ctype = self.content_type()
+        self.request.content_type = ctype
+
+        if not self.is_multipart() and ctype:
+            loadee = self.loader_for_type(ctype)
+
+            if loadee:
+                try:
+                    self.request.data = loadee(self.request.raw_post_data)
+
+                    # Reset both POST and PUT from request, as its
+                    # misleading having their presence around.
+                    self.request.POST = self.request.PUT = dict()
+                except (TypeError, ValueError):
+                    # The body is not valid for the declared content type.
+                    raise MimerDataException
+            else:
+                # No loader for this type: still define `request.data` so
+                # handlers never hit an AttributeError.
+                self.request.data = None
+        else:
+            self.request.data = None
+
+        return self.request
+
+    @classmethod
+    def register(cls, loadee, types):
+        """Use `loadee` for content types starting with any of `types`."""
+        cls.TYPES[loadee] = types
+
+    @classmethod
+    def unregister(cls, loadee):
+        """Remove `loadee` and return the types it was registered for."""
+        return cls.TYPES.pop(loadee)
+
+def translate_mime(request):
+    """Decode the body of `request` in place via `Mimer.translate`."""
+    Mimer(request).translate()
+
+def require_mime(*mimes):
+    """
+    Decorator requiring a certain mimetype. There's a nifty
+    helper called `require_extended` below which requires everything
+    we support except for post-data via form.
+
+    Short names ('json', 'yaml', 'xml', 'pickle') are expanded to full
+    mimetypes; other values are used verbatim. Media-type parameters such
+    as '; charset=utf-8' are ignored. Non-matching requests get
+    `rc.BAD_REQUEST`.
+    """
+    rewrite = { 'json': 'application/json',
+                'yaml': 'application/x-yaml',
+                'xml': 'text/xml',
+                'pickle': 'application/python-pickle' }
+    # The accepted set never changes, so build it once at decoration time.
+    realmimes = set(rewrite.get(mime, mime) for mime in mimes)
+
+    @decorator
+    def wrap(f, self, request, *args, **kwargs):
+        ctype = Mimer(request).content_type()
+        if ctype is not None:
+            ctype = ctype.split(';', 1)[0].strip()
+
+        if ctype not in realmimes:
+            return rc.BAD_REQUEST
+
+        return f(self, request, *args, **kwargs)
+    return wrap
+
+require_extended = require_mime('json', 'yaml', 'xml', 'pickle')
+
+def send_consumer_mail(consumer):
+    """
+    Send a consumer an email depending on what their status is.
+
+    The subject is `settings.PISTON_OAUTH_EMAIL_SUBJECTS[status]` when that
+    entry exists, otherwise a default built from the current Site's name.
+    The body is rendered from `piston/mails/consumer_<status>.txt`; if the
+    template is missing, nothing is sent. The consumer's user is mailed if
+    set, and site admins are also mailed for pending consumers.
+    """
+    try:
+        subject = settings.PISTON_OAUTH_EMAIL_SUBJECTS[consumer.status]
+    except (AttributeError, KeyError):
+        # No custom subjects configured, or none for this status.
+        subject = "Your API Consumer for %s " % Site.objects.get_current().name
+        if consumer.status == "accepted":
+            subject += "was accepted!"
+        elif consumer.status == "canceled":
+            subject += "has been canceled."
+        elif consumer.status == "rejected":
+            subject += "has been rejected."
+        else:
+            subject += "is awaiting approval."
+
+    template = "piston/mails/consumer_%s.txt" % consumer.status
+
+    try:
+        body = loader.render_to_string(template,
+                { 'consumer' : consumer, 'user' : consumer.user })
+    except TemplateDoesNotExist:
+        # They haven't set up the templates, which means they might not
+        # want these emails sent.
+        return
+
+    try:
+        sender = settings.PISTON_FROM_EMAIL
+    except AttributeError:
+        sender = settings.DEFAULT_FROM_EMAIL
+
+    if consumer.user:
+        send_mail(_(subject), body, sender, [consumer.user.email], fail_silently=True)
+
+    if consumer.status == 'pending' and len(settings.ADMINS):
+        mail_admins(_(subject), body, fail_silently=True)
+
+    if settings.DEBUG and consumer.user:
+        print "Mail being sent, to=%s" % consumer.user.email
+        print "Subject: %s" % _(subject)
+        print body
+