From 93472270e0bb9c9b3d3c54e99e9f13e4d272b367 Mon Sep 17 00:00:00 2001 From: Alex Kamedov Date: Sun, 24 Apr 2011 16:56:59 +0600 Subject: [PATCH 1/1] small refactoring --- cas_provider/forms.py | 50 +++++++++++++--- cas_provider/locale/ru/LC_MESSAGES/django.mo | Bin 1469 -> 1469 bytes cas_provider/locale/ru/LC_MESSAGES/django.po | 60 +++++++++---------- cas_provider/models.py | 53 ++++++++++++---- cas_provider/tests.py | 29 +++++++-- cas_provider/utils.py | 24 -------- cas_provider/views.py | 59 +++++++----------- 7 files changed, 161 insertions(+), 114 deletions(-) delete mode 100644 cas_provider/utils.py diff --git a/cas_provider/forms.py b/cas_provider/forms.py index 912f184..47c2fdc 100644 --- a/cas_provider/forms.py +++ b/cas_provider/forms.py @@ -1,16 +1,52 @@ from django import forms +from django.conf import settings +from django.contrib.auth import authenticate +from django.core.exceptions import ValidationError from django.utils.translation import ugettext_lazy as _ -from utils import create_login_ticket +from models import LoginTicket +import datetime + + +__all__ = ['LoginForm', ] class LoginForm(forms.Form): username = forms.CharField(max_length=30, label=_('username')) password = forms.CharField(widget=forms.PasswordInput, label=_('password')) - #warn = forms.BooleanField(required=False) # TODO: Implement - lt = forms.CharField(widget=forms.HiddenInput, initial=create_login_ticket) + lt = forms.CharField(widget=forms.HiddenInput) + service = forms.CharField(widget=forms.HiddenInput, required=False) - def __init__(self, service=None, renew=None, gateway=None, request=None, *args, **kwargs): + def __init__(self, *args, **kwargs): super(LoginForm, self).__init__(*args, **kwargs) - self.request = request - if service is not None: - self.fields['service'] = forms.CharField(widget=forms.HiddenInput, initial=service) + self._user = None + + def clean_lt(self): + ticket = self.cleaned_data['lt'] + timeframe = datetime.datetime.now() - \ + datetime.timedelta(minutes=settings.CAS_TICKET_EXPIRATION) + try: + return LoginTicket.objects.get(ticket=ticket, created__gte=timeframe) + except LoginTicket.DoesNotExist: + raise ValidationError(_('Login ticket expired. Please try again.')) + return ticket + + def clean(self): + username = self.cleaned_data.get('username') + password = self.cleaned_data.get('password') + user = authenticate(username=username, password=password) + if user is None: + raise ValidationError(_('Incorrect username and/or password.')) + if not user.is_active: + raise ValidationError(_('This account is disabled.')) + self._user = user + self.cleaned_data.get('lt').delete() + return self.cleaned_data + + def get_user(self): + return self._user + + def get_errors(self): + errors = [] + for k, error in self.errors.items(): + errors += [e for e in error] + return errors diff --git a/cas_provider/locale/ru/LC_MESSAGES/django.mo b/cas_provider/locale/ru/LC_MESSAGES/django.mo index b0e243a3f4a879f06bb2b65a2a452e48c50989bb..1091e950144740f02e6dbe02b744f7674212bb7b 100644 GIT binary patch delta 21 ccmdnXy_b7KE)$25iGrb-m8t3GQYH^(07m2mz5oCK delta 21 ccmdnXy_b7KE)$1=xq_jgm8tpWQYH^(07lIPyZ`_I diff --git a/cas_provider/locale/ru/LC_MESSAGES/django.po b/cas_provider/locale/ru/LC_MESSAGES/django.po index 8f3cbec..a72c935 100644 --- a/cas_provider/locale/ru/LC_MESSAGES/django.po +++ b/cas_provider/locale/ru/LC_MESSAGES/django.po @@ -1,12 +1,12 @@ # SOME DESCRIPTIVE TITLE. # Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER # This file is distributed under the same license as the PACKAGE package. -# Volf , 2011. +# Alex Kamedov , 2011. msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2011-04-07 11:57+0600\n" +"POT-Creation-Date: 2011-04-24 16:55+0600\n" "PO-Revision-Date: 2011-04-07 12:01+0600\n" "Last-Translator: Volf \n" "Language-Team: delux\n" @@ -14,58 +14,58 @@ msgstr "" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=UTF-8\n" "Content-Transfer-Encoding: 8bit\n" -"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n%" -"10<=4 && (n%100<10 || n%100>=20) ? 1 : 2);\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n" +"%10<=4 && (n%100<10 || n%100>=20) ? 1 : 2);\n" "X-Generator: Virtaal 0.6.1\n" -#: forms.py:9 +#: forms.py:14 msgid "username" msgstr "имя пользователя" -#: forms.py:10 +#: forms.py:15 msgid "password" msgstr "пароль" -#: models.py:6 -msgid "user" -msgstr "пользователь" +#: forms.py:30 +msgid "Login ticket expired. Please try again." +msgstr "Истек срок действия билета входа. Пожалуйста, попробуйте еще раз." -#: models.py:7 -msgid "service" -msgstr "сервис" +#: forms.py:38 +msgid "Incorrect username and/or password." +msgstr "Неверное имя пользователя и/или пароль." + +#: forms.py:40 +msgid "This account is disabled." +msgstr "Эта учетная запись отключена." -#: models.py:8 models.py:19 +#: models.py:14 msgid "ticket" msgstr "билет" -#: models.py:9 models.py:20 +#: models.py:15 msgid "created" msgstr "создан" -#: models.py:12 +#: models.py:34 +msgid "user" +msgstr "пользователь" + +#: models.py:35 +msgid "service" +msgstr "сервис" + +#: models.py:40 msgid "Service Ticket" msgstr "Билет для сервиса" -#: models.py:13 +#: models.py:41 msgid "Service Tickets" msgstr "Билеты для сервисов" -#: models.py:23 +#: models.py:59 msgid "Login Ticket" msgstr "Билет для входа" -#: models.py:24 +#: models.py:60 msgid "Login Tickets" msgstr "Билеты для входа" - -#: views.py:36 -msgid "Login ticket expired. Please try again." -msgstr "Истек срок действия билета входа. Пожалуйста, попробуйте еще раз." - -#: views.py:49 -msgid "This account is disabled." -msgstr "Эта учетная запись отключена." - -#: views.py:51 -msgid "Incorrect username and/or password." -msgstr "Неверное имя пользователя и/или пароль." diff --git a/cas_provider/models.py b/cas_provider/models.py index 516992d..ec4b695 100644 --- a/cas_provider/models.py +++ b/cas_provider/models.py @@ -1,29 +1,60 @@ from django.contrib.auth.models import User from django.db import models from django.utils.translation import ugettext_lazy as _ +from random import Random +import string +import urllib +import urlparse + __all__ = ['ServiceTicket', 'LoginTicket'] -class ServiceTicket(models.Model): + +class BaseTicket(models.Model): + ticket = models.CharField(_('ticket'), max_length=32) + created = models.DateTimeField(_('created'), auto_now=True) + + class Meta: + abstract = True + + def __init__(self, *args, **kwargs): + if 'ticket' not in kwargs: + kwargs['ticket'] = self._generate_ticket() + super(BaseTicket, self).__init__(*args, **kwargs) + + def __unicode__(self): + return self.ticket + + def _generate_ticket(self, length=29, chars=string.ascii_letters + string.digits): + """ Generates a random string of the requested length. Used for creation of tickets. """ + return u"%s-%s" % (self.prefix, ''.join(Random().sample(chars, length))) + + +class ServiceTicket(BaseTicket): user = models.ForeignKey(User, verbose_name=_('user')) service = models.URLField(_('service'), verify_exists=False) - ticket = models.CharField(_('ticket'), max_length=256) - created = models.DateTimeField(_('created'), auto_now=True) + + prefix = 'ST' class Meta: verbose_name = _('Service Ticket') verbose_name_plural = _('Service Tickets') - def __unicode__(self): - return "%s (%s) - %s" % (self.user.username, self.service, self.created) + def get_redirect_url(self): + parsed = urlparse.urlparse(self.service) + query = urlparse.parse_qs(parsed.query) + query['ticket'] = [self.ticket] + query = [ ((k, v) if len(v) > 1 else (k, v[0])) for k, v in query.iteritems()] + parsed = urlparse.ParseResult(parsed.scheme, parsed.netloc, + parsed.path, parsed.params, + urllib.urlencode(query), parsed.fragment) + return parsed.geturl() -class LoginTicket(models.Model): - ticket = models.CharField(_('ticket'), max_length=32) - created = models.DateTimeField(_('created'), auto_now=True) + +class LoginTicket(BaseTicket): + + prefix = 'LT' class Meta: verbose_name = _('Login Ticket') verbose_name_plural = _('Login Tickets') - - def __unicode__(self): - return "%s - %s" % (self.ticket, self.created) diff --git a/cas_provider/tests.py b/cas_provider/tests.py index 2ad95f7..58b8ded 100644 --- a/cas_provider/tests.py +++ b/cas_provider/tests.py @@ -1,9 +1,11 @@ +from cas_provider.models import ServiceTicket +from django.contrib.auth.models import User from django.core.urlresolvers import reverse from django.test import TestCase from urlparse import urlparse -class UserTest(TestCase): +class ViewsTest(TestCase): fixtures = ['cas_users.json', ] @@ -30,7 +32,7 @@ class UserTest(TestCase): def test_logout(self): response = self._login_user('root', '123') self._validate_cas1(response, True) - + response = self.client.get(reverse('cas_logout'), follow=False) self.assertEqual(response.status_code, 200) @@ -57,7 +59,7 @@ class UserTest(TestCase): self.assertEqual(response.status_code, 200) response = self.client.get(reverse('cas_login'), follow=False) self.assertEqual(response.status_code, 200) - + def _login_user(self, username, password): @@ -83,9 +85,24 @@ class UserTest(TestCase): response = self.client.get(reverse('cas_validate'), {'ticket': ticket, 'service': self.service}, follow=False) self.assertEqual(response.status_code, 200) - self.assertEqual(unicode(response.content), u'yes\r\n%s\r\n' % self.username if is_correct else u'no\r\n') + self.assertEqual(unicode(response.content), u'yes\r\n%s\r\n' % self.username) else: self.assertEqual(response.status_code, 200) - self.assertGreater(len(response.context['errors']), 0) - self.assertEqual(len(response.context['form'].errors), 0) + self.assertEqual(len(response.context['form'].errors), 1) + + response = self.client.get(reverse('cas_validate'), {'ticket': 'ST-12312312312312312312312', 'service': self.service}, follow=False) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, u'no\r\n\r\n') + + +class ModelsTestCase(TestCase): + + fixtures = ['cas_users.json', ] + + def setUp(self): + self.user = User.objects.get(username='root') + + def test_redirects(self): + ticket = ServiceTicket.objects.create(service='http://example.com', user=self.user) + self.assertEqual(ticket.get_redirect_url(), '%(service)s?ticket=%(ticket)s' % ticket.__dict__) diff --git a/cas_provider/utils.py b/cas_provider/utils.py deleted file mode 100644 index fcff8a6..0000000 --- a/cas_provider/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from models import ServiceTicket, LoginTicket -from random import Random -import string - - -def _generate_string(length=8, chars=string.ascii_letters + string.digits): - """ Generates a random string of the requested length. Used for creation of tickets. """ - return ''.join(Random().sample(chars, length)) - -def create_service_ticket(user, service): - """ Creates a new service ticket for the specified user and service. - Uses _generate_string. - """ - ticket_string = 'ST-' + _generate_string(29) # Total ticket length = 29 + 3 = 32 - ticket = ServiceTicket(service=service, user=user, ticket=ticket_string) - ticket.save() - return ticket - -def create_login_ticket(): - """ Creates a new login ticket for the login form. Uses _generate_string. """ - ticket_string = 'LT-' + _generate_string(29) - ticket = LoginTicket(ticket=ticket_string) - ticket.save() - return ticket_string diff --git a/cas_provider/views.py b/cas_provider/views.py index 3686d74..41d02fd 100644 --- a/cas_provider/views.py +++ b/cas_provider/views.py @@ -1,57 +1,44 @@ from django.conf import settings -from django.contrib.auth import authenticate, login as auth_login, \ +from django.contrib.auth import login as auth_login, \ logout as auth_logout from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import render_to_response from django.template import RequestContext -from django.utils.translation import ugettext_lazy as _ from forms import LoginForm from models import ServiceTicket, LoginTicket -from utils import create_service_ticket __all__ = ['login', 'validate', 'logout', 'service_validate'] -def login(request, template_name='cas/login.html', success_redirect=getattr(settings, 'LOGIN_REDIRECT_URL', '/accounts/')): +def login(request, template_name='cas/login.html', \ + success_redirect=getattr(settings, 'LOGIN_REDIRECT_URL', '/accounts/')): service = request.GET.get('service', None) if request.user.is_authenticated(): if service is not None: - ticket = create_service_ticket(request.user, service) - if service.find('?') == -1: - return HttpResponseRedirect(service + '?ticket=' + ticket.ticket) - else: - return HttpResponseRedirect(service + '&ticket=' + ticket.ticket) + ticket = ServiceTicket.objects.create(service=service, user=request.user) + return HttpResponseRedirect(ticket.get_redirect_url()) else: return HttpResponseRedirect(success_redirect) - errors = [] if request.method == 'POST': - username = request.POST.get('username', None) - password = request.POST.get('password', None) - service = request.POST.get('service', None) - lt = request.POST.get('lt', None) - - try: - login_ticket = LoginTicket.objects.get(ticket=lt) - except: - errors.append(_('Login ticket expired. Please try again.')) - else: - login_ticket.delete() - user = authenticate(username=username, password=password) - if user is not None: - if user.is_active: - auth_login(request, user) - if service is not None: - ticket = create_service_ticket(user, service) - return HttpResponseRedirect(service + '?ticket=' + ticket.ticket) - else: - return HttpResponseRedirect(success_redirect) - else: - errors.append(_('This account is disabled.')) - else: - errors.append(_('Incorrect username and/or password.')) - form = LoginForm(service) - return render_to_response(template_name, {'form': form, 'errors': errors}, context_instance=RequestContext(request)) + form = LoginForm(request.POST) + if form.is_valid(): + user = form.get_user() + auth_login(request, user) + service = form.cleaned_data.get('service') + if service is not None: + ticket = ServiceTicket.objects.create(service=service, user=user) + success_redirect = ticket.get_redirect_url() + return HttpResponseRedirect(success_redirect) + else: + form = LoginForm(initial={ + 'service': service, + 'lt': LoginTicket.objects.create() + }) + return render_to_response(template_name, { + 'form': form, + 'errors': form.get_errors() + }, context_instance=RequestContext(request)) def validate(request): -- 2.20.1