Skip to content

Commit

Permalink
Do not 404 when no user, add access token store to mobile authenticat…
Browse files Browse the repository at this point in the history
…e, remove unnecessary refresh token check
  • Loading branch information
Justin Zhang authored and Justin Zhang committed Aug 25, 2023
1 parent b6c04a0 commit 5a63589
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 128 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ x.y.z (UNRELEASED)
------------------
* Changes

0.9.3 (2023-02-05)
------------------
* Fix login regression for new mobile users

0.9.0 (2023-02-05)
------------------
* Introduced B2B IPC
Expand Down
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ from identity.identity import authenticated_b2b_request
result = authenticated_b2b_request('GET', 'http://url/path')
```

## Development Setup

### Install poetry:

`pipx install poetry`

### Install Dependencies:

`poetry install`

### Testing:

`export DJANGO_SETTINGS_MODULE=tests.settings && poetry run pytest`

### Linting:

`poetry run black . && poetry run isort . && poetry run flake8`

## Changelog

See [CHANGELOG.md](https://github.com/pennlabs/django-labs-accounts/blob/master/CHANGELOG.md)
Expand Down
19 changes: 19 additions & 0 deletions accounts/authentication.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from datetime import timedelta

import requests
from django.contrib import auth
from django.contrib.auth import get_user_model
from django.utils import timezone
from rest_framework import authentication, exceptions

from accounts.models import AccessToken
from accounts.settings import accounts_settings
from identity.identity import get_validated_claims

Expand All @@ -18,6 +22,11 @@ class PlatformAuthentication(authentication.BaseAuthentication):
HTTP header, prepended with the string "Bearer ". For example:
Authorization: Bearer abc
NOTE: When possible, always use the native DLA login routes.
One limitation of this route is that we only have access to
the bearer token, and thus cannot save a user's refresh token
to the database
"""

keyword = "Bearer"
Expand Down Expand Up @@ -51,6 +60,16 @@ def authenticate(self, request):
user_props = json["user"]
user = auth.authenticate(remote_user=user_props, tokens=False)
if user: # User authenticated successfully
# NOTE: Ideally we would want to store both access and refresh tokens,
# but only the access token is available via this route
AccessToken.objects.update_or_create(
user=user,
defaults={
"expires_at": timezone.now()
+ timedelta(seconds=user_props["token"]["expires_in"]),
"token": user_props["token"]["access_token"],
},
)
return (user, None)
else: # Error occurred
raise exceptions.AuthenticationFailed("Invalid User.")
Expand Down
51 changes: 0 additions & 51 deletions accounts/ipc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from datetime import timedelta

import requests
from django.utils import timezone

from accounts.settings import accounts_settings


# IPC on behalf of a user for when a user in a product wants to use an
Expand Down Expand Up @@ -34,15 +29,6 @@ def authenticated_request(
happen
"""

# Access token is expired. Try to refresh access token
if user.accesstoken.expires_at < timezone.now():
if not _refresh_access_token(user):
# Couldn't update the user's access token. Return a response with a 403 status code
# as if the user didn't have access to the requested resource
response = requests.models.Response
response.status_code = 403
return response

# Update Headers
headers = {} if headers is None else headers
headers["Authorization"] = f"Bearer {user.accesstoken.token}"
Expand All @@ -69,40 +55,3 @@ def authenticated_request(
cert=cert,
json=json,
)


def _refresh_access_token(user):
"""
Helper method to update a user's access token. Should be used when a user's
access token has expired, but still has a valid refresh token.
Returns:
bool: true if the access token is updated, false otherwise.
"""
body = {
"grant_type": "refresh_token",
"client_id": accounts_settings.CLIENT_ID, # from Product
"client_secret": accounts_settings.CLIENT_SECRET, # from Product
"refresh_token": user.refreshtoken.token, # refresh token from user
}
try:
data = requests.post(
url=accounts_settings.PLATFORM_URL + "/accounts/token/", data=body
)
if data.status_code == 200: # Access token refreshed successfully
data = data.json()
# Update Access token
user.accesstoken.token = data["access_token"]
user.accesstoken.expires_at = timezone.now() + timedelta(
seconds=data["expires_in"]
)
user.accesstoken.save()

# Update Refresh Token
user.refreshtoken.token = data["refresh_token"]
user.refreshtoken.save()

return True
except requests.exceptions.RequestException: # Can't connect to platform
return False
return False
34 changes: 18 additions & 16 deletions accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.contrib import auth
from django.contrib.auth import get_user_model
from django.http import HttpResponseServerError, JsonResponse
from django.shortcuts import get_object_or_404, redirect
from django.shortcuts import redirect
from django.urls import reverse
from django.utils import timezone
from django.utils.decorators import method_decorator
Expand Down Expand Up @@ -142,21 +142,23 @@ def post(self, request):
platform_request.status_code == 200
): # Connected to platform successfully
user_props = platform_request.json()["user"]
user = get_object_or_404(
User, id=user_props["pennid"], username=user_props["username"]
)
# Update user Access and Refresh tokens
AccessToken.objects.update_or_create(
user=user,
defaults={
"expires_at": timezone.now()
+ datetime.timedelta(seconds=token["expires_in"]),
"token": token["access_token"],
},
)
RefreshToken.objects.update_or_create(
user=user, defaults={"token": token["refresh_token"]}
)
user = User.objects.filter(
id=user_props["pennid"], username=user_props["username"]
).first()
# A user object will exist only after the first token retrieval
if user:
# Update user Access and Refresh tokens
AccessToken.objects.update_or_create(
user=user,
defaults={
"expires_at": timezone.now()
+ datetime.timedelta(seconds=token["expires_in"]),
"token": token["access_token"],
},
)
RefreshToken.objects.update_or_create(
user=user, defaults={"token": token["refresh_token"]}
)
return JsonResponse(response.json())
return JsonResponse({"detail": "Invalid tokens"}, status=403)
return JsonResponse({"detail": "Invalid parameters"}, status=400)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "django-labs-accounts"
version = "0.9.2"
version = "0.9.3"
description = "Reusable Django app for Penn Labs accounts"
authors = ["Penn Labs <[email protected]>"]
license = "MIT"
Expand Down
5 changes: 5 additions & 0 deletions tests/accounts/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from rest_framework import status
from rest_framework.test import APIClient

from accounts.models import AccessToken


User = get_user_model()

Expand Down Expand Up @@ -53,6 +55,7 @@ def test_post_form_passing_token_auth(self, mock_request):
self.path, {"example": "example"}, HTTP_AUTHORIZATION=self.auth
)
self.assertEqual(status.HTTP_200_OK, response.status_code)
self.assertEqual(len(AccessToken.objects.all()), 1)

def test_post_form_passing_token_auth_new_user(self, mock_request):
mock_request.return_value.status_code = 200
Expand All @@ -65,6 +68,7 @@ def test_post_form_passing_token_auth_new_user(self, mock_request):
user = User.objects.get(id=456)
self.assertEqual(user, response.wsgi_request.user)
self.assertEqual(status.HTTP_200_OK, response.status_code)
self.assertEqual(len(AccessToken.objects.all()), 1)

def test_fail_authentication_if_user_is_not_active(self, mock_request):
self.user.is_active = False
Expand Down Expand Up @@ -123,6 +127,7 @@ def test_post_json_passing_token_auth(self, mock_request):
HTTP_AUTHORIZATION=self.auth,
)
self.assertEqual(status.HTTP_200_OK, response.status_code)
self.assertEqual(len(AccessToken.objects.all()), 1)

def test_post_form_failing_token_auth(self, mock_request):
"""
Expand Down
61 changes: 3 additions & 58 deletions tests/accounts/test_ipc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from datetime import timedelta
from unittest.mock import patch

import requests
from django.contrib.auth import get_user_model
from django.test import Client, TestCase
from django.test import TestCase
from django.utils import timezone

from accounts.ipc import _refresh_access_token, authenticated_request
from accounts.ipc import authenticated_request
from accounts.models import AccessToken, RefreshToken


Expand All @@ -20,63 +18,10 @@ def setUp(self):
)
RefreshToken.objects.create(user=self.user)

@patch("accounts.ipc._refresh_access_token")
def test_update_refresh_token_fail(self, mock_refresh):
mock_refresh.return_value = False
response = authenticated_request(self.user, None, None)
self.assertEqual(403, response.status_code)

@patch("accounts.ipc._refresh_access_token")
@patch("accounts.ipc.requests.Session")
def test_authorization_header(self, mock_session, mock_refresh):
mock_refresh.return_value = True
def test_authorization_header(self, mock_session):
header = {"abc": "123"}
authenticated_request(self.user, None, None, headers=header)
header["Authorization"] = f"Bearer {self.token}"
arguments = mock_session.return_value.request.call_args[1]
self.assertEqual(header, arguments["headers"])


@patch("accounts.ipc.requests.post")
class RefreshAccessTokenTestCase(TestCase):
def setUp(self):
self.client = Client()
self.user = get_user_model().objects.create(username="abc")
self.now = timezone.now()
AccessToken.objects.create(user=self.user, expires_at=self.now)
RefreshToken.objects.create(user=self.user)
self.valid_response = {
"access_token": "abc",
"refresh_token": "123",
"expires_in": 100,
}

def test_valid_refresh_token(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = self.valid_response
value = _refresh_access_token(self.user)
diff = self.now + timedelta(seconds=self.valid_response["expires_in"])
self.assertTrue(value)
self.assertTrue(diff < self.user.accesstoken.expires_at)
self.assertEqual(
self.valid_response["access_token"], self.user.accesstoken.token
)
self.assertEqual(
self.valid_response["refresh_token"], self.user.refreshtoken.token
)

def test_invalid_response(self, mock_post):
mock_post.return_value.status_code = 403
value = _refresh_access_token(self.user)
self.assertFalse(value)

def test_exception_occurred(self, mock_post):
mock_post.side_effect = requests.exceptions.RequestException
value = _refresh_access_token(self.user)
self.assertFalse(value)
self.assertNotEqual(
self.valid_response["access_token"], self.user.accesstoken.token
)
self.assertNotEqual(
self.valid_response["refresh_token"], self.user.refreshtoken.token
)
7 changes: 5 additions & 2 deletions tests/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,11 @@ def test_token_unknown_user(self, mock_requests_post, mock_oauth_post):
"verifier": "correct_verifier",
}
response = self.client.post(reverse("accounts:token"), payload)
# Should fail because User object is never created in this test
self.assertEqual(404, response.status_code)
# If no User object, token should still go through, however
# no access and refresh tokens will be stored
self.assertEqual(200, response.status_code)
self.assertEqual(len(AccessToken.objects.all()), 0)
self.assertEqual(len(RefreshToken.objects.all()), 0)

@patch("accounts.views.requests.post")
def test_token_invalid_introspect(self, mock_requests_post):
Expand Down

0 comments on commit 5a63589

Please sign in to comment.