Merge branch 'develop'
[django-cas-provider.git] / cas_provider / tests.py
1 from cas_provider.models import ServiceTicket
2 from cas_provider.views import _cas2_sucess_response, _cas2_error_response, \
3     INVALID_TICKET
4 from django.contrib.auth.models import User
5 from django.core.urlresolvers import reverse
6 from django.test import TestCase
7 from urlparse import urlparse
8
9
class ViewsTest(TestCase):
    """Drive the login, logout and validation views through the test client.

    Users are loaded from the ``cas_users.json`` fixture, and every request
    that names a service uses ``self.service``.
    """

    fixtures = ['cas_users.json']

    def setUp(self):
        """Record the service URL shared by all tests in this class."""
        self.service = 'http://example.com/'

    def test_succeessful_login(self):
        """Log in as ``root`` and check how the login view then responds."""
        ticket_redirect = self._login_user('root', '123')
        self._validate_cas1(ticket_redirect, True)

        # With a service, the login view redirects back to it with a ticket.
        with_service = self.client.get(
            reverse('cas_login'), data={'service': self.service}, follow=False)
        self.assertEqual(with_service.status_code, 302)
        expected_prefix = '%s?ticket=' % self.service
        self.assertTrue(with_service['location'].startswith(expected_prefix))

        # Without a service, the redirect stays on the test server.
        without_service = self.client.get(reverse('cas_login'), follow=False)
        self.assertEqual(without_service.status_code, 302)
        local_target = without_service['location']
        self.assertTrue(local_target.startswith('http://testserver/'))

        # The redirect target may either render or redirect again.
        landing = self.client.get(without_service['location'], follow=False)
        self.assertIn(landing.status_code, [302, 200])

        # Asking for a warning renders the warn template instead.
        warn_params = {'service': self.service, 'warn': True}
        warned = self.client.get(
            reverse('cas_login'), data=warn_params, follow=False)
        self.assertEqual(warned.status_code, 200)
        self.assertTemplateUsed(warned, 'cas/warn.html')

    def test_logout(self):
        """After logout, the login view is served to an anonymous user."""
        ticket_redirect = self._login_user('root', '123')
        self._validate_cas1(ticket_redirect, True)

        logged_out = self.client.get(reverse('cas_logout'), follow=False)
        self.assertEqual(logged_out.status_code, 200)

        login_page = self.client.get(reverse('cas_login'), follow=False)
        self.assertEqual(login_page.status_code, 200)
        current_user = login_page.context['user']
        self.assertEqual(current_user.is_anonymous(), True)

    def test_broken_pwd(self):
        """A wrong password for ``root`` is rejected."""
        self._fail_login(username='root', password='321')

    def test_broken_username(self):
        """The username ``notroot`` is rejected."""
        self._fail_login(username='notroot', password='123')

    def test_nonactive_user_login(self):
        """The ``nonactive`` user is rejected."""
        self._fail_login(username='nonactive', password='123')

    def test_cas2_success_validate(self):
        """A ticket issued to ``root`` passes ``cas_service_validate``."""
        ticket_redirect = self._login_user('root', '123')
        self._validate_cas2(ticket_redirect, True)

    def test_cas2_fail_validate(self):
        """Each rejected login also fails ``cas_service_validate``."""
        rejected_credentials = (
            ('root', '321'),
            ('notroot', '123'),
            ('nonactive', '123'),
        )
        for username, password in rejected_credentials:
            rejected = self._login_user(username, password)
            self._validate_cas2(rejected, False)

    def _fail_login(self, username, password):
        """Submit rejected credentials and check no redirect follows."""
        rejected = self._login_user(username, password)
        self._validate_cas1(rejected, False)

        # The login view answers 200 with or without a service parameter.
        with_service = self.client.get(
            reverse('cas_login'), data={'service': self.service}, follow=False)
        self.assertEqual(with_service.status_code, 200)
        without_service = self.client.get(reverse('cas_login'), follow=False)
        self.assertEqual(without_service.status_code, 200)

    def _login_user(self, username, password):
        """Fetch the login form and post the given credentials to it.

        Stores ``username`` on ``self`` for the validation helpers and
        returns the unfollowed response to the POST.
        """
        self.username = username
        login_page = self.client.get(
            reverse('cas_login'), data={'service': self.service})
        self.assertEqual(login_page.status_code, 200)
        self.assertTemplateUsed(login_page, 'cas/login.html')

        # Send back the ``service`` and ``lt`` values the rendered form holds.
        login_form = login_page.context['form']
        service_value = login_form['service'].value()
        credentials = {
            'username': username,
            'password': password,
            'lt': login_form['lt'].value(),
            'service': service_value,
        }
        return self.client.post(
            reverse('cas_login'), data=credentials, follow=False)

    def _ticket_from_redirect(self, response):
        """Assert ``response`` is a 302 with a location; return its ticket.

        The ticket is the second ``=``-separated field of the query string
        of the ``location`` header.
        """
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response.has_header('location'))
        query = urlparse(response['location']).query
        return query.split('=')[1]

    def _validate_cas1(self, response, is_correct=True):
        """Check a login response against the ``cas_validate`` view.

        When ``is_correct`` is true, the ticket from the redirect must
        validate as ``yes`` followed by the stored username. Otherwise the
        response must carry exactly one form error, and a made-up ticket
        must validate as ``no``.
        """
        if not is_correct:
            self.assertEqual(response.status_code, 200)
            form_errors = response.context['form'].errors
            self.assertEqual(len(form_errors), 1)

            bogus = {'ticket': 'ST-12312312312312312312312', 'service': self.service}
            verdict = self.client.get(
                reverse('cas_validate'), data=bogus, follow=False)
            self.assertEqual(verdict.status_code, 200)
            self.assertEqual(verdict.content, u'no\n\n')
            return

        ticket = self._ticket_from_redirect(response)
        params = {'ticket': ticket, 'service': self.service}
        verdict = self.client.get(
            reverse('cas_validate'), data=params, follow=False)
        self.assertEqual(verdict.status_code, 200)
        expected_body = u'yes\n%s\n' % self.username
        self.assertEqual(unicode(verdict.content), expected_body)

    def _validate_cas2(self, response, is_correct=True):
        """Check a login response against the ``cas_service_validate`` view.

        When ``is_correct`` is true, the ticket from the redirect must give
        the same body as ``_cas2_sucess_response`` for the stored username.
        Otherwise the response must carry exactly one form error, and a
        made-up ticket must give the ``INVALID_TICKET`` error body.
        """
        if not is_correct:
            self.assertEqual(response.status_code, 200)
            form_errors = response.context['form'].errors
            self.assertEqual(len(form_errors), 1)

            bogus = {'ticket': 'ST-12312312312312312312312', 'service': self.service}
            verdict = self.client.get(
                reverse('cas_service_validate'), data=bogus, follow=False)
            self.assertEqual(verdict.status_code, 200)
            expected_body = _cas2_error_response(INVALID_TICKET).content
            self.assertEqual(verdict.content, expected_body)
            return

        ticket = self._ticket_from_redirect(response)
        params = {'ticket': ticket, 'service': self.service}
        verdict = self.client.get(
            reverse('cas_service_validate'), data=params, follow=False)
        self.assertEqual(verdict.status_code, 200)
        expected_body = _cas2_sucess_response(self.username).content
        self.assertEqual(verdict.content, expected_body)
131
132
class ModelsTestCase(TestCase):
    """Check ``ServiceTicket`` using the ``root`` user from the fixture."""

    fixtures = ['cas_users.json']

    def setUp(self):
        """Load the ``root`` user that the ticket is created for."""
        self.user = User.objects.get(username='root')

    def test_redirects(self):
        """The redirect URL is the service plus a ``ticket`` query value."""
        ticket = ServiceTicket.objects.create(
            service='http://example.com',
            user=self.user,
        )
        actual_url = ticket.get_redirect_url()
        # Build the expected URL from the ticket's own instance attributes.
        expected_url = '%(service)s?ticket=%(ticket)s' % vars(ticket)
        self.assertEqual(actual_url, expected_url)
143