|
|
|
|
from robyn import Response, jsonify
|
|
|
|
|
from util.token_utils import TokenManager
|
|
|
|
|
from functools import wraps
|
|
|
|
|
import threading
|
|
|
|
|
|
|
|
|
|
# 不需要登录的公开路径(白名单)
|
|
|
|
|
PUBLIC_PATHS = [
|
|
|
|
|
'/api/login',
|
|
|
|
|
'/api/logout',
|
|
|
|
|
'/api/register',
|
|
|
|
|
'/api/checkUsername',
|
|
|
|
|
'/resource',
|
|
|
|
|
'/api/kg/export',
|
|
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# 使用线程本地存储来保存当前请求的用户信息
|
|
|
|
|
_thread_local = threading.local()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_user():
|
|
|
|
|
"""获取当前请求的用户信息"""
|
|
|
|
|
return getattr(_thread_local, 'current_user', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_current_user(user):
|
|
|
|
|
"""设置当前请求的用户信息"""
|
|
|
|
|
_thread_local.current_user = user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_public_path(path):
|
|
|
|
|
"""检查是否为公开路径"""
|
|
|
|
|
for public_path in PUBLIC_PATHS:
|
|
|
|
|
if path.startswith(public_path):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def require_login(func):
|
|
|
|
|
"""装饰器:要求登录(支持同步和异步函数)"""
|
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
|
|
@wraps(func)
|
|
|
|
|
async def async_wrapper(request):
|
|
|
|
|
try:
|
|
|
|
|
# 提取token
|
|
|
|
|
token = TokenManager.get_token_from_request(request)
|
|
|
|
|
|
|
|
|
|
# 验证token
|
|
|
|
|
user = TokenManager.validate_token(token)
|
|
|
|
|
|
|
|
|
|
if not user:
|
|
|
|
|
return Response(
|
|
|
|
|
status_code=401,
|
|
|
|
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
|
|
|
|
|
headers={"Content-Type": "application/json; charset=utf-8"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 验证通过,保存用户信息到线程本地存储
|
|
|
|
|
set_current_user(user)
|
|
|
|
|
|
|
|
|
|
return await func(request)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
return Response(
|
|
|
|
|
status_code=500,
|
|
|
|
|
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
|
|
|
|
|
headers={"Content-Type": "application/json; charset=utf-8"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@wraps(func)
|
|
|
|
|
def sync_wrapper(request):
|
|
|
|
|
try:
|
|
|
|
|
# 提取token
|
|
|
|
|
token = TokenManager.get_token_from_request(request)
|
|
|
|
|
|
|
|
|
|
# 验证token
|
|
|
|
|
user = TokenManager.validate_token(token)
|
|
|
|
|
|
|
|
|
|
if not user:
|
|
|
|
|
return Response(
|
|
|
|
|
status_code=401,
|
|
|
|
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
|
|
|
|
|
headers={"Content-Type": "application/json; charset=utf-8"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 验证通过,保存用户信息到线程本地存储
|
|
|
|
|
set_current_user(user)
|
|
|
|
|
|
|
|
|
|
return func(request)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
return Response(
|
|
|
|
|
status_code=500,
|
|
|
|
|
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
|
|
|
|
|
headers={"Content-Type": "application/json; charset=utf-8"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 根据原函数是否为异步函数,返回对应的wrapper
|
|
|
|
|
if asyncio.iscoroutinefunction(func):
|
|
|
|
|
return async_wrapper
|
|
|
|
|
else:
|
|
|
|
|
return sync_wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def global_auth_interceptor(request):
|
|
|
|
|
"""全局认证拦截器 - 用于 @app.before_request() 装饰器"""
|
|
|
|
|
# 获取请求路径
|
|
|
|
|
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url)
|
|
|
|
|
print(path)
|
|
|
|
|
print("1111111111111111=============")
|
|
|
|
|
# 公开路径直接放行
|
|
|
|
|
if is_public_path(path):
|
|
|
|
|
return request
|
|
|
|
|
|
|
|
|
|
# 验证 token
|
|
|
|
|
token = TokenManager.get_token_from_request(request)
|
|
|
|
|
user = TokenManager.validate_token(token)
|
|
|
|
|
|
|
|
|
|
if not user:
|
|
|
|
|
return Response(
|
|
|
|
|
status_code=401,
|
|
|
|
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
|
|
|
|
|
headers={"Content-Type": "application/json; charset=utf-8"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 保存用户信息
|
|
|
|
|
set_current_user(user)
|
|
|
|
|
return request
|