Skip to content

Commit c1ae697

Browse files
committed
refactor
1 parent 50c0a65 commit c1ae697

File tree

1 file changed

+53
-22
lines changed
  • src/news_mcp_server/middlewares

1 file changed

+53
-22
lines changed
Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,76 @@
1-
# @Author: Zhu Guowei
2-
# @Date: 2025/6/17
3-
# @Function:
41
from starlette.middleware.base import BaseHTTPMiddleware
52
from starlette.requests import Request
63
from starlette.responses import JSONResponse
74
from starlette import status
85
from structlog import get_logger
96
logger = get_logger(__name__)
107
import os
8+
from typing import Optional
9+
import ipaddress
1110

1211
ALLOW_HOSTS = ["172.20.80.1", "127.0.0.1"]
12+
13+
# 提取获取客户端真实 IP 的函数
14+
def get_client_ip(request: Request) -> str:
15+
headers = request.headers
16+
x_forwarded_for = headers.get("x-forwarded-for")
17+
if x_forwarded_for:
18+
return x_forwarded_for.split(",")[0].strip()
19+
x_real_ip = headers.get("x-real-ip")
20+
if x_real_ip:
21+
return x_real_ip.strip()
22+
return request.client.host
23+
24+
# 提取 IP 白名单判断函数
25+
def is_ip_allowed(client_ip: str) -> bool:
26+
try:
27+
ip = ipaddress.ip_address(client_ip)
28+
except ValueError:
29+
return False
30+
# 私有网络或回环地址自动放行
31+
if ip.is_private or ip.is_loopback:
32+
return True
33+
# 额外白名单
34+
return client_ip in ALLOW_HOSTS
35+
36+
# 提取解析 Bearer Token 的函数
37+
def get_bearer_token(auth_header: Optional[str]) -> Optional[str]:
38+
if not auth_header or not auth_header.lower().startswith("bearer "):
39+
return None
40+
return auth_header[7:].strip()
41+
42+
# 提取设置认证通过标识函数
43+
def mark_session_authenticated(request: Request) -> None:
44+
# 如果启用了 SessionMiddleware 或 RedisSessionMiddleware,scope['session'] 应该是 dict
45+
session = request.scope.get("session")
46+
if isinstance(session, dict):
47+
session["is_authenticated"] = True
48+
1349
class SimpleAuthMiddleware(BaseHTTPMiddleware):
1450
def __init__(self, app, api_key_env="API_KEY"):
1551
super().__init__(app)
1652
self.api_key = os.getenv(api_key_env)
17-
logger.info(f"api key: {self.api_key}")
53+
if not self.api_key:
54+
raise RuntimeError("API_KEY must be set for SimpleAuthMiddleware")
55+
logger.info("SimpleAuthMiddleware initialized")
1856

1957
async def dispatch(self, request: Request, call_next):
20-
# 如果未配置 API_KEY,且允许跳过,则直接放行(便于开发环境)
21-
22-
host = request.headers.get("HOST")
23-
headers = request.headers
24-
auth_header = headers.get("authorization")
25-
26-
for h in ALLOW_HOSTS:
27-
if h in host:
28-
return await call_next(request)
29-
if not self.api_key:
58+
auth_header = request.headers.get("authorization")
59+
client_ip = get_client_ip(request)
60+
if is_ip_allowed(client_ip):
61+
mark_session_authenticated(request)
3062
return await call_next(request)
63+
logger.info("simple-auth", client_ip=client_ip, header=auth_header)
3164

32-
# 获取 Authorization头
33-
logger.info(f"simple-auth", host=host, header=auth_header)
34-
35-
if not auth_header or not auth_header.lower().startswith("bearer "):
36-
logger.warning(f"simple-auth", host=host, detail="Bearer Token Not Provided", header=auth_header)
65+
token = get_bearer_token(auth_header)
66+
if token is None:
67+
logger.warning("simple-auth", client_ip=client_ip, detail="Bearer Token Not Provided", header=auth_header)
3768
return JSONResponse({"detail": "Bearer Token Not Provided"}, status_code=status.HTTP_401_UNAUTHORIZED)
3869

39-
token = auth_header[7:].strip()
4070
if token != self.api_key:
41-
logger.warning(f"simple-auth", host=host, detail="Invalid Token", token=token, header=auth_header)
71+
logger.warning("simple-auth", client_ip=client_ip, detail="Invalid Token", token=token, header=auth_header)
4272
return JSONResponse({"detail": "Invalid Token"}, status_code=status.HTTP_403_FORBIDDEN)
4373

44-
# 认证通过,继续处理请求
74+
# 认证通过,设置 session 标识并继续处理
75+
mark_session_authenticated(request)
4576
return await call_next(request)

0 commit comments

Comments
 (0)