2769e528f9a3fb7021c712e223fa122ab267ace7
[django-cas-provider.git] / cas_provider / views.py
1 from lxml import etree
2 from urllib import urlencode
3 import urllib2
4 import urlparse
5 from django.conf import settings
6 from django.contrib.auth import login as auth_login, logout as auth_logout
7 from django.core.urlresolvers import get_callable
8 from django.http import HttpResponse, HttpResponseRedirect
9 from django.shortcuts import render_to_response
10 from django.template import RequestContext
11 from cas_provider.attribute_formatters import NSMAP, CAS
12 from cas_provider.models import ProxyGrantingTicket, ProxyTicket
13 from forms import LoginForm
14 from models import ServiceTicket, LoginTicket
15
16
# Public views of this module. proxy and proxy_validate are request-handling
# views just like the others, so they are exported as well.
__all__ = ['login', 'validate', 'logout', 'service_validate',
           'proxy', 'proxy_validate']

# Error codes emitted in the ``code`` attribute of
# <cas:authenticationFailure> by _cas2_error_response.
INVALID_TICKET = 'INVALID_TICKET'
INVALID_SERVICE = 'INVALID_SERVICE'
INVALID_REQUEST = 'INVALID_REQUEST'
INTERNAL_ERROR = 'INTERNAL_ERROR'

# Default human-readable message for each error code, used by
# _cas2_error_response when the caller supplies no explicit message.
ERROR_MESSAGES = (
    (INVALID_TICKET, u'The provided ticket is invalid.'),
    (INVALID_SERVICE, u'Service is invalid'),
    (INVALID_REQUEST, u'Not all required parameters were sent.'),
    (INTERNAL_ERROR, u'An internal error occurred during ticket validation'),
    )
30
31
def login(request, template_name='cas/login.html',
          success_redirect=settings.LOGIN_REDIRECT_URL,
          warn_template_name='cas/warn.html',
          form_class=LoginForm):
    """Display the CAS login form and issue service tickets.

    An already-authenticated user is redirected immediately:

    - with a ``service`` parameter, to the URL of a freshly created
      ServiceTicket (or first to ``warn_template_name`` when ``warn`` is set);
    - without one, to ``success_redirect``.

    An anonymous user gets ``form_class`` rendered with a new LoginTicket.
    A valid POST logs the user in and redirects the same way.

    :param template_name: template used to render the login form.
    :param success_redirect: redirect target when no service is requested.
    :param warn_template_name: template shown when ``warn`` is requested.
    :param form_class: form class. It is constructed with a ``request=``
        keyword and must provide ``is_valid()``, ``get_user()`` and
        ``cleaned_data``; ``get_errors()`` is used when present.
    """
    service = request.GET.get('service', None)
    if request.user.is_authenticated():
        # An empty ``?service=`` is treated like a missing one instead of
        # issuing a ticket for the empty string.
        if not service:
            return HttpResponseRedirect(success_redirect)
        if request.GET.get('warn', False):
            return render_to_response(warn_template_name, {
                'service': service,
                'warn': False
            }, context_instance=RequestContext(request))
        ticket = ServiceTicket.objects.create(service=service, user=request.user)
        return HttpResponseRedirect(ticket.get_redirect_url())

    if request.method == 'POST':
        form = form_class(data=request.POST, request=request)
        if form.is_valid():
            user = form.get_user()
            auth_login(request, user)
            redirect_to = success_redirect
            # The form's service value may be '' (e.g. a blank hidden field);
            # only a non-empty value names a service to issue a ticket for.
            posted_service = form.cleaned_data.get('service')
            if posted_service:
                ticket = ServiceTicket.objects.create(service=posted_service, user=user)
                redirect_to = ticket.get_redirect_url()
            return HttpResponseRedirect(redirect_to)
    else:
        form = form_class(request=request, initial={
            'service': service,
            'lt': LoginTicket.objects.create()
        })
    # Only set the test cookie when a session supporting it is attached.
    if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'):
        request.session.set_test_cookie()
    return render_to_response(template_name, {
        'form': form,
        'errors': form.get_errors() if hasattr(form, 'get_errors') else None,
    }, context_instance=RequestContext(request))
