fix race in filters
[wolnelektury.git] / src / api / tests / tests.py
index ccee689..5ad5ca3 100644 (file)
@@ -1,28 +1,26 @@
-# -*- coding: utf-8 -*-
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
-from os import path
 import hashlib
 import hmac
 import hashlib
 import hmac
+from io import BytesIO
 import json
 import json
-from StringIO import StringIO
+from os import path
 from time import time
 from time import time
-from urllib import quote, urlencode
-from urlparse import parse_qs
+from unittest.mock import patch
+from urllib.parse import quote, urlencode, parse_qs
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
-from mock import patch
-from piston.models import Consumer, Token
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
+from api.models import Consumer, Token
 
 
 @override_settings(
 
 
 @override_settings(
@@ -31,6 +29,8 @@ import picture.tests
         'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}},
 )
 class ApiTest(TestCase):
         'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}},
 )
 class ApiTest(TestCase):
+    maxDiff = None
+
     def load_json(self, url):
         content = self.client.get(url).content
         try:
     def load_json(self, url):
         content = self.client.get(url).content
         try:
@@ -39,6 +39,13 @@ class ApiTest(TestCase):
             self.fail('No JSON could be decoded: %s' % content)
         return data
 
             self.fail('No JSON could be decoded: %s' % content)
         return data
 
+    def assert_response(self, url, name):
+        content = self.client.get(url).content.decode('utf-8').rstrip()
+        filename = path.join(path.dirname(__file__), 'res', 'responses', name)
+        with open(filename) as f:
+            good_content = f.read().rstrip()
+        self.assertEqual(content, good_content, content)
+
     def assert_json_response(self, url, name):
         data = self.load_json(url)
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
     def assert_json_response(self, url, name):
         data = self.load_json(url)
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
@@ -101,16 +108,18 @@ class TagTests(ApiTest):
 class PictureTests(ApiTest):
     def test_publish(self):
         slug = "kandinsky-composition-viii"
 class PictureTests(ApiTest):
     def test_publish(self):
         slug = "kandinsky-composition-viii"
-        xml = SimpleUploadedFile(
-            'composition8.xml',
-            open(path.join(
+        with open(path.join(
                 picture.tests.__path__[0], "files", slug + ".xml"
                 picture.tests.__path__[0], "files", slug + ".xml"
-            )).read())
-        img = SimpleUploadedFile(
-            'kompozycja-8.png',
-            open(path.join(
+            ), 'rb') as f:
+            xml = SimpleUploadedFile(
+                'composition8.xml',
+                f.read())
+        with open(path.join(
                 picture.tests.__path__[0], "files", slug + ".png"
                 picture.tests.__path__[0], "files", slug + ".png"
-            )).read())
+            ), 'rb') as f:
+            img = SimpleUploadedFile(
+                'kompozycja-8.png',
+                f.read())
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
@@ -130,6 +139,7 @@ class BooksTests(ApiTest):
     def test_books(self):
         self.assert_json_response('/api/books/', 'books.json')
         self.assert_json_response('/api/books/?new_api=true', 'books.json')
     def test_books(self):
         self.assert_json_response('/api/books/', 'books.json')
         self.assert_json_response('/api/books/?new_api=true', 'books.json')
+        self.assert_response('/api/books/?format=xml', 'books.xml')
 
         self.assert_slugs('/api/audiobooks/', ['parent'])
         self.assert_slugs('/api/daisy/', ['parent'])
 
         self.assert_slugs('/api/audiobooks/', ['parent'])
         self.assert_slugs('/api/daisy/', ['parent'])
@@ -171,12 +181,15 @@ class BooksTests(ApiTest):
             '/api/filter-books/?lektura=true',
             [])
 
             '/api/filter-books/?lektura=true',
             [])
 
-        self.assert_slugs(
-            '/api/filter-books/?preview=true',
-            ['grandchild'])
+        Book.objects.filter(slug='grandchild').update(preview=True)
+        # Skipped: books in preview are not allowed in the filtered list.
+        #self.assert_slugs(
+        #    '/api/filter-books/?preview=true',
+        #    ['grandchild'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
+        Book.objects.filter(slug='grandchild').update(preview=False)
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
@@ -222,18 +235,15 @@ class BooksTests(ApiTest):
 
 class BlogTests(ApiTest):
     def test_get(self):
 
 class BlogTests(ApiTest):
     def test_get(self):
