Skip to content

Commit bd45429

Browse files
committed
feat: support ratelimit and audit middleware
1 parent 764e7ed commit bd45429

File tree

9 files changed

+223
-19
lines changed

9 files changed

+223
-19
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ wheels/
1616
# Logs
1717
*.log
1818
*.out
19-
.pytest_cache/
19+
.pytest_cache/
20+
volumes

src/news_mcp_server/app.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,20 @@
44
from .mcp_server import mcp_app
55
from .middlewares.auth import SimpleAuthMiddleware
66
from .middlewares.monitor import MonitorMiddleware, metrics
7-
from .middlewares.session import SessionMiddleware
7+
from .middlewares.rate_limit import RedisRateLimitMiddleware
88
from .config.settings import app_settings
9-
middlewares = [Middleware(SessionMiddleware,
10-
secret_key=app_settings.SESSION_SECREY_KEY,
11-
max_age=3600),
12-
Middleware(MonitorMiddleware),
13-
Middleware(SimpleAuthMiddleware)
14-
]
9+
10+
middlewares = [
11+
# IP 速率限制
12+
Middleware(RedisRateLimitMiddleware,
13+
redis_url=app_settings.REDIS_URL,
14+
max_requests=app_settings.RATE_LIMIT_MAX,
15+
window_seconds=app_settings.RATE_LIMIT_WINDOW),
16+
# 监控中间件
17+
Middleware(MonitorMiddleware),
18+
# 简单认证
19+
Middleware(SimpleAuthMiddleware),
20+
]
1521

