App compatibility fix + some error handling.
author Radek Czajka <rczajka@rczajka.pl>
Tue, 12 Feb 2019 21:51:47 +0000 (22:51 +0100)
committer Radek Czajka <rczajka@rczajka.pl>
Tue, 12 Feb 2019 21:51:47 +0000 (22:51 +0100)
src/api/piston_patch.py
src/api/tests/tests.py

index 3c7e50f..6a80e15 100644 (file)
@@ -2,10 +2,10 @@
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
 # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
 # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
 #
-from oauthlib.oauth1 import AuthorizationEndpoint
+from oauthlib.oauth1 import AuthorizationEndpoint, OAuth1Error
 from django.contrib.auth.decorators import login_required
 from django import forms
 from django.contrib.auth.decorators import login_required
 from django import forms
-from django.http import HttpResponseRedirect
+from django.http import HttpResponse, HttpResponseRedirect
 from django.shortcuts import render
 from .request_validator import PistonRequestValidator
 from .utils import oauthlib_request, oauthlib_response
 from django.shortcuts import render
 from .request_validator import PistonRequestValidator
 from .utils import oauthlib_request, oauthlib_response
@@ -21,15 +21,26 @@ class OAuthAuthenticationForm(forms.Form):
     # removed authorize_access - redundant
 
 
     # removed authorize_access - redundant
 
 
+class OAuth1AuthorizationEndpoint(AuthorizationEndpoint):
+    def create_verifier(self, request, credentials):
+        verifier = super(OAuth1AuthorizationEndpoint, self).create_verifier(request, credentials)
+        return {
+            'oauth_token': verifier['oauth_token'],
+        }
+
+
 @login_required
 def oauth_user_auth(request):
 @login_required
 def oauth_user_auth(request):
-    endpoint = AuthorizationEndpoint(PistonRequestValidator())
+    endpoint = OAuth1AuthorizationEndpoint(PistonRequestValidator())
 
     if request.method == "GET":
         # Why not just get oauth_token here?
         # This is fairly straightforward, in't?
 
     if request.method == "GET":
         # Why not just get oauth_token here?
         # This is fairly straightforward, in't?
-        realms, credentials = endpoint.get_realms_and_credentials(
-            **oauthlib_request(request))
+        try:
+            realms, credentials = endpoint.get_realms_and_credentials(
+                **oauthlib_request(request))
+        except OAuth1Error as e:
+            return HttpResponse(e.message, status=400)
         callback = request.GET.get('oauth_callback')
 
         form = OAuthAuthenticationForm(initial={
         callback = request.GET.get('oauth_callback')
 
         form = OAuthAuthenticationForm(initial={
@@ -40,11 +51,14 @@ def oauth_user_auth(request):
         return render(request, 'piston/authorize_token.html', {'form': form})
 
     elif request.method == "POST":
         return render(request, 'piston/authorize_token.html', {'form': form})
 
     elif request.method == "POST":
-        response = oauthlib_response(
-            endpoint.create_authorization_response(
-                credentials={"user": request.user},
-                **oauthlib_request(request)
+        try:
+            response = oauthlib_response(
+                endpoint.create_authorization_response(
+                    credentials={"user": request.user},
+                    **oauthlib_request(request)
+                )
             )
             )
-        )
-
-        return response
+        except OAuth1Error as e:
+            return HttpResponse(e.message, status=400)
+        else:
+            return response
index c8e07d2..298a794 100644 (file)
@@ -242,6 +242,8 @@ class OAuth1Tests(ApiTest):
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
     @classmethod
     def setUpClass(cls):
         cls.user = User.objects.create(username='test')
+        cls.user.set_password('test')
+        cls.user.save()
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
         cls.consumer_secret = 'len(quote(consumer secret))>=32'
         Consumer.objects.create(
             key='client',
@@ -253,6 +255,7 @@ class OAuth1Tests(ApiTest):
         User.objects.all().delete()
 
     def test_create_token(self):
         User.objects.all().delete()
 
     def test_create_token(self):
+        # Fetch request token.
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_version=1.0".format(int(time())))
@@ -268,16 +271,26 @@ class OAuth1Tests(ApiTest):
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
         sign = quote(h)
         query = "{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get('/api/oauth/request_token/?' + query)
-        request_token = parse_qs(response.content)
+        request_token_data = parse_qs(response.content)
+        request_token = request_token_data['oauth_token'][0]
+        request_token_secret = request_token_data['oauth_token_secret'][0]
 
 
-        Token.objects.filter(
-            key=request_token['oauth_token'][0], token_type=Token.REQUEST
-        ).update(user=self.user, is_approved=True)
+        # Request token authorization.
+        self.client.login(username='test', password='test')
+        response = self.client.get('/api/oauth/authorize/?oauth_token=%s&oauth_callback=test://oauth.callback/' % request_token)
+        post_data = response.context['form'].initial
 
 
+        response = self.client.post('/api/oauth/authorize/?' + urlencode(post_data))
+        self.assertEqual(
+            response['Location'],
+            'test://oauth.callback/?oauth_token=' + request_token
+        )
+
+        # Fetch access token.
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
         base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&"
                       "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&"
                       "oauth_token={}&oauth_version=1.0".format(
-                          int(time()), request_token['oauth_token'][0]))
+                          int(time()), request_token))
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
         raw = '&'.join([
             'GET',
             quote('http://testserver/api/oauth/access_token/', safe=''),
@@ -285,7 +298,7 @@ class OAuth1Tests(ApiTest):
         ])
         h = hmac.new(
             quote(self.consumer_secret) + '&' +
         ])
         h = hmac.new(
             quote(self.consumer_secret) + '&' +
-            quote(request_token['oauth_token_secret'][0], safe=''),
+            quote(request_token_secret, safe=''),
             raw,
             hashlib.sha1
         ).digest()
             raw,
             hashlib.sha1
         ).digest()
@@ -293,11 +306,12 @@ class OAuth1Tests(ApiTest):
         sign = quote(h)
         query = u"{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get(u'/api/oauth/access_token/?' + query)
         sign = quote(h)
         query = u"{}&oauth_signature={}".format(base_query, sign)
         response = self.client.get(u'/api/oauth/access_token/?' + query)
-        access_token = parse_qs(response.content)
+        access_token_data = parse_qs(response.content)
+        access_token = access_token_data['oauth_token'][0]
 
         self.assertTrue(
             Token.objects.filter(
 
         self.assertTrue(
             Token.objects.filter(
-                key=access_token['oauth_token'][0],
+                key=access_token,
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())
                 token_type=Token.ACCESS,
                 user=self.user
             ).exists())