X-Git-Url: https://git.mdrn.pl/redakcja.git/blobdiff_plain/5780495e13ec2d55bc2dee96dec372a9ea395462..3e4d96cd7fdd08910887be91ead79e446a96fe53:/apps/django_cas/backends.py
diff --git a/apps/django_cas/backends.py b/apps/django_cas/backends.py
index d55c9db3..4cf9f625 100644
--- a/apps/django_cas/backends.py
+++ b/apps/django_cas/backends.py
@@ -11,7 +11,7 @@ __all__ = ['CASBackend']
 def _verify_cas1(ticket, service):
     """Verifies CAS 1.0 authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, None) on success and (None, None) on failure.
     """
 
     params = {'ticket': ticket, 'service': service}
@@ -21,9 +21,9 @@ def _verify_cas1(ticket, service):
     try:
         verified = page.readline().strip()
         if verified == 'yes':
-            return page.readline().strip()
+            return page.readline().strip(), None
         else:
-            return None
+            return None, None
     finally:
         page.close()
 
@@ -31,7 +31,7 @@ def _verify_cas1(ticket, service):
 def _verify_cas2(ticket, service):
     """Verifies CAS 2.0+ XML-based authentication ticket.
 
-    Returns username on success and None on failure.
+    Returns (username, attr_dict) on success and (None, None) on failure.
     """
 
     try:
@@ -47,9 +47,14 @@ def _verify_cas2(ticket, service):
         response = page.read()
         tree = ElementTree.fromstring(response)
         if tree[0].tag.endswith('authenticationSuccess'):
-            return tree[0][0].text
+            # Elements after <cas:user> carry user attributes. Key them by
+            # local name: ElementTree reports tags as "{namespace}name".
+            attrs = {}
+            for elem in tree[0][1:]:
+                attrs[elem.tag.split('}', 1)[-1]] = elem.text
+            return tree[0][0].text, attrs
         else:
-            return None
+            return None, None
     except:
         import traceback
         traceback.print_exc()
@@ -74,14 +79,39 @@ class CASBackend(object):
     def authenticate(self, ticket, service):
         """Verifies CAS ticket and gets or creates User object"""
 
-        username = _verify(ticket, service)
+        # _verify may return a bare None when validation raised and the
+        # error was only printed, so fall back to a failure tuple.
+        username, attrs = _verify(ticket, service) or (None, None)
         if not username:
             return None
+        # CAS 1.0 validation yields no attribute dict at all.
+        attrs = attrs or {}
+
+        # Translate CAS attribute names into User field names.
+        attr_map = getattr(settings, 'CAS_USER_ATTRS_MAP', {})
+        user_attrs = dict((attr_map[k], v) for k, v in attrs.items()
+                          if k in attr_map)
+
         try:
             user = User.objects.get(username__iexact=username)
+            # Refresh stored profile fields from CAS, saving only when
+            # something actually changed.
+            changed = False
+            for k, v in user_attrs.items():
+                if getattr(user, k) != v:
+                    setattr(user, k, v)
+                    changed = True
+            if changed:
+                user.save()
         except User.DoesNotExist:
             # user will have an "unusable" password
             user = User.objects.create_user(username, '')
+            # Name defaults go first so CAS_USER_ATTRS_MAP overrides them
+            # instead of being clobbered by them.
+            user.first_name = attrs.get('firstname', '')
+            user.last_name = attrs.get('lastname', '')
+            for k, v in user_attrs.items():
+                setattr(user, k, v)
             user.save()
         return user
 