You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

263 lines
9.3 KiB

3 months ago
# 全局 client(可复用)
3 months ago
import base64
import io
import json
import os
import re
import tempfile
3 months ago
import traceback
3 months ago
from docx import Document
3 months ago
import httpx
3 months ago
from robyn import jsonify, Response, Request
3 months ago
from app import app
3 months ago
from controller.client import client
3 months ago
from util.neo4j_utils import neo4j_client
3 months ago
3 months ago
# Mapping from Chinese entity-type names (the `t` field of extracted entities,
# presumably emitted by the extraction service) to the English node labels used
# in the Neo4j graph. Built from an ordered sequence of (chinese, english) pairs
# so insertion order matches the listing below.
CHINESE_TO_ENGLISH_LABEL = dict((
    ("疾病", "Disease"),
    ("症状", "Symptom"),
    ("检查项目", "AuxiliaryExamination"),
    ("药物", "Drug"),
    ("手术", "Operation"),
    ("解剖部位", "CheckSubject"),  # or AnatomicalSite, depending on the graph schema
    ("并发症", "Complication"),
    ("诊断", "Diagnosis"),
    ("治疗", "Treatment"),
    ("辅助治疗", "AdjuvantTherapy"),
    ("不良反应", "AdverseReactions"),
    ("检查", "Check"),
    ("部门", "Department"),
    ("疾病部位", "DiseaseSite"),
    ("相关疾病", "RelatedDisease"),
    ("相关症状", "RelatedSymptom"),
    ("传播途径", "SpreadWay"),
    ("阶段", "Stage"),
    ("主题/主体", "Subject"),
    ("症状与体征", "SymptomAndSign"),
    ("治疗方案", "TreatmentPrograms"),
    ("类型", "Type"),
    ("原因", "Cause"),
    ("属性", "Attribute"),
    ("指示/适应症", "Indications"),
    ("成分", "Ingredients"),
    # NOTE(review): 病原学 literally means "etiology" and 发病机制 means
    # "pathogenesis"; these two labels look shifted. Kept as-is because they
    # must match labels already stored in the graph — confirm before changing.
    ("病原学", "Pathogenesis"),
    ("病理类型", "PathologicalType"),
    ("发病机制", "Pathophysiology"),
    ("注意事项", "Precautions"),
    ("预后", "Prognosis"),
    ("预后生存时间", "PrognosticSurvivalTime"),
    ("疾病比率", "DiseaseRatio"),
    ("药物治疗", "DrugTherapy"),
    ("感染性", "Infectious"),
    ("实体", "Entity"),
    # Add further pairs here as new entity types appear.
))
3 months ago
3 months ago
def json_response(data: dict, status_code: int = 200):
    """Serialize *data* to compact JSON and wrap it in a Robyn ``Response``.

    ``ensure_ascii=False`` keeps non-ASCII (e.g. Chinese) text as literal
    characters instead of ASCII escape sequences, and the Content-Type header
    declares UTF-8 accordingly.

    Args:
        data: JSON-serializable payload.
        status_code: HTTP status code to send (default 200).

    Returns:
        A ``Response`` whose ``description`` (the body field) holds the JSON text.
    """
    payload = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
    return Response(
        status_code=status_code,
        headers={"Content-Type": "application/json; charset=utf-8"},  # passed as a plain dict
        description=payload,
    )


def _decode_base64_payload(payload: str) -> bytes:
    """Decode a base64 string, tolerating an optional ``data:...;base64,`` prefix.

    Plain base64 never contains a comma, so stripping everything up to the first
    comma of a ``data:`` URL (e.g. one produced by a browser's
    ``FileReader.readAsDataURL``) cannot corrupt a bare payload.

    Raises:
        ValueError: (``binascii.Error``) if the payload is not valid base64.
    """
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    return base64.b64decode(payload)


def _extract_docx_text(file_b64: str) -> str:
    """Decode a base64-encoded .docx and return its paragraph texts joined by newlines."""
    doc = Document(io.BytesIO(_decode_base64_payload(file_b64)))
    return "\n".join(para.text for para in doc.paragraphs).strip()


