diff --git a/backend/api/authentication/middleware.py b/backend/api/authentication/middleware.py new file mode 100644 index 0000000..277edd6 --- /dev/null +++ b/backend/api/authentication/middleware.py @@ -0,0 +1,46 @@ +# api/authentication/middleware.py + +from rest_framework_simplejwt.authentication import JWTAuthentication +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError + + +class JWTParamMiddleware: + """ + Middleware that allows JWT authentication via query parameters. + + This middleware extracts a JWT token from a query parameter named 'token' + and authenticates the user if the token is valid. + """ + + def __init__(self, get_response): + self.get_response = get_response + self.jwt_auth = JWTAuthentication() + + def __call__(self, request): + self._authenticate_token_param(request) + response = self.get_response(request) + return response + + def _authenticate_token_param(self, request): + # Don't authenticate if already authenticated via headers + if hasattr(request, "user") and request.user.is_authenticated: + return + + # Get token from the query parameter + token = request.GET.get("token") + if not token: + return + + # Validate the token + try: + validated_token = self.jwt_auth.get_validated_token(token) + user = self.jwt_auth.get_user(validated_token) + + # Set the authenticated user on the request + request.user = user + + # Also set auth in DRF format for API views + request._auth = validated_token + except (InvalidToken, TokenError): + # Don't raise exceptions, just leave as anonymous + pass diff --git a/backend/api/files/serializers.py b/backend/api/files/serializers.py index 48bc0c8..1dc4089 100644 --- a/backend/api/files/serializers.py +++ b/backend/api/files/serializers.py @@ -5,28 +5,51 @@ from apps.files.models import PostFileModel class PostFileSerializer(serializers.ModelSerializer): """Serializer for PostFileModel.""" + filename = serializers.SerializerMethodField() + thumbnails = serializers.SerializerMethodField() + download_url = serializers.SerializerMethodField() + class Meta: model = PostFileModel - fields = ["hash_blake3", "file_type", "file", "thumbnail"] - # Add any other fields you need + fields = [ + "hash_blake3", + "file_type", + "file", + "thumbnail", + "filename", + "thumbnails", + "download_url", + ] - def to_representation(self, instance): - """Customize the representation of the model.""" - representation = super().to_representation(instance) - - # Add file name from related model + def get_filename(self, obj): try: - representation["filename"] = instance.name.first().filename + return obj.name.first().filename except (AttributeError, IndexError): - representation["filename"] = "Unknown" + return "Unknown" - # Add URLs for different thumbnail sizes - base_url = f"/api/files/{instance.hash_blake3}/" + def get_thumbnails(self, obj): + base_url = f"/api/files/{obj.hash_blake3}/" thumbnails = {} - for size_key in THUMBNAIL_SIZES: + for size_key in ["sx", "sm", "md", "lg", "xl"]: thumbnails[size_key] = f"{base_url}?t={size_key}" + return thumbnails - representation["thumbnails"] = thumbnails - representation["download_url"] = f"{base_url}?d=0" + def get_download_url(self, obj): + return f"/api/files/{obj.hash_blake3}/?d=0" - return representation + +class FileResponseSerializer(serializers.Serializer): + """ + Dummy serializer for file response schema documentation. + This is only used for OpenAPI schema generation and will never be used to serialize data. + """ + + file = serializers.FileField(help_text="The file content") + + +class ErrorResponseSerializer(serializers.Serializer): + """ + Serializer for error responses. + """ + + error = serializers.CharField(help_text="Error message") diff --git a/backend/api/files/urls.py b/backend/api/files/urls.py index 67b40f5..1b3112b 100644 --- a/backend/api/files/urls.py +++ b/backend/api/files/urls.py @@ -3,7 +3,7 @@ from .views import FileServeView, FileDetailView urlpatterns = [ # Serve the actual file - path("files//", FileServeView.as_view(), name="serve_file"), + path("/", FileServeView.as_view(), name="serve_file"), # Get file metadata - path("files//info/", FileDetailView.as_view(), name="file_info"), + path("/info/", FileDetailView.as_view(), name="file_info"), ] diff --git a/backend/api/files/views.py b/backend/api/files/views.py index 16bb908..c2340e1 100644 --- a/backend/api/files/views.py +++ b/backend/api/files/views.py @@ -2,12 +2,17 @@ import os from django.conf import settings from django.http import FileResponse from rest_framework import status -from rest_framework.views import APIView +from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework.permissions import IsAuthenticated +from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiResponse from sorl.thumbnail import get_thumbnail from apps.files.models import PostFileModel -from .serializers import PostFileSerializer # You'll need to create this +from .serializers import ( + PostFileSerializer, + FileResponseSerializer, + ErrorResponseSerializer, +) THUMBNAIL_SIZES = { "sx": (64, ".thumb_64.jpg"), @@ -18,13 +23,19 @@ THUMBNAIL_SIZES = { } -class FileServeView(APIView): +class FileServeView(GenericAPIView): """ API view to serve content files for download or inline viewing. + + Authentication can be provided via: + 1. Authorization header (JWT token) + 2. 'token' query parameter (JWT token) """ - # Uncomment the following line if authentication is required - # permission_classes = [IsAuthenticated] + # Set permissions as needed + permission_classes = [IsAuthenticated] + serializer_class = FileResponseSerializer + queryset = PostFileModel.objects.all() def get_thumbnail_file(self, source_path, size_key): """Generates and retrieves the thumbnail file.""" @@ -36,6 +47,34 @@ class FileServeView(APIView): ), suffix return None, "" + @extend_schema( + parameters=[ + OpenApiParameter( + name="d", + description="Download flag (0 = download, otherwise inline)", + required=False, + type=str, + ), + OpenApiParameter( + name="t", + description="Thumbnail size (sx, sm, md, lg, xl)", + required=False, + type=str, + ), + OpenApiParameter( + name="token", + description="JWT token for authentication (alternative to Authorization header)", + required=False, + type=str, + ), + ], + responses={ + 200: OpenApiResponse(description="File returned successfully"), + 401: ErrorResponseSerializer, + 404: ErrorResponseSerializer, + 500: ErrorResponseSerializer, + }, + ) def get(self, request, file_hash): """Handle GET requests for file serving.""" download = request.query_params.get("d") == "0" @@ -82,14 +121,35 @@ class FileServeView(APIView): ) -class FileDetailView(APIView): +class FileDetailView(GenericAPIView): """ API view to get file metadata without serving the actual file. + + Authentication can be provided via: + 1. Authorization header (JWT token) + 2. 'token' query parameter (JWT token) """ - # Uncomment the following line if authentication is required - # permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticated] + serializer_class = PostFileSerializer + queryset = PostFileModel.objects.all() + @extend_schema( + parameters=[ + OpenApiParameter( + name="token", + description="JWT token for authentication (alternative to Authorization header)", + required=False, + type=str, + ) + ], + responses={ + 200: PostFileSerializer, + 401: ErrorResponseSerializer, + 404: ErrorResponseSerializer, + 500: ErrorResponseSerializer, + }, + ) def get(self, request, file_hash): """Return file metadata.""" try: @@ -99,7 +159,7 @@ class FileDetailView(APIView): {"error": "File not found"}, status=status.HTTP_404_NOT_FOUND ) - serializer = PostFileSerializer(obj_file) + serializer = self.get_serializer(obj_file) return Response(serializer.data) except Exception as e: diff --git a/backend/core/settings.py b/backend/core/settings.py index 2955fad..1627100 100644 --- a/backend/core/settings.py +++ b/backend/core/settings.py @@ -69,6 +69,7 @@ MIDDLEWARE = [ "django.middleware.csrf.CsrfViewMiddleware", "django.middleware.locale.LocaleMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", + "api.authentication.middleware.JWTParamMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", ]