funkwhale/api/funkwhale_api/users/oauth/views.py

239 lines
8.4 KiB
Python

import json
import secrets
import urllib.parse
from django import http
from django.db.models import Q
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from oauth2_provider import exceptions as oauth2_exceptions
from oauth2_provider import views as oauth_views
from oauth2_provider.settings import oauth2_settings
from rest_framework import mixins, permissions, response, views, viewsets
from rest_framework.decorators import action
from funkwhale_api.common import throttling
from .. import models
from . import serializers
from .permissions import ScopePermission
class ApplicationViewSet(
    mixins.CreateModelMixin,
    mixins.ListModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    mixins.RetrieveModelMixin,
    viewsets.GenericViewSet,
):
    """
    CRUD endpoints for OAuth applications, looked up by ``client_id``.

    ``list``, ``destroy``, ``update``, ``partial_update`` and
    ``refresh_token`` only ever see applications owned by the requesting
    user (see ``get_queryset``).
    """

    anonymous_policy = True
    # Per-action scope requirements, consumed by the permission layer.
    # NOTE(review): ``None`` presumably means "no scope needed" — confirm in
    # ScopePermission.
    required_scope = {
        "retrieve": None,
        "create": None,
        "destroy": "write:security",
        "update": "write:security",
        "partial_update": "write:security",
        "refresh_token": "write:security",
        "list": "read:security",
    }
    lookup_field = "client_id"
    queryset = models.Application.objects.all().order_by("-created")
    serializer_class = serializers.ApplicationSerializer
    # Throttling scope names for ``create``, keyed by caller kind.
    throttling_scopes = {
        "create": {
            "anonymous": "anonymous-oauth-app",
            "authenticated": "authenticated-oauth-app",
        }
    }

    def create(self, request, *args, **kwargs):
        """
        Create an application with a freshly generated client secret.

        The serializer exposes the hashed secret after saving, so the
        plaintext value is injected into this (201) response only.
        """
        secret = secrets.token_hex(64)
        payload = request.data.copy()
        payload["client_secret"] = secret

        serializer = self.get_serializer(data=payload)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)

        headers = self.get_success_headers(serializer.data)
        body = serializer.data
        # Replace the hashed secret with the plaintext one for the caller.
        body["client_secret"] = secret
        return response.Response(body, status=201, headers=headers)

    def get_serializer_class(self):
        """Use the creation serializer for POST requests, the default otherwise."""
        is_post = self.request.method.lower() == "post"
        if not is_post:
            return super().get_serializer_class()
        return serializers.CreateApplicationSerializer

    def perform_create(self, serializer):
        """
        Save the application as a confidential, authorization-code client.

        An authenticated creator becomes the owner and gets a token from
        ``models.get_token()``; anonymous creations get neither.
        """
        if self.request.user.is_authenticated:
            owner, app_token = self.request.user, models.get_token()
        else:
            owner, app_token = None, None
        return serializer.save(
            client_type=models.Application.CLIENT_CONFIDENTIAL,
            authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
            user=owner,
            token=app_token,
        )

    def get_serializer(self, *args, **kwargs):
        """
        Build the serializer, switching to ``CreateApplicationSerializer``
        when the first positional argument's ``user`` is the requesting user.

        Anything without a ``user`` attribute (or no positional argument at
        all) keeps the class chosen by ``get_serializer_class``.
        """
        chosen_class = self.get_serializer_class()
        try:
            is_owner = args[0].user == self.request.user
        except (IndexError, AttributeError):
            is_owner = False

        if is_owner:
            chosen_class = serializers.CreateApplicationSerializer

        kwargs["context"] = self.get_serializer_context()
        return chosen_class(*args, **kwargs)

    def get_queryset(self):
        """Limit owner-only actions to applications of the requesting user."""
        base = super().get_queryset()
        owner_only_actions = {
            "list",
            "destroy",
            "update",
            "partial_update",
            "refresh_token",
        }
        if self.action not in owner_only_actions:
            return base
        return base.filter(user=self.request.user)

    @extend_schema(operation_id="refresh_oauth_token")
    @action(
        detail=True,
        methods=["post"],
        url_name="refresh_token",
        url_path="refresh-token",
    )
    def refresh_token(self, request, *args, **kwargs):
        """
        Regenerate the application's token and return the updated application.

        Responds 404 unless the application has an owner and that owner is
        the requesting user.
        """
        application = self.get_object()
        if not application.user_id or request.user != application.user:
            return response.Response(status=404)

        application.token = models.get_token()
        application.save(update_fields=["token"])
        payload = serializers.CreateApplicationSerializer(application).data
        return response.Response(payload, status=200)
