You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

128 lines
3.9 KiB

from robyn import Response, jsonify
from util.token_utils import TokenManager
from functools import wraps
import threading
# 不需要登录的公开路径(白名单)
PUBLIC_PATHS = [
'/api/login',
'/api/logout',
'/api/register',
'/api/checkUsername',
'/resource',
]
# 使用线程本地存储来保存当前请求的用户信息
_thread_local = threading.local()
def get_current_user():
"""获取当前请求的用户信息"""
return getattr(_thread_local, 'current_user', None)
def set_current_user(user):
"""设置当前请求的用户信息"""
_thread_local.current_user = user
def is_public_path(path):
"""检查是否为公开路径"""
for public_path in PUBLIC_PATHS:
if path.startswith(public_path):
return True
return False
def require_login(func):
"""装饰器:要求登录(支持同步和异步函数)"""
import asyncio
@wraps(func)
async def async_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return await func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
@wraps(func)
def sync_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 根据原函数是否为异步函数,返回对应的wrapper
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
def global_auth_interceptor(request):
"""全局认证拦截器 - 用于 @app.before_request() 装饰器"""
# 获取请求路径
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url)
print(path)
print("1111111111111111=============")
# 公开路径直接放行
if is_public_path(path):
return request
# 验证 token
token = TokenManager.get_token_from_request(request)
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 保存用户信息
set_current_user(user)
return request