import script
[redakcja.git] / apps / django_cas / backends.py
index f14619d..4cf9f62 100644 (file)
@@ -7,10 +7,11 @@ from django_cas.models import User
 
 __all__ = ['CASBackend']
 
 
 __all__ = ['CASBackend']
 
+
 def _verify_cas1(ticket, service):
     """Verifies CAS 1.0 authentication ticket.
 
 def _verify_cas1(ticket, service):
     """Verifies CAS 1.0 authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, None) on success and (None, None) on failure.
     """
 
     params = {'ticket': ticket, 'service': service}
     """
 
     params = {'ticket': ticket, 'service': service}
@@ -20,9 +21,9 @@ def _verify_cas1(ticket, service):
     try:
         verified = page.readline().strip()
         if verified == 'yes':
     try:
         verified = page.readline().strip()
         if verified == 'yes':
-            return page.readline().strip()
+            return page.readline().strip(), None
         else:
         else:
-            return None
+            return None, None
     finally:
         page.close()
 
     finally:
         page.close()
 
@@ -30,7 +31,7 @@ def _verify_cas1(ticket, service):
 def _verify_cas2(ticket, service):
     """Verifies CAS 2.0+ XML-based authentication ticket.
 
 def _verify_cas2(ticket, service):
     """Verifies CAS 2.0+ XML-based authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, attr_dict) on success and (None, None) on failure.
     """
 
     try:
     """
 
     try:
@@ -46,13 +47,16 @@ def _verify_cas2(ticket, service):
         response = page.read()
         tree = ElementTree.fromstring(response)
         if tree[0].tag.endswith('authenticationSuccess'):
         response = page.read()
         tree = ElementTree.fromstring(response)
         if tree[0].tag.endswith('authenticationSuccess'):
-            return tree[0][0].text
+            attrs = {}
+            for tag in tree[0][1:]:
+                attrs[tag.tag] = tag.text
+            return tree[0][0].text, attrs
         else:
         else:
-            return None
+            return None, None
     except:
         import traceback
         traceback.print_exc()
     except:
         import traceback
         traceback.print_exc()
-        print "****"
+        print "****", url
         print response
         print "****"
     finally:
         print response
         print "****"
     finally:
@@ -73,14 +77,34 @@ class CASBackend(object):
     def authenticate(self, ticket, service):
         """Verifies CAS ticket and gets or creates User object"""
 
     def authenticate(self, ticket, service):
         """Verifies CAS ticket and gets or creates User object"""
 
-        username = _verify(ticket, service)
+        username, attrs = _verify(ticket, service)
         if not username:
             return None
         if not username:
             return None
+
+        user_attrs = {}
+        if hasattr(settings, 'CAS_USER_ATTRS_MAP'):
+            attr_map = settings.CAS_USER_ATTRS_MAP
+            for k, v in attrs.items():
+                if k in attr_map:
+                    user_attrs[attr_map[k]] = v # unicode(v, 'utf-8')
+
         try:
         try:
-            user = User.objects.get(username__iexact = username)
+            user = User.objects.get(username__iexact=username)
+            # update user info
+            changed = False
+            for k, v in user_attrs.items():
+                if getattr(user, k) != v:
+                    setattr(user, k, v)
+                    changed = True
+            if changed:
+                user.save()
         except User.DoesNotExist:
             # user will have an "unusable" password
             user = User.objects.create_user(username, '')
         except User.DoesNotExist:
             # user will have an "unusable" password
             user = User.objects.create_user(username, '')
+            for k, v in user_attrs.items():
+                setattr(user, k, v)
+            user.first_name = attrs.get('firstname', '')
+            user.last_name = attrs.get('lastname', '')
             user.save()
         return user
 
             user.save()
         return user
 
@@ -88,6 +112,6 @@ class CASBackend(object):
         """Retrieve the user's entry in the User model if it exists"""
 
         try:
         """Retrieve the user's entry in the User model if it exists"""
 
         try:
-            return User.objects.get(pk = user_id)
+            return User.objects.get(pk=user_id)
         except User.DoesNotExist:
             return None
         except User.DoesNotExist:
             return None