Inherit AuthenticationForm, support form override
author Gunnlaugur Thor Briem <gunnlaugur@gmail.com>
Wed, 8 Jun 2011 21:29:06 +0000 (21:29 +0000)
committer Gunnlaugur Thor Briem <gunnlaugur@gmail.com>
Wed, 8 Jun 2011 21:37:12 +0000 (21:37 +0000)
... for compatibility with existing form customizations.

cas_provider/forms.py
cas_provider/views.py

index 652f437..731e7da 100644 (file)
@@ -1,6 +1,7 @@
 from django import forms
 from django.conf import settings
 from django.contrib.auth import authenticate
+from django.contrib.auth.forms import AuthenticationForm
 from django.forms import ValidationError
 from django.utils.translation import ugettext_lazy as _
 from models import LoginTicket
@@ -10,16 +11,10 @@ import datetime
 __all__ = ['LoginForm', ]
 
 
-class LoginForm(forms.Form):
-    username = forms.CharField(max_length=30, label=_('username'))
-    password = forms.CharField(widget=forms.PasswordInput, label=_('password'))
+class LoginForm(AuthenticationForm):
     lt = forms.CharField(widget=forms.HiddenInput)
     service = forms.CharField(widget=forms.HiddenInput, required=False)
 
-    def __init__(self, *args, **kwargs):
-        super(LoginForm, self).__init__(*args, **kwargs)
-        self._user = None
-
     def clean_lt(self):
         ticket = self.cleaned_data['lt']
         timeframe = datetime.datetime.now() - \
@@ -31,20 +26,10 @@ class LoginForm(forms.Form):
         return ticket
 
     def clean(self):
-        username = self.cleaned_data.get('username')
-        password = self.cleaned_data.get('password')
-        user = authenticate(username=username, password=password)
-        if user is None:
-            raise ValidationError(_('Incorrect username and/or password.'))
-        if not user.is_active:
-            raise ValidationError(_('This account is disabled.'))
-        self._user = user
+        super(LoginForm, self).clean(self)
         self.cleaned_data.get('lt').delete()
         return self.cleaned_data
 
-    def get_user(self):
-        return self._user
-    
     def get_errors(self):
         errors = []
         for k, error in self.errors.items():
index c9a44db..7a3e690 100644 (file)
@@ -26,7 +26,8 @@ ERROR_MESSAGES = (
 
 def login(request, template_name='cas/login.html', \
                 success_redirect=settings.LOGIN_REDIRECT_URL,
-                warn_template_name='cas/warn.html'):
+                warn_template_name='cas/warn.html',
+                form_class=LoginForm):
     service = request.GET.get('service', None)
     if request.user.is_authenticated():
         if service is not None:
@@ -40,7 +41,7 @@ def login(request, template_name='cas/login.html', \
         else:
             return HttpResponseRedirect(success_redirect)
     if request.method == 'POST':
-        form = LoginForm(request.POST)
+        form = form_class(data=request.POST, request=request)
         if form.is_valid():
             user = form.get_user()
             auth_login(request, user)
@@ -50,13 +51,15 @@ def login(request, template_name='cas/login.html', \
                 success_redirect = ticket.get_redirect_url()
             return HttpResponseRedirect(success_redirect)
     else:
-        form = LoginForm(initial={
+        form = form_class(request=request, initial={
             'service': service,
             'lt': LoginTicket.objects.create()
         })
+    if hasattr(request, 'session') and hasattr(request.session, 'set_test_cookie'):
+        request.session.set_test_cookie()
     return render_to_response(template_name, {
         'form': form,
-        'errors': form.get_errors()
+        'errors': form.get_errors() if hasattr(form, 'get_errors') else None,
     }, context_instance=RequestContext(request))