From 4a51c1aaa9c3eeac5368ff4e06870a94f18123aa Mon Sep 17 00:00:00 2001 From: deyk Date: Wed, 18 Jan 2012 11:44:49 -0800 Subject: [PATCH 1/1] Added support for handling user-merging workflows at authentication time. - views.login now has a merge mode flag, and a corresponding /cas/merge/ url. - forms.LoginForm now does its own validation, simplifying views.login - views.social_auth_login hasn't been adapted yet. --- .gitignore | 1 + cas_provider/exceptions.py | 6 ++++ cas_provider/forms.py | 34 +++++++++++++++++--- cas_provider/urls.py | 11 ++++--- cas_provider/views.py | 64 +++++++++++++++++++++++++++----------- 5 files changed, 89 insertions(+), 27 deletions(-) create mode 100644 cas_provider/exceptions.py diff --git a/.gitignore b/.gitignore index 0d20b64..9063127 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.pyc +django_cas_provider.egg-info/ diff --git a/cas_provider/exceptions.py b/cas_provider/exceptions.py new file mode 100644 index 0000000..202eda8 --- /dev/null +++ b/cas_provider/exceptions.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""cas_provider.exceptions -- exceptions defined for CAS login workflows +""" + + +class SameEmailMismatchedPasswords(Exception): pass diff --git a/cas_provider/forms.py b/cas_provider/forms.py index cb3660b..8a1452a 100644 --- a/cas_provider/forms.py +++ b/cas_provider/forms.py @@ -2,15 +2,41 @@ from django import forms from django.contrib.auth.forms import AuthenticationForm from django.contrib.auth import authenticate +from models import ServiceTicket, LoginTicket from utils import create_login_ticket + class LoginForm(forms.Form): email = forms.CharField(max_length=255) password = forms.CharField(widget=forms.PasswordInput) #warn = forms.BooleanField(required=False) # TODO: Implement - lt = forms.CharField(widget=forms.HiddenInput, initial=create_login_ticket) - def __init__(self, service=None, renew=None, gateway=None, request=None, *args, **kwargs): + lt = forms.CharField(widget=forms.HiddenInput, initial=create_login_ticket, required=False) + service = forms.CharField(widget=forms.HiddenInput, initial='', required=False) + remember_me = forms.BooleanField(required=False) + + def __init__(self, *args, **kwargs): + # renew = kwargs.pop('renew', None) + # gateway = kwargs.pop('gateway', None) + request = kwargs.pop('request', None) super(LoginForm, self).__init__(*args, **kwargs) self.request = request - if service is not None: - self.fields['service'] = forms.CharField(widget=forms.HiddenInput, initial=service) + + def clean_remember_me(self): + remember = self.cleaned_data['remember_me'] + if not remember and self.request is not None: + self.request.session.set_expiry(0) + + def clean_lt(self): + lt = self.cleaned_data.get('lt', + self.initial.get('lt', + self.fields['lt'].initial())) + try: + login_ticket = LoginTicket.objects.get(ticket=lt) + except: + raise forms.ValidationError("Login ticket expired. Please try again.") + else: + login_ticket.delete() + + +class MergeLoginForm(LoginForm): + email = forms.CharField(max_length=255, widget=forms.HiddenInput) diff --git a/cas_provider/urls.py b/cas_provider/urls.py index e8b8a14..75399c0 100644 --- a/cas_provider/urls.py +++ b/cas_provider/urls.py @@ -3,8 +3,9 @@ from django.conf.urls.defaults import * from views import * urlpatterns = patterns('', - url(r'^login/', login), - url(r'^socialauth-login/$', socialauth_login), - url(r'^validate/', validate), - url(r'^logout/', logout), -) \ No newline at end of file + url(r'^login/', login), + url(r'^socialauth-login/$', socialauth_login), + url(r'^validate/', validate), + url(r'^logout/', logout), + url(r'^merge/', login, {'merge': True}) + ) diff --git a/cas_provider/views.py b/cas_provider/views.py index 6a13b22..ba7b0a2 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -1,16 +1,21 @@ +import urllib + from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import render_to_response from django.template import RequestContext from django.contrib.auth import authenticate from django.contrib.auth import login as auth_login, logout as auth_logout +from django.core.urlresolvers import reverse -from forms import LoginForm -from models import ServiceTicket, LoginTicket +from forms import LoginForm, MergeLoginForm +from models import ServiceTicket from utils import create_service_ticket +from exceptions import SameEmailMismatchedPasswords __all__ = ['login', 'validate', 'logout'] -def login(request, template_name='cas/login.html', success_redirect='/account/'): + +def login(request, template_name='cas/login.html', success_redirect='/account/', merge=False): service = request.GET.get('service', None) if service is not None: request.session['service'] = service @@ -25,20 +30,37 @@ def login(request, template_name='cas/login.html', success_redirect='/account/') return HttpResponseRedirect(success_redirect) errors = [] if request.method == 'POST': - email = request.POST.get('email', None) - password = request.POST.get('password', None) - service = request.POST.get('service', None) - lt = request.POST.get('lt', None) - if not request.POST.get('remember_me', None): - request.session.set_expiry(0) - - try: - login_ticket = LoginTicket.objects.get(ticket=lt) - except: - errors.append('Login ticket expired. Please try again.') + if merge: + form = MergeLoginForm(request.POST, request=request) else: - login_ticket.delete() - user = authenticate(username=email, password=password) + form = LoginForm(request.POST, request=request) + + if form.is_valid(): + try: + auth_args = dict(username=form.cleaned_data['email'], + password=form.cleaned_data['password']) + if merge: + # We only want to send the merge argument if it's + # True. If it it's False, we want it to propagate + # through the auth backends properly. + auth_args['merge'] = merge + user = authenticate(**auth_args) + except SameEmailMismatchedPasswords: + # Need to merge the accounts? + if merge: + # We shouldn't get here... + raise + else: + base_url = reverse('cas_provider_merge') + args = dict( + success_redirect=success_redirect, + email=form.cleaned_data['email'], + ) + if service is not None: + args['service'] = service + args = urllib.urlencode(args) + + return HttpResponseRedirect('%s?%s' % (base_url, args)) if user is not None: if user.is_active: auth_login(request, user) @@ -51,7 +73,11 @@ def login(request, template_name='cas/login.html', success_redirect='/account/') errors.append('This account is disabled.') else: errors.append('Incorrect username and/or password.') - form = LoginForm(service) + else: + if merge: + form = MergeLoginForm(request.GET, request=request) + else: + form = LoginForm(request.GET, request=request) return render_to_response(template_name, {'form': form, 'errors': errors}, context_instance=RequestContext(request)) def socialauth_login(request, template_name='cas/login.html', success_redirect='/account/'): @@ -79,7 +105,8 @@ def socialauth_login(request, template_name='cas/login.html', success_redirect=' else: errors.append('Incorrect username and/or password.') return render_to_response(template_name, {'errors': errors}, context_instance=RequestContext(request)) - + + def validate(request): service = request.GET.get('service', None) ticket_string = request.GET.get('ticket', None) @@ -93,6 +120,7 @@ def validate(request): pass return HttpResponse("no\n\n") + def logout(request, template_name='cas/logout.html'): url = request.GET.get('url', None) auth_logout(request) -- 2.20.1