Start replacing Piston in OAuth flow with OAuthLib.
authorRadek Czajka <rczajka@rczajka.pl>
Wed, 6 Feb 2019 23:56:21 +0000 (00:56 +0100)
committerRadek Czajka <rczajka@rczajka.pl>
Wed, 6 Feb 2019 23:56:21 +0000 (00:56 +0100)
requirements/requirements.txt
src/api/drf_auth.py
src/api/request_validator.py [new file with mode: 0644]
src/api/tests/tests.py
src/api/urls.py
src/api/views.py

index 3b8c95e..34f4d9a 100644 (file)
@@ -15,6 +15,7 @@ django-allauth>=0.32,<0.33
 django-extensions
 djangorestframework<3.7
 djangorestframework-xml
 django-extensions
 djangorestframework<3.7
 djangorestframework-xml
+oauthlib>=3.0.1,<3.1
 
 # contact
 pyyaml
 
 # contact
 pyyaml
index 26018c6..ca6a491 100644 (file)
@@ -1,20 +1,29 @@
-"""
-Transitional code: bridge between Piston's OAuth implementation
-and DRF views.
-"""
-from piston.authentication import OAuthAuthentication
+# -*- coding: utf-8 -*-
+# This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
+# Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
+#
+from oauthlib.oauth1 import ResourceEndpoint
 from rest_framework.authentication import BaseAuthentication
 from rest_framework.authentication import BaseAuthentication
+from .request_validator import PistonRequestValidator
 
 
 class PistonOAuthAuthentication(BaseAuthentication):
     def __init__(self):
 
 
 class PistonOAuthAuthentication(BaseAuthentication):
     def __init__(self):
-        self.piston_auth = OAuthAuthentication()
+        validator = PistonRequestValidator()
+        self.provider = ResourceEndpoint(validator)
 
     def authenticate_header(self, request):
         return 'OAuth realm="API"'
 
     def authenticate(self, request):
 
     def authenticate_header(self, request):
         return 'OAuth realm="API"'
 
     def authenticate(self, request):
-        if self.piston_auth.is_valid_request(request):
-            consumer, token, parameters = self.piston_auth.validate_token(request)
-            if consumer and token:
-                return token.user, token
+        v, r = self.provider.validate_protected_resource_request(
+            request.build_absolute_uri(),
+            http_method=request.method,
+            body=request.body,
+            headers={
+                "Authorization": request.META['HTTP_AUTHORIZATION'],
+                "Content-Type": request.content_type,
+            } if 'HTTP_AUTHORIZATION' in request.META else None
+        )
+        if v:
+            return r.token.user, r.token
diff --git a/src/api/request_validator.py b/src/api/request_validator.py
new file mode 100644 (file)
index 0000000..b8554ad
--- /dev/null
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
+# Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
+#
+from oauthlib.oauth1 import RequestValidator
+from piston.models import Consumer, Nonce, Token
+
+
+class PistonRequestValidator(RequestValidator):
+    dummy_access_token = '!'
+    realms = ['API']
+
+    # Just for the tests.
+    # It'd be a little more kosher to use test client with secure=True.
+    enforce_ssl = False
+
+    # iOS app generates 8-char nonces.
+    nonce_length = 8, 250
+
+    # Because piston.models.Token.key is char(18).
+    access_token_length = 18, 32
+
+    def check_client_key(self, client_key):
+        """We control the keys anyway."""
+        return True
+
+    def get_access_token_secret(self, client_key, token, request):
+        return request.token.secret
+
+    def get_default_realms(self, client_key, request):
+        return ['API']
+
+    def validate_access_token(self, client_key, token, request):
+        try:
+            token = Token.objects.get(
+                token_type=Token.ACCESS,
+                consumer__key=client_key,
+                key=token
+            )
+        except Token.DoesNotExist:
+            return False
+        else:
+            request.token = token
+            return True
+
+    def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
+                                     request, request_token=None, access_token=None):
+        # TODO: validate the timestamp
+        token = request_token or access_token
+        # Yes, this is what Piston did.
+        if token is None:
+            return True
+
+        nonce, created = Nonce.objects.get_or_create(consumer_key=client_key,
+                                                     token_key=token,
+                                                     key=nonce)
+        return created
+
+    def validate_client_key(self, client_key, request):
+        try:
+            request.oauth_consumer = Consumer.objects.get(key=client_key)
+        except Consumer.DoesNotExist:
+            return False
+        return True
+
+    def validate_realms(self, client_key, token, request, uri=None, realms=None):
+        return True
+
+    def validate_requested_realms(self, *args, **kwargs):
+        return True
+
+    def validate_redirect_uri(self, *args, **kwargs):
+        return True
+
+    def get_client_secret(self, client_key, request):
+        return request.oauth_consumer.secret
+
+    def save_request_token(self, token, request):
+        Token.objects.create(
+            token_type=Token.REQUEST,
+            timestamp=request.timestamp,
+            key=token['oauth_token'],
+            secret=token['oauth_token_secret'],
+            consumer=request.oauth_consumer,
+        )
index 38f1882..ee3d66e 100644 (file)
@@ -253,7 +253,7 @@ class OAuth1Tests(ApiTest):
         User.objects.all().delete()
 
     def test_create_token(self):
         User.objects.all().delete()
 
     def test_create_token(self):
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         raw = '&'.join([
@@ -274,7 +274,7 @@ class OAuth1Tests(ApiTest):
             key=request_token['oauth_token'][0], token_type=Token.REQUEST
         ).update(user=self.user, is_approved=True)
 
             key=request_token['oauth_token'][0], token_type=Token.REQUEST
         ).update(user=self.user, is_approved=True)
 
