X-Git-Url: https://git.mdrn.pl/django-cas-provider.git/blobdiff_plain/23f180a9809cbeed655e82359f7c7e6cfc16fb68..6e6126a7402676610d76703bdb7c37dabddde1a8:/cas_provider/forms.py?ds=inline
diff --git a/cas_provider/forms.py b/cas_provider/forms.py
index 8a8226b..ba6aff0 100644
--- a/cas_provider/forms.py
+++ b/cas_provider/forms.py
@@ -1,17 +1,48 @@
 from django import forms
-from django.contrib.auth.forms import AuthenticationForm
+from django.conf import settings
 from django.contrib.auth import authenticate
+from django.contrib.auth.forms import AuthenticationForm
+from django.forms import ValidationError
 from django.utils.translation import ugettext_lazy as _
+from models import LoginTicket
+import datetime
+
+
+__all__ = ['LoginForm', ]
+
+
+class LoginForm(AuthenticationForm):
+    """Authentication form that also validates and consumes a login ticket."""
+    lt = forms.CharField(widget=forms.HiddenInput)
+    service = forms.CharField(widget=forms.HiddenInput, required=False)
+
+    def clean_lt(self):
+        """Return the LoginTicket matching the submitted ``lt`` value.
+
+        Raises ValidationError when no ticket with that value was created
+        within the last ``settings.CAS_TICKET_EXPIRATION`` minutes.
+        """
+        ticket = self.cleaned_data['lt']
+        timeframe = datetime.datetime.now() - \
+            datetime.timedelta(minutes=settings.CAS_TICKET_EXPIRATION)
+        try:
+            return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe)
+        except LoginTicket.DoesNotExist:
+            raise ValidationError(_('Login ticket expired. Please try again.'))
 
-from utils import create_login_ticket
+    def clean(self):
+        """Run AuthenticationForm checks, then delete the login ticket."""
+        super(LoginForm, self).clean()
+        # 'lt' is missing from cleaned_data when clean_lt() rejected it;
+        # calling .delete() on None would raise AttributeError.
+        ticket = self.cleaned_data.get('lt')
+        if ticket is not None:
+            ticket.delete()
+        return self.cleaned_data
 
-class LoginForm(forms.Form):
-    username = forms.CharField(max_length=30, label=_('username'))
-    password = forms.CharField(widget=forms.PasswordInput, label=_('password'))
-    #warn = forms.BooleanField(required=False) # TODO: Implement
-    lt = forms.CharField(widget=forms.HiddenInput, initial=create_login_ticket)
-    def __init__(self, service=None, renew=None, gateway=None, request=None, *args, **kwargs):
-        super(LoginForm, self).__init__(*args, **kwargs)
-        self.request = request
-        if service is not None:
-            self.fields['service'] = forms.CharField(widget=forms.HiddenInput, initial=service)
\ No newline at end of file
+    def get_errors(self):
+        """Return every message in ``self.errors`` as one flat list."""
+        errors = []
+        for field_errors in self.errors.values():
+            errors.extend(field_errors)
+        return errors