Browse Source

Changes

yangrongze
yangrongze 3 months ago
parent
commit
c4a4a63b29
  1. 4
      .gitignore
  2. 6
      app.py
  3. 115
      controller/LoginController.py
  4. 127
      util/auth_interceptor.py
  5. 112
      util/token_utils.py
  6. 11
      vue/src/api/profile.js
  7. 12
      vue/src/utils/request.js
  8. 5
      web_main.py

4
.gitignore

@ -0,0 +1,4 @@
node_modules/
__pycache__/
*.pyc
.idea/

6
app.py

@ -9,4 +9,8 @@ if web_server_path not in sys.path:
sys.path.insert(0, web_server_path)
app = Robyn(__file__)
app = Robyn(__file__)
# 注册全局认证拦截器
from util.auth_interceptor import global_auth_interceptor
app.before_request()(global_auth_interceptor)

115
controller/LoginController.py

@ -1,30 +1,19 @@
import os
from robyn import jsonify, Response, Request
from app import app
from datetime import datetime, timedelta
import uuid
import json
from service.UserService import user_service
from util.token_utils import TokenManager
from util.auth_interceptor import require_login, get_current_user
# 临时存储token,用于会话管理
TEMP_TOKENS = {}
def generate_token() -> str:
"""生成随机token"""
return str(uuid.uuid4())
def validate_token(token: str) -> dict:
"""验证token并返回用户信息"""
if not token or token not in TEMP_TOKENS:
return None
# 检查token是否过期
if datetime.now() > TEMP_TOKENS[token]["expires_at"]:
del TEMP_TOKENS[token]
return None
return TEMP_TOKENS[token]["user"]
@app.post("/api/login")
def login_route(request):
@ -63,10 +52,9 @@ def login_route(request):
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 生成token并设置过期时间
token = generate_token()
expires_at = datetime.now() + timedelta(days=7 if remember else 1)
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at}
token=TokenManager.generate_token()
TokenManager.store_token(token,user, remember)
return Response(
status_code=200,
@ -84,10 +72,8 @@ def login_route(request):
def logout_route(request):
"""登出接口"""
try:
request_data = json.loads(request.body) if request.body else {}
token = request_data.get("token", "")
# 删除token
TEMP_TOKENS.pop(token, None)
token=TokenManager.get_token_from_request(request)
TokenManager.remove_token(token)
return Response(
status_code=200,
@ -102,27 +88,18 @@ def logout_route(request):
)
@app.get("/api/userInfo")
@require_login
def user_info_route(request):
"""获取用户信息接口"""
try:
query_params = getattr(request, 'query_params', {})
token = query_params.get("token", "")
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
user = get_current_user()
# 从数据库获取最新的用户信息
username = user["username"]
db_user = user_service.get_user_info(username)
if db_user:
# 更新TEMP_TOKENS中的用户信息
TEMP_TOKENS[token]["user"] = db_user
user_info = db_user
else:
user_info = user
@ -140,48 +117,22 @@ def user_info_route(request):
)
@app.post("/api/updateAvatar")
@require_login
async def update_avatar_route(request: Request):
"""更新用户头像接口"""
try:
user = get_current_user()
# 从files中获取文件和token
avatar_file = request.files.get('avatar') if hasattr(request, 'files') else None
token = None
# 从form_data中获取token
if hasattr(request, 'form_data'):
token = request.form_data.get('token')
# 如果files中没有直接找到'avatar',尝试获取第一个文件
if not avatar_file and hasattr(request, 'files') and request.files:
first_key = list(request.files.keys())[0]
avatar_file = request.files[first_key]
# 如果form_data中没有token,尝试从headers获取
if not token:
headers_dict = {}
if hasattr(request, 'headers'):
try:
headers_dict = dict(request.headers)
except:
pass
token = headers_dict.get('Authorization') or headers_dict.get('authorization')
# 如果还是没有token,尝试从body中解析
if not token and hasattr(request, 'body'):
try:
body_data = json.loads(request.body)
token = body_data.get('token')
except:
pass
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 检查文件是否存在
if not avatar_file:
@ -282,8 +233,7 @@ async def update_avatar_route(request: Request):
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 更新token中的用户信息
TEMP_TOKENS[token]["user"]["avatar"] = avatar_relative_path
return Response(
status_code=200,
@ -302,32 +252,21 @@ async def update_avatar_route(request: Request):
)
@app.post("/api/updatePassword")
@require_login
def update_password_route(request):
"""更新用户密码接口"""
try:
# 解析请求数据
request_data = json.loads(request.body) if request.body else {}
token = request_data.get("token", "")
current_password = request_data.get("currentPassword", "")
new_password = request_data.get("newPassword", "")
# 验证输入
if not current_password or not new_password:
return Response(
status_code=400,
description=jsonify({"success": False, "message": "当前密码和新密码不能为空"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 获取当前用户
user = get_current_user()
# 获取用户信息
username = user["username"]

127
util/auth_interceptor.py

@ -0,0 +1,127 @@
from robyn import Response, jsonify
from util.token_utils import TokenManager
from functools import wraps
import threading
# 不需要登录的公开路径(白名单)
PUBLIC_PATHS = [
'/api/login',
'/api/logout',
'/api/register',
'/api/checkUsername',
'/resource',
]
# 使用线程本地存储来保存当前请求的用户信息
_thread_local = threading.local()
def get_current_user():
"""获取当前请求的用户信息"""
return getattr(_thread_local, 'current_user', None)
def set_current_user(user):
"""设置当前请求的用户信息"""
_thread_local.current_user = user
def is_public_path(path):
"""检查是否为公开路径"""
for public_path in PUBLIC_PATHS:
if path.startswith(public_path):
return True
return False
def require_login(func):
"""装饰器:要求登录(支持同步和异步函数)"""
import asyncio
@wraps(func)
async def async_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return await func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
@wraps(func)
def sync_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 根据原函数是否为异步函数,返回对应的wrapper
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
def global_auth_interceptor(request):
"""全局认证拦截器 - 用于 @app.before_request() 装饰器"""
# 获取请求路径
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url)
# 公开路径直接放行
if is_public_path(path):
return request
# 验证 token
token = TokenManager.get_token_from_request(request)
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 保存用户信息
set_current_user(user)
return request

112
util/token_utils.py

@ -0,0 +1,112 @@
import uuid
from datetime import datetime, timedelta
# 全局变量存储token
TEMP_TOKENS = {}
class TokenManager:
@staticmethod
def generate_token():
"""生成随机token"""
return str(uuid.uuid4())
@staticmethod
def store_token(token, user, remember=False):
"""存储token"""
expires_at = datetime.now() + timedelta(days=7 if remember else 1)
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at}
@staticmethod
def validate_token(token):
"""验证token并返回用户信息"""
if not token or token not in TEMP_TOKENS:
return None
# 检查token是否过期
if datetime.now() > TEMP_TOKENS[token]["expires_at"]:
del TEMP_TOKENS[token]
return None
return TEMP_TOKENS[token]["user"]
@staticmethod
def remove_token(token):
"""删除token"""
if token in TEMP_TOKENS:
del TEMP_TOKENS[token]
@staticmethod
def get_token_from_request(request):
"""从请求中提取token"""
import json
token = None
# 从headers获取
if hasattr(request, 'headers'):
try:
headers = request.headers
if headers:
# 方法1:直接用get方法获取(Robyn的Headers对象支持)
try:
token_value = headers.get('token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 方法2:尝试大写Token
if not token:
try:
token_value = headers.get('Token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 方法3:如果headers是字符串(JSON格式),解析它
if not token:
try:
headers_str = str(headers)
if headers_str.startswith('{'):
headers_dict = json.loads(headers_str)
token_value = headers_dict.get('token') or headers_dict.get('Token')
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
except:
pass
# 从query参数获取
if not token and hasattr(request, 'query_params'):
try:
query_params = getattr(request, 'query_params', {})
# Robyn的query_params.get()需要提供默认值
token_value = query_params.get('token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
# 也尝试从 headers[token] 参数获取
if not token:
token_value = query_params.get('headers[token]', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
# 尝试URL编码的版本
if not token:
token_value = query_params.get('headers%5Btoken%5D', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 从body中获取
if not token and hasattr(request, 'body') and request.body:
try:
body_data = json.loads(request.body)
token = body_data.get('token')
except:
pass
return token

11
vue/src/api/profile.js

@ -10,9 +10,7 @@ export function getUserProfile(token) {
return request({
url: '/api/userInfo',
method: 'get',
params: {
token
}
});
}
@ -26,9 +24,7 @@ export function updateAvatar(formData) {
url: '/api/updateAvatar',
method: 'post',
data: formData,
headers: {
'Content-Type': 'multipart/form-data'
}
});
}
@ -41,6 +37,7 @@ export function updatePassword(data) {
return request({
url: '/api/updatePassword',
method: 'post',
data
data,
});
}

12
vue/src/utils/request.js

@ -7,6 +7,10 @@ const service = axios.create({
service.interceptors.request.use(
(config) => {
const token = localStorage.getItem('token');
if (token) {
config.headers['token'] = token;
}
return config;
},
(error) => {
@ -14,14 +18,20 @@ service.interceptors.request.use(
}
);
service.interceptors.response.use(
(response) => {
return response.data;
},
(error) => {
console.error('请求错误拦截器:', error);
if (error.response && error.response.status === 401) {
localStorage.removeItem('token');
window.location.href = '/login';
}
return Promise.reject(error);
}
);
export default service;
export default service;

5
web_main.py

@ -1,5 +1,3 @@
from robyn import ALLOW_CORS
from app import app
import controller
from service.UserService import init_mysql_connection
@ -10,7 +8,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
resource_dir = os.path.join(current_dir, "resource")
if os.path.exists(resource_dir):
app.serve_directory("/resource", resource_dir)
print(f"静态资源目录已配置: {resource_dir}")
if __name__ == "__main__":
init_mysql_connection() and app.start(host="0.0.0.0", port=8088)
init_mysql_connection() and app.start(host="0.0.0.0", port=8088)

Loading…
Cancel
Save