Added support for custom attributes from CAS version 3.0 (see setting CAS_CUSTOM_ATTRIBUTES_CALLBACK).
author zuber <marek@stepniowski.com>
Mon, 5 Oct 2009 14:39:52 +0000 (16:39 +0200)
committer zuber <marek@stepniowski.com>
Mon, 5 Oct 2009 14:39:52 +0000 (16:39 +0200)
cas_provider/__init__.py
cas_provider/models.py
cas_provider/views.py
cas_provider/xml.py [new file with mode: 0644]

index 1a719b4..b70697c 100644 (file)
@@ -4,6 +4,7 @@ __all__ = []
 
 _DEFAULTS = {
     'CAS_TICKET_EXPIRATION': 5, # In minutes
+    'CAS_CUSTOM_ATTRIBUTES_CALLBACK': None, # None disables; otherwise resolved with get_callable(), called with the user, and expected to return a dict of extra attributes
 }
 
 for key, value in _DEFAULTS.iteritems():
index 94023b7..eae9cf6 100644 (file)
@@ -1,5 +1,9 @@
 from django.db import models
 from django.contrib.auth.models import User
+from django.conf import settings
+from django.core.urlresolvers import get_callable
+
+from cas_provider.xml import etree, register_namespace, ElementRoot
 
 class ServiceTicket(models.Model):
     user = models.ForeignKey(User)
@@ -15,4 +19,35 @@ class LoginTicket(models.Model):
     created = models.DateTimeField(auto_now=True)
     
     def __unicode__(self):
-        return "%s - %s" % (self.ticket, self.created)
\ No newline at end of file
+        return "%s - %s" % (self.ticket, self.created)
+
+CAS_URI = 'http://www.yale.edu/tp/cas'
+register_namespace('cas', CAS_URI)
+CAS = '{%s}' % CAS_URI
+
+def auth_success_response(user):
+    """Serialize a CAS authentication-success response for ``user``.
+
+    The document holds ``cas:user`` (``user.username``) plus one element per
+    entry of the mapping returned by the callback named in
+    settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK (resolved with get_callable and
+    called with ``user``). Returns the XML as a unicode string.
+    """
+    attrs = {}
+    if settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK:
+        callback = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK)
+        # A callback returning None means "no extra attributes".
+        attrs = callback(user) or {}
+    
+    response = ElementRoot(CAS + 'serviceResponse')
+    auth_success = etree.SubElement(response, CAS + 'authenticationSuccess')
+    username = etree.SubElement(auth_success, CAS + 'user')
+    username.text = user.username
+    # NOTE(review): attributes are emitted as unqualified children of
+    # cas:authenticationSuccess; confirm this is the layout clients expect.
+    for name, value in attrs.items():
+        element = etree.SubElement(auth_success, name)
+        # Neither lxml nor ElementTree can serialize non-string text such as
+        # ints, so coerce every value to unicode.
+        element.text = unicode(value)
+    return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')
index cdd5d7b..3ee5d7c 100644 (file)
@@ -6,7 +6,7 @@ from django.contrib.auth import authenticate
 from django.contrib.auth import login as auth_login, logout as auth_logout
 
 from cas_provider.forms import LoginForm
-from cas_provider.models import ServiceTicket, LoginTicket
+from cas_provider.models import ServiceTicket, LoginTicket, auth_success_response
 from cas_provider.utils import create_service_ticket
 
 __all__ = ['login', 'validate', 'service_validate', 'logout']
@@ -76,15 +76,12 @@ def service_validate(request):
     
     try:
         ticket = ServiceTicket.objects.get(ticket=ticket_string)
-        username = ticket.user.username
-        ticket.delete()
-        return HttpResponse('''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
-            <cas:authenticationSuccess>
-                <cas:user>%(username)s</cas:user>
-            </cas:authenticationSuccess>
-        </cas:serviceResponse>''' % {'username': username}, mimetype='text/xml')
+        user = ticket.user
+        # Service tickets are single-use: consume it before answering.
+        ticket.delete()
+        return HttpResponse(auth_success_response(user), mimetype='text/xml')
     except ServiceTicket.DoesNotExist:
-        return HttpResponse(''''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
+        return HttpResponse('''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
             <cas:authenticationFailure code="INVALID_TICKET">
                 The provided ticket is invalid.
             </cas:authenticationFailure>
diff --git a/cas_provider/xml.py b/cas_provider/xml.py
new file mode 100644 (file)
index 0000000..012fb41
--- /dev/null
@@ -0,0 +1,53 @@
+"""Import etree from anywhere.
+
+Exposes ``etree`` (lxml when installed, otherwise an ElementTree
+implementation), ``register_namespace(prefix, uri)`` and
+``ElementRoot(*args, **kwargs)``, which creates a root element that
+serializes with the prefixes registered so far.
+"""
+# This module is named ``xml``; without absolute imports, Python 2 would
+# resolve ``import xml.etree...`` below to this very module and fail.
+from __future__ import absolute_import
+
+try:
+    # lxml http://codespeak.net/lxml/
+    from lxml import etree
+    
+    # lxml takes namespace prefixes via the nsmap argument at element
+    # creation, so collect registrations here and attach them in ElementRoot.
+    NSMAP = {}
+    def register_namespace(prefix, uri):
+        NSMAP[prefix] = uri
+    
+    def ElementRoot(*args, **kwargs):
+        return etree.Element(*args, nsmap=NSMAP, **kwargs)
+
+except ImportError:
+    try:
+        # Python 2.5+ (C implementation)
+        import xml.etree.cElementTree as etree
+    except ImportError:
+        try:
+            # Python 2.5+ (pure-Python implementation)
+            import xml.etree.ElementTree as etree
+        except ImportError:
+            try:
+                # normal cElementTree install
+                import cElementTree as etree
+            except ImportError:
+                # normal ElementTree install
+                import elementtree.ElementTree as etree
+
+    try:
+        register_namespace = etree.register_namespace
+    except AttributeError:
+        # Older ElementTree releases lack register_namespace(); write the
+        # module's private uri -> prefix map directly instead.
+        # NOTE(review): confirm the selected module exposes _namespace_map.
+        def register_namespace(prefix, uri):
+            etree._namespace_map[uri] = prefix
+
+    def ElementRoot(*args, **kwargs):
+        return etree.Element(*args, **kwargs)
+
+__all__ = ('etree', 'register_namespace', 'ElementRoot')