import script
[redakcja.git] / apps / django_cas / backends.py
1 """CAS authentication backend"""
2
3 from urllib import urlencode, urlopen
4 from urlparse import urljoin
5 from django.conf import settings
6 from django_cas.models import User
7
8 __all__ = ['CASBackend']
9
10
11 def _verify_cas1(ticket, service):
12     """Verifies CAS 1.0 authentication ticket.
13
14     Returns (username, None) on success and (None, None) on failure.
15     """
16
17     params = {'ticket': ticket, 'service': service}
18     url = (urljoin(settings.CAS_SERVER_URL, 'validate') + '?' +
19            urlencode(params))
20     page = urlopen(url)
21     try:
22         verified = page.readline().strip()
23         if verified == 'yes':
24             return page.readline().strip(), None
25         else:
26             return None, None
27     finally:
28         page.close()
29
30
31 def _verify_cas2(ticket, service):
32     """Verifies CAS 2.0+ XML-based authentication ticket.
33
34     Returns (username, attr_dict) on success and (None, None) on failure.
35     """
36
37     try:
38         from lxml import etree as ElementTree
39     except ImportError:
40         from elementtree import ElementTree
41
42     params = {'ticket': ticket, 'service': service}
43     url = (urljoin(settings.CAS_SERVER_URL, 'serviceValidate') + '?' +
44            urlencode(params))
45     page = urlopen(url)
46     try:
47         response = page.read()
48         tree = ElementTree.fromstring(response)
49         if tree[0].tag.endswith('authenticationSuccess'):
50             attrs = {}
51             for tag in tree[0][1:]:
52                 attrs[tag.tag] = tag.text
53             return tree[0][0].text, attrs
54         else:
55             return None, None
56     except:
57         import traceback
58         traceback.print_exc()
59         print "****", url
60         print response
61         print "****"
62     finally:
63         page.close()
64
65
66 _PROTOCOLS = {'1': _verify_cas1, '2': _verify_cas2}
67
68 if settings.CAS_VERSION not in _PROTOCOLS:
69     raise ValueError('Unsupported CAS_VERSION %r' % settings.CAS_VERSION)
70
71 _verify = _PROTOCOLS[settings.CAS_VERSION]
72
73
74 class CASBackend(object):
75     """CAS authentication backend"""
76
77     def authenticate(self, ticket, service):
78         """Verifies CAS ticket and gets or creates User object"""
79
80         username, attrs = _verify(ticket, service)
81         if not username:
82             return None
83
84         user_attrs = {}
85         if hasattr(settings, 'CAS_USER_ATTRS_MAP'):
86             attr_map = settings.CAS_USER_ATTRS_MAP
87             for k, v in attrs.items():
88                 if k in attr_map:
89                     user_attrs[attr_map[k]] = v # unicode(v, 'utf-8')
90
91         try:
92             user = User.objects.get(username__iexact=username)
93             # update user info
94             changed = False
95             for k, v in user_attrs.items():
96                 if getattr(user, k) != v:
97                     setattr(user, k, v)
98                     changed = True
99             if changed:
100                 user.save()
101         except User.DoesNotExist:
102             # user will have an "unusable" password
103             user = User.objects.create_user(username, '')
104             for k, v in user_attrs.items():
105                 setattr(user, k, v)
106             user.first_name = attrs.get('firstname', '')
107             user.last_name = attrs.get('lastname', '')
108             user.save()
109         return user
110
111     def get_user(self, user_id):
112         """Retrieve the user's entry in the User model if it exists"""
113
114         try:
115             return User.objects.get(pk=user_id)
116         except User.DoesNotExist:
117             return None