Merge remote-tracking branch 'cas2/master'
author: deyk <deyk@crossway.org>
Fri, 13 Apr 2012 22:25:48 +0000 (15:25 -0700)
committer: deyk <deyk@crossway.org>
Fri, 13 Apr 2012 22:26:25 +0000 (15:26 -0700)
Conflicts:
.gitignore
cas_provider/forms.py
cas_provider/models.py
cas_provider/urls.py
cas_provider/views.py

Merged from https://github.com/castlabs/django-cas-provider

PT #27996721

Files changed (combined diff: marker columns 1 and 2 show changes relative to parent 1 and parent 2):
cas_provider/forms.py
cas_provider/urls.py
cas_provider/views.py

@@@ -1,26 -1,37 +1,33 @@@
  from django import forms
+ from django.conf import settings
+ from django.contrib.auth import authenticate
+ from django.contrib.auth.forms import AuthenticationForm
+ from django.forms import ValidationError
+ from django.utils.translation import ugettext_lazy as _
+ from models import LoginTicket
+ import datetime
  
  
- class LoginForm(forms.Form):
 -__all__ = ['LoginForm', ]
 -
 -
+ class LoginForm(AuthenticationForm):
 -    lt = forms.CharField(widget=forms.HiddenInput)
 +    email = forms.CharField(widget=forms.TextInput(attrs={'autofocus': 'autofocus',
 +                                                          'max_length': '255'}))
 +    password = forms.CharField(widget=forms.PasswordInput)
      service = forms.CharField(widget=forms.HiddenInput, required=False)
 +    remember_me = forms.BooleanField(required=False, label="Keep me signed in",
 +                                     widget=forms.CheckboxInput(attrs={'class': 'remember_me'}))
  
 -    def clean_lt(self):
 -        ticket = self.cleaned_data['lt']
 -        timeframe = datetime.datetime.now() - \
 -                    datetime.timedelta(minutes=settings.CAS_TICKET_EXPIRATION)
 -        try:
 -            return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe)
 -        except LoginTicket.DoesNotExist:
 -            raise ValidationError(_('Login ticket expired. Please try again.'))
 -        return ticket
 +    def __init__(self, *args, **kwargs):
 +        # renew = kwargs.pop('renew', None)
 +        # gateway = kwargs.pop('gateway', None)
 +        request = kwargs.pop('request', None)
 +        super(LoginForm, self).__init__(*args, **kwargs)
 +        self.request = request
  
 -    def clean(self):
 -        AuthenticationForm.clean(self)
 -        self.cleaned_data.get('lt').delete()
 -        return self.cleaned_data
 +    def clean_remember_me(self):
 +        remember = self.cleaned_data['remember_me']
 +        if not remember and self.request is not None:
 +            self.request.session.set_expiry(0)
 +            
  
 -    def get_errors(self):
 -        errors = []
 -        for k, error in self.errors.items():
 -            errors += [e for e in error]
 -        return errors
 +class MergeLoginForm(LoginForm):
 +    email = forms.CharField(max_length=255, widget=forms.HiddenInput)
@@@ -1,11 -1,11 +1,13 @@@
- from django.conf.urls.defaults import *
+ from django.conf.urls.defaults import patterns, url
  
- from views import *
  
- urlpatterns = patterns('',
-                        url(r'^login/', login),
-                        url(r'^socialauth-login/$', socialauth_login),
-                        url(r'^validate/', validate),
-                        url(r'^logout/', logout),
-                        url(r'^login/merge/', login, {'merge': True, 'template_name': 'cas/merge.html'})
-                        )
+ urlpatterns = patterns('cas_provider.views',
++    url(r'^login/merge/', 'login', {'merge': True, 'template_name': 'cas/merge.html'})
+     url(r'^login/?$', 'login', name='cas_login'),
++    url(r'^socialauth-login/$', 'socialauth_login', name='cas_socialauth_login'),
+     url(r'^validate/?$', 'validate', name='cas_validate'),
+     url(r'^proxy/?$', 'proxy', name='proxy'),
+     url(r'^serviceValidate/?$', 'service_validate', name='cas_service_validate'),
+     url(r'^proxyValidate/?$', 'proxy_validate', name='cas_proxy_validate'),
+     url(r'^logout/?$', 'logout', name='cas_logout'),
+ )
  import logging
 -from lxml import etree
 +logger = logging.getLogger('cas_provider.views')
 +import urllib
 +