69
70
def validate(request):
    """Validate a service ticket via the CAS 1.0 protocol.

    Reads the ``service`` and ``ticket`` GET parameters. Responds with
    ``"yes\\n<username>\\n"`` when the ticket exists and was issued for that
    exact service, and with ``"no\\n\\n"`` otherwise.
    """
    service = request.GET.get('service', None)
    ticket_string = request.GET.get('ticket', None)
    if service is not None and ticket_string is not None:
        # TODO: honour the ``renew`` parameter by checking the user's SSO session.
        try:
            ticket = ServiceTicket.objects.get(ticket=ticket_string)
        except (ServiceTicket.DoesNotExist, ServiceTicket.MultipleObjectsReturned):
            ticket = None
        # Explicit comparison rather than ``assert``: asserts are stripped
        # under ``python -O``, which would accept a ticket for any service.
        if ticket is not None and ticket.service == service:
            return HttpResponse("yes\n%s\n" % ticket.user.username)
    return HttpResponse("no\n\n")
87
88
def logout(request, template_name='cas/logout.html',
           auto_redirect=settings.CAS_AUTO_REDIRECT_AFTER_LOGOUT):
    """Log the current user out and discard their service tickets.

    For an authenticated user, every ServiceTicket belonging to them is
    deleted (each via its own ``delete()``) before the Django session is
    ended. If ``auto_redirect`` is truthy and a ``url`` GET parameter was
    supplied, the user is then redirected there. Otherwise
    ``template_name`` is rendered with ``url`` in the context.

    NOTE(review): ``url`` is redirected to without validation.
    """
    redirect_url = request.GET.get('url', None)
    user = request.user
    if user.is_authenticated():
        for service_ticket in ServiceTicket.objects.filter(user=user):
            service_ticket.delete()
        auth_logout(request)
        if auto_redirect and redirect_url:
            return HttpResponseRedirect(redirect_url)
    context = {'url': redirect_url}
    return render_to_response(template_name, context,
                              context_instance=RequestContext(request))
100
101
def proxy(request):
    """Issue a proxy ticket (CAS 2.0 ``/proxy``).

    Expects two GET parameters:

    - ``targetService``: the service the proxy ticket is requested for.
    - ``pgt``: the proxy-granting ticket *id*, i.e. the ``pgtId`` value that
      generate_proxy_granting_ticket delivers to the proxy callback. It is
      not the ``pgtIou`` that appears in the validation response.

    Returns a proxySuccess XML response carrying the new ProxyTicket's id.
    Returns an error response instead when:

    - a parameter is missing (INVALID_REQUEST);
    - the PGT is unknown (INVALID_TICKET);
    - ``targetService`` does not match the PGT's recorded target
      (INVALID_SERVICE).
    """
    targetService = request.GET.get('targetService')
    pgt_id = request.GET.get('pgt')
    # Report missing parameters instead of letting a KeyError become a 500.
    if not targetService or not pgt_id:
        return _cas2_error_response(INVALID_REQUEST)

    try:
        proxyGrantingTicket = ProxyGrantingTicket.objects.get(ticket=pgt_id)
    except ProxyGrantingTicket.DoesNotExist:
        return _cas2_error_response(INVALID_TICKET)

    # NOTE(review): generate_proxy_granting_ticket stores the proxy callback
    # URL (pgtUrl) as targetService, so this only allows proxy tickets for
    # the callback URL itself. Confirm this restriction is intended.
    if not proxyGrantingTicket.targetService == targetService:
        return _cas2_error_response(INVALID_SERVICE,
            "The PGT was issued for %(original)s but the PT was requested for %(but)s" % dict(
                original=proxyGrantingTicket.targetService, but=targetService))

    pt = ProxyTicket.objects.create(proxyGrantingTicket=proxyGrantingTicket,
        user=proxyGrantingTicket.serviceTicket.user,
        service=targetService)
    return _cas2_proxy_success(pt.ticket)
