You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

237 lines
8.1 KiB

import json
import sys
from datetime import datetime
import httpx
from app import app
from robyn import Robyn, jsonify, Response
from typing import Optional, List, Any, Dict
from service.GraphService import build_g6_subgraph_by_props, get_drug_names_from_neo4j, get_group_key, \
get_check_names_from_neo4j
from util.neo4j_utils import Neo4jUtil
from util.neo4j_utils import neo4j_client
from util.redis_utils import set as redis_set, get as redis_get # 使用你已有的模块级 Redis 工
# Redis cache keys: preload_data() stores the serialized Drug / Check trees
# under these keys and the tree endpoints read them back.
DRUG_TREE_KEY = "cache:drug_tree"
CHECK_TREE_KEY = "cache:check_tree"
# ========================
# 🔥 启动时预加载数据(在 app 启动前执行)
# ========================
def _build_grouped_tree(names, node_type):
    """Group ``names`` with ``get_group_key`` and return them as a two-level tree.

    Args:
        names: Iterable of node names to group.
        node_type: Value stored under ``"type"`` on every group and leaf node
            (``"Drug"`` or ``"Check"`` for the callers in this file).

    Returns:
        A list of ``{"label", "type", "children"}`` dicts, one per non-empty
        group, ordered A-Z, then ``"0-9"``, then ``"其他"``. Children are sorted
        by name. Groups whose key is not in that ordered list are omitted.
    """
    groups = {}
    for name in names:
        groups.setdefault(get_group_key(name), []).append(name)

    # Fixed display order: A-Z first, then the digit bucket, then "other".
    ordered_keys = [chr(code) for code in range(ord("A"), ord("Z") + 1)]
    ordered_keys += ["0-9", "其他"]

    return [
        {
            "label": key,
            "type": node_type,
            "children": [
                {"label": name, "type": node_type} for name in sorted(groups[key])
            ],
        }
        for key in ordered_keys
        if key in groups
    ]


def preload_data():
    """Build the Drug and Check trees from Neo4j and store them in Redis.

    Each tree is serialized with ``json.dumps(..., ensure_ascii=False)`` and
    written with ``ex=3600`` (presumably a one-hour expiry; confirm against
    ``util.redis_utils.set``). The Drug tree is fetched and written before the
    Check tree.

    Any exception is reported on stderr and swallowed so that a failed preload
    does not abort the caller.
    """
    print("🚀 正在预加载 Drug 和 Check 树...")
    try:
        # (cache key, node type, name loader) for every tree that is cached.
        tree_specs = (
            (DRUG_TREE_KEY, "Drug", get_drug_names_from_neo4j),
            (CHECK_TREE_KEY, "Check", get_check_names_from_neo4j),
        )
        for cache_key, node_type, load_names in tree_specs:
            tree_data = _build_grouped_tree(load_names(), node_type)
            redis_set(cache_key, json.dumps(tree_data, ensure_ascii=False), ex=3600)
        print("✅ 预加载完成!数据已写入 Redis 缓存。")
    except Exception as e:
        # Deliberate best-effort: keep going after a failed preload.
        # Switch to sys.exit(1) here if a failed preload should abort startup.
        print(f"❌ 预加载失败: {e}", file=sys.stderr)
# Run the preload once at import time, before the route handlers below are
# registered.
preload_data()
@app.get("/api/getData")
def get_data():
    """Return the graph built around the ``Disease`` node named ``"霍乱"``.

    Returns:
        A 200 ``Response`` whose body is the JSON-serialized result of
        ``build_g6_subgraph_by_props`` (relationships in both directions, no
        relationship-type filter), or a 500 ``Response`` with a JSON
        ``{"error": ...}`` body if building or serializing the graph fails.
    """
    try:
        graph_data = build_g6_subgraph_by_props(
            neo4j_client,
            node_label="Disease",
            node_properties={"name": "霍乱"},
            direction="both",
            rel_type=None,
        )
        return Response(
            status_code=200,
            description=jsonify(graph_data),
            # NOTE(review): the body is JSON but is labelled text/plain; kept
            # as-is because existing clients may rely on it.
            headers={"Content-Type": "text/plain; charset=utf-8"},
        )
    except Exception as e:
        # A Flask-style ``(body, 500)`` tuple is not a (body, status) pair in
        # Robyn, so the status code has to be carried by a Response object.
        return Response(
            status_code=500,
            description=jsonify({"error": str(e)}),
            headers={"Content-Type": "application/json; charset=utf-8"},
        )