++import logging
+ from urllib import urlencode
+ import urllib2
+ import urlparse
++
 +from django.http import HttpResponse, HttpResponseRedirect
+ from django.conf import settings
+ from django.contrib.auth import login as auth_login, logout as auth_logout
+ from django.core.urlresolvers import get_callable
 -from django.http import HttpResponse, HttpResponseRedirect
  from django.shortcuts import render_to_response
  from django.template import RequestContext
- from django.contrib.auth import login as auth_login, logout as auth_logout
 +from django.contrib.auth import authenticate
- from forms import LoginForm, MergeLoginForm
- from models import ServiceTicket
- from utils import create_service_ticket
- from exceptions import SameEmailMismatchedPasswords
 +from django.core.urlresolvers import reverse
 +
 -from forms import LoginForm
 -from models import ServiceTicket, LoginTicket
++from lxml import etree
+ from cas_provider.attribute_formatters import NSMAP, CAS
+ from cas_provider.models import ProxyGrantingTicket, ProxyTicket
++from cas_provider.models import ServiceTicket
++from cas_provider.exceptions import SameEmailMismatchedPasswords
++from cas_provider.forms import LoginForm, MergeLoginForm
 +
 +from . import signals
  
- __all__ = ['login', 'validate', 'logout']
+ __all__ = ['login', 'validate', 'logout', 'service_validate']
  
+ INVALID_TICKET = 'INVALID_TICKET'
+ INVALID_SERVICE = 'INVALID_SERVICE'
+ INVALID_REQUEST = 'INVALID_REQUEST'
+ INTERNAL_ERROR = 'INTERNAL_ERROR'
+ ERROR_MESSAGES = (
+     (INVALID_TICKET, u'The provided ticket is invalid.'),
+     (INVALID_SERVICE, u'Service is invalid'),
+     (INVALID_REQUEST, u'Not all required parameters were sent.'),
+     (INTERNAL_ERROR, u'An internal error occurred during ticket validation'),
+     )
  
- def _build_service_url(service, ticket):
-     if service.find('?') == -1:
-         return service + '?ticket=' + ticket
-     else:
-         return service + '&ticket=' + ticket
  
+ logger = logging.getLogger(__name__)
  
