Cleaning.
[wolnelektury.git] / src / api / tests / tests.py
index adaaa60..f85a061 100644 (file)
@@ -1,28 +1,26 @@
-# -*- coding: utf-8 -*-
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 from base64 import b64encode
-from os import path
 import hashlib
 import hmac
 import hashlib
 import hmac
+from io import BytesIO
 import json
 import json
-from StringIO import StringIO
+from os import path
 from time import time
 from time import time
-from urllib import quote, urlencode
-from urlparse import parse_qs
+from unittest.mock import patch
+from urllib.parse import quote, urlencode, parse_qs
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
 
 from django.contrib.auth.models import User
 from django.core.files.uploadedfile import SimpleUploadedFile
 from django.test import TestCase
 from django.test.utils import override_settings
-from mock import patch
-from piston.models import Consumer, Token
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
 
 from catalogue.models import Book, Tag
 from picture.forms import PictureImportForm
 from picture.models import Picture
 import picture.tests
+from api.models import Consumer, Token
 
 
 @override_settings(
 
 
 @override_settings(
@@ -40,12 +38,12 @@ class ApiTest(TestCase):
         return data
 
     def assert_response(self, url, name):
         return data
 
     def assert_response(self, url, name):
-        content = self.client.get(url).content.rstrip()
+        content = self.client.get(url).content.decode('utf-8').rstrip()
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
         with open(filename) as f:
             good_content = f.read().rstrip()
         self.assertEqual(content, good_content, content)
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
         with open(filename) as f:
             good_content = f.read().rstrip()
         self.assertEqual(content, good_content, content)
-    
+
     def assert_json_response(self, url, name):
         data = self.load_json(url)
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
     def assert_json_response(self, url, name):
         data = self.load_json(url)
         filename = path.join(path.dirname(__file__), 'res', 'responses', name)
@@ -112,12 +110,12 @@ class PictureTests(ApiTest):
             'composition8.xml',
             open(path.join(
                 picture.tests.__path__[0], "files", slug + ".xml"
             'composition8.xml',
             open(path.join(
                 picture.tests.__path__[0], "files", slug + ".xml"
-            )).read())
+            ), 'rb').read())
         img = SimpleUploadedFile(
             'kompozycja-8.png',
             open(path.join(
                 picture.tests.__path__[0], "files", slug + ".png"
         img = SimpleUploadedFile(
             'kompozycja-8.png',
             open(path.join(
                 picture.tests.__path__[0], "files", slug + ".png"
-            )).read())
+            ), 'rb').read())
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
 
         import_form = PictureImportForm({}, {
             'picture_xml_file': xml,
@@ -179,12 +177,15 @@ class BooksTests(ApiTest):
             '/api/filter-books/?lektura=true',
             [])
 
             '/api/filter-books/?lektura=true',
             [])
 
-        self.assert_slugs(
-            '/api/filter-books/?preview=true',
-            ['grandchild'])
+        Book.objects.filter(slug='grandchild').update(preview=True)
+        # Skipping: we don't allow previewed books in filtered list.
+        #self.assert_slugs(
+        #    '/api/filter-books/?preview=true',
+        #    ['grandchild'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
         self.assert_slugs(
             '/api/filter-books/?preview=false',
             ['child', 'parent'])
