X-Git-Url: https://git.mdrn.pl/django-cas-provider.git/blobdiff_plain/4f3dc0c8018d8ca6ff31429043d64ccbdae8b956..11d994f1d3d527e130eedbdacca59aa600f98fa8:/cas_provider/forms.py?ds=sidebyside diff --git a/cas_provider/forms.py b/cas_provider/forms.py index ba6aff0..662e517 100644 --- a/cas_provider/forms.py +++ b/cas_provider/forms.py @@ -1,37 +1,32 @@ from django import forms from django.conf import settings from django.contrib.auth import authenticate -from django.contrib.auth.forms import AuthenticationForm from django.forms import ValidationError from django.utils.translation import ugettext_lazy as _ from models import LoginTicket import datetime -__all__ = ['LoginForm', ] - - -class LoginForm(AuthenticationForm): - lt = forms.CharField(widget=forms.HiddenInput) +class LoginForm(forms.Form): + email = forms.CharField(widget=forms.TextInput(attrs={'autofocus': 'autofocus', + 'max_length': '255'})) + password = forms.CharField(widget=forms.PasswordInput) service = forms.CharField(widget=forms.HiddenInput, required=False) + remember_me = forms.BooleanField(required=False, label="Keep me signed in", + widget=forms.CheckboxInput(attrs={'class': 'remember_me'})) - def clean_lt(self): - ticket = self.cleaned_data['lt'] - timeframe = datetime.datetime.now() - \ - datetime.timedelta(minutes=settings.CAS_TICKET_EXPIRATION) - try: - return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe) - except LoginTicket.DoesNotExist: - raise ValidationError(_('Login ticket expired. Please try again.')) - return ticket + def __init__(self, *args, **kwargs): + # renew = kwargs.pop('renew', None) + # gateway = kwargs.pop('gateway', None) + request = kwargs.pop('request', None) + super(LoginForm, self).__init__(*args, **kwargs) + self.request = request - def clean(self): - AuthenticationForm.clean(self) - self.cleaned_data.get('lt').delete() - return self.cleaned_data + def clean_remember_me(self): + remember = self.cleaned_data['remember_me'] + if not remember and self.request is not None: + self.request.session.set_expiry(0) + - def get_errors(self): - errors = [] - for k, error in self.errors.items(): - errors += [e for e in error] - return errors +class MergeLoginForm(LoginForm): + email = forms.CharField(max_length=255, widget=forms.HiddenInput)