From: zuber Date: Mon, 5 Oct 2009 14:39:52 +0000 (+0200) Subject: Added support for custom attributes from CAS version 3.0 (see setting CAS_CUSTOM_ATTR... X-Git-Tag: 22.4~32^2~18^2~1^2~9 X-Git-Url: https://git.mdrn.pl/django-cas-provider.git/commitdiff_plain/bf06f0f8cde976e4e579b059fdb0e8a8bc279f01?ds=sidebyside;hp=5392ec6f8245dd4ecad27bba8a610dc198e537b6 Added support for custom attributes from CAS version 3.0 (see setting CAS_CUSTOM_ATTRIBUTES_CALLBACK). --- diff --git a/cas_provider/__init__.py b/cas_provider/__init__.py index 1a719b4..b70697c 100644 --- a/cas_provider/__init__.py +++ b/cas_provider/__init__.py @@ -4,6 +4,7 @@ __all__ = [] _DEFAULTS = { 'CAS_TICKET_EXPIRATION': 5, # In minutes + 'CAS_CUSTOM_ATTRIBUTES_CALLBACK': None, } for key, value in _DEFAULTS.iteritems(): diff --git a/cas_provider/models.py b/cas_provider/models.py index 94023b7..eae9cf6 100644 --- a/cas_provider/models.py +++ b/cas_provider/models.py @@ -1,5 +1,9 @@ from django.db import models from django.contrib.auth.models import User +from django.conf import settings +from django.core.urlresolvers import get_callable + +from cas_provider.xml import etree, register_namespace, ElementRoot class ServiceTicket(models.Model): user = models.ForeignKey(User) @@ -15,4 +19,23 @@ class LoginTicket(models.Model): created = models.DateTimeField(auto_now=True) def __unicode__(self): - return "%s - %s" % (self.ticket, self.created) \ No newline at end of file + return "%s - %s" % (self.ticket, self.created) + +CAS_URI = 'http://www.yale.edu/tp/cas' +register_namespace('cas', CAS_URI) +CAS = '{%s}' % CAS_URI + +def auth_success_response(user): + attrs = {} + if settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK: + callback = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK) + attrs = callback(user) + + response = ElementRoot(CAS + 'serviceResponse') + auth_success = etree.SubElement(response, CAS + 'authenticationSuccess') + username = etree.SubElement(auth_success, CAS + 'user') + username.text = 
user.username + for name, value in attrs.items(): + element = etree.SubElement(auth_success, name) + element.text = value + return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8') diff --git a/cas_provider/views.py b/cas_provider/views.py index cdd5d7b..3ee5d7c 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -6,7 +6,7 @@ from django.contrib.auth import authenticate from django.contrib.auth import login as auth_login, logout as auth_logout from cas_provider.forms import LoginForm -from cas_provider.models import ServiceTicket, LoginTicket +from cas_provider.models import ServiceTicket, LoginTicket, auth_success_response from cas_provider.utils import create_service_ticket __all__ = ['login', 'validate', 'service_validate', 'logout'] @@ -76,15 +76,10 @@ def service_validate(request): try: ticket = ServiceTicket.objects.get(ticket=ticket_string) - username = ticket.user.username - ticket.delete() - return HttpResponse(''' - - %(username)s - - ''' % {'username': username}, mimetype='text/xml') + # ticket.delete() + return HttpResponse(auth_success_response(ticket.user), mimetype='text/xml') except ServiceTicket.DoesNotExist: - return HttpResponse('''' + return HttpResponse(''' The provided ticket is invalid. 
diff --git a/cas_provider/xml.py b/cas_provider/xml.py
new file mode 100644
index 0000000..012fb41
--- /dev/null
+++ b/cas_provider/xml.py
@@ -0,0 +1,51 @@
+# Import etree from anywhere
+#
+# Exposes `etree`, `register_namespace(prefix, uri)` and `ElementRoot(...)`
+# no matter which ElementTree implementation could be imported.
+#
+# This module is itself named `xml`: under Python 2's implicit relative
+# imports, `import xml.etree...` below would resolve against this module
+# instead of the standard library, so absolute imports are required.
+from __future__ import absolute_import
+
+try:
+    # lxml http://codespeak.net/lxml/
+    from lxml import etree
+
+    # Define register_namespace function and ElementRoot for proper serialization
+    NSMAP = {}
+    def register_namespace(prefix, uri):
+        NSMAP[prefix] = uri
+
+    def ElementRoot(*args, **kwargs):
+        return etree.Element(*args, nsmap=NSMAP, **kwargs)
+
+except ImportError:
+    try:
+        # Python 2.5
+        import xml.etree.cElementTree as etree
+    except ImportError:
+        try:
+            # Python 2.5, pure-Python implementation
+            import xml.etree.ElementTree as etree
+        except ImportError:
+            try:
+                # normal cElementTree install
+                import cElementTree as etree
+            except ImportError:
+                # normal ElementTree install
+                import elementtree.ElementTree as etree
+
+    # Look the function up on whichever module was bound as `etree` above.
+    # Backends lacking register_namespace() get a shim that writes to the
+    # module's private prefix map instead.
+    try:
+        register_namespace = etree.register_namespace
+    except AttributeError:
+        def register_namespace(prefix, uri):
+            etree._namespace_map[uri] = prefix
+
+    def ElementRoot(*args, **kwargs):
+        return etree.Element(*args, **kwargs)
+
+__all__ = ('etree', 'register_namespace', 'ElementRoot')