-        self.assertEqual(self.load_json('/api/blog/'), [])
-
-
-class PreviewTests(ApiTest):
-    def unauth(self):
-        self.assert_json_response('/api/preview/', 'preview.json')
+        self.assertEqual(self.load_json('/api/blog'), [])
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
+        cls.user.set_password('test')
+        cls.user.save()
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
@@ -245,7 +255,8 @@ class OAuth1Tests(ApiTest):
         User.objects.all().delete()
 
     def test_create_token(self):
         User.objects.all().delete()
 
     def test_create_token(self):
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        # Fetch request token.
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
@@ -254,42 +265,59 @@ class OAuth1Tests(ApiTest):
             quote(base_query, safe='')
         ])
         h = hmac.new(
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&', raw, hashlib.sha1
+            (quote(self.consumer_secret) + '&').encode('latin1'),
+            raw.encode('latin1'),
+            hashlib.sha1
         ).digest()
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
-        request_token = parse_qs(response.content)
+        request_token_data = parse_qs(response.content.decode('latin1'))
+        request_token = request_token_data['oauth_token'][0]
+        request_token_secret = request_token_data['oauth_token_secret'][0]
+
+        # Request token authorization.
+        self.client.login(username='test', password='test')
+        response = self.client.get(
+            '/api/oauth/authorize/?oauth_token=%s&oauth_callback=test://oauth.callback/' % (
+                request_token,
+            )
+        )
+        post_data = response.context['form'].initial
 
 
-        Token.objects.filter(
-            key=request_token['oauth_token'][0], token_type=Token.REQUEST
-        ).update(user=self.user, is_approved=True)
+        response = self.client.post('/api/oauth/authorize/?' + urlencode(post_data))
+        self.assertEqual(
+            response['Location'],
+            'test://oauth.callback/?oauth_token=' + request_token
+        )
 
 
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        # Fetch access token.
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
-                          int(time()), request_token['oauth_token'][0]))
+                          int(time()), request_token))
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&' +
-            quote(request_token['oauth_token_secret'][0], safe=''),
-            raw,
+            (quote(self.consumer_secret) + '&' +
+             quote(request_token_secret, safe='')).encode('latin1'),
+            raw.encode('latin1'),
             hashlib.sha1
         ).digest()
             hashlib.sha1
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         sign = quote(h)
-        query = u"{}&oauth_signature={}".format(base_query, sign)
-        response = self.client.get(u'/api/oauth/access_token/?' + query)
-        access_token = parse_qs(response.content)
+        query = "{}&oauth_signature={}".format(base_query, sign)
+        response = self.client.get('/api/oauth/access_token/?' + query)
+        access_token_data = parse_qs(response.content.decode('latin1'))
+        access_token = access_token_data['oauth_token'][0]
 
         self.assertTrue(
             Token.objects.filter(
 
         self.assertTrue(
             Token.objects.filter(
-                key=access_token['oauth_token'][0],
+                key=access_token,
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
@@ -311,7 +339,7 @@ class AuthorizedTests(ApiTest):
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
-        cls.key = cls.consumer.secret + '&' + cls.token.secret
+        cls.key = (cls.consumer.secret + '&' + cls.token.secret).encode('latin1')
 
     @classmethod
     def tearDownClass(cls):
 
     @classmethod
     def tearDownClass(cls):
@@ -319,10 +347,10 @@ class AuthorizedTests(ApiTest):
         cls.consumer.delete()
         super(AuthorizedTests, cls).tearDownClass()
 
         cls.consumer.delete()
         super(AuthorizedTests, cls).tearDownClass()
 
-    def signed(self, url, method='GET', params=None):
+    def signed(self, url, method='GET', params=None, data=None):
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
-            "oauth_nonce": "%f" % time(),
+            "oauth_nonce": ("%f" % time()).replace('.', ''),
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
@@ -332,30 +360,45 @@ class AuthorizedTests(ApiTest):
         sign_params = {}
         if params:
             sign_params.update(params)
         sign_params = {}
         if params:
             sign_params.update(params)
+        if data:
+            sign_params.update(data)
         sign_params.update(auth_params)
         raw = "&".join([
             method.upper(),
             quote('http://testserver' + url, safe=''),
             quote("&".join(
         sign_params.update(auth_params)
         raw = "&".join([
             method.upper(),
             quote('http://testserver' + url, safe=''),
             quote("&".join(
-                quote(str(k)) + "=" + quote(str(v))
+                quote(str(k), safe='') + "=" + quote(str(v), safe='')
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
-            self.key, raw, hashlib.sha1).digest()).rstrip('\n'))
+            self.key,
+            raw.encode('latin1'),
+            hashlib.sha1
+        ).digest()).rstrip(b'\n'))
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
-                url,
-                HTTP_AUTHORIZATION=auth
-            )
+            url,
+            data=urlencode(data) if data else None,
+            content_type='application/x-www-form-urlencoded',
+            HTTP_AUTHORIZATION=auth,
+        )
 
 
-    def signed_json(self, url, method='GET', params=None):
-        return json.loads(self.signed(url, method, params).content)
+    def signed_json(self, url, method='GET', params=None, data=None):
+        return json.loads(self.signed(url, method, params, data).content)
 
     def test_books(self):
 
     def test_books(self):
+        self.assertEqual(
+            [b['liked'] for b in self.signed_json('/api/books/')],
+            [False, False, False]
+        )
+        data = self.signed_json('/api/books/child/')
+        self.assertFalse(data['parent']['liked'])
+        self.assertFalse(data['children'][0]['liked'])
+
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
             {"likes": False}
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
             {"likes": False}
@@ -369,6 +412,9 @@ class AuthorizedTests(ApiTest):
         self.assertTrue(self.signed_json('/api/parent_books/')[0]['liked'])
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
         self.assertTrue(self.signed_json('/api/parent_books/')[0]['liked'])
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
+
+        self.assertTrue(self.signed_json(
+            '/api/books/child/')['parent']['liked'])
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
@@ -396,20 +442,60 @@ class AuthorizedTests(ApiTest):
             ['parent'])
 
     def test_subscription(self):
             ['parent'])
 
     def test_subscription(self):
+        Book.objects.filter(slug='grandchild').update(preview=True)
+
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
             {"username": "test", "premium": False})
         self.assertEqual(
             self.signed('/api/epub/grandchild/').status_code,
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
             {"username": "test", "premium": False})
         self.assertEqual(
             self.signed('/api/epub/grandchild/').status_code,
-            401)  # Not 403 because Piston.
+            403)
 
 
-        with patch('api.handlers.user_is_subscribed', return_value=True):
+        with patch('club.models.Membership.is_active_for', return_value=True):
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
             with patch('django.core.files.storage.Storage.open',
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
             with patch('django.core.files.storage.Storage.open',
-                       return_value=StringIO("<epub>")):
+                       return_value=BytesIO(b"<epub>")):
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
-                    "<epub>")
+                    b"<epub>")
+
+        Book.objects.filter(slug='grandchild').update(preview=False)
+
+    def test_publish(self):
+        response = self.signed('/api/books/',
+                               method='POST',
+                               data={"data": json.dumps({})})
+        self.assertEqual(response.status_code, 403)
+
+        response = self.signed('/api/pictures/',
+                               method='POST',
+                               data={"data": json.dumps({})})
+        self.assertEqual(response.status_code, 403)
+
+        self.user.is_superuser = True
+        self.user.save()
+
+        with patch('catalogue.models.Book.from_xml_file') as mock:
+            response = self.signed('/api/books/',
+                                   method='POST',
+                                   data={"data": json.dumps({
+                                       "book_xml": "<utwor/>"
+                                   })})
+            self.assertTrue(mock.called)
+        self.assertEqual(response.status_code, 201)
+
+        with patch('picture.models.Picture.from_xml_file') as mock:
+            response = self.signed('/api/pictures/',
+                                   method='POST',
+                                   data={"data": json.dumps({
+                                       "picture_xml": "<utwor/>",
+                                       "picture_image_data": "Kg==",
+                                   })})
+            self.assertTrue(mock.called)
+        self.assertEqual(response.status_code, 201)
+
+        self.user.is_superuser = False
+        self.user.save()