Skip to content

Commit 07f642d

Browse files
committed
fix validate gateway url function
Signed-off-by: Keval Mahajan <[email protected]>
1 parent c71484f commit 07f642d

File tree

1 file changed

+131
-39
lines changed

1 file changed

+131
-39
lines changed

mcpgateway/services/gateway_service.py

Lines changed: 131 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -383,64 +383,156 @@ def normalize_url(url: str) -> str:
383383
return url
384384

385385
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None):
386-
"""
387-
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
386+
"""Validates whether a given URL is a valid MCP SSE or StreamableHTTP endpoint.
387+
388+
The function performs a lightweight protocol verification:
389+
* For STREAMABLEHTTP, it sends a JSON-RPC ping request.
390+
* For SSE, it sends a GET request expecting ``text/event-stream``.
391+
392+
Any authentication error, invalid content-type, unreachable endpoint,
393+
unsupported transport type, or raised exception results in ``False``.
388394
389395
Args:
390-
url (str): The full URL of the endpoint to validate.
391-
headers (dict): Headers to be included in the requests (e.g., Authorization).
392-
transport_type (str): SSE or STREAMABLEHTTP
393-
timeout (int, optional): Timeout in seconds. Defaults to settings.gateway_validation_timeout.
396+
url (str): The endpoint URL to validate.
397+
headers (dict): Request headers including authorization or protocol version.
398+
transport_type (str): Expected transport type. One of:
399+
* "SSE"
400+
* "STREAMABLEHTTP"
401+
timeout (int, optional): Request timeout in seconds. Uses default
402+
settings.gateway_validation_timeout if not provided.
394403
395404
Returns:
396-
bool: True if the endpoint is reachable and supports SSE/StreamableHTTP, otherwise False.
405+
bool: True if endpoint is reachable and matches protocol expectations.
406+
False for any failure or exception.
407+
408+
Examples:
409+
410+
Invalid transport type:
411+
>>> class T:
412+
... async def _validate_gateway_url(self, *a, **k):
413+
... return False
414+
>>> import asyncio
415+
>>> asyncio.run(T()._validate_gateway_url(
416+
... "http://example.com", {}, "WRONG"
417+
... ))
418+
False
419+
420+
Authentication failure (simulated):
421+
>>> class T:
422+
... async def _validate_gateway_url(self, *a, **k):
423+
... return False
424+
>>> asyncio.run(T()._validate_gateway_url(
425+
... "http://example.com/protected",
426+
... {"Authorization": "Invalid"},
427+
... "SSE"
428+
... ))
429+
False
430+
431+
Incorrect content-type (simulated):
432+
>>> class T:
433+
... async def _validate_gateway_url(self, *a, **k):
434+
... return False
435+
>>> asyncio.run(T()._validate_gateway_url(
436+
... "http://example.com/stream", {}, "STREAMABLEHTTP"
437+
... ))
438+
False
439+
440+
Network or unexpected exception (simulated):
441+
>>> class T:
442+
... async def _validate_gateway_url(self, *a, **k):
443+
... raise Exception("Simulated error")
444+
>>> try:
445+
... asyncio.run(T()._validate_gateway_url(
446+
... "http://example.com", {}, "SSE"
447+
... ))
448+
... except Exception as e:
449+
... isinstance(e, Exception)
450+
True
397451
"""
398-
if timeout is None:
399-
timeout = settings.gateway_validation_timeout
452+
timeout = timeout or settings.gateway_validation_timeout
453+
protocol_version = settings.protocol_version
454+
transport = (transport_type or "").upper()
455+
456+
# create validation client
400457
validation_client = ResilientHttpClient(
401458
client_args={
402-
"timeout": settings.gateway_validation_timeout,
459+
"timeout": timeout,
403460
"verify": not settings.skip_ssl_verify,
404-
# Let httpx follow only proper HTTP redirects (3xx) and
405-
# enforce a sensible redirect limit.
406461
"follow_redirects": True,
407462
"max_redirects": settings.gateway_max_redirects,
408463
}
409464
)
410465

466+
# headers copy
467+
h = dict(headers or {})
468+
469+
# Small helper
470+
def _auth_or_not_found(status: int) -> bool:
471+
return status in (401, 403, 404)
472+
411473
try:
412-
# Make a single request and let httpx follow valid redirects.
413-
async with validation_client.client.stream("GET", url, headers=headers, timeout=timeout) as response:
414-
response_headers = dict(response.headers)
415-
content_type = response_headers.get("content-type", "")
416-
logger.info(f"Validating gateway URL {url}, received status {response.status_code}, content_type: {content_type}")
417-
418-
# Authentication failures mean the endpoint is not usable
419-
if response.status_code in (401, 403, 404):
420-
logger.debug(f"Authentication failed for {url} with status {response.status_code}")
421-
return False
474+
# STREAMABLE HTTP VALIDATION
475+
if transport == "STREAMABLEHTTP":
476+
h.setdefault("Content-Type", "application/json")
477+
h.setdefault("Accept", "application/json, text/event-stream")
478+
h.setdefault("MCP-Protocol-Version", "2025-06-18")
479+
480+
ping = {
481+
"jsonrpc": "2.0",
482+
"id": "ping-1",
483+
"method": "ping",
484+
"params": {},
485+
}
486+
487+
try:
488+
async with validation_client.client.stream("POST", url, headers=h, timeout=timeout, json=ping) as resp:
489+
status = resp.status_code
490+
ctype = resp.headers.get("content-type", "")
422491

423-
# STREAMABLEHTTP: expect an MCP session id and JSON content
424-
if transport_type == "STREAMABLEHTTP":
425-
mcp_session_id = response_headers.get("mcp-session-id")
426-
if mcp_session_id is not None and mcp_session_id != "":
427-
if content_type is not None and content_type != "" and "application/json" in content_type:
492+
if _auth_or_not_found(status):
493+
return False
494+
495+
# Accept both JSON and EventStream
496+
if ("application/json" in ctype) or ("text/event-stream" in ctype):
428497
return True
429498

430-
# SSE: expect text/event-stream
431-
if transport_type == "SSE":
432-
logger.info(f"Validating SSE gateway URL {url}")
433-
if "text/event-stream" in content_type:
434-
return True
499+
return False
500+
501+
except Exception:
502+
return False
503+
504+
# SSE VALIDATION
505+
elif transport == "SSE":
506+
h.setdefault("Accept", "text/event-stream")
507+
h.setdefault("MCP-Protocol-Version", protocol_version)
508+
509+
try:
510+
async with validation_client.client.stream("GET", url, headers=h, timeout=timeout) as resp:
511+
status = resp.status_code
512+
ctype = resp.headers.get("content-type", "")
513+
514+
if _auth_or_not_found(status):
515+
return False
516+
517+
if "text/event-stream" not in ctype:
518+
return False
519+
520+
# Check if at least one SSE line arrives
521+
async for line in resp.aiter_lines():
522+
if line.strip():
523+
return True
524+
525+
return False
526+
527+
except Exception:
528+
return False
529+
530+
# INVALID TRANSPORT
531+
else:
532+
return False
435533

436-
return False
437-
except httpx.UnsupportedProtocol as e:
438-
logger.debug(f"Gateway URL Unsupported Protocol for {url}: {str(e)}", exc_info=True)
439-
return False
440-
except Exception as e:
441-
logger.debug(f"Gateway validation failed for {url}: {str(e)}", exc_info=True)
442-
return False
443534
finally:
535+
# always cleanly close the client
444536
await validation_client.aclose()
445537

446538
def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext:

0 commit comments

Comments
 (0)