Updates and fixes.
[wolnelektury.git] / src / api / tests / tests.py
index c8e07d2..5ad5ca3 100644 (file)
@@ -1,28 +1,26 @@
-# -*- coding: utf-8 -*-
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
-from os import path
 import hashlib
 import hmac
 import hashlib
 import hmac
+from io import BytesIO
 import json
 import json
-from StringIO import StringIO
+from os import path
 from time import time
 from time import time
-from urllib import quote, urlencode
-from urlparse import parse_qs
+from unittest.mock import patch
+from urllib.parse import quote, urlencode, parse_qs
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
-from mock import patch
-from piston.models import Consumer, Token
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
+from api.models import Consumer, Token
 
 
 @override_settings(
 
 
 @override_settings(
@@ -31,6 +29,8 @@ import picture.tests
         'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}},
 )
 class ApiTest(TestCase):
         'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}},
 )
 class ApiTest(TestCase):
+    maxDiff = None
+
     def load_json(self, url):
         content = self.client.get(url).content
         try:
     def load_json(self, url):
         content = self.client.get(url).content
         try:
@@ -40,7 +40,7 @@ class ApiTest(TestCase):
         return data
 
     def assert_response(self, url, name):
         return data
 
     def assert_response(self, url, name):
-        content = self.client.get(url).content.rstrip()
+        content = self.client.get(url).content.decode('utf-8').rstrip()
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
         with open(filename) as f:
             good_content = f.read().rstrip()
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
         with open(filename) as f:
             good_content = f.read().rstrip()
@@ -108,16 +108,18 @@ class TagTests(ApiTest):
 class PictureTests(ApiTest):
     def test_publish(self):
         slug = "kandinsky-composition-viii"
 class PictureTests(ApiTest):
     def test_publish(self):
         slug = "kandinsky-composition-viii"
-        xml = SimpleUploadedFile(
-            'composition8.xml',
-            open(path.join(
+        with open(path.join(
                 picture.tests.__path__[0], "files", slug + ".xml"
                 picture.tests.__path__[0], "files", slug + ".xml"
-            )).read())
-        img = SimpleUploadedFile(
-            'kompozycja-8.png',
-            open(path.join(
+            ), 'rb') as f:
+            xml = SimpleUploadedFile(
+                'composition8.xml',
+                f.read())
+        with open(path.join(
                 picture.tests.__path__[0], "files", slug + ".png"
                 picture.tests.__path__[0], "files", slug + ".png"
-            )).read())
+            ), 'rb') as f:
+            img = SimpleUploadedFile(
+                'kompozycja-8.png',
+                f.read())
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
@@ -179,12 +181,15 @@ class BooksTests(ApiTest):
             '/api/filter-books/?lektura=true',
             [])
 
             '/api/filter-books/?lektura=true',
             [])
 
-        self.assert_slugs(
-            '/api/filter-books/?preview=true',
-            ['grandchild'])
+        Book.objects.filter(slug='grandchild').update(preview=True)
+        # Skipped: previewed books are not allowed in the filtered list.
+        #self.assert_slugs(
+        #    '/api/filter-books/?preview=true',
+        #    ['grandchild'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
+        Book.objects.filter(slug='grandchild').update(preview=False)
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
@@ -230,18 +235,15 @@ class BooksTests(ApiTest):
 
 class BlogTests(ApiTest):
     def test_get(self):
 
 class BlogTests(ApiTest):
     def test_get(self):
-        self.assertEqual(self.load_json('/api/blog/'), [])
-
-
-class PreviewTests(ApiTest):
-    def unauth(self):
-        self.assert_json_response('/api/preview/', 'preview.json')
+        self.assertEqual(self.load_json('/api/blog'), [])
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
+        cls.user.set_password('test')
+        cls.user.save()
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
@@ -253,6 +255,7 @@ class OAuth1Tests(ApiTest):
         User.objects.all().delete()
 
     def test_create_token(self):
         User.objects.all().delete()
 
     def test_create_token(self):
