from django import forms
from django.conf import settings
from django.contrib.auth import authenticate
from django.contrib.auth.forms import AuthenticationForm
from django.forms import ValidationError
from django.utils.translation import ugettext_lazy as _
from models import LoginTicket
import datetime


__all__ = ['LoginForm', ]


class LoginForm(AuthenticationForm):
    """Username/password form for the CAS login page.

    Extends Django's ``AuthenticationForm`` with two hidden fields:

    * ``lt`` -- the login ticket string sent with the form. It is only
      accepted if a ``LoginTicket`` with that value was created within the
      last ``settings.CAS_TICKET_EXPIRATION`` minutes.
    * ``service`` -- optional value carried through the login round trip.
    """

    lt = forms.CharField(widget=forms.HiddenInput)
    service = forms.CharField(widget=forms.HiddenInput, required=False)

    def clean_lt(self):
        """Validate the submitted login ticket.

        Returns:
            The matching ``LoginTicket`` model instance (not the raw string),
            so that ``clean`` can delete it after a successful login.

        Raises:
            ValidationError: if no ``LoginTicket`` with this value was created
                within the last ``settings.CAS_TICKET_EXPIRATION`` minutes.
        """
        ticket = self.cleaned_data['lt']
        # Tickets created before this cutoff are treated as expired.
        # NOTE(review): uses naive local time; if USE_TZ is enabled, confirm
        # this compares correctly against how LoginTicket.created is stored.
        timeframe = datetime.datetime.now() - \
            datetime.timedelta(minutes=settings.CAS_TICKET_EXPIRATION)
        try:
            return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe)
        except LoginTicket.DoesNotExist:
            raise ValidationError(_('Login ticket expired. Please try again.'))

    def clean(self):
        """Run the base credential check, then consume the login ticket.

        If ``AuthenticationForm.clean`` raises (e.g. bad credentials), the
        exception propagates before the ticket is touched. Otherwise the
        ``LoginTicket`` produced by ``clean_lt`` is deleted so it cannot be
        reused.

        Returns:
            ``self.cleaned_data``.
        """
        super(LoginForm, self).clean()
        # 'lt' is missing from cleaned_data when clean_lt rejected it; in that
        # case the form is already invalid and there is no ticket to delete.
        # (Unconditionally calling .delete() here crashed with AttributeError.)
        # NOTE(review): the CAS protocol treats a login ticket as valid for a
        # single authentication attempt; failed attempts currently leave the
        # ticket usable -- confirm whether that is intended.
        login_ticket = self.cleaned_data.get('lt')
        if login_ticket is not None:
            login_ticket.delete()
        return self.cleaned_data

    def get_errors(self):
        """Return all error messages from every field as one flat list."""
        return [message
                for field_errors in self.errors.values()
                for message in field_errors]