120
121
def ticket_validate(service, ticket_string, pgtUrl):
    """Validate ``ticket_string`` for ``service`` and build the CAS 2.0 reply.

    Shared by service_validate and proxy_validate. Ticket lookup depends on
    the prefix:

    - ``ST``: looked up as a ServiceTicket;
    - ``PT``: looked up as a ProxyTicket;
    - anything else: rejected with INVALID_TICKET.

    If ``pgtUrl`` is given, a proxy-granting ticket is generated and its IOU
    is put in the success response. For proxy tickets, the services of the
    proxy tickets in the issuing chain are reported as proxies.

    :returns: an HttpResponse carrying either the authenticationSuccess or
        the authenticationFailure XML document. Failure cases:

        - INVALID_REQUEST: a parameter is missing.
        - INVALID_TICKET: the ticket is unknown.
        - INVALID_SERVICE: the ticket was issued for another service.
    """
    if service is None or ticket_string is None:
        return _cas2_error_response(INVALID_REQUEST)

    try:
        if ticket_string.startswith('ST'):
            ticket = ServiceTicket.objects.get(ticket=ticket_string)
        elif ticket_string.startswith('PT'):
            ticket = ProxyTicket.objects.get(ticket=ticket_string)
        else:
            return _cas2_error_response(INVALID_TICKET,
                '%(ticket)s is neither Service (ST-...) nor Proxy Ticket (PT-...)' % {
                    'ticket': ticket_string})
    # Catch both explicitly: a failed ProxyTicket lookup raises
    # ProxyTicket.DoesNotExist.
    except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
        return _cas2_error_response(INVALID_TICKET)

    if ticket.service != service:
        return _cas2_error_response(INVALID_SERVICE)

    pgtIouId = None
    if pgtUrl is not None:
        new_pgt = generate_proxy_granting_ticket(pgtUrl, ticket)
        if new_pgt:
            pgtIouId = new_pgt.pgtiou

    # Walk back through the PGTs that were themselves obtained by validating
    # a proxy ticket, collecting each such proxy ticket's service.
    # NOTE(review): a PT issued from a PGT obtained with a plain ST yields no
    # proxies here; verify against the CAS 2.0 <cas:proxies> requirements.
    proxies = ()
    if hasattr(ticket, 'proxyticket'):
        pgt = ticket.proxyticket.proxyGrantingTicket
        while pgt and hasattr(pgt.serviceTicket, 'proxyticket'):
            proxies += (pgt.serviceTicket.service,)
            pgt = pgt.serviceTicket.proxyticket.proxyGrantingTicket

    return _cas2_sucess_response(ticket.user, pgtIouId, proxies)
161
162
def service_validate(request):
    """Validate a service ticket via CAS 2.0 ``/serviceValidate``.

    Reads the ``service``, ``ticket`` and optional ``pgtUrl`` GET
    parameters. Proxy tickets (``PT-...``) are rejected here; they must be
    checked through proxy_validate instead.
    """
    service = request.GET.get('service', None)
    ticket_string = request.GET.get('ticket', None)
    pgtUrl = request.GET.get('pgtUrl', None)
    # Guard against a missing ticket: calling startswith on None would raise
    # AttributeError (HTTP 500). ticket_validate reports INVALID_REQUEST instead.
    if ticket_string is not None and ticket_string.startswith('PT-'):
        return _cas2_error_response(INVALID_TICKET, "serviceValidate cannot verify proxy tickets")
    return ticket_validate(service, ticket_string, pgtUrl)
172
173
def proxy_validate(request):
    """Validate a service or proxy ticket via CAS 2.0 ``/proxyValidate``.

    Unlike service_validate, ``PT-`` tickets are accepted. The ``service``,
    ``ticket`` and optional ``pgtUrl`` GET parameters are passed straight
    to ticket_validate.
    """
    params = request.GET
    return ticket_validate(params.get('service', None),
                           params.get('ticket', None),
                           params.get('pgtUrl', None))
180
181
def generate_proxy_granting_ticket(pgt_url, ticket):
    """Create a proxy-granting ticket for ``ticket`` and send it to ``pgt_url``.

    The callback URL is requested with ``pgtId`` (the PGT id) and ``pgtIou``
    added to its query string. The PGT is saved only when the callback is
    reachable and answers with one of ``proxy_callback_good_status``.

    :param pgt_url: the proxy callback URL supplied as ``pgtUrl``.
    :param ticket: the ServiceTicket or ProxyTicket being validated.
    :returns: the saved ProxyGrantingTicket, or None on callback failure.

    NOTE(review): no timeout is passed to urlopen, so a slow callback blocks
    ticket validation. The CAS spec also expects pgtUrl to be HTTPS with a
    verified certificate, which is not checked here.
    """
    proxy_callback_good_status = (200, 202, 301, 302, 304)

    pgt = ProxyGrantingTicket()
    pgt.serviceTicket = ticket
    pgt.targetService = pgt_url

    if hasattr(ticket, 'proxyGrantingTicket'):
        # The ticket being validated is itself a proxy ticket: link the new
        # PGT to the PGT that issued it.
        pgt.pgt = ticket.proxyGrantingTicket

    # pgt.ticket / pgt.pgtiou are read before save(), so they are presumably
    # filled by model field defaults; confirm in models.py.
    callback_url = _proxy_callback_url(
        pgt_url, {'pgtId': pgt.ticket, 'pgtIou': pgt.pgtiou})

    try:
        response = urllib2.urlopen(callback_url)
    except urllib2.HTTPError as e:
        if e.code not in proxy_callback_good_status:
            return None
    except urllib2.URLError:
        return None
    else:
        # Release the connection; the callback's body is not needed.
        response.close()

    pgt.save()
    return pgt