1622
allow_origins = [
1723
"http://localhost:8000",

src/news_mcp_server/config/settings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ class ApplicationSettings(BaseModel):
1010
CORS_HEADERS: list = ["*"]
1111
CORS_ALLOW_CREDENTIALS: bool = True
1212
API_KEY: str | None = os.getenv("NEWS_MCP_API_KEY")
13-
SESSION_SECREY_KEY: str = os.getenv("SESSION_SECREY_KEY")
13+
SESSION_SECRET_KEY: str = os.getenv("SESSION_SECRET_KEY")
14+
REDIS_URL: str = os.getenv("REDIS_URL", "redis://redis:6379/0")
15+
# IP 限流配置
16+
RATE_LIMIT_MAX: int = int(os.getenv("RATE_LIMIT_MAX", 100)) # 单个 IP 在时间窗口内最大请求数
17+
RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", 60)) # 限流窗口时长(秒)
1418
TRANSPORT: str = "streamable-http"
1519

1620

src/news_mcp_server/mcp_server.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import List
22
from fastmcp import FastMCP, Context
3+
from fastmcp.server.http import Middleware
34
from pydantic import Field
45
import contextlib
56
from .services.news_service import NewsService
67
from .clients.elastic_client import AsyncElasticClient
8+
from .middlewares.audit import AuditMiddleware
79
import structlog
810
logger = structlog.get_logger(__name__)
911

@@ -13,7 +15,10 @@ class NewsMCP(FastMCP):
1315
pass
1416

1517
def create_http_app(mcp):
16-
mcp_app = mcp.http_app("/es-news-mcp")
18+
middlewares = [
19+
Middleware(AuditMiddleware)
20+
]
21+
mcp_app = mcp.http_app("/es-news-mcp", middleware=middlewares)
1722
return mcp_app
1823
app_services = {}
1924

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from starlette.middleware.base import BaseHTTPMiddleware
2+
from starlette.requests import Request
3+
from starlette.responses import Response
4+
import time
5+
from ..utils.logger import logger
6+
7+
class AuditMiddleware(BaseHTTPMiddleware):
8+
async def dispatch(self, request: Request, call_next) -> Response:
9+
method = None
10+
params = None
11+
# 尝试解析 JSON-RPC 请求体中的方法名和参数
12+
if request.method.upper() == "POST":
13+
try:
14+
body = await request.json()
15+
method = body.get("method")
16+
params = body.get("params")
17+
except Exception:
18+
pass
19+
# 获取客户端 IP
20+
client_ip = None
21+
if request.client:
22+
client_ip = request.client.host
23+
start_time = time.time()
24+
response = await call_next(request)
25+
duration = time.time() - start_time
26+
# 审计记录:工具名、参数、客户端IP、状态码、耗时(ms)
27+
logger.info(
28+
"mcp_tool_audit",
29+
method=method,
30+
params=params,
31+
client_ip=client_ip,
32+
status_code=response.status_code,
33+
duration_ms=int(duration * 1000)
34+
)
35+
return response
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# @Author: Zhu Guowei
22
# @Date: 2025/6/17
33
# @Function:
4-
from starlette.middleware.authentication import AuthenticationMiddleware
54
from starlette.middleware.base import BaseHTTPMiddleware
65
from starlette.requests import Request
76
from starlette.responses import JSONResponse
7+
from starlette import status
88
from structlog import get_logger
99
logger = get_logger(__name__)
1010
import os
@@ -20,24 +20,26 @@ async def dispatch(self, request: Request, call_next):
2020
# 如果未配置 API_KEY,且允许跳过,则直接放行(便于开发环境)
2121

2222
host = request.headers.get("HOST")
23-
logger.info(f"host: {host}, {request.headers}")
23+
headers = request.headers
24+
auth_header = headers.get("authorization")
25+
2426
for h in ALLOW_HOSTS:
2527
if h in host:
2628
return await call_next(request)
2729
if not self.api_key:
28-
logger.info(f"request.url: {request.url}")
2930
return await call_next(request)
3031

3132
# 获取 Authorization头
32-
auth_header = request.headers.get("authorization")
33-
logger.info(f"url: {request.url}, query_params: {request.query_params}")
34-
logger.info(f"auth-header: {auth_header}")
33+
logger.info(f"simple-auth", host=host, header=auth_header)
34+
3535
if not auth_header or not auth_header.lower().startswith("bearer "):
36-
return JSONResponse({"detail": "Bearer Token Not Provided"}, status_code=401)
36+
logger.warning(f"simple-auth", host=host, detail="Bearer Token Not Provided", header=auth_header)
37+
return JSONResponse({"detail": "Bearer Token Not Provided"}, status_code=status.HTTP_401_UNAUTHORIZED)
3738

3839
token = auth_header[7:].strip()
3940
if token != self.api_key:
40-
return JSONResponse({"detail": "Invalid Token"}, status_code=403)
41+
logger.warning(f"simple-auth", host=host, detail="Invalid Token", token=token, header=auth_header)
42+
return JSONResponse({"detail": "Invalid Token"}, status_code=status.HTTP_403_FORBIDDEN)
4143

4244
# 认证通过,继续处理请求
4345
return await call_next(request)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import time
2+
from starlette.middleware.base import BaseHTTPMiddleware
3+
from starlette.requests import Request
4+
from starlette.responses import JSONResponse
5+
from starlette import status
6+
from redis import asyncio as aioredis
7+
from ..utils.logger import logger
8+
9+
10+
class RedisRateLimitMiddleware(BaseHTTPMiddleware):
11+
"""
12+
基于 Redis 的简单 IP 限流中间件。
13+
每个 IP 在固定时间窗口内最多允许 max_requests 次请求。
14+
"""
15+
def __init__(self, app, redis_url: str, max_requests: int, window_seconds: int):
16+
super().__init__(app)
17+
self.redis_url = redis_url
18+
self.max_requests = max_requests
19+
self.window = window_seconds
20+
self._redis = None
21+
22+
async def _get_redis(self):
23+
if self._redis is None:
24+
# 仅在首次调用时创建 Redis 连接
25+
self._redis = await aioredis.from_url(
26+
self.redis_url, encoding="utf-8", decode_responses=True
27+
)
28+
return self._redis
29+
30+
async def dispatch(self, request: Request, call_next):
31+
# 获取客户端 IP
32+
client_host = request.client.host if request.client else "unknown"
33+
logger.debug("rate-limiter", host=client_host)
34+
# 计算当前时间窗口
35+
now = int(time.time())
36+
window_key = now // self.window
37+
key = f"ratelimit:{client_host}:{window_key}"
38+
redis = await self._get_redis()
39+
# 自增计数
40+
count = await redis.incr(key)
41+
if count == 1:
42+
# 设置过期时间为一个窗口长度
43+
await redis.expire(key, self.window)
44+
45+
# 超出限流阈值,返回 429
46+
if count > self.max_requests:
47+
logger.info("rate-limiter", host=client_host, key=key)
48+
return JSONResponse(
49+
{"detail": "请求过多,请稍后重试"},
50+
status_code=status.HTTP_429_TOO_MANY_REQUESTS
51+
)
52+
53+
# 继续处理请求
54+
response = await call_next(request)
55+
return response
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import json
2+
import uuid
3+
from itsdangerous import TimestampSigner, BadSignature
4+
from starlette.middleware.base import BaseHTTPMiddleware
5+
from starlette.requests import Request
6+
from starlette.responses import Response
7+
from redis import asyncio as aioredis
8+
9+
10+
class RedisSessionMiddleware(BaseHTTPMiddleware):
11+
"""
12+
TODO 基于 Redis 的服务端 Session 中间件。
13+
- session 使用 UUID 作为 key 存储在 Redis 中
14+
- cookie 存储签名后的 session_id
15+
"""
16+
def __init__(self, app, secret_key: str, redis_url: str, cookie_name: str = "session", max_age: int = 14*24*60*60):
17+
super().__init__(app)
18+
if not secret_key:
19+
raise ValueError("SESSION_SECRET_KEY 未配置")
20+
self.signer = TimestampSigner(secret_key)
21+
self.redis_url = redis_url
22+
self.cookie_name = cookie_name
23+
self.max_age = max_age
24+
self._redis = None
25+
26+
async def _get_redis(self):
27+
if self._redis is None:
28+
self._redis = await aioredis.from_url(self.redis_url, encoding="utf-8", decode_responses=True)
29+
return self._redis
30+
31+
async def dispatch(self, request: Request, call_next):
32+
# 获取或生成 session_id
33+
session_id = None
34+
cookie = request.cookies.get(self.cookie_name)
35+
if cookie:
36+
try:
37+
unsigned = self.signer.unsign(cookie, max_age=self.max_age)
38+
session_id = unsigned.decode()
39+
except BadSignature:
40+
session_id = None
41+
if not session_id:
42+
session_id = str(uuid.uuid4())
43+
44+
# 取 Redis 中的数据
45+
redis = await self._get_redis()
46+
raw = await redis.get(f"session:{session_id}")
47+
try:
48+
session_data = json.loads(raw) if raw else {}
49+
except json.JSONDecodeError:
50+
session_data = {}
51+
request.scope["session"] = session_data
52+
53+
# 调用下游
54+
response: Response = await call_next(request)
55+
56+
# 写回 Redis
57+
await redis.setex(f"session:{session_id}", self.max_age, json.dumps(request.scope.get('session', {})))
58+
59+
# 设置 cookie
60+
signed = self.signer.sign(session_id.encode()).decode()
61+
response.set_cookie(
62+
self.cookie_name,
63+
signed,
64+
max_age=self.max_age,
65+
httponly=True,
66+
samesite="lax"
67+
)
68+
return response
Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,30 @@
1+
import os
2+
import logging
3+
from pathlib import Path
14
from structlog import get_logger
2-
logger = get_logger("suwen-news-mcp-server")
5+
from logging.handlers import TimedRotatingFileHandler
6+
BASE_DIR = Path(__file__).parent.parent.parent
7+
8+
9+
logger = get_logger("suwen-news-mcp-server")
10+
# === 文件日志处理器 ===
11+
LOG_DIR = os.getenv("LOG_DIR", os.path.join(BASE_DIR, "logs"))
12+
os.makedirs(LOG_DIR, exist_ok=True)
13+
LOG_FILE = os.path.join(LOG_DIR, "es_news_mcp_server.log")
14+
15+
# === 标准库日志根配置 ===
16+
root_logger = logging.getLogger()
17+
root_logger.setLevel(logging.INFO)
18+
19+
# 控制台 Handler(保留 JSON 格式)
20+
console_handler = logging.StreamHandler()
21+
console_handler.setFormatter(logging.Formatter("%(message)s"))
22+
23+
# 文件 Handler:每日 0 点轮转,保留 7 天
24+
file_handler = TimedRotatingFileHandler(LOG_FILE, when="midnight", backupCount=7, encoding="utf-8")
25+
file_handler.setFormatter(logging.Formatter("%(message)s"))
26+
27+
# 仅在首次配置时添加(防止重复)
28+
if not root_logger.handlers:
29+
root_logger.addHandler(console_handler)
30+
root_logger.addHandler(file_handler)

0 commit comments

Comments
 (0)