From c4a4a63b29865b06a21a80b3c8331763b364f59e Mon Sep 17 00:00:00 2001 From: yangrongze <2303542064@qq.com> Date: Wed, 31 Dec 2025 15:12:48 +0800 Subject: [PATCH] Changes --- .gitignore | 4 ++ app.py | 6 +- controller/LoginController.py | 115 +++++++++----------------------------- util/auth_interceptor.py | 127 ++++++++++++++++++++++++++++++++++++++++++ util/token_utils.py | 112 +++++++++++++++++++++++++++++++++++++ vue/src/api/profile.js | 11 ++-- vue/src/utils/request.js | 12 +++- web_main.py | 5 +- 8 files changed, 291 insertions(+), 101 deletions(-) create mode 100644 .gitignore create mode 100644 util/auth_interceptor.py create mode 100644 util/token_utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cd52a4e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +__pycache__/ +*.pyc +.idea/ diff --git a/app.py b/app.py index a67e49f..d6dea93 100644 --- a/app.py +++ b/app.py @@ -9,4 +9,8 @@ if web_server_path not in sys.path: sys.path.insert(0, web_server_path) -app = Robyn(__file__) \ No newline at end of file +app = Robyn(__file__) + +# 注册全局认证拦截器 +from util.auth_interceptor import global_auth_interceptor +app.before_request()(global_auth_interceptor) \ No newline at end of file diff --git a/controller/LoginController.py b/controller/LoginController.py index 3662096..a24a4bb 100644 --- a/controller/LoginController.py +++ b/controller/LoginController.py @@ -1,30 +1,19 @@ import os + from robyn import jsonify, Response, Request from app import app -from datetime import datetime, timedelta -import uuid + + import json from service.UserService import user_service +from util.token_utils import TokenManager +from util.auth_interceptor import require_login, get_current_user + + -# 临时存储token,用于会话管理 -TEMP_TOKENS = {} -def generate_token() -> str: - """生成随机token""" - return str(uuid.uuid4()) -def validate_token(token: str) -> dict: - """验证token并返回用户信息""" - if not token or token not in TEMP_TOKENS: - return None - - # 检查token是否过期 - if datetime.now() > TEMP_TOKENS[token]["expires_at"]: - del TEMP_TOKENS[token] - return None - - return TEMP_TOKENS[token]["user"] @app.post("/api/login") def login_route(request): @@ -63,10 +52,9 @@ def login_route(request): headers={"Content-Type": "application/json; charset=utf-8"} ) - # 生成token并设置过期时间 - token = generate_token() - expires_at = datetime.now() + timedelta(days=7 if remember else 1) - TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at} + + token=TokenManager.generate_token() + TokenManager.store_token(token,user, remember) return Response( status_code=200, @@ -84,10 +72,8 @@ def login_route(request): def logout_route(request): """登出接口""" try: - request_data = json.loads(request.body) if request.body else {} - token = request_data.get("token", "") - # 删除token - TEMP_TOKENS.pop(token, None) + token=TokenManager.get_token_from_request(request) + TokenManager.remove_token(token) return Response( status_code=200, @@ -102,27 +88,18 @@ def logout_route(request): ) @app.get("/api/userInfo") +@require_login def user_info_route(request): """获取用户信息接口""" try: - query_params = getattr(request, 'query_params', {}) - token = query_params.get("token", "") - - # 验证token - user = validate_token(token) - if not user: - return Response( - status_code=401, - description=jsonify({"success": False, "message": "未登录或登录已过期"}), - headers={"Content-Type": "application/json; charset=utf-8"} - ) + user = get_current_user() # 从数据库获取最新的用户信息 username = user["username"] db_user = user_service.get_user_info(username) if db_user: # 更新TEMP_TOKENS中的用户信息 - TEMP_TOKENS[token]["user"] = db_user + user_info = db_user else: user_info = user @@ -140,48 +117,22 @@ def user_info_route(request): ) @app.post("/api/updateAvatar") +@require_login async def update_avatar_route(request: Request): """更新用户头像接口""" try: + + user = get_current_user() # 从files中获取文件和token avatar_file = request.files.get('avatar') if hasattr(request, 'files') else None - token = None + - # 从form_data中获取token - if hasattr(request, 'form_data'): - token = request.form_data.get('token') - - # 如果files中没有直接找到'avatar',尝试获取第一个文件 + if not avatar_file and hasattr(request, 'files') and request.files: first_key = list(request.files.keys())[0] avatar_file = request.files[first_key] - # 如果form_data中没有token,尝试从headers获取 - if not token: - headers_dict = {} - if hasattr(request, 'headers'): - try: - headers_dict = dict(request.headers) - except: - pass - token = headers_dict.get('Authorization') or headers_dict.get('authorization') - - # 如果还是没有token,尝试从body中解析 - if not token and hasattr(request, 'body'): - try: - body_data = json.loads(request.body) - token = body_data.get('token') - except: - pass - - # 验证token - user = validate_token(token) - if not user: - return Response( - status_code=401, - description=jsonify({"success": False, "message": "未登录或登录已过期"}), - headers={"Content-Type": "application/json; charset=utf-8"} - ) + # 检查文件是否存在 if not avatar_file: @@ -282,8 +233,7 @@ async def update_avatar_route(request: Request): headers={"Content-Type": "application/json; charset=utf-8"} ) - # 更新token中的用户信息 - TEMP_TOKENS[token]["user"]["avatar"] = avatar_relative_path + return Response( status_code=200, @@ -302,32 +252,21 @@ async def update_avatar_route(request: Request): ) @app.post("/api/updatePassword") +@require_login def update_password_route(request): """更新用户密码接口""" try: # 解析请求数据 request_data = json.loads(request.body) if request.body else {} - token = request_data.get("token", "") + current_password = request_data.get("currentPassword", "") new_password = request_data.get("newPassword", "") # 验证输入 - if not current_password or not new_password: - return Response( - status_code=400, - description=jsonify({"success": False, "message": "当前密码和新密码不能为空"}), - headers={"Content-Type": "application/json; charset=utf-8"} - ) - - # 验证token - user = validate_token(token) - if not user: - return Response( - status_code=401, - description=jsonify({"success": False, "message": "未登录或登录已过期"}), - headers={"Content-Type": "application/json; charset=utf-8"} - ) + + # 获取当前用户 + user = get_current_user() # 获取用户信息 username = user["username"] diff --git a/util/auth_interceptor.py b/util/auth_interceptor.py new file mode 100644 index 0000000..ad74ac9 --- /dev/null +++ b/util/auth_interceptor.py @@ -0,0 +1,127 @@ +from robyn import Response, jsonify +from util.token_utils import TokenManager +from functools import wraps +import threading + +# 不需要登录的公开路径(白名单) +PUBLIC_PATHS = [ + '/api/login', + '/api/logout', + '/api/register', + '/api/checkUsername', + '/resource', + +] + +# 使用线程本地存储来保存当前请求的用户信息 +_thread_local = threading.local() + + +def get_current_user(): + """获取当前请求的用户信息""" + return getattr(_thread_local, 'current_user', None) + + +def set_current_user(user): + """设置当前请求的用户信息""" + _thread_local.current_user = user + + +def is_public_path(path): + """检查是否为公开路径""" + for public_path in PUBLIC_PATHS: + if path.startswith(public_path): + return True + return False + + +def require_login(func): + """装饰器:要求登录(支持同步和异步函数)""" + import asyncio + + @wraps(func) + async def async_wrapper(request): + try: + # 提取token + token = TokenManager.get_token_from_request(request) + + # 验证token + user = TokenManager.validate_token(token) + + if not user: + return Response( + status_code=401, + description=jsonify({"success": False, "message": "未登录或登录已过期"}), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + + # 验证通过,保存用户信息到线程本地存储 + set_current_user(user) + + return await func(request) + + except Exception as e: + return Response( + status_code=500, + description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + + @wraps(func) + def sync_wrapper(request): + try: + # 提取token + token = TokenManager.get_token_from_request(request) + + # 验证token + user = TokenManager.validate_token(token) + + if not user: + return Response( + status_code=401, + description=jsonify({"success": False, "message": "未登录或登录已过期"}), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + + # 验证通过,保存用户信息到线程本地存储 + set_current_user(user) + + return func(request) + + except Exception as e: + return Response( + status_code=500, + description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + + # 根据原函数是否为异步函数,返回对应的wrapper + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + +def global_auth_interceptor(request): + """全局认证拦截器 - 用于 @app.before_request() 装饰器""" + # 获取请求路径 + path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url) + + # 公开路径直接放行 + if is_public_path(path): + return request + + # 验证 token + token = TokenManager.get_token_from_request(request) + user = TokenManager.validate_token(token) + + if not user: + return Response( + status_code=401, + description=jsonify({"success": False, "message": "未登录或登录已过期"}), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + + # 保存用户信息 + set_current_user(user) + return request diff --git a/util/token_utils.py b/util/token_utils.py new file mode 100644 index 0000000..099a217 --- /dev/null +++ b/util/token_utils.py @@ -0,0 +1,112 @@ +import uuid +from datetime import datetime, timedelta + +# 全局变量存储token +TEMP_TOKENS = {} + + +class TokenManager: + @staticmethod + def generate_token(): + """生成随机token""" + return str(uuid.uuid4()) + + @staticmethod + def store_token(token, user, remember=False): + """存储token""" + expires_at = datetime.now() + timedelta(days=7 if remember else 1) + TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at} + + @staticmethod + def validate_token(token): + """验证token并返回用户信息""" + if not token or token not in TEMP_TOKENS: + return None + + # 检查token是否过期 + if datetime.now() > TEMP_TOKENS[token]["expires_at"]: + del TEMP_TOKENS[token] + return None + + return TEMP_TOKENS[token]["user"] + + @staticmethod + def remove_token(token): + """删除token""" + if token in TEMP_TOKENS: + del TEMP_TOKENS[token] + + @staticmethod + def get_token_from_request(request): + """从请求中提取token""" + import json + token = None + + # 从headers获取 + if hasattr(request, 'headers'): + try: + headers = request.headers + if headers: + # 方法1:直接用get方法获取(Robyn的Headers对象支持) + try: + token_value = headers.get('token', None) + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + except: + pass + + # 方法2:尝试大写Token + if not token: + try: + token_value = headers.get('Token', None) + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + except: + pass + + # 方法3:如果headers是字符串(JSON格式),解析它 + if not token: + try: + headers_str = str(headers) + if headers_str.startswith('{'): + headers_dict = json.loads(headers_str) + token_value = headers_dict.get('token') or headers_dict.get('Token') + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + except: + pass + except: + pass + + # 从query参数获取 + if not token and hasattr(request, 'query_params'): + try: + query_params = getattr(request, 'query_params', {}) + # Robyn的query_params.get()需要提供默认值 + token_value = query_params.get('token', None) + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + + # 也尝试从 headers[token] 参数获取 + if not token: + token_value = query_params.get('headers[token]', None) + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + + # 尝试URL编码的版本 + if not token: + token_value = query_params.get('headers%5Btoken%5D', None) + if token_value: + token = token_value[0] if isinstance(token_value, list) else token_value + except: + pass + + # 从body中获取 + if not token and hasattr(request, 'body') and request.body: + try: + body_data = json.loads(request.body) + token = body_data.get('token') + except: + pass + + return token diff --git a/vue/src/api/profile.js b/vue/src/api/profile.js index 6945a06..04e2b1e 100644 --- a/vue/src/api/profile.js +++ b/vue/src/api/profile.js @@ -10,9 +10,7 @@ export function getUserProfile(token) { return request({ url: '/api/userInfo', method: 'get', - params: { - token - } + }); } @@ -26,9 +24,7 @@ export function updateAvatar(formData) { url: '/api/updateAvatar', method: 'post', data: formData, - headers: { - 'Content-Type': 'multipart/form-data' - } + }); } @@ -41,6 +37,7 @@ export function updatePassword(data) { return request({ url: '/api/updatePassword', method: 'post', - data + data, + }); } \ No newline at end of file diff --git a/vue/src/utils/request.js b/vue/src/utils/request.js index e2e5969..013e72a 100644 --- a/vue/src/utils/request.js +++ b/vue/src/utils/request.js @@ -7,6 +7,10 @@ const service = axios.create({ service.interceptors.request.use( (config) => { + const token = localStorage.getItem('token'); + if (token) { + config.headers['token'] = token; + } return config; }, (error) => { @@ -14,14 +18,20 @@ service.interceptors.request.use( } ); + service.interceptors.response.use( (response) => { return response.data; }, (error) => { console.error('请求错误拦截器:', error); + if (error.response && error.response.status === 401) { + localStorage.removeItem('token'); + window.location.href = '/login'; + } return Promise.reject(error); } ); -export default service; \ No newline at end of file + +export default service; diff --git a/web_main.py b/web_main.py index 5f8d544..8b1d75c 100644 --- a/web_main.py +++ b/web_main.py @@ -1,5 +1,3 @@ -from robyn import ALLOW_CORS - from app import app import controller from service.UserService import init_mysql_connection @@ -10,7 +8,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) resource_dir = os.path.join(current_dir, "resource") if os.path.exists(resource_dir): app.serve_directory("/resource", resource_dir) - print(f"静态资源目录已配置: {resource_dir}") if __name__ == "__main__": - init_mysql_connection() and app.start(host="0.0.0.0", port=8088) \ No newline at end of file + init_mysql_connection() and app.start(host="0.0.0.0", port=8088)