def _proxy_callback_url(pgt_url, params):
    """Return ``pgt_url`` with ``params`` merged into its query string."""
    parts = list(urlparse.urlsplit(pgt_url))
    # urlsplit -> (scheme, netloc, path, query, fragment): index 3 is the
    # query. Index 4 would be the fragment, which is never sent to the server.
    # Existing parameters (including blank ones) are kept, except those
    # overridden by ``params``.
    query = [(key, value)
             for key, value in urlparse.parse_qsl(parts[3], keep_blank_values=True)
             if key not in params]
    query.extend(sorted(params.items()))
    parts[3] = urlencode(query)
    return urlparse.urlunsplit(parts)
211
212
def _cas2_proxy_success(pt):
    """Wrap the proxySuccess document for proxy ticket id ``pt`` in an HttpResponse.

    Served as ``text/xml``, consistent with the other CAS 2.0 responses.
    """
    return HttpResponse(proxy_success(pt), mimetype='text/xml')
215
216
def _cas2_sucess_response(user, pgt=None, proxies=None):
    """Wrap the authenticationSuccess document for ``user`` in an XML HttpResponse.

    ``pgt`` and ``proxies`` are forwarded unchanged to auth_success_response.
    The misspelt name is kept because ticket_validate refers to it.
    """
    body = auth_success_response(user, pgt, proxies)
    return HttpResponse(body, mimetype='text/xml')
219
220
def _cas2_error_response(code, message=None):
    """Return a CAS 2.0 authenticationFailure response as ``text/xml``.

    :param code: error code for the ``code`` attribute (one of the
        INVALID_* / INTERNAL_ERROR constants of this module).
    :param message: human-readable text. When empty, the ERROR_MESSAGES
        entry for ``code`` is used, or an empty string for an unknown code.
    """
    if not message:
        message = dict(ERROR_MESSAGES).get(code, u'')
    return HttpResponse(auth_failure_response(code, message), mimetype='text/xml')


def auth_failure_response(code, message):
    """Render the authenticationFailure XML document as a unicode string.

    Both values are XML-escaped. Messages can contain request data (e.g. the
    ``targetService`` echoed by the proxy view), which must not be able to
    inject markup.
    """
    from xml.sax.saxutils import escape, quoteattr
    # quoteattr() returns the value already wrapped in quotes.
    return u'''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
            <cas:authenticationFailure code=%(code)s>
                %(message)s
            </cas:authenticationFailure>
        </cas:serviceResponse>''' % {
        'code': quoteattr(code),
        'message': escape(message),
    }
230
231
def proxy_success(pt):
    """Build the CAS 2.0 proxySuccess document carrying proxy ticket id ``pt``.

    Returns the document serialized with lxml as a unicode string.
    """
    root = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP)
    ticket_node = etree.SubElement(
        etree.SubElement(root, CAS + 'proxySuccess'),
        CAS + 'proxyTicket')
    ticket_node.text = pt
    serialized = etree.tostring(root, encoding='utf-8')
    return serialized.decode('utf-8')
238
239
def auth_success_response(user, pgt, proxies):
    """Build the CAS 2.0 authenticationSuccess document for ``user``.

    :param user: authenticated user; ``user.username`` becomes <cas:user>.
    :param pgt: text for <cas:proxyGrantingTicket> (ticket_validate passes
        the PGT IOU), or a falsy value to omit the element.
    :param proxies: iterable of proxy identifiers for <cas:proxies>, or a
        falsy value to omit the element.
    :returns: the document serialized with lxml as a unicode string.

    Extra attributes are added when ``settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK``
    names a callable returning a non-empty collection for ``user``. They are
    written into the document by the callable named in
    ``settings.CAS_CUSTOM_ATTRIBUTES_FORMATER``.
    """
    response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP)
    auth_success = etree.SubElement(response, CAS + 'authenticationSuccess')
    username = etree.SubElement(auth_success, CAS + 'user')
    username.text = user.username

    # An undefined setting is treated like an empty one instead of raising
    # AttributeError.
    callback_path = getattr(settings, 'CAS_CUSTOM_ATTRIBUTES_CALLBACK', None)
    if callback_path:
        callback = get_callable(callback_path)
        attrs = callback(user)
        # Truthiness (not len() > 0) also tolerates a callback returning None.
        if attrs:
            formater = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_FORMATER)
            formater(auth_success, attrs)

    if pgt:
        pgtElement = etree.SubElement(auth_success, CAS + 'proxyGrantingTicket')
        pgtElement.text = pgt

    if proxies:
        proxiesElement = etree.SubElement(auth_success, CAS + "proxies")
        for proxy in proxies:
            proxyElement = etree.SubElement(proxiesElement, CAS + "proxy")
            proxyElement.text = proxy

    return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')