X-Git-Url: https://git.mdrn.pl/django-cas-provider.git/blobdiff_plain/4d08fe00954d5a4a3a5d963ff04d001be57f9007..6e6126a7402676610d76703bdb7c37dabddde1a8:/cas_provider/views.py diff --git a/cas_provider/views.py b/cas_provider/views.py index 32ed6e9..2769e52 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -1,16 +1,21 @@ +from lxml import etree +from urllib import urlencode +import urllib2 +import urlparse from django.conf import settings -from django.contrib.auth import login as auth_login, \ - logout as auth_logout +from django.contrib.auth import login as auth_login, logout as auth_logout +from django.core.urlresolvers import get_callable from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import render_to_response from django.template import RequestContext +from cas_provider.attribute_formatters import NSMAP, CAS +from cas_provider.models import ProxyGrantingTicket, ProxyTicket from forms import LoginForm from models import ServiceTicket, LoginTicket __all__ = ['login', 'validate', 'logout', 'service_validate'] - INVALID_TICKET = 'INVALID_TICKET' INVALID_SERVICE = 'INVALID_SERVICE' INVALID_REQUEST = 'INVALID_REQUEST' @@ -21,12 +26,13 @@ ERROR_MESSAGES = ( (INVALID_SERVICE, u'Service is invalid'), (INVALID_REQUEST, u'Not all required parameters were sent.'), (INTERNAL_ERROR, u'An internal error occurred during ticket validation'), -) + ) -def login(request, template_name='cas/login.html', \ - success_redirect=settings.LOGIN_REDIRECT_URL, - warn_template_name='cas/warn.html'): +def login(request, template_name='cas/login.html',\ + success_redirect=settings.LOGIN_REDIRECT_URL, + warn_template_name='cas/warn.html', + form_class=LoginForm): service = request.GET.get('service', None) if request.user.is_authenticated(): if service is not None: @@ -40,7 +46,7 @@ def login(request, template_name='cas/login.html', \ else: return HttpResponseRedirect(success_redirect) if request.method == 'POST': - form = LoginForm(request.POST) + form = form_class(data=request.POST, request=request) if form.is_valid(): user = form.get_user() auth_login(request, user) @@ -50,14 +56,16 @@ def login(request, template_name='cas/login.html', \ success_redirect = ticket.get_redirect_url() return HttpResponseRedirect(success_redirect) else: - form = LoginForm(initial={ + form = form_class(request=request, initial={ 'service': service, 'lt': LoginTicket.objects.create() }) + if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'): + request.session.set_test_cookie() return render_to_response(template_name, { 'form': form, - 'errors': form.get_errors() - }, context_instance=RequestContext(request)) + 'errors': form.get_errors() if hasattr(form, 'get_errors') else None, + }, context_instance=RequestContext(request)) def validate(request): @@ -70,56 +78,186 @@ def validate(request): # TODO: check user SSO session try: ticket = ServiceTicket.objects.get(ticket=ticket_string) + assert ticket.service == service username = ticket.user.username - ticket.delete() return HttpResponse("yes\n%s\n" % username) except: pass return HttpResponse("no\n\n") -def logout(request, template_name='cas/logout.html'): +def logout(request, template_name='cas/logout.html', + auto_redirect=settings.CAS_AUTO_REDIRECT_AFTER_LOGOUT): url = request.GET.get('url', None) - auth_logout(request) - return render_to_response(template_name, {'url': url}, \ - context_instance=RequestContext(request)) + if request.user.is_authenticated(): + for ticket in ServiceTicket.objects.filter(user=request.user): + ticket.delete() + auth_logout(request) + if url and auto_redirect: + return HttpResponseRedirect(url) + return render_to_response(template_name, {'url': url}, + context_instance=RequestContext(request)) -def service_validate(request): - """Validate ticket via CAS v.2 protocol""" - service = request.GET.get('service', None) - ticket_string = request.GET.get('ticket', None) +def proxy(request): + targetService = request.GET['targetService'] + pgtiou = request.GET['pgt'] + + try: + proxyGrantingTicket = ProxyGrantingTicket.objects.get(pgtiou=pgtiou) + except ProxyGrantingTicket.DoesNotExist: + return _cas2_error_response(INVALID_TICKET) + + if not proxyGrantingTicket.targetService == targetService: + return _cas2_error_response(INVALID_SERVICE, + "The PGT was issued for %(original)s but the PT was requested for %(but)s" % dict( + original=proxyGrantingTicket.targetService, but=targetService)) + + pt = ProxyTicket.objects.create(proxyGrantingTicket=proxyGrantingTicket, + user=proxyGrantingTicket.serviceTicket.user, + service=targetService) + return _cas2_proxy_success(pt.ticket) + + +def ticket_validate(service, ticket_string, pgtUrl): if service is None or ticket_string is None: return _cas2_error_response(INVALID_REQUEST) try: - ticket = ServiceTicket.objects.get(ticket=ticket_string) + if ticket_string.startswith('ST'): + ticket = ServiceTicket.objects.get(ticket=ticket_string) + elif ticket_string.startswith('PT'): + ticket = ProxyTicket.objects.get(ticket=ticket_string) + else: + return _cas2_error_response(INVALID_TICKET, + '%(ticket)s is neither Service (ST-...) nor Proxy Ticket (PT-...)' % { + 'ticket': ticket_string}) except ServiceTicket.DoesNotExist: return _cas2_error_response(INVALID_TICKET) - if settings.CAS_CHECK_SERVICE and ticket.service != service: - ticket.delete() + if ticket.service != service: return _cas2_error_response(INVALID_SERVICE) - username = ticket.user.username - ticket.delete() - return _cas2_sucess_response(username) + pgtIouId = None + proxies = () + if pgtUrl is not None: + pgt = generate_proxy_granting_ticket(pgtUrl, ticket) + if pgt: + pgtIouId = pgt.pgtiou + + if hasattr(ticket, 'proxyticket'): + pgt = ticket.proxyticket.proxyGrantingTicket + # I am issued by this proxy granting ticket + if hasattr(pgt.serviceTicket, 'proxyticket'): + while pgt: + if hasattr(pgt.serviceTicket, 'proxyticket'): + proxies += (pgt.serviceTicket.service,) + pgt = pgt.serviceTicket.proxyticket.proxyGrantingTicket + else: + pgt = None + user = ticket.user + return _cas2_sucess_response(user, pgtIouId, proxies) -def _cas2_error_response(code): - return HttpResponse(u'''' + +def service_validate(request): + """Validate ticket via CAS v.2 protocol""" + service = request.GET.get('service', None) + ticket_string = request.GET.get('ticket', None) + pgtUrl = request.GET.get('pgtUrl', None) + if ticket_string.startswith('PT-'): + return _cas2_error_response(INVALID_TICKET, "serviceValidate cannot verify proxy tickets") + else: + return ticket_validate(service, ticket_string, pgtUrl) + + +def proxy_validate(request): + """Validate ticket via CAS v.2 protocol""" + service = request.GET.get('service', None) + ticket_string = request.GET.get('ticket', None) + pgtUrl = request.GET.get('pgtUrl', None) + return ticket_validate(service, ticket_string, pgtUrl) + + +def generate_proxy_granting_ticket(pgt_url, ticket): + proxy_callback_good_status = (200, 202, 301, 302, 304) + uri = list(urlparse.urlsplit(pgt_url)) + + pgt = ProxyGrantingTicket() + pgt.serviceTicket = ticket + pgt.targetService = pgt_url + + if hasattr(ticket, 'proxyGrantingTicket'): + # here we got a proxy ticket! tata! + pgt.pgt = ticket.proxyGrantingTicket + + params = {'pgtId': pgt.ticket, 'pgtIou': pgt.pgtiou} + + query = dict(urlparse.parse_qsl(uri[4])) + query.update(params) + + uri[4] = urlencode(query) + + try: + response = urllib2.urlopen(urlparse.urlunsplit(uri)) + except urllib2.HTTPError, e: + if not e.code in proxy_callback_good_status: + return + except urllib2.URLError, e: + return + + pgt.save() + return pgt + + +def _cas2_proxy_success(pt): + return HttpResponse(proxy_success(pt)) + + +def _cas2_sucess_response(user, pgt=None, proxies=None): + return HttpResponse(auth_success_response(user, pgt, proxies), mimetype='text/xml') + + +def _cas2_error_response(code, message=None): + return HttpResponse(u''' %(message)s ''' % { - 'code': code, - 'message': dict(ERROR_MESSAGES).get(code) + 'code': code, + 'message': message if message else dict(ERROR_MESSAGES).get(code) }, mimetype='text/xml') -def _cas2_sucess_response(username): - return HttpResponse(u''' - - %(username)s - - ''' % {'username': username}, mimetype='text/xml') +def proxy_success(pt): + response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP) + proxySuccess = etree.SubElement(response, CAS + 'proxySuccess') + proxyTicket = etree.SubElement(proxySuccess, CAS + 'proxyTicket') + proxyTicket.text = pt + return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8') + + +def auth_success_response(user, pgt, proxies): + response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP) + auth_success = etree.SubElement(response, CAS + 'authenticationSuccess') + username = etree.SubElement(auth_success, CAS + 'user') + username.text = user.username + + if settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK: + callback = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK) + attrs = callback(user) + if len(attrs) > 0: + formater = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_FORMATER) + formater(auth_success, attrs) + + if pgt: + pgtElement = etree.SubElement(auth_success, CAS + 'proxyGrantingTicket') + pgtElement.text = pgt + + if proxies: + proxiesElement = etree.SubElement(auth_success, CAS + "proxies") + for proxy in proxies: + proxyElement = etree.SubElement(proxiesElement, CAS + "proxy") + proxyElement.text = proxy + + return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')