- def login(request, template_name='cas/login.html', success_redirect='/account/', **kwargs):
 -def login(request, template_name='cas/login.html',\
++
++def login(request, template_name='cas/login.html',
+           success_redirect=settings.LOGIN_REDIRECT_URL,
 -          warn_template_name='cas/warn.html',
 -          form_class=LoginForm):
++          warn_template_name='cas/warn.html', **kwargs):
 +    merge = kwargs.get('merge', False)
 +    logging.debug('CAS Provider Login view. Method is %s, merge is %s, template is %s.',
 +                  request.method, merge, template_name)
 +
      service = request.GET.get('service', None)
 -    if request.user.is_authenticated():
 -        if service is not None:
 -            if request.GET.get('warn', False):
 -                return render_to_response(warn_template_name, {
 -                    'service': service,
 -                    'warn': False
 -                }, context_instance=RequestContext(request))
 -            ticket = ServiceTicket.objects.create(service=service, user=request.user)
 -            return HttpResponseRedirect(ticket.get_redirect_url())
 -        else:
 -            return HttpResponseRedirect(success_redirect)
 +    if service is not None:
 +        # Save the service on the session, for later use if we end up
 +        # in one of the more complicated workflows.
 +        request.session['service'] = service
 +
 +    user = request.user
 +
 +    errors = []
 +
      if request.method == 'POST':
 -        form = form_class(data=request.POST, request=request)
 +        if merge:
 +            form = MergeLoginForm(request.POST, request=request)
 +        else:
 +            form = LoginForm(request.POST, request=request)
 +
          if form.is_valid():
 -            user = form.get_user()
 -            auth_login(request, user)
 -            service = form.cleaned_data.get('service')
 -            if service is not None:
 +            service = form.cleaned_data.get('service', None)
 +            try:
 +                auth_args = dict(username=form.cleaned_data['email'],
 +                                 password=form.cleaned_data['password'])
 +                if merge:
 +                    # We only want to send the merge argument if it's
 +                    # True. If it it's False, we want it to propagate
 +                    # through the auth backends properly.
 +                    auth_args['merge'] = merge
 +                user = authenticate(**auth_args)
 +            except SameEmailMismatchedPasswords:
 +                # Need to merge the accounts?
 +                if merge:
 +                    # We shouldn't get here...
 +                    raise
 +                else:
 +                    base_url = reverse('cas_provider_merge')
 +                    args = dict(
 +                        success_redirect=success_redirect,
 +                        email=form.cleaned_data['email'],
 +                        )
 +                    if service is not None:
 +                        args['service'] = service
 +                    args = urllib.urlencode(args)
 +
 +                    url = '%s?%s' % (base_url, args)
 +                    logging.debug('Redirecting to %s', url)
 +                    return HttpResponseRedirect(url)
 +            
 +            if user is None:
 +                errors.append('Incorrect username and/or password.')
 +            else:
 +                if user.is_active:
 +                    auth_login(request, user)
 +
 +    else:  # Not a POST...
 +        if merge:
 +            form = MergeLoginForm(initial={'service': service, 'email': request.GET.get('email')})
 +        else:
 +            form = LoginForm(initial={'service': service})
 +
 +    if user is not None and user.is_authenticated():
 +        # We have an authenticated user.
 +        if not user.is_active:
 +            errors.append('This account is disabled.')
 +        else:
 +            # Send the on_cas_login signal. If we get an HttpResponse, return that.
 +            for receiver, response in signals.on_cas_login.send(sender=login, request=request, **kwargs):
 +                if isinstance(response, HttpResponse):
 +                    return response
 +            
 +            if service is None:
 +                # Try and pull the service off the session
 +                service = request.session.pop('service', service)
 +            
 +            if service is None:
 +                # Normal internal success redirection.
 +                logging.debug('Redirecting to %s', success_redirect)
 +                return HttpResponseRedirect(success_redirect)
 +            else:
++                if request.GET.get('warn', False):
++                    return render_to_response(warn_template_name, {
++                        'service': service,
++                        'warn': False
++                    }, context_instance=RequestContext(request))
++                
 +                # Create a service ticket and redirect to the service.
-                 ticket = create_service_ticket(request.user, service)
+                 ticket = ServiceTicket.objects.create(service=service, user=user)
 -                success_redirect = ticket.get_redirect_url()
 -            return HttpResponseRedirect(success_redirect)
 -    else:
 -        form = form_class(request=request, initial={
 -            'service': service,
 -            'lt': LoginTicket.objects.create()
 -        })
 -    if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'):
 -        request.session.set_test_cookie()
 -    return render_to_response(template_name, {
 -        'form': form,
 -        'errors': form.get_errors() if hasattr(form, 'get_errors') else None,
 -        }, context_instance=RequestContext(request))
 +                if 'service' in request.session:
 +                    # Don't need this any more.
 +                    del request.session['service']
 +
-                 url = _build_service_url(service, ticket.ticket)
++                url = ticket.get_redirect_url()
 +                logging.debug('Redirecting to %s', url)
 +                return HttpResponseRedirect(url)
 +
 +    logging.debug('Rendering response on %s, merge is %s', template_name, merge)
 +    return render_to_response(template_name, {'form': form, 'errors': errors}, context_instance=RequestContext(request))
  
  
  def validate(request):
 -    """Validate ticket via CAS v.1 protocol"""
++    """Validate ticket via CAS v.1 protocol
++    """
      service = request.GET.get('service', None)
      ticket_string = request.GET.get('ticket', None)
 +    logger.info('Validating ticket %s for %s', ticket_string, service)
      if service is not None and ticket_string is not None:
+         #renew = request.GET.get('renew', True)
+         #if not renew:
+         # TODO: check user SSO session
          try:
              ticket = ServiceTicket.objects.get(ticket=ticket_string)
