import json
import logging
import traceback
import uuid
from util.neo4j_utils import Neo4jUtil, neo4j_client
import httpx
from robyn import jsonify, Response
from app import app
from controller.client import client

logger = logging.getLogger(__name__)

# Per-request timeout (seconds) for calls to the model service; they can be slow.
LLM_TIMEOUT = 1800.0  # 30 minutes
# How many of the highest-scoring answers are converted into graphs.
TOP_N_ANSWERS = 5
# Upstream status codes treated as success.
_OK_STATUSES = (200, 202)
_JSON_HEADERS = {"Content-Type": "application/json; charset=utf-8"}


def convert_to_g6_format(data):
    """Convert an entity/relation extraction result into G6 graph data.

    Args:
        data: Mapping with ``"entities"`` (items carrying ``"n"`` = name and
            ``"t"`` = type) and ``"relations"`` (items carrying ``"e1"`` /
            ``"e2"`` = entity names and ``"r"`` = relation label).

    Returns:
        ``{"nodes": [...], "edges": [...]}``. Each distinct entity name becomes
        one node with a freshly generated UUID string id; the first occurrence
        of a name wins. Each relation whose two endpoint names both resolve to
        a node becomes an edge; relations naming unknown entities are skipped
        with a warning.

    Raises:
        KeyError: if ``data`` or an entity/relation lacks a required key.
    """
    entities = data["entities"]
    relations = data["relations"]

    name_to_id = {}
    nodes = []
    for ent in entities:
        name = ent["n"]
        if name in name_to_id:
            continue  # duplicate name: keep the node created for its first occurrence
        node_id = str(uuid.uuid4())
        name_to_id[name] = node_id
        nodes.append({
            "id": node_id,
            "label": name,
            "data": {"type": ent["t"]},  # lets G6 style nodes by entity type
        })

    edges = []
    for rel in relations:
        source_id = name_to_id.get(rel["e1"])
        target_id = name_to_id.get(rel["e2"])
        if source_id is None or target_id is None:
            logger.warning("Entity not found for relation: %s", rel)
            continue
        edges.append({
            "source": source_id,
            "target": target_id,
            "data": {"label": rel["r"]},
        })

    return {"nodes": nodes, "edges": edges}


def _json_response(payload, status_code=200):
    """Serialize *payload* with robyn's ``jsonify`` into an explicit JSON Response.

    Used for every exit path instead of Flask-style ``(body, status)`` pairs,
    which are not a robyn handler return form.
    """
    return Response(
        status_code=status_code,
        description=jsonify(payload),
        headers=dict(_JSON_HEADERS),
    )


def _upstream_error(resp):
    """Relay a failed model-service response as ``{"error", "detail"}`` with its status code."""
    return _json_response({"error": "提交失败", "detail": resp.text}, resp.status_code)


def _decode_data_field(payload):
    """Return ``payload["data"]``, JSON-decoding it when it was sent as a string.

    A missing ``"data"`` key yields ``{}`` rather than a ``json.loads`` TypeError.
    """
    raw = payload.get("data", {})
    if isinstance(raw, (str, bytes, bytearray)):
        return json.loads(raw)
    return raw


async def _post_llm(path, payload):
    """POST *payload* as JSON to *path* on the shared model-service client."""
    return await client.post(path, json=payload, timeout=LLM_TIMEOUT)


def _lookup_neighbors(entity_names):
    """Query Neo4j for the neighbourhood of each entity name.

    Returns a list of single-key dicts ``{name: neighbors}`` in input order.
    """
    # NOTE(review): the neo4j client is called without `await`; if it is a
    # synchronous driver, these queries block the event loop while they run.
    return [
        {
            name: neo4j_client.find_neighbors_with_relationshipsAI(
                node_label=None,
                direction="both",
                node_properties={"name": name},
                rel_type=None,
            )
        }
        for name in entity_names
    ]


def _top_answers(items, n=TOP_N_ANSWERS):
    """Return the *n* items with the highest ``"sort"`` score.

    ``sorted`` is stable, so items with equal scores keep their original order.
    """
    return sorted(items, key=lambda item: item["sort"], reverse=True)[:n]


@app.post("/api/qa/analyze")
async def analyze(request):
    """Answer a question and return each top answer as a G6 graph.

    Expects a JSON body ``{"text": "<question>"}``. Pipeline:
      1. ``/getEntity`` extracts entity names from the text;
      2. each name's neighbourhood is looked up in Neo4j;
      3. ``/question_agent`` returns scored candidate answers (``data.json``);
      4. the top answers are each sent to ``/extract_entities_and_relations``
         and converted with ``convert_to_g6_format``.

    Returns:
        200 with a JSON list of ``{"answer", "result"}`` objects;
        400 if the body is not a JSON object with a non-empty ``text`` string;
        the upstream status with ``{"error", "detail"}`` if ``/getEntity`` or
        ``/question_agent`` fails; 500 with ``{"error", "traceback"}`` on any
        unexpected exception.
    """
    try:
        body = request.json()
    except Exception:  # the exact exception robyn raises for a malformed body is not visible here
        body = None
    if not isinstance(body, dict):
        return _json_response({"error": "请求体必须是 JSON 对象"}, 400)

    raw_text = body.get("text")
    input_text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not input_text:
        return _json_response({"error": "缺少 text 字段"}, 400)

    try:
        resp = await _post_llm("/getEntity", {"text": input_text})
        if resp.status_code not in _OK_STATUSES:
            return _upstream_error(resp)

        entities = _decode_data_field(resp.json()).get("entities", [])
        logger.debug("entities: %s", entities)

        # NOTE(review): the neighbourhood data gathered here is discarded —
        # /question_agent receives neo4j_data=[] below. Confirm whether it
        # should be forwarded; the empty list is kept to preserve behaviour.
        _lookup_neighbors(entities)

        resp = await _post_llm("/question_agent", {"neo4j_data": [], "text": input_text})
        if resp.status_code not in _OK_STATUSES:
            return _upstream_error(resp)
        items = _decode_data_field(resp.json())["json"]

        qa_list = []
        for item in _top_answers(items):
            resp = await _post_llm("/extract_entities_and_relations", {"text": item["answer"]})
            if resp.status_code not in _OK_STATUSES:
                # Best-effort: a failed extraction drops only this answer.
                logger.warning(
                    "extract_entities_and_relations failed (%s) for answer: %s",
                    resp.status_code,
                    item["answer"],
                )
                continue
            result = resp.json()
            logger.debug("extraction result: %s", result)
            qa_list.append({
                "answer": item["answer"],
                "result": convert_to_g6_format(result),
            })
            logger.debug(
                "xh: %s, answer: %s, sort: %s",
                item.get("xh"),  # debug-only field; absence must not fail the request
                item["answer"],
                item["sort"],
            )

        return _json_response(qa_list)
    except Exception as e:
        error_trace = traceback.format_exc()
        logger.exception("analyze failed")
        # NOTE(review): sending the traceback to the client exposes server
        # internals; kept only for backward compatibility with existing callers.
        return _json_response({"error": str(e), "traceback": error_trace}, 500)