+        Book.objects.filter(slug='grandchild').update(preview=False)
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
 
         self.assert_slugs(
             '/api/filter-books/?audiobook=true',
@@ -230,18 +231,15 @@ class BooksTests(ApiTest):
 
 class BlogTests(ApiTest):
     def test_get(self):
 
 class BlogTests(ApiTest):
     def test_get(self):
-        self.assertEqual(self.load_json('/api/blog/'), [])
-
-
-class PreviewTests(ApiTest):
-    def unauth(self):
-        self.assert_json_response('/api/preview/', 'preview.json')
+        self.assertEqual(self.load_json('/api/blog'), [])
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
 
 
 class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
+        cls.user.set_password('test')
+        cls.user.save()
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
@@ -253,7 +251,8 @@ class OAuth1Tests(ApiTest):
         User.objects.all().delete()
 
     def test_create_token(self):
         User.objects.all().delete()
 
     def test_create_token(self):
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        # Fetch request token.
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
@@ -262,42 +261,59 @@ class OAuth1Tests(ApiTest):
             quote(base_query, safe='')
         ])
         h = hmac.new(
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&', raw, hashlib.sha1
+            (quote(self.consumer_secret) + '&').encode('latin1'),
+            raw.encode('latin1'),
+            hashlib.sha1
         ).digest()
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
-        request_token = parse_qs(response.content)
+        request_token_data = parse_qs(response.content.decode('latin1'))
+        request_token = request_token_data['oauth_token'][0]
+        request_token_secret = request_token_data['oauth_token_secret'][0]
+
+        # Request token authorization.
+        self.client.login(username='test', password='test')
+        response = self.client.get(
+            '/api/oauth/authorize/?oauth_token=%s&oauth_callback=test://oauth.callback/' % (
+                request_token,
+            )
+        )
+        post_data = response.context['form'].initial
 
 
-        Token.objects.filter(
-            key=request_token['oauth_token'][0], token_type=Token.REQUEST
-        ).update(user=self.user, is_approved=True)
+        response = self.client.post('/api/oauth/authorize/?' + urlencode(post_data))
+        self.assertEqual(
+            response['Location'],
+            'test://oauth.callback/?oauth_token=' + request_token
+        )
 
 
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        # Fetch access token.
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
-                          int(time()), request_token['oauth_token'][0]))
+                          int(time()), request_token))
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
             quote(base_query, safe='')
         ])
         h = hmac.new(
-            quote(self.consumer_secret) + '&' +
-            quote(request_token['oauth_token_secret'][0], safe=''),
-            raw,
+            (quote(self.consumer_secret) + '&' +
+             quote(request_token_secret, safe='')).encode('latin1'),
+            raw.encode('latin1'),
             hashlib.sha1
         ).digest()
             hashlib.sha1
         ).digest()
-        h = b64encode(h).rstrip('\n')
+        h = b64encode(h).rstrip(b'\n')
         sign = quote(h)
         sign = quote(h)
-        query = u"{}&oauth_signature={}".format(base_query, sign)
-        response = self.client.get(u'/api/oauth/access_token/?' + query)
-        access_token = parse_qs(response.content)
+        query = "{}&oauth_signature={}".format(base_query, sign)
+        response = self.client.get('/api/oauth/access_token/?' + query)
+        access_token_data = parse_qs(response.content.decode('latin1'))
+        access_token = access_token_data['oauth_token'][0]
 
         self.assertTrue(
             Token.objects.filter(
 
         self.assertTrue(
             Token.objects.filter(
-                key=access_token['oauth_token'][0],
+                key=access_token,
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
@@ -319,7 +335,7 @@ class AuthorizedTests(ApiTest):
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
             consumer=cls.consumer,
             token_type=Token.ACCESS,
             timestamp=time())
-        cls.key = cls.consumer.secret + '&' + cls.token.secret
+        cls.key = (cls.consumer.secret + '&' + cls.token.secret).encode('latin1')
 
     @classmethod
     def tearDownClass(cls):
 
     @classmethod
     def tearDownClass(cls):
@@ -327,10 +343,10 @@ class AuthorizedTests(ApiTest):
         cls.consumer.delete()
         super(AuthorizedTests, cls).tearDownClass()
 
         cls.consumer.delete()
         super(AuthorizedTests, cls).tearDownClass()
 
-    def signed(self, url, method='GET', params=None):
+    def signed(self, url, method='GET', params=None, data=None):
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
-            "oauth_nonce": "%f" % time(),
+            "oauth_nonce": ("%f" % time()).replace('.', ''),
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
@@ -340,30 +356,45 @@ class AuthorizedTests(ApiTest):
         sign_params = {}
         if params:
             sign_params.update(params)
         sign_params = {}
         if params:
             sign_params.update(params)
+        if data:
+            sign_params.update(data)
         sign_params.update(auth_params)
         raw = "&".join([
             method.upper(),
             quote('http://testserver' + url, safe=''),
             quote("&".join(
         sign_params.update(auth_params)
         raw = "&".join([
             method.upper(),
             quote('http://testserver' + url, safe=''),
             quote("&".join(
-                quote(str(k)) + "=" + quote(str(v))
+                quote(str(k), safe='') + "=" + quote(str(v), safe='')
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
                 for (k, v) in sorted(sign_params.items())))
         ])
         auth_params["oauth_signature"] = quote(b64encode(hmac.new(
-            self.key, raw, hashlib.sha1).digest()).rstrip('\n'))
+            self.key,
+            raw.encode('latin1'),
+            hashlib.sha1
+        ).digest()).rstrip(b'\n'))
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
         auth = 'OAuth realm="API", ' + ', '.join(
             '{}="{}"'.format(k, v) for (k, v) in auth_params.items())
 
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
-                url,
-                HTTP_AUTHORIZATION=auth
-            )
+            url,
+            data=urlencode(data) if data else None,
+            content_type='application/x-www-form-urlencoded',
+            HTTP_AUTHORIZATION=auth,
+        )
 
 
-    def signed_json(self, url, method='GET', params=None):
-        return json.loads(self.signed(url, method, params).content)
+    def signed_json(self, url, method='GET', params=None, data=None):
+        return json.loads(self.signed(url, method, params, data).content)
 
     def test_books(self):
 
     def test_books(self):
+        self.assertEqual(
+            [b['liked'] for b in self.signed_json('/api/books/')],
+            [False, False, False]
+        )
+        data = self.signed_json('/api/books/child/')
+        self.assertFalse(data['parent']['liked'])
+        self.assertFalse(data['children'][0]['liked'])
+
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
             {"likes": False}
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
             {"likes": False}
@@ -377,6 +408,9 @@ class AuthorizedTests(ApiTest):
         self.assertTrue(self.signed_json('/api/parent_books/')[0]['liked'])
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
         self.assertTrue(self.signed_json('/api/parent_books/')[0]['liked'])
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
+
+        self.assertTrue(self.signed_json(
+            '/api/books/child/')['parent']['liked'])
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
@@ -404,20 +438,60 @@ class AuthorizedTests(ApiTest):
             ['parent'])
 
     def test_subscription(self):
             ['parent'])
 
     def test_subscription(self):