+             assert ticket.service == service
 +        except ServiceTicket.DoesNotExist:
 +            logger.exception("Tried to validate with an invalid ticket %s for %s", ticket_string, service)
 +        except Exception as e:
 +            logger.exception('Got an exception: %s', e)
 +        else:
              username = ticket.user.username
 -            return HttpResponse("yes\n%s\n" % username)
 -        except:
 -            pass
 -    return HttpResponse("no\n\n")
 +            ticket.delete()
 +
 +            results = signals.on_cas_collect_histories.send(sender=validate, for_email=ticket.user.email)
 +            histories = '\n'.join('\n'.join(rs) for rc, rs in results)
 +            logger.info('Validated %s %s', username, "(also %s)" % histories if histories else '')
 +            return HttpResponse("yes\n%s\n%s" % (username, histories))
  
 +    logger.info('Validation failed.')
 +    return HttpResponse("no\n\n")
 +    
  
- def logout(request, template_name='cas/logout.html'):
+ def logout(request, template_name='cas/logout.html',
+            auto_redirect=settings.CAS_AUTO_REDIRECT_AFTER_LOGOUT):
      url = request.GET.get('url', None)
-     auth_logout(request)
-     return render_to_response(template_name, {'url': url}, context_instance=RequestContext(request))
+     if request.user.is_authenticated():
+         for ticket in ServiceTicket.objects.filter(user=request.user):
+             ticket.delete()
+         auth_logout(request)
+         if url and auto_redirect:
+             return HttpResponseRedirect(url)
+     return render_to_response(template_name, {'url': url},
+         context_instance=RequestContext(request))
+ def proxy(request):
+     targetService = request.GET['targetService']
+     pgt_id = request.GET['pgt']
+     try:
+         proxyGrantingTicket = ProxyGrantingTicket.objects.get(ticket=pgt_id)
+     except ProxyGrantingTicket.DoesNotExist:
+         return _cas2_error_response(INVALID_TICKET)
+     pt = ProxyTicket.objects.create(proxyGrantingTicket=proxyGrantingTicket,
+         user=proxyGrantingTicket.serviceTicket.user,
+         service=targetService)
+     return _cas2_proxy_success(pt.ticket)
+ def ticket_validate(service, ticket_string, pgtUrl):
+     if service is None or ticket_string is None:
+         return _cas2_error_response(INVALID_REQUEST)
+     try:
+         if ticket_string.startswith('ST'):
+             ticket = ServiceTicket.objects.get(ticket=ticket_string)
+         elif ticket_string.startswith('PT'):
+             ticket = ProxyTicket.objects.get(ticket=ticket_string)
+         else:
+             return _cas2_error_response(INVALID_TICKET,
+                 '%(ticket)s is neither Service (ST-...) nor Proxy Ticket (PT-...)' % {
+                     'ticket': ticket_string})
+     except ServiceTicket.DoesNotExist:
+         return _cas2_error_response(INVALID_TICKET)
 -    ticketUrl =  urlparse.urlparse(ticket.service)
 -    serviceUrl =  urlparse.urlparse(service)
