Browse Source

Merge branch 'yangrongze' of http://124.70.32.114:3100/hanyuqing/KGPython into hanyuqing

hanyuqing
hanyuqing 3 months ago
parent
commit
06669daae9
  1. 4
      .gitignore
  2. 6
      app.py
  3. 115
      controller/LoginController.py
  4. 127
      util/auth_interceptor.py
  5. 112
      util/token_utils.py
  6. 11
      vue/src/api/profile.js
  7. 12
      vue/src/utils/request.js
  8. 5
      web_main.py

4
.gitignore

@ -0,0 +1,4 @@
node_modules/
__pycache__/
*.pyc
.idea/

6
app.py

@ -9,4 +9,8 @@ if web_server_path not in sys.path:
sys.path.insert(0, web_server_path) sys.path.insert(0, web_server_path)
app = Robyn(__file__) app = Robyn(__file__)
# 注册全局认证拦截器
from util.auth_interceptor import global_auth_interceptor
app.before_request()(global_auth_interceptor)

115
controller/LoginController.py

@ -1,30 +1,19 @@
import os import os
from robyn import jsonify, Response, Request from robyn import jsonify, Response, Request
from app import app from app import app
from datetime import datetime, timedelta
import uuid
import json import json
from service.UserService import user_service from service.UserService import user_service
from util.token_utils import TokenManager
from util.auth_interceptor import require_login, get_current_user
# 临时存储token,用于会话管理
TEMP_TOKENS = {}
def generate_token() -> str:
"""生成随机token"""
return str(uuid.uuid4())
def validate_token(token: str) -> dict:
"""验证token并返回用户信息"""
if not token or token not in TEMP_TOKENS:
return None
# 检查token是否过期
if datetime.now() > TEMP_TOKENS[token]["expires_at"]:
del TEMP_TOKENS[token]
return None
return TEMP_TOKENS[token]["user"]
@app.post("/api/login") @app.post("/api/login")
def login_route(request): def login_route(request):
@ -63,10 +52,9 @@ def login_route(request):
headers={"Content-Type": "application/json; charset=utf-8"} headers={"Content-Type": "application/json; charset=utf-8"}
) )
# 生成token并设置过期时间
token = generate_token() token=TokenManager.generate_token()
expires_at = datetime.now() + timedelta(days=7 if remember else 1) TokenManager.store_token(token,user, remember)
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at}
return Response( return Response(
status_code=200, status_code=200,
@ -84,10 +72,8 @@ def login_route(request):
def logout_route(request): def logout_route(request):
"""登出接口""" """登出接口"""
try: try:
request_data = json.loads(request.body) if request.body else {} token=TokenManager.get_token_from_request(request)
token = request_data.get("token", "") TokenManager.remove_token(token)
# 删除token
TEMP_TOKENS.pop(token, None)
return Response( return Response(
status_code=200, status_code=200,
@ -102,27 +88,18 @@ def logout_route(request):
) )
@app.get("/api/userInfo") @app.get("/api/userInfo")
@require_login
def user_info_route(request): def user_info_route(request):
"""获取用户信息接口""" """获取用户信息接口"""
try: try:
query_params = getattr(request, 'query_params', {}) user = get_current_user()
token = query_params.get("token", "")
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 从数据库获取最新的用户信息 # 从数据库获取最新的用户信息
username = user["username"] username = user["username"]
db_user = user_service.get_user_info(username) db_user = user_service.get_user_info(username)
if db_user: if db_user:
# 更新TEMP_TOKENS中的用户信息 # 更新TEMP_TOKENS中的用户信息
TEMP_TOKENS[token]["user"] = db_user
user_info = db_user user_info = db_user
else: else:
user_info = user user_info = user
@ -140,48 +117,22 @@ def user_info_route(request):
) )
@app.post("/api/updateAvatar") @app.post("/api/updateAvatar")
@require_login
async def update_avatar_route(request: Request): async def update_avatar_route(request: Request):
"""更新用户头像接口""" """更新用户头像接口"""
try: try:
user = get_current_user()
# 从files中获取文件和token # 从files中获取文件和token
avatar_file = request.files.get('avatar') if hasattr(request, 'files') else None avatar_file = request.files.get('avatar') if hasattr(request, 'files') else None
token = None
# 从form_data中获取token
if hasattr(request, 'form_data'):
token = request.form_data.get('token')
# 如果files中没有直接找到'avatar',尝试获取第一个文件
if not avatar_file and hasattr(request, 'files') and request.files: if not avatar_file and hasattr(request, 'files') and request.files:
first_key = list(request.files.keys())[0] first_key = list(request.files.keys())[0]
avatar_file = request.files[first_key] avatar_file = request.files[first_key]
# 如果form_data中没有token,尝试从headers获取
if not token:
headers_dict = {}
if hasattr(request, 'headers'):
try:
headers_dict = dict(request.headers)
except:
pass
token = headers_dict.get('Authorization') or headers_dict.get('authorization')
# 如果还是没有token,尝试从body中解析
if not token and hasattr(request, 'body'):
try:
body_data = json.loads(request.body)
token = body_data.get('token')
except:
pass
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 检查文件是否存在 # 检查文件是否存在
if not avatar_file: if not avatar_file:
@ -282,8 +233,7 @@ async def update_avatar_route(request: Request):
headers={"Content-Type": "application/json; charset=utf-8"} headers={"Content-Type": "application/json; charset=utf-8"}
) )
# 更新token中的用户信息
TEMP_TOKENS[token]["user"]["avatar"] = avatar_relative_path
return Response( return Response(
status_code=200, status_code=200,
@ -302,32 +252,21 @@ async def update_avatar_route(request: Request):
) )
@app.post("/api/updatePassword") @app.post("/api/updatePassword")
@require_login
def update_password_route(request): def update_password_route(request):
"""更新用户密码接口""" """更新用户密码接口"""
try: try:
# 解析请求数据 # 解析请求数据
request_data = json.loads(request.body) if request.body else {} request_data = json.loads(request.body) if request.body else {}
token = request_data.get("token", "")
current_password = request_data.get("currentPassword", "") current_password = request_data.get("currentPassword", "")
new_password = request_data.get("newPassword", "") new_password = request_data.get("newPassword", "")
# 验证输入 # 验证输入
if not current_password or not new_password:
return Response( # 获取当前用户
status_code=400, user = get_current_user()
description=jsonify({"success": False, "message": "当前密码和新密码不能为空"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证token
user = validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 获取用户信息 # 获取用户信息
username = user["username"] username = user["username"]

127
util/auth_interceptor.py

@ -0,0 +1,127 @@
from robyn import Response, jsonify
from util.token_utils import TokenManager
from functools import wraps
import threading
# 不需要登录的公开路径(白名单)
PUBLIC_PATHS = [
'/api/login',
'/api/logout',
'/api/register',
'/api/checkUsername',
'/resource',
]
# 使用线程本地存储来保存当前请求的用户信息
_thread_local = threading.local()
def get_current_user():
"""获取当前请求的用户信息"""
return getattr(_thread_local, 'current_user', None)
def set_current_user(user):
"""设置当前请求的用户信息"""
_thread_local.current_user = user
def is_public_path(path):
"""检查是否为公开路径"""
for public_path in PUBLIC_PATHS:
if path.startswith(public_path):
return True
return False
def require_login(func):
"""装饰器:要求登录(支持同步和异步函数)"""
import asyncio
@wraps(func)
async def async_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return await func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
@wraps(func)
def sync_wrapper(request):
try:
# 提取token
token = TokenManager.get_token_from_request(request)
# 验证token
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 验证通过,保存用户信息到线程本地存储
set_current_user(user)
return func(request)
except Exception as e:
return Response(
status_code=500,
description=jsonify({"success": False, "message": f"认证异常: {str(e)}"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 根据原函数是否为异步函数,返回对应的wrapper
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
def global_auth_interceptor(request):
"""全局认证拦截器 - 用于 @app.before_request() 装饰器"""
# 获取请求路径
path = str(request.url.path) if hasattr(request.url, 'path') else str(request.url)
# 公开路径直接放行
if is_public_path(path):
return request
# 验证 token
token = TokenManager.get_token_from_request(request)
user = TokenManager.validate_token(token)
if not user:
return Response(
status_code=401,
description=jsonify({"success": False, "message": "未登录或登录已过期"}),
headers={"Content-Type": "application/json; charset=utf-8"}
)
# 保存用户信息
set_current_user(user)
return request

112
util/token_utils.py

@ -0,0 +1,112 @@
import uuid
from datetime import datetime, timedelta
# 全局变量存储token
TEMP_TOKENS = {}
class TokenManager:
@staticmethod
def generate_token():
"""生成随机token"""
return str(uuid.uuid4())
@staticmethod
def store_token(token, user, remember=False):
"""存储token"""
expires_at = datetime.now() + timedelta(days=7 if remember else 1)
TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at}
@staticmethod
def validate_token(token):
"""验证token并返回用户信息"""
if not token or token not in TEMP_TOKENS:
return None
# 检查token是否过期
if datetime.now() > TEMP_TOKENS[token]["expires_at"]:
del TEMP_TOKENS[token]
return None
return TEMP_TOKENS[token]["user"]
@staticmethod
def remove_token(token):
"""删除token"""
if token in TEMP_TOKENS:
del TEMP_TOKENS[token]
@staticmethod
def get_token_from_request(request):
"""从请求中提取token"""
import json
token = None
# 从headers获取
if hasattr(request, 'headers'):
try:
headers = request.headers
if headers:
# 方法1:直接用get方法获取(Robyn的Headers对象支持)
try:
token_value = headers.get('token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 方法2:尝试大写Token
if not token:
try:
token_value = headers.get('Token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 方法3:如果headers是字符串(JSON格式),解析它
if not token:
try:
headers_str = str(headers)
if headers_str.startswith('{'):
headers_dict = json.loads(headers_str)
token_value = headers_dict.get('token') or headers_dict.get('Token')
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
except:
pass
# 从query参数获取
if not token and hasattr(request, 'query_params'):
try:
query_params = getattr(request, 'query_params', {})
# Robyn的query_params.get()需要提供默认值
token_value = query_params.get('token', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
# 也尝试从 headers[token] 参数获取
if not token:
token_value = query_params.get('headers[token]', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
# 尝试URL编码的版本
if not token:
token_value = query_params.get('headers%5Btoken%5D', None)
if token_value:
token = token_value[0] if isinstance(token_value, list) else token_value
except:
pass
# 从body中获取
if not token and hasattr(request, 'body') and request.body:
try:
body_data = json.loads(request.body)
token = body_data.get('token')
except:
pass
return token

11
vue/src/api/profile.js

@ -10,9 +10,7 @@ export function getUserProfile(token) {
return request({ return request({
url: '/api/userInfo', url: '/api/userInfo',
method: 'get', method: 'get',
params: {
token
}
}); });
} }
@ -26,9 +24,7 @@ export function updateAvatar(formData) {
url: '/api/updateAvatar', url: '/api/updateAvatar',
method: 'post', method: 'post',
data: formData, data: formData,
headers: {
'Content-Type': 'multipart/form-data'
}
}); });
} }
@ -41,6 +37,7 @@ export function updatePassword(data) {
return request({ return request({
url: '/api/updatePassword', url: '/api/updatePassword',
method: 'post', method: 'post',
data data,
}); });
} }

12
vue/src/utils/request.js

@ -7,6 +7,10 @@ const service = axios.create({
service.interceptors.request.use( service.interceptors.request.use(
(config) => { (config) => {
const token = localStorage.getItem('token');
if (token) {
config.headers['token'] = token;
}
return config; return config;
}, },
(error) => { (error) => {
@ -14,14 +18,20 @@ service.interceptors.request.use(
} }
); );
service.interceptors.response.use( service.interceptors.response.use(
(response) => { (response) => {
return response.data; return response.data;
}, },
(error) => { (error) => {
console.error('请求错误拦截器:', error); console.error('请求错误拦截器:', error);
if (error.response && error.response.status === 401) {
localStorage.removeItem('token');
window.location.href = '/login';
}
return Promise.reject(error); return Promise.reject(error);
} }
); );
export default service;
export default service;

5
web_main.py

@ -1,5 +1,3 @@
from robyn import ALLOW_CORS
from app import app from app import app
import controller import controller
from service.UserService import init_mysql_connection from service.UserService import init_mysql_connection
@ -10,7 +8,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
resource_dir = os.path.join(current_dir, "resource") resource_dir = os.path.join(current_dir, "resource")
if os.path.exists(resource_dir): if os.path.exists(resource_dir):
app.serve_directory("/resource", resource_dir) app.serve_directory("/resource", resource_dir)
print(f"静态资源目录已配置: {resource_dir}")
if __name__ == "__main__": if __name__ == "__main__":
init_mysql_connection() and app.start(host="0.0.0.0", port=8088) init_mysql_connection() and app.start(host="0.0.0.0", port=8088)

Loading…
Cancel
Save