1- # @Author: Zhu Guowei
2- # @Date: 2025/6/17
3- # @Function:
41from starlette .middleware .base import BaseHTTPMiddleware
52from starlette .requests import Request
63from starlette .responses import JSONResponse
74from starlette import status
85from structlog import get_logger
96logger = get_logger (__name__ )
107import os
8+ from typing import Optional
9+ import ipaddress
1110
1211ALLOW_HOSTS = ["172.20.80.1" , "127.0.0.1" ]
12+
13+ # 提取获取客户端真实 IP 的函数
14+ def get_client_ip (request : Request ) -> str :
15+ headers = request .headers
16+ x_forwarded_for = headers .get ("x-forwarded-for" )
17+ if x_forwarded_for :
18+ return x_forwarded_for .split ("," )[0 ].strip ()
19+ x_real_ip = headers .get ("x-real-ip" )
20+ if x_real_ip :
21+ return x_real_ip .strip ()
22+ return request .client .host
23+
24+ # 提取 IP 白名单判断函数
25+ def is_ip_allowed (client_ip : str ) -> bool :
26+ try :
27+ ip = ipaddress .ip_address (client_ip )
28+ except ValueError :
29+ return False
30+ # 私有网络或回环地址自动放行
31+ if ip .is_private or ip .is_loopback :
32+ return True
33+ # 额外白名单
34+ return client_ip in ALLOW_HOSTS
35+
36+ # 提取解析 Bearer Token 的函数
37+ def get_bearer_token (auth_header : Optional [str ]) -> Optional [str ]:
38+ if not auth_header or not auth_header .lower ().startswith ("bearer " ):
39+ return None
40+ return auth_header [7 :].strip ()
41+
42+ # 提取设置认证通过标识函数
43+ def mark_session_authenticated (request : Request ) -> None :
44+ # 如果启用了 SessionMiddleware 或 RedisSessionMiddleware,scope['session'] 应该是 dict
45+ session = request .scope .get ("session" )
46+ if isinstance (session , dict ):
47+ session ["is_authenticated" ] = True
48+
1349class SimpleAuthMiddleware (BaseHTTPMiddleware ):
1450 def __init__ (self , app , api_key_env = "API_KEY" ):
1551 super ().__init__ (app )
1652 self .api_key = os .getenv (api_key_env )
17- logger .info (f"api key: { self .api_key } " )
53+ if not self .api_key :
54+ raise RuntimeError ("API_KEY must be set for SimpleAuthMiddleware" )
55+ logger .info ("SimpleAuthMiddleware initialized" )
1856
1957 async def dispatch (self , request : Request , call_next ):
20- # 如果未配置 API_KEY,且允许跳过,则直接放行(便于开发环境)
21-
22- host = request .headers .get ("HOST" )
23- headers = request .headers
24- auth_header = headers .get ("authorization" )
25-
26- for h in ALLOW_HOSTS :
27- if h in host :
28- return await call_next (request )
29- if not self .api_key :
58+ auth_header = request .headers .get ("authorization" )
59+ client_ip = get_client_ip (request )
60+ if is_ip_allowed (client_ip ):
61+ mark_session_authenticated (request )
3062 return await call_next (request )
63+ logger .info ("simple-auth" , client_ip = client_ip , header = auth_header )
3164
32- # 获取 Authorization头
33- logger .info (f"simple-auth" , host = host , header = auth_header )
34-
35- if not auth_header or not auth_header .lower ().startswith ("bearer " ):
36- logger .warning (f"simple-auth" , host = host , detail = "Bearer Token Not Provided" , header = auth_header )
65+ token = get_bearer_token (auth_header )
66+ if token is None :
67+ logger .warning ("simple-auth" , client_ip = client_ip , detail = "Bearer Token Not Provided" , header = auth_header )
3768 return JSONResponse ({"detail" : "Bearer Token Not Provided" }, status_code = status .HTTP_401_UNAUTHORIZED )
3869
39- token = auth_header [7 :].strip ()
4070 if token != self .api_key :
41- logger .warning (f "simple-auth" , host = host , detail = "Invalid Token" , token = token , header = auth_header )
71+ logger .warning ("simple-auth" , client_ip = client_ip , detail = "Invalid Token" , token = token , header = auth_header )
4272 return JSONResponse ({"detail" : "Invalid Token" }, status_code = status .HTTP_403_FORBIDDEN )
4373
44- # 认证通过,继续处理请求
74+ # 认证通过,设置 session 标识并继续处理
75+ mark_session_authenticated (request )
4576 return await call_next (request )
0 commit comments