+        Book.objects.filter(slug='grandchild').update(preview=True)
+
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
             {"username": "test", "premium": False})
         self.assertEqual(
             self.signed('/api/epub/grandchild/').status_code,
         self.assert_slugs('/api/preview/', ['grandchild'])
         self.assertEqual(
             self.signed_json('/api/username/'),
             {"username": "test", "premium": False})
         self.assertEqual(
             self.signed('/api/epub/grandchild/').status_code,
-            401)  # Not 403 because Piston.
+            403)
 
 
-        with patch('api.handlers.user_is_subscribed', return_value=True):
+        with patch('club.models.Membership.is_active_for', return_value=True):
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
             with patch('django.core.files.storage.Storage.open',
             self.assertEqual(
                 self.signed_json('/api/username/'),
                 {"username": "test", "premium": True})
             with patch('django.core.files.storage.Storage.open',
-                       return_value=StringIO("<epub>")):
+                       return_value=BytesIO(b"<epub>")):
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
                 self.assertEqual(
                     self.signed('/api/epub/grandchild/').content,
-                    "<epub>")
+                    b"<epub>")
+
+        Book.objects.filter(slug='grandchild').update(preview=False)
+
+    def test_publish(self):
+        response = self.signed('/api/books/',
+                               method='POST',
+                               data={"data": json.dumps({})})
+        self.assertEqual(response.status_code, 403)
+
+        response = self.signed('/api/pictures/',
+                               method='POST',
+                               data={"data": json.dumps({})})
+        self.assertEqual(response.status_code, 403)
+
+        self.user.is_superuser = True
+        self.user.save()
+
+        with patch('catalogue.models.Book.from_xml_file') as mock:
+            response = self.signed('/api/books/',
+                                   method='POST',
+                                   data={"data": json.dumps({
+                                       "book_xml": "<utwor/>"
+                                   })})
+            self.assertTrue(mock.called)
+        self.assertEqual(response.status_code, 201)
+
+        with patch('picture.models.Picture.from_xml_file') as mock:
+            response = self.signed('/api/pictures/',
+                                   method='POST',
+                                   data={"data": json.dumps({
+                                       "picture_xml": "<utwor/>",
+                                       "picture_image_data": "Kg==",
+                                   })})
+            self.assertTrue(mock.called)
+        self.assertEqual(response.status_code, 201)
+
+        self.user.is_superuser = False
+        self.user.save()