def _graph_error_response(status_code, message):
    """Build a JSON ``{"error": message}`` Response with the given status code."""
    return Response(
        status_code=status_code,
        description=jsonify({"error": message}),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )


@app.post("/api/getGraph")
def get_graph(req):
    """Return the graph built around the node described by the request body.

    The JSON body is read with ``req.json()``. Its ``"label"`` field is the
    node name to look up and its ``"type"`` field is passed through as the
    node label.

    Args:
        req: The incoming request object.

    Returns:
        A 200 ``Response`` with the JSON-serialized graph, a 400 ``Response``
        when ``"label"`` is missing or empty (or the body is not a JSON
        object), or a 500 ``Response`` when anything else fails.
    """
    try:
        body = req.json()
        # A body that is not a JSON object has no "label"; treat it as missing.
        node_name = body.get("label") if isinstance(body, dict) else None
        if not node_name:
            return _graph_error_response(400, "Missing 'label' in request body")

        # NOTE(review): "type" comes straight from the client and is used as
        # the node label; confirm build_g6_subgraph_by_props validates it.
        graph_data = build_g6_subgraph_by_props(
            neo4j_client,
            node_label=body.get("type"),
            node_properties={"name": node_name},
            direction="both",
            rel_type=None,
        )
        return Response(
            status_code=200,
            description=jsonify(graph_data),
            # NOTE(review): JSON body labelled text/plain; kept for existing clients.
            headers={"Content-Type": "text/plain; charset=utf-8"},
        )
    except Exception as e:
        # A Flask-style ``(body, status)`` tuple is not a (body, status) pair
        # in Robyn, so errors are returned as explicit Response objects.
        return _graph_error_response(500, str(e))
@app.get("/api/drug-tree")
def get_drug_tree():
    """Serve the cached Drug tree as JSON.

    The tree is read from Redis under ``DRUG_TREE_KEY``. If the entry is
    missing, ``preload_data()`` is run once to rebuild it before giving up.
    The entry can be missing because it is written with ``ex=3600`` or because
    the startup preload failed.

    Returns:
        A 200 ``Response`` with the cached JSON string, a 503 ``Response`` if
        the tree is still unavailable after a rebuild attempt, or a 500
        ``Response`` if reading the cache raises.
    """
    json_headers = {"Content-Type": "application/json; charset=utf-8"}
    try:
        cached_tree = redis_get(DRUG_TREE_KEY)
        if cached_tree is None:
            # Cache miss: rebuild both trees, then read the entry again.
            preload_data()
            cached_tree = redis_get(DRUG_TREE_KEY)
    except Exception as e:
        return Response(
            status_code=500,
            description=json.dumps({"error": str(e)}, ensure_ascii=False),
            headers=json_headers,
        )

    if cached_tree is None:
        return Response(
            status_code=503,
            description=json.dumps(
                {"error": "Drug tree cache is unavailable"}, ensure_ascii=False
            ),
            headers=json_headers,
        )
    return Response(status_code=200, description=cached_tree, headers=json_headers)
