Require LXML (maybe I'll fix this later).
[cas.git] / provider / cas_provider / views.py
1 from django.http import HttpResponse, HttpResponseForbidden, HttpResponseRedirect
2 from django.shortcuts import render_to_response
3 from django.template import RequestContext
4 from django.contrib.auth.models import User
5 from django.contrib.auth import authenticate
6 from django.contrib.auth import login as auth_login, logout as auth_logout
7
8 from cas_provider.forms import LoginForm
9 from cas_provider.models import ServiceTicket, LoginTicket, auth_success_response
10 from cas_provider.utils import create_service_ticket
11
12 import urlparse, urllib
13
14 try:
15     from urlparse import parse_qs as url_parse_qs
16 except ImportError:
17     from cgi import parse_qs as url_parse_qs
18      
19
import logging

# Named logger used by every view in this module.
logger = logging.getLogger("fnp.cas.provider")

# Views exported by ``from cas_provider.views import *``.
__all__ = [
    'login',
    'validate',
    'service_validate',
    'logout',
]
24
def _add_query_param(url, param, value):
    """Return ``url`` with the query parameter ``param`` set to ``value``.

    Existing values of ``param`` are replaced. Every other query parameter
    is kept, including repeated parameters and parameters with blank values.

    Unicode arguments are encoded to UTF-8 before URL-quoting. Byte strings
    are used unchanged (they are assumed to already be UTF-8). Any other
    ``value`` is converted with ``str()`` by ``urllib.urlencode``. The
    result is a byte string.

    :param url: URL to extend (for example the ``service`` URL).
    :param param: name of the query parameter to set.
    :param value: new value for ``param``.
    """
    def utf8(s):
        # urllib.urlencode() calls str() on keys and values, which fails for
        # non-ASCII unicode, so normalise text to UTF-8 bytes up front.
        return s.encode('utf-8') if isinstance(s, unicode) else s

    parsed = urlparse.urlparse(utf8(url))
    # The second positional argument is keep_blank_values. It is in the same
    # position for urlparse.parse_qs and cgi.parse_qs, so "a" in "?a=&b=1"
    # is preserved.
    query = url_parse_qs(parsed.query, True)
    query[utf8(param)] = [utf8(value)]
    # Emit one (key, value) pair per value, so repeated parameters come out
    # as "k=v1&k=v2" instead of the repr of a Python list.
    pairs = [(key, val) for key, values in query.iteritems() for val in values]
    parsed = urlparse.ParseResult(parsed.scheme, parsed.netloc,
                                  parsed.path, parsed.params,
                                  urllib.urlencode(pairs), parsed.fragment)
    return parsed.geturl()
34
35
def _service_redirect(user, service, success_redirect):
    """Build the redirect response sent after a successful login.

    When ``service`` is non-empty, a service ticket is issued for ``user``
    via ``create_service_ticket``. The response then redirects to
    ``service`` with that ticket appended as the ``ticket`` query parameter.
    When ``service`` is ``None`` or empty, the response redirects to
    ``success_redirect``.
    """
    if not service:
        logger.info("Redirecting to default: %s", success_redirect)
        return HttpResponseRedirect(success_redirect)
    # NOTE(review): ``service`` comes straight from the request and is not
    # checked against any allow-list in this module, so any client-supplied
    # URL becomes a redirect target -- confirm this is validated elsewhere.
    ticket = create_service_ticket(user, service)
    logger.info("Service=%s, ticket=%s", service, ticket)
    target = _add_query_param(service, 'ticket', ticket.ticket)
    logger.info("Redirecting to %s", target)
    return HttpResponseRedirect(target)


def login(request, template_name = 'cas/login.html', success_redirect = '/accounts/'):
    """Display the login form and handle its submission.

    * An already-authenticated user is redirected immediately.
    * A POST consumes its login ticket (``lt``) and checks the submitted
      credentials. On success the user is logged in and redirected.
    * In every other case the login form is rendered together with any
      accumulated ``errors``.

    Redirect target: if a non-empty ``service`` URL was supplied, a service
    ticket is issued and appended to it as ``ticket``; otherwise the user is
    sent to ``success_redirect``. The URL comes from the GET query when the
    user is already authenticated, and from the POST field on form
    submission.

    :param template_name: template used to render the login form.
    :param success_redirect: URL used when no ``service`` was given.
    """
    service = request.GET.get('service', None)

    if request.user.is_authenticated():
        logger.info("User %s passed auth, service is %s", request.user, service)
        return _service_redirect(request.user, service, success_redirect)

    errors = []
    if request.method == 'POST':
        username = request.POST.get('username', None)
        password = request.POST.get('password', None)
        service = request.POST.get('service', None)
        lt = request.POST.get('lt', None)

        logger.debug("User %s logging in", username)
        logger.info("Login submit: serivce = %s, Lticket=%s", service, lt)

        try:
            login_ticket = LoginTicket.objects.get(ticket = lt)
        except LoginTicket.DoesNotExist:
            errors.append('Login ticket expired. Please try again.')
        else:
            # The login ticket is deleted before the credentials are checked,
            # so each ticket can be submitted at most once, whether or not
            # authentication succeeds.
            login_ticket.delete()
            logger.debug("Auth")
            user = authenticate(username = username, password = password)
            if user is None:
                errors.append('Incorrect username and/or password.')
            elif not user.is_active:
                errors.append('This account is disabled.')
            else:
                logger.debug("AuthLogin")
                auth_login(request, user)
                return _service_redirect(user, service, success_redirect)

    logger.debug("LOGIN GET, service = %s", service)
    form = LoginForm(service)
    return render_to_response(template_name, {'form': form, 'errors': errors}, context_instance = RequestContext(request))
90
def validate(request):
    """Validate a service ticket and answer in the plain-text yes/no format.

    Both the ``service`` and ``ticket`` GET parameters must be present.
    When the ticket exists it is deleted and ``"yes\\n<username>\\n"`` is
    returned. When a parameter is missing or the ticket is unknown,
    ``"no\\n\\n"`` is returned.

    NOTE(review): ``service`` is only checked for presence; it is not
    compared with the service the ticket was issued for.

    Unlike the previous bare ``except:``, unexpected errors (e.g. database
    failures) now propagate instead of being reported as an invalid ticket.
    """
    service = request.GET.get('service', None)
    ticket_string = request.GET.get('ticket', None)
    if service is None or ticket_string is None:
        return HttpResponse("no\n\n")
    try:
        ticket = ServiceTicket.objects.get(ticket = ticket_string)
    except ServiceTicket.DoesNotExist:
        return HttpResponse("no\n\n")
    username = ticket.user.username
    # Delete the ticket so it cannot be validated a second time.
    ticket.delete()
    return HttpResponse("yes\n%s\n" % username)
103
def service_validate(request):
    """Validate a service ticket and answer with a CAS XML service response.

    Reads the ``service`` and ``ticket`` GET parameters. If either is
    missing, an ``INVALID_REQUEST`` failure document is returned.

    A matching ``ServiceTicket`` is deleted, and the body produced by
    ``auth_success_response`` for its user is returned. An unknown ticket
    yields an ``INVALID_TICKET`` failure. Every response is served as
    ``text/xml``.

    NOTE(review): ``service`` is only checked for presence; it is not
    compared with the service the ticket was issued for.
    """
    def xml_response(body):
        return HttpResponse(body, mimetype = 'text/xml')

    service = request.GET.get('service', None)
    ticket_string = request.GET.get('ticket', None)

    params_missing = service is None or ticket_string is None
    if params_missing:
        return xml_response('''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
            <cas:authenticationFailure code="INVALID_REQUEST">
                Not all required parameters were sent.
            </cas:authenticationFailure>
        </cas:serviceResponse>''')

    try:
        found = ServiceTicket.objects.get(ticket = ticket_string)
        found.delete()
        return xml_response(auth_success_response(found.user))
    except ServiceTicket.DoesNotExist:
        return xml_response('''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
            <cas:authenticationFailure code="INVALID_TICKET">
                The provided ticket is invalid.
            </cas:authenticationFailure>
        </cas:serviceResponse>''')
124
def logout(request, template_name = 'cas/logout.html'):
    """Log the current user out and render the logout page.

    The optional ``url`` GET parameter is read before logging out. It is
    passed to the template as ``url`` and is ``None`` when absent.
    """
    url_param = request.GET.get('url', None)
    auth_logout(request)
    context = {'url': url_param}
    return render_to_response(template_name, context,
                              context_instance = RequestContext(request))