++    ticketUrl = urlparse.urlparse(ticket.service)
++    serviceUrl = urlparse.urlparse(service)
+     if not(ticketUrl.hostname == serviceUrl.hostname and ticketUrl.path == serviceUrl.path and ticketUrl.port == serviceUrl.port):
+         return _cas2_error_response(INVALID_SERVICE)
+     pgtIouId = None
+     proxies = ()
+     if pgtUrl is not None:
+         pgt = generate_proxy_granting_ticket(pgtUrl, ticket)
+         if pgt:
+             pgtIouId = pgt.pgtiou
+     if hasattr(ticket, 'proxyticket'):
+         pgt = ticket.proxyticket.proxyGrantingTicket
+         # I am issued by this proxy granting ticket
+         if hasattr(pgt.serviceTicket, 'proxyticket'):
+             while pgt:
+                 if hasattr(pgt.serviceTicket, 'proxyticket'):
+                     proxies += (pgt.serviceTicket.service,)
+                     pgt = pgt.serviceTicket.proxyticket.proxyGrantingTicket
+                 else:
+                     pgt = None
+     user = ticket.user
+     return _cas2_sucess_response(user, pgtIouId, proxies)
+ def service_validate(request):
+     """Validate ticket via CAS v.2 protocol"""
+     service = request.GET.get('service', None)
+     ticket_string = request.GET.get('ticket', None)
+     pgtUrl = request.GET.get('pgtUrl', None)
+     if ticket_string.startswith('PT-'):
+         return _cas2_error_response(INVALID_TICKET, "serviceValidate cannot verify proxy tickets")
+     else:
+         return ticket_validate(service, ticket_string, pgtUrl)
+ def proxy_validate(request):
+     """Validate ticket via CAS v.2 protocol"""
+     service = request.GET.get('service', None)
+     ticket_string = request.GET.get('ticket', None)
+     pgtUrl = request.GET.get('pgtUrl', None)
+     return ticket_validate(service, ticket_string, pgtUrl)
+ def generate_proxy_granting_ticket(pgt_url, ticket):
+     proxy_callback_good_status = (200, 202, 301, 302, 304)
+     uri = list(urlparse.urlsplit(pgt_url))
+     pgt = ProxyGrantingTicket()
+     pgt.serviceTicket = ticket
+     pgt.targetService = pgt_url
+     if hasattr(ticket, 'proxyGrantingTicket'):
+         # here we got a proxy ticket! tata!
+         pgt.pgt = ticket.proxyGrantingTicket
+     params = {'pgtId': pgt.ticket, 'pgtIou': pgt.pgtiou}
+     query = dict(urlparse.parse_qsl(uri[4]))
+     query.update(params)
+     uri[3] = urlencode(query)
+     try:
+         response = urllib2.urlopen(urlparse.urlunsplit(uri))
 -    except urllib2.HTTPError, e:
++    except urllib2.HTTPError as e:
+         if not e.code in proxy_callback_good_status:
+             logger.debug('Checking Proxy Callback URL {} returned {}. Not issuing PGT.'.format(uri, e.code))
+             return
 -    except urllib2.URLError, e:
++    except urllib2.URLError as e:
+         logger.debug('Checking Proxy Callback URL {} raised URLError. Not issuing PGT.'.format(uri))
+         return
+     pgt.save()
+     return pgt
+ def _cas2_proxy_success(pt):
+     return HttpResponse(proxy_success(pt))
+ def _cas2_sucess_response(user, pgt=None, proxies=None):
+     return HttpResponse(auth_success_response(user, pgt, proxies), mimetype='text/xml')
+ def _cas2_error_response(code, message=None):
+     return HttpResponse(u'''<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
+             <cas:authenticationFailure code="%(code)s">
+                 %(message)s
+             </cas:authenticationFailure>
+         </cas:serviceResponse>''' % {
+         'code': code,
+         'message': message if message else dict(ERROR_MESSAGES).get(code)
+     }, mimetype='text/xml')
+ def proxy_success(pt):
+     response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP)
+     proxySuccess = etree.SubElement(response, CAS + 'proxySuccess')
+     proxyTicket = etree.SubElement(proxySuccess, CAS + 'proxyTicket')
+     proxyTicket.text = pt
+     return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')
+ def auth_success_response(user, pgt, proxies):
+     response = etree.Element(CAS + 'serviceResponse', nsmap=NSMAP)
+     auth_success = etree.SubElement(response, CAS + 'authenticationSuccess')
+     username = etree.SubElement(auth_success, CAS + 'user')
+     username.text = user.username
+     if settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK:
+         callback = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_CALLBACK)
+         attrs = callback(user)
+         if len(attrs) > 0:
+             formater = get_callable(settings.CAS_CUSTOM_ATTRIBUTES_FORMATER)
+             formater(auth_success, attrs)
+     if pgt:
+         pgtElement = etree.SubElement(auth_success, CAS + 'proxyGrantingTicket')
+         pgtElement.text = pgt
+     if proxies:
+         proxiesElement = etree.SubElement(auth_success, CAS + "proxies")
+         for proxy in proxies:
+             proxyElement = etree.SubElement(proxiesElement, CAS + "proxy")
+             proxyElement.text = proxy
+     return unicode(etree.tostring(response, encoding='utf-8'), 'utf-8')