8 changed files with 291 additions and 101 deletions
@ -0,0 +1,4 @@ |
|||||
|
node_modules/ |
||||
|
__pycache__/ |
||||
|
*.pyc |
||||
|
.idea/ |
||||
@ -0,0 +1,127 @@ |
|||||
|
from robyn import Response, jsonify |
||||
|
from util.token_utils import TokenManager |
||||
|
from functools import wraps |
||||
|
import threading |
||||
|
|
||||
|
# 不需要登录的公开路径(白名单) |
||||
|
PUBLIC_PATHS = [ |
||||
|
'/api/login', |
||||
|
'/api/logout', |
||||
|
'/api/register', |
||||
|
'/api/checkUsername', |
||||
|
'/resource', |
||||
|
|
||||
|
] |
||||
|
|
||||
|
# 使用线程本地存储来保存当前请求的用户信息 |
||||
|
_thread_local = threading.local() |
||||
|
|
||||
|
|
||||
|
def get_current_user(): |
||||
|
"""获取当前请求的用户信息""" |
||||
|
return getattr(_thread_local, 'current_user', None) |
||||
|
|
||||
|
|
||||
|
def set_current_user(user): |
||||
|
"""设置当前请求的用户信息""" |
||||
|
_thread_local.current_user = user |
||||
|
|
||||
|
|
||||
|
def is_public_path(path): |
||||
|
"""检查是否为公开路径""" |
||||
|
for public_path in PUBLIC_PATHS: |
||||
|
if path.startswith(public_path): |
||||
|
return True |
||||
|
return False |
||||
|
|
||||
|
|
||||
|
def require_login(func): |
||||
|
"""装饰器:要求登录(支持同步和异步函数)""" |
||||
|
import asyncio |
||||
|
|
||||
|
@wraps(func) |
||||
|
async def async_wrapper(request): |
||||
|
try: |
||||
|
# 提取token |
||||
|
token = TokenManager.get_token_from_request(request) |
||||
|
|
||||
|
# 验证token |
||||
|
user = TokenManager.validate_token(token) |
||||
|
|
||||
|
if not user: |
||||
|
return Response( |
||||
|
status_code=401, |
||||
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
||||
|
headers={"Content-Type": "application/json; charset=utf-8"} |
||||
|
) |
||||
|
|
||||
|
# 验证通过,保存用户信息到线程本地存储 |
||||
|
set_current_user(user) |
||||
|
|
||||
|
return await func(request) |
||||
|
|
||||
|
except Exception as e: |
||||
|
return Response( |
||||
|
status_code=500, |
||||
|
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), |
||||
|
headers={"Content-Type": "application/json; charset=utf-8"} |
||||
|
) |
||||
|
|
||||
|
@wraps(func) |
||||
|
def sync_wrapper(request): |
||||
|
try: |
||||
|
# 提取token |
||||
|
token = TokenManager.get_token_from_request(request) |
||||
|
|
||||
|
# 验证token |
||||
|
user = TokenManager.validate_token(token) |
||||
|
|
||||
|
if not user: |
||||
|
return Response( |
||||
|
status_code=401, |
||||
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
||||
|
headers={"Content-Type": "application/json; charset=utf-8"} |
||||
|
) |
||||
|
|
||||
|
# 验证通过,保存用户信息到线程本地存储 |
||||
|
set_current_user(user) |
||||
|
|
||||
|
return func(request) |
||||
|
|
||||
|
except Exception as e: |
||||
|
return Response( |
||||
|
status_code=500, |
||||
|
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), |
||||
|
headers={"Content-Type": "application/json; charset=utf-8"} |
||||
|
) |
||||
|
|
||||
|
# 根据原函数是否为异步函数,返回对应的wrapper |
||||
|
if asyncio.iscoroutinefunction(func): |
||||
|
return async_wrapper |
||||
|
else: |
||||
|
return sync_wrapper |
||||
|
|
||||
|
|
||||
|
def global_auth_interceptor(request): |
||||
|
"""全局认证拦截器 - 用于 @app.before_request() 装饰器""" |
||||
|
# 获取请求路径 |
||||
|
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url) |
||||
|
|
||||
|
# 公开路径直接放行 |
||||
|
if is_public_path(path): |
||||
|
return request |
||||
|
|
||||
|
# 验证 token |
||||
|
token = TokenManager.get_token_from_request(request) |
||||
|
user = TokenManager.validate_token(token) |
||||
|
|
||||
|
if not user: |
||||
|
return Response( |
||||
|
status_code=401, |
||||
|
description=jsonify({"success": False, "message": "未登录或登录已过期"}), |
||||
|
headers={"Content-Type": "application/json; charset=utf-8"} |
||||
|
) |
||||
|
|
||||
|
# 保存用户信息 |
||||
|
set_current_user(user) |
||||
|
return request |
||||
@ -0,0 +1,112 @@ |
|||||
|
import uuid |
||||
|
from datetime import datetime, timedelta |
||||
|
|
||||
|
# 全局变量存储token |
||||
|
TEMP_TOKENS = {} |
||||
|
|
||||
|
|
||||
|
class TokenManager: |
||||
|
@staticmethod |
||||
|
def generate_token(): |
||||
|
"""生成随机token""" |
||||
|
return str(uuid.uuid4()) |
||||
|
|
||||
|
@staticmethod |
||||
|
def store_token(token, user, remember=False): |
||||
|
"""存储token""" |
||||
|
expires_at = datetime.now() + timedelta(days=7 if remember else 1) |
||||
|
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at} |
||||
|
|
||||
|
@staticmethod |
||||
|
def validate_token(token): |
||||
|
"""验证token并返回用户信息""" |
||||
|
if not token or token not in TEMP_TOKENS: |
||||
|
return None |
||||
|
|
||||
|
# 检查token是否过期 |
||||
|
if datetime.now() > TEMP_TOKENS[token]["expires_at"]: |
||||
|
del TEMP_TOKENS[token] |
||||
|
return None |
||||
|
|
||||
|
return TEMP_TOKENS[token]["user"] |
||||
|
|
||||
|
@staticmethod |
||||
|
def remove_token(token): |
||||
|
"""删除token""" |
||||
|
if token in TEMP_TOKENS: |
||||
|
del TEMP_TOKENS[token] |
||||
|
|
||||
|
@staticmethod |
||||
|
def get_token_from_request(request): |
||||
|
"""从请求中提取token""" |
||||
|
import json |
||||
|
token = None |
||||
|
|
||||
|
# 从headers获取 |
||||
|
if hasattr(request, 'headers'): |
||||
|
try: |
||||
|
headers = request.headers |
||||
|
if headers: |
||||
|
# 方法1:直接用get方法获取(Robyn的Headers对象支持) |
||||
|
try: |
||||
|
token_value = headers.get('token', None) |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
# 方法2:尝试大写Token |
||||
|
if not token: |
||||
|
try: |
||||
|
token_value = headers.get('Token', None) |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
# 方法3:如果headers是字符串(JSON格式),解析它 |
||||
|
if not token: |
||||
|
try: |
||||
|
headers_str = str(headers) |
||||
|
if headers_str.startswith('{'): |
||||
|
headers_dict = json.loads(headers_str) |
||||
|
token_value = headers_dict.get('token') or headers_dict.get('Token') |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
except: |
||||
|
pass |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
# 从query参数获取 |
||||
|
if not token and hasattr(request, 'query_params'): |
||||
|
try: |
||||
|
query_params = getattr(request, 'query_params', {}) |
||||
|
# Robyn的query_params.get()需要提供默认值 |
||||
|
token_value = query_params.get('token', None) |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
|
||||
|
# 也尝试从 headers[token] 参数获取 |
||||
|
if not token: |
||||
|
token_value = query_params.get('headers[token]', None) |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
|
||||
|
# 尝试URL编码的版本 |
||||
|
if not token: |
||||
|
token_value = query_params.get('headers%5Btoken%5D', None) |
||||
|
if token_value: |
||||
|
token = token_value[0] if isinstance(token_value, list) else token_value |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
# 从body中获取 |
||||
|
if not token and hasattr(request, 'body') and request.body: |
||||
|
try: |
||||
|
body_data = json.loads(request.body) |
||||
|
token = body_data.get('token') |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
return token |
||||
Loading…
Reference in new issue