using south for db migrations
[django-cas-provider.git] / cas_provider / forms.py
index ba77b62..ba6aff0 100644 (file)
@@ -1,16 +1,37 @@
 from django import forms
 from django import forms
-from django.contrib.auth.forms import AuthenticationForm
+from django.conf import settings
 from django.contrib.auth import authenticate
 from django.contrib.auth import authenticate
+from django.contrib.auth.forms import AuthenticationForm
+from django.forms import ValidationError
+from django.utils.translation import ugettext_lazy as _
+from models import LoginTicket
+import datetime
+
+
+__all__ = ['LoginForm', ]
+
+
+class LoginForm(AuthenticationForm):
+    """Django AuthenticationForm extended with CAS login-ticket handling.
+
+    Adds a hidden, required ``lt`` field (the login ticket string) and a
+    hidden, optional ``service`` field. ``lt`` is validated against the
+    ``LoginTicket`` table and consumed once authentication validation passes.
+    """
+    lt = forms.CharField(widget=forms.HiddenInput)
+    service = forms.CharField(widget=forms.HiddenInput, required=False)
+
+    def clean_lt(self):
+        """Return the ``LoginTicket`` matching the submitted ``lt`` value.
+
+        Only tickets created within the last
+        ``settings.CAS_TICKET_EXPIRATION`` minutes are accepted.
+
+        Raises:
+            ValidationError: no unexpired ticket matches the submitted value.
+        """
+        ticket = self.cleaned_data['lt']
+        # NOTE(review): naive local time; confirm this matches how
+        # LoginTicket.created is stored (e.g. if USE_TZ is enabled).
+        timeframe = datetime.datetime.now() - datetime.timedelta(
+            minutes=settings.CAS_TICKET_EXPIRATION)
+        try:
+            return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe)
+        except LoginTicket.DoesNotExist:
+            raise ValidationError(_('Login ticket expired. Please try again.'))
 
 
-from cas_provider.utils import create_login_ticket
+    def clean(self):
+        """Run AuthenticationForm validation, then delete the login ticket.
+
+        The ``LoginTicket`` produced by ``clean_lt()`` is deleted only if the
+        base-class ``clean()`` does not raise.
+
+        Returns:
+            The form's ``cleaned_data`` dict.
+        """
+        super(LoginForm, self).clean()
+        # clean_lt() may have rejected the ticket, in which case Django leaves
+        # no 'lt' entry in cleaned_data; the form is already invalid, so there
+        # is nothing to delete.
+        login_ticket = self.cleaned_data.get('lt')
+        if login_ticket is not None:
+            # NOTE(review): the CAS protocol expects a login ticket to be good
+            # for one attempt whether or not it succeeds; here it survives a
+            # failed authentication. Confirm the view issues a fresh ticket
+            # on re-render before deleting it unconditionally.
+            login_ticket.delete()
+        return self.cleaned_data
 
 
-class LoginForm(forms.Form):
-    username = forms.CharField(max_length=30)
-    password = forms.CharField(widget=forms.PasswordInput)
-    #warn = forms.BooleanField(required=False)  # TODO: Implement
-    lt = forms.CharField(widget=forms.HiddenInput, initial=create_login_ticket)
-    def __init__(self, service=None, renew=None, gateway=None, request=None, *args, **kwargs):
-        super(LoginForm, self).__init__(*args, **kwargs)
-        self.request = request
-        if service is not None:
-            self.fields['service'] = forms.CharField(widget=forms.HiddenInput, initial=service)
\ No newline at end of file
+    def get_errors(self):
+        """Return all error messages from every field as one flat list."""
+        errors = []
+        for field_errors in self.errors.values():
+            errors.extend(field_errors)
+        return errors