From e0a665a26179e5ab5bd781dba56e82f0fcdc4dd7 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Wed, 8 Jun 2011 21:29:06 +0000 Subject: [PATCH] Inherit AuthenticationForm, support form override ... for compatibility with existing form customizations. --- cas_provider/forms.py | 21 +++------------------ cas_provider/views.py | 11 +++++++---- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/cas_provider/forms.py b/cas_provider/forms.py index 652f437..731e7da 100644 --- a/cas_provider/forms.py +++ b/cas_provider/forms.py @@ -1,6 +1,7 @@ from django import forms from django.conf import settings from django.contrib.auth import authenticate +from django.contrib.auth.forms import AuthenticationForm from django.forms import ValidationError from django.utils.translation import ugettext_lazy as _ from models import LoginTicket @@ -10,16 +11,10 @@ import datetime __all__ = ['LoginForm', ] -class LoginForm(forms.Form): - username = forms.CharField(max_length=30, label=_('username')) - password = forms.CharField(widget=forms.PasswordInput, label=_('password')) +class LoginForm(AuthenticationForm): lt = forms.CharField(widget=forms.HiddenInput) service = forms.CharField(widget=forms.HiddenInput, required=False) - def __init__(self, *args, **kwargs): - super(LoginForm, self).__init__(*args, **kwargs) - self._user = None - def clean_lt(self): ticket = self.cleaned_data['lt'] timeframe = datetime.datetime.now() - \ @@ -31,20 +26,10 @@ class LoginForm(forms.Form): return ticket def clean(self): - username = self.cleaned_data.get('username') - password = self.cleaned_data.get('password') - user = authenticate(username=username, password=password) - if user is None: - raise ValidationError(_('Incorrect username and/or password.')) - if not user.is_active: - raise ValidationError(_('This account is disabled.')) - self._user = user + super(LoginForm, self).clean(self) self.cleaned_data.get('lt').delete() return self.cleaned_data - def get_user(self): - return self._user - def get_errors(self): errors = [] for k, error in self.errors.items(): diff --git a/cas_provider/views.py b/cas_provider/views.py index c9a44db..7a3e690 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -26,7 +26,8 @@ ERROR_MESSAGES = ( def login(request, template_name='cas/login.html', \ success_redirect=settings.LOGIN_REDIRECT_URL, - warn_template_name='cas/warn.html'): + warn_template_name='cas/warn.html', + form_class=LoginForm): service = request.GET.get('service', None) if request.user.is_authenticated(): if service is not None: @@ -40,7 +41,7 @@ def login(request, template_name='cas/login.html', \ else: return HttpResponseRedirect(success_redirect) if request.method == 'POST': - form = LoginForm(request.POST) + form = form_class(data=request.POST, request=request) if form.is_valid(): user = form.get_user() auth_login(request, user) @@ -50,13 +51,15 @@ def login(request, template_name='cas/login.html', \ success_redirect = ticket.get_redirect_url() return HttpResponseRedirect(success_redirect) else: - form = LoginForm(initial={ + form = form_class(request=request, initial={ 'service': service, 'lt': LoginTicket.objects.create() }) + if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'): + request.session.set_test_cookie() return render_to_response(template_name, { 'form': form, - 'errors': form.get_errors() + 'errors': form.get_errors() if hasattr(form, 'get_errors') else None, }, context_instance=RequestContext(request)) -- 2.20.1