You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
4.1 KiB
112 lines
4.1 KiB
import uuid
|
|
from datetime import datetime, timedelta
|
|
|
|
# Process-wide, in-memory token store.
# Shape: token -> {"user": <user object>, "expires_at": <datetime>}
TEMP_TOKENS: dict = dict()
class TokenManager:
    """Issue, store, validate and extract authentication tokens.

    Tokens are kept in the module-level ``TEMP_TOKENS`` dict (process memory
    only), mapping ``token -> {"user": <user>, "expires_at": <datetime>}``
    where ``expires_at`` is a naive local time from ``datetime.now()``.
    Every method is static; the class only serves as a namespace.
    """

    @staticmethod
    def generate_token():
        """Return a new random token: ``str(uuid.uuid4())``."""
        return str(uuid.uuid4())

    @staticmethod
    def store_token(token, user, remember=False):
        """Store *token* for *user*.

        The token expires 7 days from now when *remember* is truthy, otherwise
        1 day from now. Already-expired entries are purged first, so tokens
        that are never presented again do not accumulate forever.
        """
        now = datetime.now()
        TokenManager._purge_expired(now)
        expires_at = now + timedelta(days=7 if remember else 1)
        TEMP_TOKENS[token] = {"user": user, "expires_at": expires_at}

    @staticmethod
    def _purge_expired(now):
        """Delete every TEMP_TOKENS entry whose expiry time is before *now*."""
        # Iterate over a snapshot because entries are deleted inside the loop.
        for token, entry in list(TEMP_TOKENS.items()):
            if now > entry["expires_at"]:
                TEMP_TOKENS.pop(token, None)

    @staticmethod
    def validate_token(token):
        """Return the user stored for *token*, or None if unknown or expired.

        An expired token is removed from TEMP_TOKENS as a side effect.
        Unhashable values (e.g. a list taken from a JSON body) are treated as
        unknown instead of raising TypeError.
        """
        if not token:
            return None
        try:
            entry = TEMP_TOKENS.get(token)
        except TypeError:  # unhashable value can never be a stored token
            return None
        if entry is None:
            return None

        # Expired: drop it so it cannot be used again.
        if datetime.now() > entry["expires_at"]:
            TEMP_TOKENS.pop(token, None)
            return None

        return entry["user"]

    @staticmethod
    def remove_token(token):
        """Forget *token*; tokens that are not stored are ignored."""
        TEMP_TOKENS.pop(token, None)

    @staticmethod
    def get_token_from_request(request):
        """Extract the auth token from *request*, or return None.

        Sources are tried in this order and the first non-empty string wins:

        1. headers ``token`` then ``Token``, then a JSON-object rendering of
           the headers (``str(headers)`` starting with ``{``);
        2. query params ``token``, ``headers[token]``, ``headers%5Btoken%5D``;
        3. the ``token`` field of a JSON-object request body.

        Each source is best-effort: an ordinary exception while reading one
        means "no token there" and the next source is tried. A list value is
        reduced to its first item; non-string values are ignored.
        """
        return (
            TokenManager._token_from_headers(request)
            or TokenManager._token_from_query(request)
            or TokenManager._token_from_body(request)
        )

    @staticmethod
    def _coerce_token(value):
        """Normalise a raw header/query/body value to a token string or None."""
        if isinstance(value, list):
            value = value[0] if value else None
        if isinstance(value, str) and value:
            return value
        return None

    @staticmethod
    def _token_from_headers(request):
        """Return the token carried in the request headers, or None."""
        import json

        headers = getattr(request, "headers", None)
        if headers is None:
            return None

        # Use .get() where the headers object supports it (e.g. Robyn Headers).
        for key in ("token", "Token"):
            try:
                token = TokenManager._coerce_token(headers.get(key, None))
            except Exception:
                # Best-effort: this headers object does not support .get().
                token = None
            if token:
                return token

        # Fallback: the headers object may stringify to a JSON object.
        try:
            headers_str = str(headers)
            if headers_str.startswith("{"):
                headers_dict = json.loads(headers_str)
                return TokenManager._coerce_token(
                    headers_dict.get("token") or headers_dict.get("Token")
                )
        except Exception:
            # Best-effort: unparsable representation means no token here.
            pass
        return None

    @staticmethod
    def _token_from_query(request):
        """Return the token carried in the query parameters, or None."""
        query_params = getattr(request, "query_params", None)
        if query_params is None:
            return None

        # Also accept a token nested as headers[token], raw or URL-encoded.
        for key in ("token", "headers[token]", "headers%5Btoken%5D"):
            try:
                # Robyn's query_params.get() needs an explicit default.
                token = TokenManager._coerce_token(query_params.get(key, None))
            except Exception:
                # Best-effort: unusable query params object, give up on it.
                return None
            if token:
                return token
        return None

    @staticmethod
    def _token_from_body(request):
        """Return the ``token`` field of a JSON-object body, or None."""
        import json

        body = getattr(request, "body", None)
        if not body:
            return None
        try:
            body_data = json.loads(body)
        except Exception:
            # Not (decodable) JSON: there is no token in the body.
            return None
        if not isinstance(body_data, dict):
            return None
        return TokenManager._coerce_token(body_data.get("token"))