From 40377c9b3b045814857e1190047ef83334c1131f Mon Sep 17 00:00:00 2001 From: Radek Czajka Date: Tue, 12 Feb 2019 22:51:47 +0100 Subject: [PATCH] App compatibility fix + some error handling. --- src/api/piston_patch.py | 38 ++++++++++++++++++++++++++------------ src/api/tests/tests.py | 30 ++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/src/api/piston_patch.py b/src/api/piston_patch.py index 3c7e50f4f..6a80e15cd 100644 --- a/src/api/piston_patch.py +++ b/src/api/piston_patch.py @@ -2,10 +2,10 @@ # This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later. # Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information. # -from oauthlib.oauth1 import AuthorizationEndpoint +from oauthlib.oauth1 import AuthorizationEndpoint, OAuth1Error from django.contrib.auth.decorators import login_required from django import forms -from django.http import HttpResponseRedirect +from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import render from .request_validator import PistonRequestValidator from .utils import oauthlib_request, oauthlib_response @@ -21,15 +21,26 @@ class OAuthAuthenticationForm(forms.Form): # removed authorize_access - redundant +class OAuth1AuthorizationEndpoint(AuthorizationEndpoint): + def create_verifier(self, request, credentials): + verifier = super(OAuth1AuthorizationEndpoint, self).create_verifier(request, credentials) + return { + 'oauth_token': verifier['oauth_token'], + } + + @login_required def oauth_user_auth(request): - endpoint = AuthorizationEndpoint(PistonRequestValidator()) + endpoint = OAuth1AuthorizationEndpoint(PistonRequestValidator()) if request.method == "GET": # Why not just get oauth_token here? # This is fairly straightforward, in't? - realms, credentials = endpoint.get_realms_and_credentials( - **oauthlib_request(request)) + try: + realms, credentials = endpoint.get_realms_and_credentials( + **oauthlib_request(request)) + except OAuth1Error as e: + return HttpResponse(e.message, status=400) callback = request.GET.get('oauth_callback') form = OAuthAuthenticationForm(initial={ @@ -40,11 +51,14 @@ def oauth_user_auth(request): return render(request, 'piston/authorize_token.html', {'form': form}) elif request.method == "POST": - response = oauthlib_response( - endpoint.create_authorization_response( - credentials={"user": request.user}, - **oauthlib_request(request) + try: + response = oauthlib_response( + endpoint.create_authorization_response( + credentials={"user": request.user}, + **oauthlib_request(request) + ) ) - ) - - return response + except OAuth1Error as e: + return HttpResponse(e.message, status=400) + else: + return response diff --git a/src/api/tests/tests.py b/src/api/tests/tests.py index c8e07d26c..298a79433 100644 --- a/src/api/tests/tests.py +++ b/src/api/tests/tests.py @@ -242,6 +242,8 @@ class OAuth1Tests(ApiTest): @classmethod def setUpClass(cls): cls.user = User.objects.create(username='test') + cls.user.set_password('test') + cls.user.save() cls.consumer_secret = 'len(quote(consumer secret))>=32' Consumer.objects.create( key='client', @@ -253,6 +255,7 @@ class OAuth1Tests(ApiTest): User.objects.all().delete() def test_create_token(self): + # Fetch request token. base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&" "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&" "oauth_version=1.0".format(int(time()))) @@ -268,16 +271,26 @@ class OAuth1Tests(ApiTest): sign = quote(h) query = "{}&oauth_signature={}".format(base_query, sign) response = self.client.get('/api/oauth/request_token/?' + query) - request_token = parse_qs(response.content) + request_token_data = parse_qs(response.content) + request_token = request_token_data['oauth_token'][0] + request_token_secret = request_token_data['oauth_token_secret'][0] - Token.objects.filter( - key=request_token['oauth_token'][0], token_type=Token.REQUEST - ).update(user=self.user, is_approved=True) + # Request token authorization. + self.client.login(username='test', password='test') + response = self.client.get('/api/oauth/authorize/?oauth_token=%s&oauth_callback=test://oauth.callback/' % request_token) + post_data = response.context['form'].initial + response = self.client.post('/api/oauth/authorize/?' + urlencode(post_data)) + self.assertEqual( + response['Location'], + 'test://oauth.callback/?oauth_token=' + request_token + ) + + # Fetch access token. base_query = ("oauth_consumer_key=client&oauth_nonce=12345678&" "oauth_signature_method=HMAC-SHA1&oauth_timestamp={}&" "oauth_token={}&oauth_version=1.0".format( - int(time()), request_token['oauth_token'][0])) + int(time()), request_token)) raw = '&'.join([ 'GET', quote('http://testserver/api/oauth/access_token/', safe=''), @@ -285,7 +298,7 @@ class OAuth1Tests(ApiTest): ]) h = hmac.new( quote(self.consumer_secret) + '&' + - quote(request_token['oauth_token_secret'][0], safe=''), + quote(request_token_secret, safe=''), raw, hashlib.sha1 ).digest() @@ -293,11 +306,12 @@ class OAuth1Tests(ApiTest): sign = quote(h) query = u"{}&oauth_signature={}".format(base_query, sign) response = self.client.get(u'/api/oauth/access_token/?' + query) - access_token = parse_qs(response.content) + access_token_data = parse_qs(response.content) + access_token = access_token_data['oauth_token'][0] self.assertTrue( Token.objects.filter( - key=access_token['oauth_token'][0], + key=access_token, token_type=Token.ACCESS, user=self.user ).exists()) -- 2.20.1