fix
[redakcja.git] / apps / django_cas / backends.py
index d55c9db..4cf9f62 100644 (file)
@@ -11,7 +11,7 @@ __all__ = ['CASBackend']
 def _verify_cas1(ticket, service):
     """Verifies CAS 1.0 authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, None) on success and (None, None) on failure.
     """
 
     params = {'ticket': ticket, 'service': service}
@@ -21,9 +21,9 @@ def _verify_cas1(ticket, service):
     try:
         verified = page.readline().strip()
         if verified == 'yes':
-            return page.readline().strip()
+            return page.readline().strip(), None
         else:
-            return None
+            return None, None
     finally:
         page.close()
 
@@ -31,7 +31,7 @@ def _verify_cas1(ticket, service):
 def _verify_cas2(ticket, service):
     """Verifies CAS 2.0+ XML-based authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, attr_dict) on success and (None, None) on failure.
     """
 
     try:
@@ -47,9 +47,12 @@ def _verify_cas2(ticket, service):
         response = page.read()
         tree = ElementTree.fromstring(response)
         if tree[0].tag.endswith('authenticationSuccess'):
-            return tree[0][0].text
+            attrs = {}
+            for tag in tree[0][1:]:
+                attrs[tag.tag] = tag.text
+            return tree[0][0].text, attrs
         else:
-            return None
+            return None, None
     except:
         import traceback
         traceback.print_exc()
@@ -74,14 +77,34 @@ class CASBackend(object):
     def authenticate(self, ticket, service):
         """Verifies CAS ticket and gets or creates User object"""
 
-        username = _verify(ticket, service)
+        username, attrs = _verify(ticket, service)
         if not username:
             return None
+
+        user_attrs = {}
+        if hasattr(settings, 'CAS_USER_ATTRS_MAP'):
+            attr_map = settings.CAS_USER_ATTRS_MAP
+            for k, v in attrs.items():
+                if k in attr_map:
+                    user_attrs[attr_map[k]] = v # unicode(v, 'utf-8')
+
         try:
             user = User.objects.get(username__iexact=username)
+            # update user info
+            changed = False
+            for k, v in user_attrs.items():
+                if getattr(user, k) != v:
+                    setattr(user, k, v)
+                    changed = True
+            if changed:
+                user.save()
         except User.DoesNotExist:
             # user will have an "unusable" password
             user = User.objects.create_user(username, '')
+            for k, v in user_attrs.items():
+                setattr(user, k, v)
+            user.first_name = attrs.get('firstname', '')
+            user.last_name = attrs.get('lastname', '')
             user.save()
         return user