You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

144 lines
4.8 KiB

3 months ago
import json
import traceback
3 months ago
import uuid
3 months ago
from util.neo4j_utils import Neo4jUtil, neo4j_client
3 months ago
import httpx
from robyn import jsonify, Response
from app import app
from controller.client import client
3 months ago
import uuid
3 months ago
def convert_to_g6_format(data):
    """Convert an entity/relation extraction result into AntV G6 graph data.

    Args:
        data: Mapping with an ``"entities"`` list of ``{"n": name, "t": type}``
            dicts and a ``"relations"`` list of
            ``{"e1": source_name, "e2": target_name, "r": relation_text}`` dicts.
            ``None``, a missing mapping, or missing/``None`` lists are treated
            as empty.

    Returns:
        ``{"nodes": [...], "edges": [...]}``.

        - Each node is ``{"id": <uuid4 str>, "label": name, "data": {"type": t}}``.
          Entities sharing a name collapse into the first one seen. Entities
          without a name are skipped with a printed warning.
        - Each edge is ``{"id": <uuid4 str>, "source": node_id,
          "target": node_id, "label": r, "data": {"label": r}}``. Relations
          whose endpoints do not match a known entity name are skipped with a
          printed warning.
    """
    data = data or {}
    entities = data.get("entities") or []
    relations = data.get("relations") or []

    # Relations reference entities by name, so map each name to the generated
    # node id to resolve edge endpoints below.
    name_to_id = {}
    nodes = []
    for ent in entities:
        name = ent.get("n")
        if name is None:
            print(f"Warning: Entity without name skipped: {ent}")
            continue
        if name in name_to_id:
            continue  # keep the first occurrence of a duplicated entity name
        node_id = str(uuid.uuid4())
        name_to_id[name] = node_id
        nodes.append({
            "id": node_id,
            "label": name,
            "data": {
                "type": ent.get("t"),  # can drive per-type node styling in G6
            },
        })

    edges = []
    for rel in relations:
        source_id = name_to_id.get(rel.get("e1"))
        target_id = name_to_id.get(rel.get("e2"))
        if source_id is None or target_id is None:
            print(f"Warning: Entity not found for relation: {rel}")
            continue
        label = rel.get("r")
        edges.append({
            "id": str(uuid.uuid4()),  # unique id per edge
            "source": source_id,
            "target": target_id,
            "label": label,  # top-level label, read directly by G6
            "data": {
                "label": label,  # duplicated under data for extensions
            },
        })

    return {"nodes": nodes, "edges": edges}
3 months ago
# Per-request timeout for model-service calls (30 minutes).
_UPSTREAM_TIMEOUT = 1800.0
# Status codes from the model service that are treated as success.
_OK_STATUSES = (200, 202)


def _json_response(payload, status_code=200):
    """Wrap *payload* in a Robyn Response with a JSON body and the given status."""
    # Robyn handlers may return a Response or a (description, headers, status)
    # 3-tuple; a Flask-style (body, status) pair is rejected, so always build
    # the Response explicitly.
    return Response(
        status_code=status_code,
        description=jsonify(payload),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )


def _upstream_error(resp):
    """Build the error response for a model-service call with a non-success status."""
    return _json_response({"error": "提交失败", "detail": resp.text}, resp.status_code)


def _decode_data_field(payload):
    """Return payload["data"] as a dict, decoding it first when it is a JSON string."""
    raw = payload.get("data") if isinstance(payload, dict) else None
    if isinstance(raw, (str, bytes, bytearray)):
        raw = json.loads(raw)
    return raw if isinstance(raw, dict) else {}


def _query_neighbors(entities):
    """Look up Neo4j neighbours for each entity name, as a list of {name: neighbors}."""
    # NOTE(review): if neo4j_client is a synchronous driver these lookups block
    # the event loop; consider asyncio.to_thread once that is confirmed safe.
    return [
        {
            name: neo4j_client.find_neighbors_with_relationshipsAI(
                node_label=None,
                direction="both",
                node_properties={"name": name},
                rel_type=None,
            )
        }
        for name in entities
    ]


def _top_answers(items, limit=5):
    """Return the *limit* items with the highest "sort" score."""
    # sorted() stays stable with reverse=True, so equal scores keep input order.
    return sorted(items, key=lambda item: item.get("sort", 0), reverse=True)[:limit]


async def _answer_to_graph(answer):
    """Extract entities/relations from *answer* as G6 data, or return None on failure."""
    resp = await client.post(
        "/extract_entities_and_relations",
        json={"text": answer},
        timeout=_UPSTREAM_TIMEOUT,
    )
    if resp.status_code not in _OK_STATUSES:
        print(f"Warning: extraction failed ({resp.status_code}) for answer: {answer!r}")
        return None
    return convert_to_g6_format(resp.json())


@app.post("/api/qa/analyze")
async def analyze(request):
    """Answer a free-text question and return graph views of the top answers.

    Request body: a JSON object with a non-empty string ``text``.

    Pipeline:
      1. ``POST /getEntity`` on the model service to extract entity names.
      2. Look up each entity's neighbours in Neo4j.
      3. ``POST /question_agent`` with the text and neighbour data to get
         scored answers (``{"data": <json with "json": [...]>}``).
      4. For the five highest-``sort`` answers, ``POST
         /extract_entities_and_relations`` and convert the result with
         ``convert_to_g6_format``. Answers whose extraction fails are skipped.

    Returns:
        - 200 with a JSON list of ``{"answer": ..., "result": {"nodes", "edges"}}``.
        - 400 if ``text`` is missing or blank.
        - The upstream status with ``{"error", "detail"}`` if a model-service
          call fails.
        - 500 with ``{"error"}`` on any other exception. The traceback is only
          printed server-side.
    """
    try:
        body = request.json()
    except Exception:  # exact parse-error type raised by Robyn is not visible here
        body = None
    text = body.get("text") if isinstance(body, dict) else None
    input_text = text.strip() if isinstance(text, str) else ""
    if not input_text:
        return _json_response({"error": "缺少 text 字段"}, 400)

    try:
        # Step 1: entity names mentioned in the question.
        resp = await client.post(
            "/getEntity",
            json={"text": input_text},
            timeout=_UPSTREAM_TIMEOUT,
        )
        if resp.status_code not in _OK_STATUSES:
            return _upstream_error(resp)
        entities = _decode_data_field(resp.json()).get("entities") or []

        # Step 2: graph context for those entities, forwarded to the agent.
        # NOTE(review): assumes find_neighbors_with_relationshipsAI returns
        # JSON-serialisable values (httpx encodes this payload as JSON) — confirm.
        neo4j_data = _query_neighbors(entities)

        # Step 3: scored candidate answers.
        resp = await client.post(
            "/question_agent",
            json={"neo4j_data": neo4j_data, "text": input_text},
            timeout=_UPSTREAM_TIMEOUT,
        )
        if resp.status_code not in _OK_STATUSES:
            return _upstream_error(resp)
        items = _decode_data_field(resp.json()).get("json") or []

        # Step 4: graph view for each of the best answers.
        qa_list = []
        for item in _top_answers(items):
            answer = item.get("answer")
            if not answer:
                continue
            graph = await _answer_to_graph(answer)
            if graph is not None:
                qa_list.append({"answer": answer, "result": graph})

        return _json_response(qa_list)
    except Exception as e:
        # Request boundary: log the full trace here, but do not leak it to clients.
        print("❌ 发生异常:")
        print(traceback.format_exc())
        return _json_response({"error": str(e)}, 500)