X-Git-Url: https://git.mdrn.pl/redakcja.git/blobdiff_plain/b0d77fd4dd2b177e77e2bb038a5864567adfd9df..ae4e892b4aae2778761e0a75378055237c9e1308:/apps/django_cas/backends.py diff --git a/apps/django_cas/backends.py b/apps/django_cas/backends.py index f14619d0..4cf9f625 100644 --- a/apps/django_cas/backends.py +++ b/apps/django_cas/backends.py @@ -7,10 +7,11 @@ from django_cas.models import User __all__ = ['CASBackend'] + def _verify_cas1(ticket, service): """Verifies CAS 1.0 authentication ticket. - Returns username on success and None on failure. + Returns (username, None) on success and (None, None) on failure. """ params = {'ticket': ticket, 'service': service} @@ -20,9 +21,9 @@ def _verify_cas1(ticket, service): try: verified = page.readline().strip() if verified == 'yes': - return page.readline().strip() + return page.readline().strip(), None else: - return None + return None, None finally: page.close() @@ -30,7 +31,7 @@ def _verify_cas1(ticket, service): def _verify_cas2(ticket, service): """Verifies CAS 2.0+ XML-based authentication ticket. - Returns username on success and None on failure. + Returns (username, attr_dict) on success and (None, None) on failure. """ try: @@ -46,13 +47,16 @@ def _verify_cas2(ticket, service): response = page.read() tree = ElementTree.fromstring(response) if tree[0].tag.endswith('authenticationSuccess'): - return tree[0][0].text + attrs = {} + for tag in tree[0][1:]: + attrs[tag.tag] = tag.text + return tree[0][0].text, attrs else: - return None + return None, None except: import traceback traceback.print_exc() - print "****" + print "****", url print response print "****" finally: @@ -73,14 +77,34 @@ class CASBackend(object): def authenticate(self, ticket, service): """Verifies CAS ticket and gets or creates User object""" - username = _verify(ticket, service) + username, attrs = _verify(ticket, service) if not username: return None + + user_attrs = {} + if hasattr(settings, 'CAS_USER_ATTRS_MAP'): + attr_map = settings.CAS_USER_ATTRS_MAP + for k, v in attrs.items(): + if k in attr_map: + user_attrs[attr_map[k]] = v # unicode(v, 'utf-8') + try: - user = User.objects.get(username__iexact = username) + user = User.objects.get(username__iexact=username) + # update user info + changed = False + for k, v in user_attrs.items(): + if getattr(user, k) != v: + setattr(user, k, v) + changed = True + if changed: + user.save() except User.DoesNotExist: # user will have an "unusable" password user = User.objects.create_user(username, '') + for k, v in user_attrs.items(): + setattr(user, k, v) + user.first_name = attrs.get('firstname', '') + user.last_name = attrs.get('lastname', '') user.save() return user @@ -88,6 +112,6 @@ class CASBackend(object): """Retrieve the user's entry in the User model if it exists""" try: - return User.objects.get(pk = user_id) + return User.objects.get(pk=user_id) except User.DoesNotExist: return None