class GrantViewSet(
    mixins.RetrieveModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
):
    """
    This is a viewset that list applications that have access to the request user
    account, to allow revoking tokens easily.

    An application is listed when the requesting user has, for it, an access
    token, a non-revoked refresh token, or an unexpired grant. Destroying an
    entry revokes the user's tokens and deletes the user's grants for that
    application.
    """

    permission_classes = [permissions.IsAuthenticated, ScopePermission]
    required_scope = "security"
    lookup_field = "client_id"
    queryset = models.Application.objects.all().order_by("-created")
    serializer_class = serializers.ApplicationSerializer
    pagination_class = None

    def get_queryset(self):
        """Return applications the requesting user currently has tokens/grants for."""
        now = timezone.now()
        queryset = super().get_queryset()
        grants = models.Grant.objects.filter(user=self.request.user, expires__gt=now)
        access_tokens = models.AccessToken.objects.filter(user=self.request.user)
        refresh_tokens = models.RefreshToken.objects.filter(
            user=self.request.user, revoked=None
        )
        # An application may match through several relations; distinct()
        # keeps each one listed once.
        return queryset.filter(
            Q(pk__in=access_tokens.values("application"))
            | Q(pk__in=refresh_tokens.values("application"))
            | Q(pk__in=grants.values("application"))
        ).distinct()

    def perform_create(self, serializer):
        # NOTE(review): this viewset does not include CreateModelMixin, so no
        # action in this class calls perform_create; kept for compatibility.
        return serializer.save(
            client_type=models.Application.CLIENT_CONFIDENTIAL,
            authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
        )

    def perform_destroy(self, instance):
        """
        Revoke everything the requesting user granted to the application.

        :param instance: the application selected by ``client_id``.
        """
        application = instance
        user = self.request.user

        for token in application.accesstoken_set.filter(user=user):
            token.revoke()

        # Only still-active refresh tokens need revoking. This matches
        # get_queryset (which ignores revoked ones) and keeps the fallback
        # below from overwriting the timestamp of already-revoked tokens.
        for token in application.refreshtoken_set.filter(user=user, revoked=None):
            try:
                token.revoke()
            except models.AccessToken.DoesNotExist:
                # revoke() could not load the linked access token (presumably
                # already removed by the loop above), so mark it by hand.
                token.access_token = None
                token.revoked = timezone.now()
                token.save(update_fields=["access_token", "revoked"])

        application.grant_set.filter(user=user).delete()
class AuthorizeView(views.APIView, oauth_views.AuthorizationView):
    """
    OAuth authorization endpoint adapted for the web client.

    Form errors, unknown applications, fatal client errors and missing
    authentication are reported as JSON; XHR redirects are returned as JSON
    so the client can perform them itself.
    """

    permission_classes = [permissions.IsAuthenticated]
    server_class = oauth2_settings.OAUTH2_SERVER_CLASS
    validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
    oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS
    skip_authorization_completely = False
    # NOTE(review): class-level dict shared by all instances; it should only
    # be reassigned per request, never mutated in place.
    oauth2_data = {}

    def form_invalid(self, form):
        """
        Return a JSON response instead of a template one
        """
        errors = form.errors
        return self.json_payload(errors, status_code=400)

    def post(self, request, *args, **kwargs):
        """Apply the "oauth-authorize" throttle, then delegate to the parent view."""
        throttling.check_request(request, "oauth-authorize")
        return super().post(request, *args, **kwargs)

    def form_valid(self, form):
        """Report an unknown application as a JSON 400 instead of an error page."""
        try:
            return super().form_valid(form)
        except models.Application.DoesNotExist:
            return self.json_payload({"non_field_errors": ["Invalid application"]}, 400)

    def redirect(self, redirect_to, application):
        """
        Return the redirect target as JSON for XHR requests; defer otherwise.

        The JSON body is ``{"redirect_uri": ..., "code": ...}``. ``code`` is
        ``None`` when ``redirect_to`` has no ``code`` query parameter
        (presumably error redirects such as ``?error=access_denied``).
        """
        if self.request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest":
            # Web client need this to be able to redirect the user
            query = urllib.parse.urlparse(redirect_to).query
            # .get(): indexing a missing "code" raised KeyError (HTTP 500).
            code = urllib.parse.parse_qs(query).get("code", [None])[0]
            return self.json_payload(
                {"redirect_uri": redirect_to, "code": code}, status_code=200
            )
        return super().redirect(redirect_to, application)

    def error_response(self, error, application, **kwargs):
        """
        Report fatal client errors as a JSON 400; delegate other errors.

        Extra keyword arguments are forwarded to the parent implementation.
        """
        if isinstance(error, oauth2_exceptions.FatalClientError):
            return self.json_payload({"detail": error.oauthlib_error.description}, 400)
        return super().error_response(error, application, **kwargs)

    def json_payload(self, payload, status_code):
        """Serialize ``payload`` with json.dumps into an application/json response."""
        return http.HttpResponse(
            json.dumps(payload), status=status_code, content_type="application/json"
        )

    def handle_no_permission(self):
        """Return a JSON 401 instead of redirecting to a login page."""
        return self.json_payload(
            {"detail": "Authentication credentials were not provided."}, 401
        )
class TokenView(oauth_views.TokenView):
    """Token endpoint that is throttled before any token is issued."""

    def post(self, request, *args, **kwargs):
        # Check the "oauth-token" throttle first (presumably raises when the
        # rate is exceeded), then hand off to the toolkit's token view.
        throttle_scope = "oauth-token"
        throttling.check_request(request, throttle_scope)
        return super(TokenView, self).post(request, *args, **kwargs)
class RevokeTokenView(oauth_views.RevokeTokenView):
    """Token revocation endpoint that is throttled before revoking."""

    def post(self, request, *args, **kwargs):
        # Check the "oauth-revoke-token" throttle first (presumably raises
        # when the rate is exceeded), then hand off to the toolkit's view.
        throttle_scope = "oauth-revoke-token"
        throttling.check_request(request, throttle_scope)
        return super(RevokeTokenView, self).post(request, *args, **kwargs)