mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-07-05 04:21:43 +02:00
Add API tokens and OAuth2 client support for external integrations
- Personal API tokens (model, user-settings UI, admin, management command, DRF auth class) for non-interactive API access from automations like n8n. Raw token shown once; only a SHA-256 hash is stored; last_used_at writes are throttled. - OAuth2 authorization server via django-oauth-toolkit with authorization server metadata and optional, off-by-default Dynamic Client Registration (RFC 7591), so remote OAuth/MCP clients can authenticate and self-register. - Tests for token auth, DCR gating and the management commands, plus .env.example and README documentation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.users.models import APIToken
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Creates a hashed API token for a WYGIWYH user and prints the raw token once."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument("email", help="WYGIWYH user email that will own this token.")
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
default="n8n",
|
||||
help="Human-readable token name. Defaults to 'n8n'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expires-in-days",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Optional token lifetime in whole days.",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
email = options["email"].strip()
|
||||
name = options["name"].strip()
|
||||
expires_in_days = options["expires_in_days"]
|
||||
|
||||
if not email:
|
||||
raise CommandError("Email is required.")
|
||||
if not name:
|
||||
raise CommandError("Token name cannot be empty.")
|
||||
if expires_in_days is not None and expires_in_days <= 0:
|
||||
raise CommandError("--expires-in-days must be greater than zero.")
|
||||
|
||||
user = get_user_model().objects.filter(email__iexact=email).first()
|
||||
if user is None:
|
||||
raise CommandError(f"No WYGIWYH user exists for '{email}'.")
|
||||
|
||||
expires_at = None
|
||||
if expires_in_days is not None:
|
||||
expires_at = timezone.now() + timedelta(days=expires_in_days)
|
||||
|
||||
token, raw_token = APIToken.objects.create_token(
|
||||
user=user,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"Created API token '{token.name}' for {user.email} ({token.token_key})."
|
||||
)
|
||||
)
|
||||
self.stdout.write(raw_token)
|
||||
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
|
||||
Application = get_application_model()
|
||||
|
||||
|
||||
def _get_env(name: str) -> str:
|
||||
return os.getenv(name, "").strip()
|
||||
|
||||
|
||||
def _get_bool_env(name: str, default: bool = False) -> bool:
|
||||
raw = _get_env(name)
|
||||
if not raw:
|
||||
return default
|
||||
return raw.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = (
|
||||
"Creates or updates the OAuth application used by MCP clients when "
|
||||
"MCP_OAUTH_CLIENT_* environment variables are configured."
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
client_id = _get_env("MCP_OAUTH_CLIENT_ID")
|
||||
client_secret = _get_env("MCP_OAUTH_CLIENT_SECRET")
|
||||
redirect_uris = " ".join(_get_env("MCP_OAUTH_REDIRECT_URIS").split())
|
||||
name = _get_env("MCP_OAUTH_CLIENT_NAME") or "WYGIWYH MCP"
|
||||
skip_authorization = _get_bool_env("MCP_OAUTH_SKIP_AUTHORIZATION", default=False)
|
||||
|
||||
if not any([client_id, client_secret, redirect_uris]):
|
||||
self.stdout.write(
|
||||
self.style.NOTICE(
|
||||
"MCP OAuth client env vars are not set. Skipping OAuth application setup."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
missing = []
|
||||
if not client_id:
|
||||
missing.append("MCP_OAUTH_CLIENT_ID")
|
||||
if not client_secret:
|
||||
missing.append("MCP_OAUTH_CLIENT_SECRET")
|
||||
if not redirect_uris:
|
||||
missing.append("MCP_OAUTH_REDIRECT_URIS")
|
||||
if missing:
|
||||
raise CommandError(
|
||||
"Missing required MCP OAuth settings: " + ", ".join(missing)
|
||||
)
|
||||
|
||||
application, created = Application.objects.get_or_create(
|
||||
client_id=client_id,
|
||||
defaults={
|
||||
"name": name,
|
||||
"client_type": Application.CLIENT_CONFIDENTIAL,
|
||||
"authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE,
|
||||
"redirect_uris": redirect_uris,
|
||||
"skip_authorization": skip_authorization,
|
||||
"client_secret": client_secret,
|
||||
"hash_client_secret": True,
|
||||
},
|
||||
)
|
||||
|
||||
updated_fields = []
|
||||
if application.name != name:
|
||||
application.name = name
|
||||
updated_fields.append("name")
|
||||
if application.client_type != Application.CLIENT_CONFIDENTIAL:
|
||||
application.client_type = Application.CLIENT_CONFIDENTIAL
|
||||
updated_fields.append("client_type")
|
||||
if (
|
||||
application.authorization_grant_type
|
||||
!= Application.GRANT_AUTHORIZATION_CODE
|
||||
):
|
||||
application.authorization_grant_type = Application.GRANT_AUTHORIZATION_CODE
|
||||
updated_fields.append("authorization_grant_type")
|
||||
if application.redirect_uris != redirect_uris:
|
||||
application.redirect_uris = redirect_uris
|
||||
updated_fields.append("redirect_uris")
|
||||
if application.skip_authorization != skip_authorization:
|
||||
application.skip_authorization = skip_authorization
|
||||
updated_fields.append("skip_authorization")
|
||||
if application.hash_client_secret is not True:
|
||||
application.hash_client_secret = True
|
||||
updated_fields.append("hash_client_secret")
|
||||
if not application.client_secret or not check_password(
|
||||
client_secret,
|
||||
application.client_secret,
|
||||
):
|
||||
application.client_secret = client_secret
|
||||
updated_fields.append("client_secret")
|
||||
|
||||
try:
|
||||
application.full_clean()
|
||||
except ValidationError as exc:
|
||||
errors = "; ".join(
|
||||
f"{field}: {', '.join(messages)}"
|
||||
for field, messages in exc.message_dict.items()
|
||||
)
|
||||
raise CommandError(f"Invalid MCP OAuth application settings: {errors}") from exc
|
||||
|
||||
if created:
|
||||
application.save()
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"Created MCP OAuth application '{application.client_id}'."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if updated_fields:
|
||||
application.save(update_fields=updated_fields)
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"Updated MCP OAuth application '{application.client_id}'."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"MCP OAuth application '{application.client_id}' is already up to date."
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,253 @@
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from secrets import token_urlsafe
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.http import JsonResponse
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.decorators.http import require_http_methods
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
|
||||
Application = get_application_model()
|
||||
|
||||
SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {
|
||||
"none": Application.CLIENT_PUBLIC,
|
||||
"client_secret_basic": Application.CLIENT_CONFIDENTIAL,
|
||||
"client_secret_post": Application.CLIENT_CONFIDENTIAL,
|
||||
}
|
||||
SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
|
||||
SUPPORTED_RESPONSE_TYPES = {"code"}
|
||||
|
||||
|
||||
def _base_url(request):
|
||||
return settings.PUBLIC_BASE_URL or request.build_absolute_uri("/").rstrip("/")
|
||||
|
||||
|
||||
def _json_error(error, error_description, status=400):
|
||||
response = JsonResponse(
|
||||
{"error": error, "error_description": error_description},
|
||||
status=status,
|
||||
)
|
||||
response["Cache-Control"] = "no-store"
|
||||
response["Pragma"] = "no-cache"
|
||||
return response
|
||||
|
||||
|
||||
def _set_no_store_headers(response):
|
||||
response["Cache-Control"] = "no-store"
|
||||
response["Pragma"] = "no-cache"
|
||||
return response
|
||||
|
||||
|
||||
def _parse_json_request_body(request):
|
||||
try:
|
||||
payload = json.loads(request.body.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||
raise ValueError("Request body must be valid JSON.") from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Request body must be a JSON object.")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _get_string_list(payload, field_name, *, required=False, default=None):
|
||||
value = payload.get(field_name, default)
|
||||
if value is None:
|
||||
if required:
|
||||
raise ValueError(f"'{field_name}' is required.")
|
||||
return None
|
||||
|
||||
if not isinstance(value, list) or not value:
|
||||
raise ValueError(f"'{field_name}' must be a non-empty array of strings.")
|
||||
|
||||
normalized = []
|
||||
for item in value:
|
||||
if not isinstance(item, str) or not item.strip():
|
||||
raise ValueError(f"'{field_name}' must contain only non-empty strings.")
|
||||
normalized.append(item.strip())
|
||||
return normalized
|
||||
|
||||
|
||||
def _get_supported_scopes():
|
||||
return set(settings.OAUTH2_PROVIDER.get("SCOPES", {}).keys())
|
||||
|
||||
|
||||
def _dcr_initial_access_token_ok(request):
|
||||
"""Validate the optional RFC 7591 initial access token, if one is configured."""
|
||||
expected = settings.OAUTH2_DCR_INITIAL_ACCESS_TOKEN
|
||||
if not expected:
|
||||
return True
|
||||
|
||||
header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
scheme, _, value = header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not value:
|
||||
return False
|
||||
return hmac.compare_digest(value, expected)
|
||||
|
||||
|
||||
@require_http_methods(["GET"])
|
||||
def authorization_server_metadata(request):
|
||||
base_url = _base_url(request)
|
||||
metadata = {
|
||||
"issuer": base_url,
|
||||
"authorization_endpoint": f"{base_url}/oauth/authorize/",
|
||||
"token_endpoint": f"{base_url}/oauth/token/",
|
||||
"revocation_endpoint": f"{base_url}/oauth/revoke_token/",
|
||||
"introspection_endpoint": f"{base_url}/oauth/introspect/",
|
||||
"scopes_supported": sorted(settings.OAUTH2_PROVIDER["SCOPES"].keys()),
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"none",
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
}
|
||||
# Only advertise registration when DCR is actually enabled.
|
||||
if settings.OAUTH2_DCR_ENABLED:
|
||||
metadata["registration_endpoint"] = f"{base_url}/oauth/register/"
|
||||
return JsonResponse(metadata)
|
||||
|
||||
|
||||
@csrf_exempt
|
||||
@require_http_methods(["POST"])
|
||||
def dynamic_client_registration(request):
|
||||
if not settings.OAUTH2_DCR_ENABLED:
|
||||
return _json_error(
|
||||
"not_found",
|
||||
"Dynamic client registration is disabled.",
|
||||
status=404,
|
||||
)
|
||||
|
||||
if not _dcr_initial_access_token_ok(request):
|
||||
return _json_error(
|
||||
"invalid_token",
|
||||
"A valid initial access token is required to register a client.",
|
||||
status=401,
|
||||
)
|
||||
|
||||
try:
|
||||
payload = _parse_json_request_body(request)
|
||||
redirect_uris = _get_string_list(payload, "redirect_uris", required=True)
|
||||
grant_types = _get_string_list(
|
||||
payload,
|
||||
"grant_types",
|
||||
default=["authorization_code"],
|
||||
)
|
||||
response_types = _get_string_list(
|
||||
payload,
|
||||
"response_types",
|
||||
default=["code"],
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _json_error("invalid_client_metadata", str(exc))
|
||||
|
||||
unsupported_grant_types = sorted(set(grant_types) - SUPPORTED_GRANT_TYPES)
|
||||
if unsupported_grant_types:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"Unsupported grant_types: " + ", ".join(unsupported_grant_types),
|
||||
)
|
||||
|
||||
if "authorization_code" not in grant_types:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"grant_types must include 'authorization_code'.",
|
||||
)
|
||||
|
||||
unsupported_response_types = sorted(set(response_types) - SUPPORTED_RESPONSE_TYPES)
|
||||
if unsupported_response_types:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"Unsupported response_types: "
|
||||
+ ", ".join(unsupported_response_types),
|
||||
)
|
||||
|
||||
if "code" not in response_types:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"response_types must include 'code'.",
|
||||
)
|
||||
|
||||
token_endpoint_auth_method = payload.get(
|
||||
"token_endpoint_auth_method",
|
||||
"client_secret_basic",
|
||||
)
|
||||
if token_endpoint_auth_method not in SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"Unsupported token_endpoint_auth_method: "
|
||||
+ token_endpoint_auth_method,
|
||||
)
|
||||
|
||||
supported_scopes = _get_supported_scopes()
|
||||
raw_scope = payload.get("scope", "mcp")
|
||||
if not isinstance(raw_scope, str):
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"'scope' must be a space-delimited string.",
|
||||
)
|
||||
requested_scope = raw_scope.strip() or "mcp"
|
||||
requested_scopes = set(requested_scope.split())
|
||||
unsupported_scopes = sorted(requested_scopes - supported_scopes)
|
||||
if unsupported_scopes:
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"Unsupported scope values: " + ", ".join(unsupported_scopes),
|
||||
)
|
||||
|
||||
client_name = str(payload.get("client_name", "Dynamic MCP Client")).strip()
|
||||
if not client_name:
|
||||
client_name = "Dynamic MCP Client"
|
||||
|
||||
client_secret = None
|
||||
client_type = SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS[token_endpoint_auth_method]
|
||||
if client_type == Application.CLIENT_CONFIDENTIAL:
|
||||
client_secret = token_urlsafe(48)
|
||||
|
||||
application = Application(
|
||||
name=client_name,
|
||||
client_type=client_type,
|
||||
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
|
||||
redirect_uris=" ".join(redirect_uris),
|
||||
skip_authorization=False,
|
||||
hash_client_secret=True,
|
||||
client_secret=client_secret or "",
|
||||
)
|
||||
|
||||
try:
|
||||
application.full_clean()
|
||||
except ValidationError as exc:
|
||||
errors = []
|
||||
for field, messages in exc.message_dict.items():
|
||||
errors.extend(f"{field}: {message}" for message in messages)
|
||||
return _json_error(
|
||||
"invalid_client_metadata",
|
||||
"; ".join(errors),
|
||||
)
|
||||
|
||||
application.save()
|
||||
|
||||
response_payload = {
|
||||
"client_id": application.client_id,
|
||||
"client_id_issued_at": int(time.time()),
|
||||
"client_name": client_name,
|
||||
"redirect_uris": redirect_uris,
|
||||
# Report what was actually provisioned, not the raw request echo. The app
|
||||
# is created with the authorization_code grant; refresh_token is implicit
|
||||
# to that grant in django-oauth-toolkit rather than a separate capability.
|
||||
"grant_types": sorted(set(grant_types) & SUPPORTED_GRANT_TYPES),
|
||||
"response_types": sorted(set(response_types) & SUPPORTED_RESPONSE_TYPES),
|
||||
"scope": " ".join(sorted(requested_scopes)),
|
||||
"token_endpoint_auth_method": token_endpoint_auth_method,
|
||||
}
|
||||
if client_secret is not None:
|
||||
response_payload["client_secret"] = client_secret
|
||||
response_payload["client_secret_expires_at"] = 0
|
||||
|
||||
return _set_no_store_headers(JsonResponse(response_payload, status=201))
|
||||
@@ -0,0 +1,298 @@
|
||||
import os
|
||||
import json
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.management import call_command
|
||||
from django.test import SimpleTestCase, TestCase, override_settings
|
||||
from django.utils import timezone
|
||||
from django.urls import reverse
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
from apps.users.models import APIToken
|
||||
|
||||
Application = get_application_model()
|
||||
|
||||
|
||||
@override_settings(
|
||||
PUBLIC_BASE_URL="https://wygiwyh.example.com",
|
||||
SECRET_KEY="test-secret-key",
|
||||
OAUTH2_PROVIDER={"SCOPES": {"mcp": "Access WYGIWYH from MCP clients."}},
|
||||
)
|
||||
class AuthorizationServerMetadataTests(SimpleTestCase):
|
||||
@override_settings(OAUTH2_DCR_ENABLED=True)
|
||||
def test_returns_oauth_authorization_server_metadata(self):
|
||||
response = self.client.get(reverse("oauth-authorization-server-metadata"))
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json()["issuer"], "https://wygiwyh.example.com")
|
||||
self.assertEqual(
|
||||
response.json()["authorization_endpoint"],
|
||||
"https://wygiwyh.example.com/oauth/authorize/",
|
||||
)
|
||||
self.assertEqual(
|
||||
response.json()["registration_endpoint"],
|
||||
"https://wygiwyh.example.com/oauth/register/",
|
||||
)
|
||||
self.assertEqual(response.json()["scopes_supported"], ["mcp"])
|
||||
self.assertIn("none", response.json()["token_endpoint_auth_methods_supported"])
|
||||
|
||||
@override_settings(OAUTH2_DCR_ENABLED=False)
|
||||
def test_omits_registration_endpoint_when_dcr_disabled(self):
|
||||
response = self.client.get(reverse("oauth-authorization-server-metadata"))
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertNotIn("registration_endpoint", response.json())
|
||||
|
||||
|
||||
@override_settings(
|
||||
PUBLIC_BASE_URL="https://wygiwyh.example.com",
|
||||
SECRET_KEY="test-secret-key",
|
||||
OAUTH2_PROVIDER={"SCOPES": {"mcp": "Access WYGIWYH from MCP clients."}},
|
||||
OAUTH2_DCR_ENABLED=True,
|
||||
OAUTH2_DCR_INITIAL_ACCESS_TOKEN="",
|
||||
)
|
||||
class DynamicClientRegistrationTests(TestCase):
|
||||
def test_registers_public_client_for_pkce_flow(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"client_name": "Copilot MCP",
|
||||
"redirect_uris": ["http://127.0.0.1:8765/callback"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scope": "mcp",
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 201)
|
||||
payload = response.json()
|
||||
self.assertEqual(payload["client_name"], "Copilot MCP")
|
||||
self.assertEqual(
|
||||
payload["redirect_uris"],
|
||||
["http://127.0.0.1:8765/callback"],
|
||||
)
|
||||
self.assertEqual(
|
||||
payload["grant_types"],
|
||||
["authorization_code", "refresh_token"],
|
||||
)
|
||||
self.assertEqual(payload["response_types"], ["code"])
|
||||
self.assertEqual(payload["scope"], "mcp")
|
||||
self.assertEqual(payload["token_endpoint_auth_method"], "none")
|
||||
self.assertNotIn("client_secret", payload)
|
||||
|
||||
application = Application.objects.get(client_id=payload["client_id"])
|
||||
self.assertEqual(application.name, "Copilot MCP")
|
||||
self.assertEqual(application.client_type, Application.CLIENT_PUBLIC)
|
||||
self.assertEqual(
|
||||
application.authorization_grant_type,
|
||||
Application.GRANT_AUTHORIZATION_CODE,
|
||||
)
|
||||
self.assertEqual(
|
||||
application.redirect_uris,
|
||||
"http://127.0.0.1:8765/callback",
|
||||
)
|
||||
|
||||
def test_registers_confidential_client_with_generated_secret(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"client_name": "Confidential MCP",
|
||||
"redirect_uris": ["http://127.0.0.1:8765/callback"],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 201)
|
||||
payload = response.json()
|
||||
self.assertEqual(payload["token_endpoint_auth_method"], "client_secret_basic")
|
||||
self.assertEqual(payload["scope"], "mcp")
|
||||
self.assertEqual(payload["client_secret_expires_at"], 0)
|
||||
self.assertTrue(payload["client_secret"])
|
||||
|
||||
application = Application.objects.get(client_id=payload["client_id"])
|
||||
self.assertEqual(application.client_type, Application.CLIENT_CONFIDENTIAL)
|
||||
self.assertTrue(check_password(payload["client_secret"], application.client_secret))
|
||||
|
||||
def test_rejects_unsupported_token_auth_method(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"redirect_uris": ["http://127.0.0.1:8765/callback"],
|
||||
"token_endpoint_auth_method": "private_key_jwt",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response.json()["error"], "invalid_client_metadata")
|
||||
self.assertIn("token_endpoint_auth_method", response.json()["error_description"])
|
||||
|
||||
def test_rejects_missing_redirect_uris(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps({"client_name": "No redirect"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response.json()["error"], "invalid_client_metadata")
|
||||
self.assertIn("redirect_uris", response.json()["error_description"])
|
||||
|
||||
@override_settings(OAUTH2_DCR_ENABLED=False)
|
||||
def test_returns_404_when_dcr_disabled(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps({"redirect_uris": ["http://127.0.0.1:8765/callback"]}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(Application.objects.count(), 0)
|
||||
|
||||
|
||||
@override_settings(
|
||||
PUBLIC_BASE_URL="https://wygiwyh.example.com",
|
||||
SECRET_KEY="test-secret-key",
|
||||
OAUTH2_PROVIDER={"SCOPES": {"mcp": "Access WYGIWYH from MCP clients."}},
|
||||
OAUTH2_DCR_ENABLED=True,
|
||||
OAUTH2_DCR_INITIAL_ACCESS_TOKEN="s3cret-iat",
|
||||
)
|
||||
class DynamicClientRegistrationInitialAccessTokenTests(TestCase):
|
||||
def test_rejects_registration_without_initial_access_token(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps({"redirect_uris": ["http://127.0.0.1:8765/callback"]}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 401)
|
||||
self.assertEqual(response.json()["error"], "invalid_token")
|
||||
self.assertEqual(Application.objects.count(), 0)
|
||||
|
||||
def test_allows_registration_with_initial_access_token(self):
|
||||
response = self.client.post(
|
||||
reverse("oauth-dynamic-client-registration"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"redirect_uris": ["http://127.0.0.1:8765/callback"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
HTTP_AUTHORIZATION="Bearer s3cret-iat",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertEqual(Application.objects.count(), 1)
|
||||
|
||||
|
||||
class SetupOAuthCommandTests(TestCase):
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"MCP_OAUTH_CLIENT_ID": "mcp-wygiwyh",
|
||||
"MCP_OAUTH_CLIENT_SECRET": "super-secret",
|
||||
"MCP_OAUTH_REDIRECT_URIS": "http://127.0.0.1:8765/callback",
|
||||
},
|
||||
clear=False,
|
||||
)
|
||||
def test_creates_mcp_oauth_application(self):
|
||||
call_command("setup_oauth")
|
||||
|
||||
application = Application.objects.get(client_id="mcp-wygiwyh")
|
||||
self.assertEqual(application.name, "WYGIWYH MCP")
|
||||
self.assertEqual(application.client_type, Application.CLIENT_CONFIDENTIAL)
|
||||
self.assertEqual(
|
||||
application.authorization_grant_type,
|
||||
Application.GRANT_AUTHORIZATION_CODE,
|
||||
)
|
||||
self.assertEqual(
|
||||
application.redirect_uris,
|
||||
"http://127.0.0.1:8765/callback",
|
||||
)
|
||||
self.assertFalse(application.skip_authorization)
|
||||
self.assertTrue(check_password("super-secret", application.client_secret))
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"MCP_OAUTH_CLIENT_ID": "mcp-wygiwyh",
|
||||
"MCP_OAUTH_CLIENT_SECRET": "new-secret",
|
||||
"MCP_OAUTH_REDIRECT_URIS": "http://127.0.0.1:8765/callback http://localhost:8765/callback",
|
||||
"MCP_OAUTH_CLIENT_NAME": "WYGIWYH MCP Production",
|
||||
"MCP_OAUTH_SKIP_AUTHORIZATION": "true",
|
||||
},
|
||||
clear=False,
|
||||
)
|
||||
def test_updates_existing_mcp_oauth_application(self):
|
||||
Application.objects.create(
|
||||
client_id="mcp-wygiwyh",
|
||||
client_secret="old-secret",
|
||||
name="Old Name",
|
||||
client_type=Application.CLIENT_CONFIDENTIAL,
|
||||
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
|
||||
redirect_uris="http://127.0.0.1:8765/callback",
|
||||
skip_authorization=False,
|
||||
)
|
||||
|
||||
call_command("setup_oauth")
|
||||
|
||||
application = Application.objects.get(client_id="mcp-wygiwyh")
|
||||
self.assertEqual(application.name, "WYGIWYH MCP Production")
|
||||
self.assertEqual(
|
||||
application.redirect_uris,
|
||||
"http://127.0.0.1:8765/callback http://localhost:8765/callback",
|
||||
)
|
||||
self.assertTrue(application.skip_authorization)
|
||||
self.assertTrue(check_password("new-secret", application.client_secret))
|
||||
|
||||
|
||||
class CreateAPITokenCommandTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = get_user_model().objects.create_user(
|
||||
email="n8n@example.com",
|
||||
password="test-password",
|
||||
)
|
||||
|
||||
def test_creates_hashed_api_token_and_prints_raw_value(self):
|
||||
stdout = StringIO()
|
||||
|
||||
call_command(
|
||||
"create_api_token",
|
||||
self.user.email,
|
||||
"--name",
|
||||
"n8n sync",
|
||||
stdout=stdout,
|
||||
)
|
||||
|
||||
token = APIToken.objects.get(user=self.user, name="n8n sync")
|
||||
lines = [line.strip() for line in stdout.getvalue().splitlines() if line.strip()]
|
||||
raw_token = lines[-1]
|
||||
|
||||
self.assertTrue(raw_token.startswith(APIToken.TOKEN_PREFIX))
|
||||
self.assertNotEqual(token.token_hash, raw_token)
|
||||
self.assertTrue(token.check_secret(APIToken.parse_raw_token(raw_token)[1]))
|
||||
|
||||
def test_supports_expiring_tokens(self):
|
||||
call_command(
|
||||
"create_api_token",
|
||||
self.user.email,
|
||||
"--expires-in-days",
|
||||
"7",
|
||||
)
|
||||
|
||||
token = APIToken.objects.get(user=self.user)
|
||||
self.assertIsNotNone(token.expires_at)
|
||||
self.assertGreater(token.expires_at, timezone.now())
|
||||
Reference in New Issue
Block a user