-        base_query = ("oauth_consumer_key=client&oauth_nonce=123&"
+        base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
                           int(time()), request_token['oauth_token'][0]))
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
                           int(time()), request_token['oauth_token'][0]))
@@ -330,7 +330,7 @@ class AuthorizedTests(ApiTest):
     def signed(self, url, method='GET', params=None, data=None):
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
     def signed(self, url, method='GET', params=None, data=None):
         auth_params = {
             "oauth_consumer_key": self.consumer.key,
-            "oauth_nonce": "%f" % time(),
+            "oauth_nonce": ("%f" % time()).replace('.', ''),
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
             "oauth_signature_method": "HMAC-SHA1",
             "oauth_timestamp": int(time()),
             "oauth_token": self.token.key,
@@ -358,10 +358,11 @@ class AuthorizedTests(ApiTest):
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
         if params:
             url = url + '?' + urlencode(params)
         return getattr(self.client, method.lower())(
-                url,
-                data=data,
-                HTTP_AUTHORIZATION=auth
-            )
+            url,
+            data=urlencode(data) if data else None,
+            content_type='application/x-www-form-urlencoded',
+            HTTP_AUTHORIZATION=auth,
+        )
 
     def signed_json(self, url, method='GET', params=None, data=None):
         return json.loads(self.signed(url, method, params, data).content)
 
     def signed_json(self, url, method='GET', params=None, data=None):
         return json.loads(self.signed(url, method, params, data).content)
@@ -371,10 +372,9 @@ class AuthorizedTests(ApiTest):
             [b['liked'] for b in self.signed_json('/api/books/')],
             [False, False, False]
         )
             [b['liked'] for b in self.signed_json('/api/books/')],
             [False, False, False]
         )
-        # This one fails in the legacy implementation
-        # data = self.signed_json('/api/books/child/')
-        # self.assertFalse(data['parent']['liked'])
-        # self.assertFalse(data['children'][0]['liked'])
+        data = self.signed_json('/api/books/child/')
+        self.assertFalse(data['parent']['liked'])
+        self.assertFalse(data['children'][0]['liked'])
 
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
 
         self.assertEqual(
             self.signed_json('/api/like/parent/'),
@@ -390,9 +390,8 @@ class AuthorizedTests(ApiTest):
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
 
         self.assertTrue(self.signed_json(
             '/api/filter-books/', params={"search": "parent"})[0]['liked'])
 
-        # This one fails in the legacy implementation.
-        #self.assertTrue(self.signed_json(
-        #    '/api/books/child/')['parent']['liked'])
+        self.assertTrue(self.signed_json(
+            '/api/books/child/')['parent']['liked'])
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
         # Liked books go on shelf.
         self.assertEqual(
             [x['slug'] for x in self.signed_json('/api/shelf/likes/')],
index 973bf03..150dc4c 100644 (file)
@@ -13,7 +13,7 @@ from . import views
 
 
 urlpatterns = [
 
 
 urlpatterns = [
-    url(r'^oauth/request_token/$', oauth_request_token),
+    url(r'^oauth/request_token/$', views.OAuth1RequestTokenView.as_view()),
     url(r'^oauth/authorize/$', oauth_user_auth, name='oauth_user_auth'),
     url(r'^oauth/access_token/$', csrf_exempt(oauth_access_token)),
 
     url(r'^oauth/authorize/$', oauth_user_auth, name='oauth_user_auth'),
     url(r'^oauth/access_token/$', csrf_exempt(oauth_access_token)),
 
index 377beb6..6d462fa 100644 (file)
@@ -2,7 +2,10 @@
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
-from django.http import Http404
+from django.http import Http404, HttpResponse
+from oauthlib.common import urlencode
+from oauthlib.oauth1 import RequestTokenEndpoint
+from piston.models import KEY_SIZE, SECRET_SIZE
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.views import APIView
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.views import APIView
@@ -11,6 +14,42 @@ from migdal.models import Entry
 from catalogue.models import Book
 from .models import BookUserData
 from . import serializers
 from catalogue.models import Book
 from .models import BookUserData
 from . import serializers
+from .request_validator import PistonRequestValidator
+
+
+class OAuth1RequestTokenEndpoint(RequestTokenEndpoint):
+    def _create_request(self, *args, **kwargs):
+        r = super(OAuth1RequestTokenEndpoint, self)._create_request(*args, **kwargs)
+        r.redirect_uri = 'oob'
+        return r
+
+    def create_request_token(self, request, credentials):
+        token = {
+            'oauth_token': self.token_generator()[:KEY_SIZE],
+            'oauth_token_secret': self.token_generator()[:SECRET_SIZE],
+        }
+        token.update(credentials)
+        self.request_validator.save_request_token(token, request)
+        return urlencode(token.items())
+
+
+class OAuth1RequestTokenView(APIView):
+    def __init__(self):
+        self.endpoint = OAuth1RequestTokenEndpoint(PistonRequestValidator())
+    def dispatch(self, request):
+        headers, body, status = self.endpoint.create_request_token_response(
+            request.build_absolute_uri(),
+            request.method,
+            request.body,
+            {
+                "Authorization": request.META['HTTP_AUTHORIZATION']
+            } if 'HTTP_AUTHORIZATION' in request.META else None
+        )
+
+        response = HttpResponse(body, status=status)
+        for k, v in headers.items():
+            response[k] = v
+        return response
 
 
 class UserView(RetrieveAPIView):
 
 
 class UserView(RetrieveAPIView):