+        # Fetch request token.
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
@@ -262,42 +265,59 @@ class OAuth1Tests(ApiTest):
             quote(base_query, safe='')
         ])
         h = hmac.new(
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&', raw, hashlib.sha1
+            (quote(self.consumer_secret) + '&').encode('latin1'),
+            raw.encode('latin1'),
+            hashlib.sha1
         ).digest()
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
-        request_token = parse_qs(response.content)
+        request_token_data = parse_qs(response.content.decode('latin1'))
+        request_token = request_token_data['oauth_token'][0]
+        request_token_secret = request_token_data['oauth_token_secret'][0]
+
+        # Request token authorization.
+        self.client.login(username='test', password='test')
+        response = self.client.get(
+            '/api/oauth/authorize/?oauth_token=%s&oauth_callback=test://oauth.callback/' % (
+                request_token,
+            )
+        )
+        post_data = response.context['form'].initial
 
 
-        Token.objects.filter(
-            key=request_token['oauth_token'][0], token_type=Token.REQUEST
-        ).update(user=self.user, is_approved=True)
+        response = self.client.post('/api/oauth/authorize/?' + urlencode(post_data))
+        self.assertEqual(
+            response['Location'],
+            'test://oauth.callback/?oauth_token=' + request_token
+        )
 
 
+        # Fetch access token.
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
-                          int(time()), request_token['oauth_token'][0]))
+                          int(time()), request_token))
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&' +
-            quote(request_token['oauth_token_secret'][0], safe=''),
-            raw,
+            (quote(self.consumer_secret) + '&' +
+             quote(request_token_secret, safe='')).encode('latin1'),
+            raw.encode('latin1'),
             hashlib.sha1
         ).digest()
             hashlib.sha1
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         sign = quote(h)
-        query = u"{}&oauth_signature={}".format(base_query, sign)
-        response = self.client.get(u'/api/oauth/access_token/?' + query)
-        access_token = parse_qs(response.content)
+        query = "{}&oauth_signature={}".format(base_query, sign)
+        response = self.client.get('/api/oauth/access_token/?' + query)
+        access_token_data = parse_qs(response.content.decode('latin1'))
+        access_token = access_token_data['oauth_token'][0]
 
         self.assertTrue(
             Token.objects.filter(
 
         self.assertTrue(
             Token.objects.filter(
-                key=access_token['oauth_token'][0],
+                key=access_token,
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
@@ -319,7 +339,7 @@ class AuthorizedTests(ApiTest):
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
-        cls.key = cls.consumer.secret + '&' + cls.token.secret
+        cls.key = (cls.consumer.secret + '&' + cls.token.secret).encode('latin1')
 
     @classmethod
     def tearDownClass(cls):
 
     @classmethod
     def tearDownClass(cls):
@@ -351,7 +371,10 @@ class AuthorizedTests(ApiTest):
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
-            self.key, raw, hashlib.sha1).digest()).rstrip('\n'))
+            self.key,
+            raw.encode('latin1'),
+            hashlib.sha1
+        ).digest()).rstrip(b'\n'))
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
@@ -419,6 +442,8 @@ class AuthorizedTests(ApiTest):
             ['parent'])
 
     def test_subscription(self):
             ['parent'])
 
     def test_subscription(self):
+        Book.objects.filter(slug='grandchild').update(preview=True)
+
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
@@ -427,16 +452,17 @@ class AuthorizedTests(ApiTest):
             self.signed('/api/epub/grandchild/').status_code,
             403)
 
             self.signed('/api/epub/grandchild/').status_code,
             403)
 
-        with patch('api.fields.user_is_subscribed', return_value=True):
+        with patch('club.models.Membership.is_active_for', return_value=True):
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
-        with patch('paypal.permissions.user_is_subscribed', return_value=True):
             with patch('django.core.files.storage.Storage.open',
             with patch('django.core.files.storage.Storage.open',
-                       return_value=StringIO("<epub>")):
+                       return_value=BytesIO(b"<epub>")):
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
-                    "<epub>")
+                    b"<epub>")
+
+        Book.objects.filter(slug='grandchild').update(preview=False)
 
     def test_publish(self):
         response = self.signed('/api/books/',
 
     def test_publish(self):
         response = self.signed('/api/books/',