# try:
# names = get_drug_names_from_neo4j()
# print(f"[Step 1] Loaded {len(names)} names")
#
# groups = {}
# for name in names:
# key = get_group_key(name)
# groups.setdefault(key, []).append(name)
# print(f"[Step 2] Grouped into {len(groups)} groups")
#
# # ✅ 顺序:A-Z, 0-9, 其他
# alphabet = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
# all_keys = alphabet + ["0-9", "其他"]
#
# tree_data = []
# total_children = 0
# for key in all_keys:
# if key in groups:
# children = [{"label": name,"type":"Drug"} for name in sorted(groups[key])]
# total_children += len(children)
# tree_data.append({"label": key,"type":"Drug", "children": children})
#
# print(f"[Step 3] Final tree: {len(tree_data)} groups, {total_children} drugs")
#
# json_str = json.dumps(tree_data, ensure_ascii=False)
# print(f"[Step 4] JSON size: {len(json_str)} chars")
#
# return Response(
# status_code=200,
# description=json_str,
# headers={"Content-Type": "application/json; charset=utf-8"}
# )
# except Exception as e:
# print(f"[ERROR] {str(e)}")
# return Response(
# status_code=500,
# description=json.dumps({"error": str(e)}, ensure_ascii=False),
# headers={"Content-Type": "application/json; charset=utf-8"}
# )
@app.get("/api/check-tree")
def get_check_tree():
    """Serve the cached Check tree as JSON.

    The tree is read from Redis under ``CHECK_TREE_KEY``. If the entry is
    missing, ``preload_data()`` is run once to rebuild it before giving up.
    The entry can be missing because it is written with ``ex=3600`` or because
    the startup preload failed.

    Returns:
        A 200 ``Response`` with the cached JSON string, a 503 ``Response`` if
        the tree is still unavailable after a rebuild attempt, or a 500
        ``Response`` if reading the cache raises.
    """
    json_headers = {"Content-Type": "application/json; charset=utf-8"}
    try:
        cached_tree = redis_get(CHECK_TREE_KEY)
        if cached_tree is None:
            # Cache miss: rebuild both trees, then read the entry again.
            preload_data()
            cached_tree = redis_get(CHECK_TREE_KEY)
    except Exception as e:
        return Response(
            status_code=500,
            description=json.dumps({"error": str(e)}, ensure_ascii=False),
            headers=json_headers,
        )

    if cached_tree is None:
        return Response(
            status_code=503,
            description=json.dumps(
                {"error": "Check tree cache is unavailable"}, ensure_ascii=False
            ),
            headers=json_headers,
        )
    return Response(status_code=200, description=cached_tree, headers=json_headers)
# try:
# names = get_check_names_from_neo4j()
# print(f"[Step 1] Loaded {len(names)} names")
#
# groups = {}
# for name in names:
# key = get_group_key(name)
# groups.setdefault(key, []).append(name)
# print(f"[Step 2] Grouped into {len(groups)} groups")
#
# # ✅ 顺序:A-Z, 0-9, 其他
# alphabet = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
# all_keys = alphabet + ["0-9", "其他"]
#
# tree_data = []
# total_children = 0
# for key in all_keys:
# if key in groups:
# children = [{"label": name,"type":"Check"} for name in sorted(groups[key])]
# total_children += len(children)
# tree_data.append({"label": key,"type":"Check", "children": children})
#
# print(f"[Step 3] Final tree: {len(tree_data)} groups, {total_children} checks")
#
# json_str = json.dumps(tree_data, ensure_ascii=False)
# print(f"[Step 4] JSON size: {len(json_str)} chars")
#
# return Response(
# status_code=200,
# description=json_str,
# headers={"Content-Type": "application/json; charset=utf-8"}
# )
# except Exception as e:
# print(f"[ERROR] {str(e)}")
# return Response(
# status_code=500,
# description=json.dumps({"error": str(e)}, ensure_ascii=False),
# headers={"Content-Type": "application/json; charset=utf-8"}
# )
@app.get("/health")
def health():
    """Report liveness plus whether each tree is currently present in Redis.

    Returns:
        A dict with ``"status": "ok"`` and the booleans ``"drug_cached"`` and
        ``"check_cached"``, each true when the corresponding cache key yields
        a non-``None`` value.
    """
    # Each key is read once and only its presence is reported; the cached
    # payloads themselves are not written to stdout on every probe.
    drug_cached = redis_get(DRUG_TREE_KEY) is not None
    check_cached = redis_get(CHECK_TREE_KEY) is not None
    return {
        "status": "ok",
        "drug_cached": drug_cached,
        "check_cached": check_cached,
    }
# @app.post("/api/analyze")
# async def analyze(request):
# # 假设前端传入 JSON: {"text": "病例文本..."}
# body = request.json()
# input_text = body.get("text", "").strip()
#
# if not input_text:
# return jsonify({"error": "缺少 text 字段"}), 400
# client = httpx.AsyncClient(base_url="http://192.168.50.113:8088")
# # 调用实体关系抽取服务
# response = await client.post(
# "/extract_entities_and_relations",
# json={"text": input_text}
# )
# print(response)
# if response.status_code == 200:
# result = response.json()
# return jsonify(result)
# else:
# return jsonify({
# "error": "实体抽取服务返回错误",
# "status": response.status_code,
# "detail": response.text
# }), response.status_code