8 changed files with 291 additions and 101 deletions
@ -0,0 +1,4 @@ |
|||
node_modules/ |
|||
__pycache__/ |
|||
*.pyc |
|||
.idea/ |
|||
@ -0,0 +1,127 @@ |
|||
from robyn import Response, jsonify |
|||
from util.token_utils import TokenManager |
|||
from functools import wraps |
|||
import threading |
|||
|
|||
# 不需要登录的公开路径(白名单) |
|||
PUBLIC_PATHS = [ |
|||
'/api/login', |
|||
'/api/logout', |
|||
'/api/register', |
|||
'/api/checkUsername', |
|||
'/resource', |
|||
|
|||
] |
|||
|
|||
# 使用线程本地存储来保存当前请求的用户信息 |
|||
_thread_local = threading.local() |
|||
|
|||
|
|||
def get_current_user(): |
|||
"""获取当前请求的用户信息""" |
|||
return getattr(_thread_local, 'current_user', None) |
|||
|
|||
|
|||
def set_current_user(user): |
|||
"""设置当前请求的用户信息""" |
|||
_thread_local.current_user = user |
|||
|
|||
|
|||
def is_public_path(path): |
|||
"""检查是否为公开路径""" |
|||
for public_path in PUBLIC_PATHS: |
|||
if path.startswith(public_path): |
|||
return True |
|||
return False |
|||
|
|||
|
|||
def require_login(func): |
|||
"""装饰器:要求登录(支持同步和异步函数)""" |
|||
import asyncio |
|||
|
|||
@wraps(func) |
|||
async def async_wrapper(request): |
|||
try: |
|||
# 提取token |
|||
token = TokenManager.get_token_from_request(request) |
|||
|
|||
# 验证token |
|||
user = TokenManager.validate_token(token) |
|||
|
|||
if not user: |
|||
return Response( |
|||
status_code=401, |
|||
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
|||
headers={"Content-Type": "application/json; charset=utf-8"} |
|||
) |
|||
|
|||
# 验证通过,保存用户信息到线程本地存储 |
|||
set_current_user(user) |
|||
|
|||
return await func(request) |
|||
|
|||
except Exception as e: |
|||
return Response( |
|||
status_code=500, |
|||
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), |
|||
headers={"Content-Type": "application/json; charset=utf-8"} |
|||
) |
|||
|
|||
@wraps(func) |
|||
def sync_wrapper(request): |
|||
try: |
|||
# 提取token |
|||
token = TokenManager.get_token_from_request(request) |
|||
|
|||
# 验证token |
|||
user = TokenManager.validate_token(token) |
|||
|
|||
if not user: |
|||
return Response( |
|||
status_code=401, |
|||
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
|||
headers={"Content-Type": "application/json; charset=utf-8"} |
|||
) |
|||
|
|||
# 验证通过,保存用户信息到线程本地存储 |
|||
set_current_user(user) |
|||
|
|||
return func(request) |
|||
|
|||
except Exception as e: |
|||
return Response( |
|||
status_code=500, |
|||
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), |
|||
headers={"Content-Type": "application/json; charset=utf-8"} |
|||
) |
|||
|
|||
# 根据原函数是否为异步函数,返回对应的wrapper |
|||
if asyncio.iscoroutinefunction(func): |
|||
return async_wrapper |
|||
else: |
|||
return sync_wrapper |
|||
|
|||
|
|||
def global_auth_interceptor(request): |
|||
"""全局认证拦截器 - 用于 @app.before_request() 装饰器""" |
|||
# 获取请求路径 |
|||
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url) |
|||
|
|||
# 公开路径直接放行 |
|||
if is_public_path(path): |
|||
return request |
|||
|
|||
# 验证 token |
|||
token = TokenManager.get_token_from_request(request) |
|||
user = TokenManager.validate_token(token) |
|||
|
|||
if not user: |
|||
return Response( |
|||
status_code=401, |
|||
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
|||
headers={"Content-Type": "application/json; charset=utf-8"} |
|||
) |
|||
|
|||
# 保存用户信息 |
|||
set_current_user(user) |
|||
return request |
|||
@ -0,0 +1,112 @@ |
|||
import uuid |
|||
from datetime import datetime, timedelta |
|||
|
|||
# 全局变量存储token |
|||
TEMP_TOKENS = {} |
|||
|
|||
|
|||
class TokenManager: |
|||
@staticmethod |
|||
def generate_token(): |
|||
"""生成随机token""" |
|||
return str(uuid.uuid4()) |
|||
|
|||
@staticmethod |
|||
def store_token(token, user, remember=False): |
|||
"""存储token""" |
|||
expires_at = datetime.now() + timedelta(days=7 if remember else 1) |
|||
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at} |
|||
|
|||
@staticmethod |
|||
def validate_token(token): |
|||
"""验证token并返回用户信息""" |
|||
if not token or token not in TEMP_TOKENS: |
|||
return None |
|||
|
|||
# 检查token是否过期 |
|||
if datetime.now() > TEMP_TOKENS[token]["expires_at"]: |
|||
del TEMP_TOKENS[token] |
|||
return None |
|||
|
|||
return TEMP_TOKENS[token]["user"] |
|||
|
|||
@staticmethod |
|||
def remove_token(token): |
|||
"""删除token""" |
|||
if token in TEMP_TOKENS: |
|||
del TEMP_TOKENS[token] |
|||
|
|||
@staticmethod |
|||
def get_token_from_request(request): |
|||
"""从请求中提取token""" |
|||
import json |
|||
token = None |
|||
|
|||
# 从headers获取 |
|||
if hasattr(request, 'headers'): |
|||
try: |
|||
headers = request.headers |
|||
if headers: |
|||
# 方法1:直接用get方法获取(Robyn的Headers对象支持) |
|||
try: |
|||
token_value = headers.get('token', None) |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
except: |
|||
pass |
|||
|
|||
# 方法2:尝试大写Token |
|||
if not token: |
|||
try: |
|||
token_value = headers.get('Token', None) |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
except: |
|||
pass |
|||
|
|||
# 方法3:如果headers是字符串(JSON格式),解析它 |
|||
if not token: |
|||
try: |
|||
headers_str = str(headers) |
|||
if headers_str.startswith('{'): |
|||
headers_dict = json.loads(headers_str) |
|||
token_value = headers_dict.get('token') or headers_dict.get('Token') |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
except: |
|||
pass |
|||
except: |
|||
pass |
|||
|
|||
# 从query参数获取 |
|||
if not token and hasattr(request, 'query_params'): |
|||
try: |
|||
query_params = getattr(request, 'query_params', {}) |
|||
# Robyn的query_params.get()需要提供默认值 |
|||
token_value = query_params.get('token', None) |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
|
|||
# 也尝试从 headers[token] 参数获取 |
|||
if not token: |
|||
token_value = query_params.get('headers[token]', None) |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
|
|||
# 尝试URL编码的版本 |
|||
if not token: |
|||
token_value = query_params.get('headers%5Btoken%5D', None) |
|||
if token_value: |
|||
token = token_value[0] if isinstance(token_value, list) else token_value |
|||
except: |
|||
pass |
|||
|
|||
# 从body中获取 |
|||
if not token and hasattr(request, 'body') and request.body: |
|||
try: |
|||
body_data = json.loads(request.body) |
|||
token = body_data.get('token') |
|||
except: |
|||
pass |
|||
|
|||
return token |
|||
Loading…
Reference in new issue