diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ca5b7b45a..8cc01248d 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_serializer, field_validator class OAuthToken(BaseModel): @@ -128,6 +128,18 @@ class OAuthMetadata(BaseModel): code_challenge_methods_supported: list[str] | None = None client_id_metadata_document_supported: bool | None = None + @field_serializer("issuer") + def serialize_issuer_without_trailing_slash(self, v: AnyHttpUrl) -> str: + """Strip trailing slash from issuer URL during serialization. + + RFC 8414 examples show issuer URLs without trailing slashes, and some + OAuth clients (Google ADK, IBM MCP Context Forge) require exact match + between discovery URL and returned issuer per RFC 8414 Section 3.3. + Pydantic's AnyHttpUrl automatically adds a trailing slash, which breaks + these clients. See: https://github.com/modelcontextprotocol/python-sdk/issues/1919 + """ + return str(v).rstrip("/") + class ProtectedResourceMetadata(BaseModel): """RFC 9728 OAuth 2.0 Protected Resource Metadata. @@ -150,3 +162,17 @@ class ProtectedResourceMetadata(BaseModel): dpop_signing_alg_values_supported: list[str] | None = None # dpop_bound_access_tokens_required default is False, but omitted here for clarity dpop_bound_access_tokens_required: bool | None = None + + @field_serializer("resource") + def serialize_resource_without_trailing_slash(self, v: AnyHttpUrl) -> str: + """Strip trailing slash from resource URL during serialization. + + Same rationale as OAuthMetadata.issuer - RFC specs show URLs without + trailing slashes, and clients may require exact URL matching. + """ + return str(v).rstrip("/") + + @field_serializer("authorization_servers") + def serialize_auth_servers_without_trailing_slash(self, v: list[AnyHttpUrl]) -> list[str]: + """Strip trailing slashes from authorization server URLs during serialization.""" + return [str(url).rstrip("/") for url in v] diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e36..188f4d518 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1414,6 +1414,7 @@ async def mock_callback() -> tuple[str, str | None]: @pytest.mark.parametrize( ( "issuer_url", + "expected_issuer", "service_documentation_url", "authorization_endpoint", "token_endpoint", @@ -1421,9 +1422,8 @@ async def mock_callback() -> tuple[str, str | None]: "revocation_endpoint", ), ( - # Pydantic's AnyUrl incorrectly adds trailing slash to base URLs - # This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+) pytest.param( + "https://auth.example.com", "https://auth.example.com", "https://auth.example.com/docs", "https://auth.example.com/authorize", @@ -1431,12 +1431,10 @@ async def mock_callback() -> tuple[str, str | None]: "https://auth.example.com/register", "https://auth.example.com/revoke", id="simple-url", - marks=pytest.mark.xfail( - reason="Pydantic AnyUrl adds trailing slash to base URLs - fixed in Pydantic 2.12+" - ), ), pytest.param( "https://auth.example.com/", + "https://auth.example.com", # trailing slash stripped per RFC 8414 "https://auth.example.com/docs", "https://auth.example.com/authorize", "https://auth.example.com/token", @@ -1445,6 +1443,7 @@ async def mock_callback() -> tuple[str, str | None]: id="with-trailing-slash", ), pytest.param( + "https://auth.example.com/v1/mcp", "https://auth.example.com/v1/mcp", "https://auth.example.com/v1/mcp/docs", "https://auth.example.com/v1/mcp/authorize", @@ -1457,6 +1456,7 @@ async def mock_callback() -> tuple[str, str | None]: ) def test_build_metadata( issuer_url: str, + expected_issuer: str, service_documentation_url: str, authorization_endpoint: str, token_endpoint: str, @@ -1472,7 +1472,7 @@ def test_build_metadata( assert metadata.model_dump(exclude_defaults=True, mode="json") == snapshot( { - "issuer": Is(issuer_url), + "issuer": Is(expected_issuer), "authorization_endpoint": Is(authorization_endpoint), "token_endpoint": Is(token_endpoint), "registration_endpoint": Is(registration_endpoint), diff --git a/tests/server/auth/test_protected_resource.py b/tests/server/auth/test_protected_resource.py index 413a80276..a04f89efa 100644 --- a/tests/server/auth/test_protected_resource.py +++ b/tests/server/auth/test_protected_resource.py @@ -94,10 +94,11 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC # For root resource, metadata should be at standard location response = await root_resource_client.get("/.well-known/oauth-protected-resource") assert response.status_code == 200 + # Note: URLs should NOT have trailing slashes per RFC 8414/9728 (see issue #1919) assert response.json() == snapshot( { - "resource": "https://example.com/", - "authorization_servers": ["https://auth.example.com/"], + "resource": "https://example.com", + "authorization_servers": ["https://auth.example.com"], "scopes_supported": ["read"], "resource_name": "Root Resource", "bearer_methods_supported": ["header"], diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 602f5cc75..4f77d74fb 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -312,7 +312,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert response.status_code == 200 metadata = response.json() - assert metadata["issuer"] == "https://auth.example.com/" + assert metadata["issuer"] == "https://auth.example.com" assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register" diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index cd3c35332..d66914c67 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -1,6 +1,10 @@ """Tests for OAuth 2.0 shared code.""" -from mcp.shared.auth import OAuthMetadata +import json + +from pydantic import AnyHttpUrl + +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata def test_oauth(): @@ -58,3 +62,72 @@ def test_oauth_with_jarm(): "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], } ) + + +class TestIssuerTrailingSlash: + """Tests for issue #1919: trailing slash in issuer URL. + + RFC 8414 examples show issuer URLs without trailing slashes, and some + OAuth clients require exact match between discovery URL and returned issuer. + Pydantic's AnyHttpUrl automatically adds a trailing slash, so we strip it + during serialization. + """ + + def test_oauth_metadata_issuer_no_trailing_slash_in_json(self): + """Serialized issuer should not have trailing slash.""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://example.com"), + authorization_endpoint=AnyHttpUrl("https://example.com/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://example.com/oauth2/token"), + ) + serialized = json.loads(metadata.model_dump_json()) + assert serialized["issuer"] == "https://example.com" + assert not serialized["issuer"].endswith("/") + + def test_oauth_metadata_issuer_with_path_preserves_path(self): + """Issuer with path should preserve the path, only strip trailing slash.""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://example.com/auth"), + authorization_endpoint=AnyHttpUrl("https://example.com/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://example.com/oauth2/token"), + ) + serialized = json.loads(metadata.model_dump_json()) + assert serialized["issuer"] == "https://example.com/auth" + assert not serialized["issuer"].endswith("/") + + def test_oauth_metadata_issuer_with_path_and_trailing_slash(self): + """Issuer with path and trailing slash should only strip the trailing slash.""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://example.com/auth/"), + authorization_endpoint=AnyHttpUrl("https://example.com/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://example.com/oauth2/token"), + ) + serialized = json.loads(metadata.model_dump_json()) + assert serialized["issuer"] == "https://example.com/auth" + + def test_protected_resource_metadata_no_trailing_slash(self): + """ProtectedResourceMetadata.resource should not have trailing slash.""" + metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://example.com"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + serialized = json.loads(metadata.model_dump_json()) + assert serialized["resource"] == "https://example.com" + assert not serialized["resource"].endswith("/") + + def test_protected_resource_metadata_auth_servers_no_trailing_slash(self): + """ProtectedResourceMetadata.authorization_servers should not have trailing slashes.""" + metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://example.com"), + authorization_servers=[ + AnyHttpUrl("https://auth1.example.com"), + AnyHttpUrl("https://auth2.example.com/path"), + ], + ) + serialized = json.loads(metadata.model_dump_json()) + assert serialized["authorization_servers"] == [ + "https://auth1.example.com", + "https://auth2.example.com/path", + ] + for url in serialized["authorization_servers"]: + assert not url.endswith("/")