X-Git-Url: https://git.mdrn.pl/django-cas-provider.git/blobdiff_plain/0ea3daef05a6126e6429d37d9706aef785104802..34efbf37f568d6db6523da7dde3a134c467ed89b:/cas_provider/views.py?ds=sidebyside diff --git a/cas_provider/views.py b/cas_provider/views.py index c9a44db..717aa6f 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -1,16 +1,21 @@ +from lxml import etree +from urllib import urlencode +import urllib2 +import urlparse from django.conf import settings from django.contrib.auth import login as auth_login, logout as auth_logout from django.core.urlresolvers import get_callable from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import render_to_response from django.template import RequestContext +from cas_provider.attribute_formatters import NSMAP, CAS +from cas_provider.models import ProxyGrantingTicket, ProxyTicket from forms import LoginForm from models import ServiceTicket, LoginTicket __all__ = ['login', 'validate', 'logout', 'service_validate'] - INVALID_TICKET = 'INVALID_TICKET' INVALID_SERVICE = 'INVALID_SERVICE' INVALID_REQUEST = 'INVALID_REQUEST' @@ -21,12 +26,13 @@ ERROR_MESSAGES = ( (INVALID_SERVICE, u'Service is invalid'), (INVALID_REQUEST, u'Not all required parameters were sent.'), (INTERNAL_ERROR, u'An internal error occurred during ticket validation'), -) + ) -def login(request, template_name='cas/login.html', \ - success_redirect=settings.LOGIN_REDIRECT_URL, - warn_template_name='cas/warn.html'): +def login(request, template_name='cas/login.html',\ + success_redirect=settings.LOGIN_REDIRECT_URL, + warn_template_name='cas/warn.html', + form_class=LoginForm): service = request.GET.get('service', None) if request.user.is_authenticated(): if service is not None: @@ -40,7 +46,7 @@ def login(request, template_name='cas/login.html', \ else: return HttpResponseRedirect(success_redirect) if request.method == 'POST': - form = LoginForm(request.POST) + form = form_class(data=request.POST, request=request) if form.is_valid(): user = form.get_user() auth_login(request, user) @@ -50,14 +56,16 @@ def login(request, template_name='cas/login.html', \ success_redirect = ticket.get_redirect_url() return HttpResponseRedirect(success_redirect) else: - form = LoginForm(initial={ + form = form_class(request=request, initial={ 'service': service, 'lt': LoginTicket.objects.create() }) + if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'): + request.session.set_test_cookie() return render_to_response(template_name, { 'form': form, - 'errors': form.get_errors() - }, context_instance=RequestContext(request)) + 'errors': form.get_errors() if hasattr(form, 'get_errors') else None, + }, context_instance=RequestContext(request)) def validate(request): @@ -72,28 +80,40 @@ def validate(request): ticket = ServiceTicket.objects.get(ticket=ticket_string) assert ticket.service == service username = ticket.user.username - ticket.delete() return HttpResponse("yes\n%s\n" % username) except: pass return HttpResponse("no\n\n") -def logout(request, template_name='cas/logout.html', \ - auto_redirect=settings.CAS_AUTO_REDIRECT_AFTER_LOGOUT): +def logout(request, template_name='cas/logout.html', + auto_redirect=settings.CAS_AUTO_REDIRECT_AFTER_LOGOUT): url = request.GET.get('url', None) if request.user.is_authenticated(): + for ticket in ServiceTicket.objects.filter(user = request.user): + ticket.delete() auth_logout(request) if url and auto_redirect: return HttpResponseRedirect(url) - return render_to_response(template_name, {'url': url}, \ + return render_to_response(template_name, {'url': url},\ context_instance=RequestContext(request)) +def proxy(request): + targetService = request.GET['targetService'] + pgtiou = request.GET['pgt'] -def service_validate(request): - """Validate ticket via CAS v.2 protocol""" - service = request.GET.get('service', None) - ticket_string = request.GET.get('ticket', None) + try: + proxyGrantingTicket = ProxyGrantingTicket.objects.get(pgtiou=pgtiou) + except ProxyGrantingTicket.DoesNotExist: + return _cas2_error_response(INVALID_TICKET) + + pt = ProxyTicket.objects.create(proxyGrantingTicket = proxyGrantingTicket, + user=proxyGrantingTicket.serviceTicket.user, + service=targetService) + return _cas2_proxy_success(pt.ticket) + + +def ticket_validate(service, ticket_string, pgtUrl): if service is None or ticket_string is None: return _cas2_error_response(INVALID_REQUEST) @@ -103,31 +123,99 @@ def service_validate(request): return _cas2_error_response(INVALID_TICKET) if ticket.service != service: - ticket.delete() return _cas2_error_response(INVALID_SERVICE) + + pgtIouId = None + proxies = () + if pgtUrl is not None: + pgt = generate_proxy_granting_ticket(pgtUrl, ticket) + if pgt: + pgtIouId = pgt.pgtiou + + while pgt: + proxies += (pgt.serviceTicket.service,) + pgt = pgt.serviceTicket.proxyGrantingTicket if hasattr(pgt.serviceTicket, 'proxyGrantingTicket') else None + + user = ticket.user - ticket.delete() - return _cas2_sucess_response(user) + return _cas2_sucess_response(user, pgtIouId, proxies) + + +def service_validate(request): + """Validate ticket via CAS v.2 protocol""" + service = request.GET.get('service', None) + ticket_string = request.GET.get('ticket', None) + pgtUrl = request.GET.get('pgtUrl', None) + if ticket_string.startswith('PT-'): + return _cas2_error_response(INVALID_TICKET, "serviceValidate cannot verify proxy tickets") + else: + return ticket_validate(service, ticket_string, pgtUrl) -def _cas2_error_response(code): +def proxy_validate(request): + """Validate ticket via CAS v.2 protocol""" + service = request.GET.get('service', None) + ticket_string = request.GET.get('ticket', None) + pgtUrl = request.GET.get('pgtUrl', None) + return ticket_validate(service, ticket_string, pgtUrl) + +def generate_proxy_granting_ticket(pgt_url, ticket): + proxy_callback_good_status = (200, 202, 301, 302, 304) + uri = list(urlparse.urlsplit(pgt_url)) + + pgt = ProxyGrantingTicket() + pgt.serviceTicket = ticket + + if hasattr(ticket, 'proxyGrantingTicket'): + # here we got a proxy ticket! tata! + pgt.pgt = ticket.proxyGrantingTicket + + params = {'pgtId': pgt.ticket, 'pgtIou': pgt.pgtiou} + + query = dict(urlparse.parse_qsl(uri[4])) + query.update(params) + + uri[4] = urlencode(query) + + + try: + response = urllib2.urlopen(urlparse.urlunsplit(uri)) + except urllib2.HTTPError, e: + if not e.code in proxy_callback_good_status: + return + except urllib2.URLError, e: + return + + pgt.save() + return pgt + + +def _cas2_proxy_success(pt): + return HttpResponse(proxy_success(pt)) + +def _cas2_sucess_response(user, pgt = None, proxies = None): + return HttpResponse(auth_success_response(user, pgt, proxies), mimetype='text/xml') + +def _cas2_error_response(code, message = None): return HttpResponse(u'''' %(message)s ''' % { 'code': code, - 'message': dict(ERROR_MESSAGES).get(code) + 'message': message if message else dict(ERROR_MESSAGES).get(code) }, mimetype='text/xml') +def proxy_success(pt): + response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP) + proxySuccess = etree.SubElement(response, CAS + 'proxySuccess') + proxyTicket = etree.SubElement(proxySuccess, CAS + 'proxyTicket') + proxyTicket.text = pt + return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8') -def _cas2_sucess_response(user): - return HttpResponse(auth_success_response(user), mimetype='text/xml') - +def auth_success_response(user, pgt, proxies): -def auth_success_response(user): - from attribute_formatters import CAS, NSMAP, etree response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP) auth_success = etree.SubElement(response, CAS + 'authenticationSuccess') @@ -140,4 +228,18 @@ def auth_success_response(user): if len(attrs) > 0: formater = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_FORMATER) formater(auth_success, attrs) + + + if pgt: + pgtElement = etree.SubElement(auth_success, CAS + 'proxyGrantingTicket') + pgtElement.text = pgt + + if proxies: + proxiesElement = etree.SubElement(auth_success , CAS + "proxies") + for proxy in proxies: + proxyElement = etree.SubElement(proxiesElement, CAS + "proxy") + proxyElement.text = proxy + + + return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')