@app.post("/api/builder/analyze")
async def analyze(request: Request):
    """Forward free text and/or a .docx upload to the entity/relation extractor.

    Expects a JSON object body with:
        text (str, optional): raw text to analyze.
        file_base64 (str, optional): base64-encoded .docx file.
        filename (str, optional): accepted but currently unused.
    At least one of ``text`` / ``file_base64`` must be non-empty. Both sources
    are concatenated and posted to the model service's
    ``/extract_entities_and_relations`` endpoint through the shared ``client``.

    Responses:
        200/202: upstream JSON relayed (or ``{"raw_response": ...}`` if not JSON).
        400: malformed body, missing input, undecodable .docx, or empty text.
        502/504: the model service could not be reached / timed out.
        other upstream status: relayed with ``error``/``detail``/``status_code``.
        500: any other unexpected error.
    """
    try:
        # 1. Parse the JSON request body.
        try:
            body = request.json()
        except ValueError as e:
            return json_response({"error": f"请求体不是合法 JSON: {e}"}, status_code=400)
        if not isinstance(body, dict):
            return json_response({"error": "请求体必须是 JSON 对象"}, status_code=400)

        # `or ""` also covers an explicit JSON null, which used to crash .strip().
        input_text = body.get("text") or ""
        if not isinstance(input_text, str):
            return json_response({"error": "text 必须是字符串"}, status_code=400)
        input_text = input_text.strip()

        file_b64 = body.get("file_base64")
        if file_b64 is not None and not isinstance(file_b64, str):
            return json_response({"error": "file_base64 必须是字符串"}, status_code=400)

        # 2. Reject requests carrying neither text nor a file.
        if not input_text and not file_b64:
            return json_response({"error": "必须提供 text 或 file"}, status_code=400)

        # 3. Extract text from the .docx upload, if any. Base64, zip and XML
        #    failures all mean "bad upload", so every error here maps to 400.
        file_text = ""
        if file_b64:
            try:
                file_text = _extract_docx_text(file_b64)
            except Exception as e:
                return json_response({"error": f"解析 .docx 文件失败: {str(e)}"}, status_code=400)

        # 4. Concatenate both sources (blank line between them).
        final_text = (input_text + "\n\n" + file_text).strip()
        if not final_text:
            return json_response({"error": "合并后文本为空"}, status_code=400)
        print(f"📄 最终提交文本(前200字符):\n{final_text[:200]}...")

        # 5. Forward to the model service. The exception types assume `client`
        #    is an httpx.AsyncClient (httpx is imported by this module).
        try:
            resp = await client.post(
                "/extract_entities_and_relations",
                json={"text": final_text},
                timeout=1800.0,  # 30 minutes
            )
        except httpx.TimeoutException as e:
            print(f"❌ 大模型服务超时: {e!r}")
            return json_response({"error": "大模型服务调用超时", "detail": str(e)}, status_code=504)
        except httpx.HTTPError as e:
            print(f"❌ 大模型服务请求失败: {e!r}")
            return json_response({"error": "大模型服务调用失败", "detail": str(e)}, status_code=502)

        # 6. Relay the upstream result.
        if resp.status_code in (200, 202):
            try:
                result = resp.json()
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                result = {"raw_response": resp.text}
            return json_response(result, status_code=resp.status_code)

        return json_response({
            "error": "大模型服务调用失败",
            "detail": resp.text,
            "status_code": resp.status_code,
        }, status_code=resp.status_code)

    except Exception as e:
        # Last-resort boundary so the client always gets a JSON error.
        # NOTE(review): the traceback is sent to the caller, which exposes server
        # internals; consider omitting it outside development.
        error_trace = traceback.format_exc()
        print("❌ 后端异常:")
        print(error_trace)
        return json_response({
            "error": str(e),
            "traceback": error_trace,
        }, status_code=500)


# @app.post("/api/builder/analyze")
# async def analyze(request: Request):
# ct = (request.headers.get("content-type") or "").lower()
# # === 关键:打印 body 前 100 字节的原始内容(作为字符串,忽略编码错误)===
# preview = request.body[:100].decode('utf-8', errors='replace')
# print("📦 Body preview (first 100 chars):", repr(preview))
# print("🔍 Content-Type:", repr(ct))
# print("📦 Body length:", len(request.body))
# if "multipart/form-data" not in ct:
# return json_response({"error": "仅支持 multipart/form-data"}, 400)
#
# try:
# form_data = parse_multipart(request.body, request.headers.get("content-type"))
# except Exception as e:
# return json_response({"error": f"表单解析失败: {str(e)}"}, 400)
#
# # 获取字段
# text_input = form_data.get("text", "")
# uploaded_file = form_data.get("file") # 是 dict,含 filename/file/content_type
#
# if not uploaded_file or not isinstance(uploaded_file, dict):
# return json_response({"error": "未提供有效文件"}, 400)
#
# file_content = uploaded_file["file"] # bytes
# filename = uploaded_file["filename"]
#
# # 后续处理 .docx 等逻辑保持不变...
3 months ago
3 months ago
# @app.post("/api/builder/analyze")
# async def analyze(request):
# body = request.json()
# input_text = body.get("text", "").strip()
# if not input_text:
# return jsonify({"error": "缺少 text 字段"}), 400
# try:
# # 直接转发到大模型服务(假设它返回 { "task_id": "xxx" })
# resp = await client.post(
# "/extract_entities_and_relations",
# json={"text": input_text},
# timeout=1800.0 # 30分钟
# )
# print(resp)
#
# if resp.status_code == 202 or resp.status_code == 200:
# return Response(
# status_code=200,
# description=jsonify(resp.json()),
# headers={"Content-Type": "text/plain; charset=utf-8"}
# )
# else:
# return jsonify({
# "error": "提交失败",
# "detail": resp.text
# }), resp.status_code
# except Exception as e:
# error_trace = traceback.format_exc()
# print("❌ 发生异常:")
# print(error_trace)
#
# return jsonify({"error": str(e), "traceback": error_trace}), 500
def _load_json_list(value, field: str) -> list:
    """Return *value* as a list, JSON-decoding it first when it is a string.

    ``None`` (field absent or JSON null) is treated as an empty list.

    Raises:
        ValueError: if a string is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or the decoded value is not a list.
    """
    if value is None:
        return []
    if isinstance(value, str):
        value = json.loads(value)
    if not isinstance(value, list):
        raise ValueError(f"{field} must be a JSON array, got {type(value).__name__}")
    return value


@app.post("/api/builder/build")
async def build(request):
    """Upsert extracted entities and relations into Neo4j.

    Expects a JSON object body with ``entities`` and ``relations``, each either a
    list or a JSON-encoded string of a list:
        entity:   {"n": <name>, "t": <Chinese type name>}
        relation: {"e1": <source name>, "r": <relation type>, "e2": <target name>}
    Entity types are translated via CHINESE_TO_ENGLISH_LABEL; unknown types are
    used verbatim as the label. A node / relationship is created only if the
    corresponding find_* query on neo4j_client returns nothing.

    Returns:
        200 JSON summary ``{"status", "nodes_created", "relationships_created"}``;
        400 JSON ``{"error": ...}`` for malformed input.
    """
    try:
        body = request.json()
    except ValueError as e:
        return json_response({"error": f"Invalid JSON body: {e}"}, status_code=400)
    if not isinstance(body, dict):
        return json_response({"error": "Request body must be a JSON object"}, status_code=400)

    try:
        entities = _load_json_list(body.get("entities", "[]"), "entities")
        relations = _load_json_list(body.get("relations", "[]"), "relations")
    except ValueError as e:
        print("JSON decode error:", e)
        # json_response supplies the headers Robyn's Response constructor takes.
        return json_response(
            {"error": f"Invalid JSON in entities or relations: {e}"}, status_code=400
        )

    # NOTE(review): label and relationship-type strings below come from the
    # request body; Cypher cannot parameterize them, so confirm neo4j_client
    # escapes them. Its calls are also not awaited, so they presumably block
    # the event loop while running.
    name_to_label = {}
    nodes_created = 0
    for ent in entities:
        if not isinstance(ent, dict):
            print(f"⚠️ Warning: skipping malformed entity {ent!r}")
            continue
        name = ent.get("n")
        typ = ent.get("t")
        print(f"Entity: {name}, Type: {typ}")
        if not isinstance(name, str) or not name:
            print(f"⚠️ Warning: skipping entity without a name: {ent!r}")
            continue

        label = CHINESE_TO_ENGLISH_LABEL.get(typ) if isinstance(typ, str) else None
        if label is None:
            # Fall back to the raw type string (None when the type is missing).
            label = typ if isinstance(typ, str) else None
            print(f"⚠️ Warning: Unknown entity type '{typ}' for entity '{name}'. Using it verbatim as the label.")
        name_to_label[name] = label

        existing = neo4j_client.find_nodes_with_element_id(label=label, properties={"name": name})
        if existing:
            print("Found node:", existing)
        else:
            neo4j_client.insert_node(label=label, properties={"name": name})
            nodes_created += 1

    relationships_created = 0
    for rel in relations:
        if not isinstance(rel, dict):
            print(f"⚠️ Warning: skipping malformed relation {rel!r}")
            continue
        e1, r, e2 = rel.get("e1"), rel.get("r"), rel.get("e2")
        if not all(isinstance(v, str) and v for v in (e1, r, e2)):
            print(f"⚠️ Warning: skipping incomplete relation {rel!r}")
            continue

        # Endpoint labels come from the entity pass above; names not listed in
        # `entities` resolve to None.
        rel_spec = {
            "source_label": name_to_label.get(e1),
            "source_props": {"name": e1},
            "target_label": name_to_label.get(e2),
            "target_props": {"name": e2},
            "rel_type": r,
            "rel_properties": {"label": r},
        }
        if not neo4j_client.find_relationships_by_condition(**rel_spec):
            neo4j_client.create_relationship(**rel_spec)
            relationships_created += 1

    return json_response({
        "status": "ok",
        "nodes_created": nodes_created,
        "relationships_created": relationships_created,
    })