import json import traceback from app import app from robyn import jsonify, Response from service.OperationService import OperationService from urllib.parse import unquote # 实例化业务逻辑对象 operation_service = OperationService() # 2. 新增/替换这个深度转换函数 def deep_convert(data): """ 专门解决 Neo4j 对象序列化问题的工具函数 """ if isinstance(data, dict): return {k: deep_convert(v) for k, v in data.items()} elif isinstance(data, list): return [deep_convert(i) for i in data] # 增加对数值类型的保护,其他的全部转为字符串 elif isinstance(data, (int, float, bool, type(None))): return data elif isinstance(data, str): return data # 如果是 Neo4j 的 ID、Long 或其他不可识别对象,一律强转字符串 else: try: return str(data) except: return None # 3. 替换原来的 create_response def create_response(status_code, data_dict): """ 统一响应格式封装。 不再直接用 jsonify(data_dict),因为那处理不了嵌套的 Neo4j 对象。 """ # 第一步:清洗数据,把所有特殊对象转为标准 Python 类型 clean_data = deep_convert(data_dict) # 第二步:手动序列化,确保中文不乱码,且 elementId 等长字符串不被截断 return Response( status_code=status_code, description=json.dumps(clean_data, ensure_ascii=False), headers={"Content-Type": "application/json; charset=utf-8"} ) def parse_request_body(req): """ 解析器:适配 Robyn 框架。 针对前端 Vue3 + ElementPlus 的请求进行深度解析,确保获取 ID、nodeId、name 和 label。 """ try: body = getattr(req, "body", None) if not body: return {} # 1. 处理 bytes 类型 (Robyn 常见的 body 类型) if isinstance(body, (bytes, bytearray)): body = body.decode('utf-8') # 2. 如果已经是字典,直接返回 if isinstance(body, dict): return body # 3. 处理字符串 (JSON 序列化后的字符串) if isinstance(body, str): try: data = json.loads(body) # 处理双层 JSON 序列化的情况 (有些前端框架会序列化两次) if isinstance(data, str): data = json.loads(data) return data except json.JSONDecodeError: # 尝试解析 URL 编码格式 (application/x-www-form-urlencoded) try: from urllib.parse import parse_qs params = parse_qs(body) return {k: v[0] for k, v in params.items()} except: return {} return {} except Exception as e: print(f"Request Body Parse Error: {e}") return {} def get_query_param(req, key, default=""): """ 提取 URL 查询参数。适配不同版本的 Robyn 参数存放位置。 """ try: # 尝试从新版/旧版 Robyn 的不同属性中提取 data_source = getattr(req, "queries", None) if data_source is None or (isinstance(data_source, dict) and not data_source): data_source = getattr(req, "query_params", {}) # 适配 Robyn 特有的 Query 对象 if hasattr(data_source, "to_dict"): data_source = data_source.to_dict() val = data_source.get(key) if val is None: return default # 提取值并进行 URL 解码 raw_val = str(val[0]) if isinstance(val, list) else str(val) return unquote(raw_val).strip() except Exception as e: print(f"Get Param Error ({key}): {e}") return default # --- 0. 数据治理修复接口 --- @app.post("/api/kg/admin/fix-ids") def fix_node_ids(req): """ 手动触发:修复数据库中 nodeId 为空或为 0 的存量数据 """ try: result = operation_service.fix_all_missing_node_ids() return create_response(200, { "code": 200 if result.get("success") else 500, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"修复接口异常: {str(e)}"}) # --- 1. 获取全量动态标签 (节点管理用) --- @app.get("/api/kg/labels") def get_labels(req): try: labels = operation_service.get_all_labels() return create_response(200, {"code": 200, "data": labels, "msg": "success"}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"获取标签失败: {str(e)}"}) # --- 新增:获取全量动态关系类型 --- @app.get("/api/kg/relationship-types") def get_rel_types(req): """ 从数据库动态获取所有关系类型 type 及其 label 映射 """ try: rel_types = operation_service.get_all_relationship_types() return create_response(200, {"code": 200, "data": rel_types, "msg": "success"}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"获取关系类型失败: {str(e)}"}) # --- 2. 输入联想建议 --- @app.get("/api/kg/node/suggest") def suggest_node(req): """ 联想词接口: 支持 keyword 模糊搜索,同时支持 label 强过滤。 """ try: # 1. 提取前端传来的参数 clean_keyword = get_query_param(req, "keyword", "") clean_label = get_query_param(req, "label", "") # 2. 调用 Service 层 # 如果 label 为 "全部" 或空,Service 层会自动处理成全库建议 suggestions = operation_service.suggest_nodes(clean_keyword, clean_label) return create_response(200, {"code": 200, "data": suggestions, "msg": "success"}) except Exception as e: print(f"Suggest Interface Error: {e}") return create_response(200, {"code": 500, "msg": f"联想接口异常: {str(e)}"}) # --- 3. 获取分页节点列表 --- @app.get("/api/kg/nodes") def get_nodes(req): try: name_raw = get_query_param(req, "name", "") label_raw = get_query_param(req, "label", "") page_str = get_query_param(req, "page", "1") size_str = get_query_param(req, "pageSize", "20") page = int(page_str) if page_str.isdigit() else 1 page_size = int(size_str) if size_str.isdigit() else 20 # 清洗参数 name = name_raw if name_raw else None label = label_raw if (label_raw and label_raw not in ["全部", "", "null"]) else None res_data = operation_service.get_nodes_subset(page, page_size, name=name, label=label) return create_response(200, {"code": 200, "data": res_data, "msg": "success"}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"获取节点失败: {str(e)}"}) # --- 4. 获取分页关系列表 --- @app.get("/api/kg/relationships") def get_relationships(req): try: source_raw = get_query_param(req, "source", "") target_raw = get_query_param(req, "target", "") type_raw = get_query_param(req, "type", "") page_str = get_query_param(req, "page", "1") size_str = get_query_param(req, "pageSize", "20") page = int(page_str) if page_str.isdigit() else 1 page_size = int(size_str) if size_str.isdigit() else 20 source = source_raw if source_raw else None target = target_raw if target_raw else None rel_type = type_raw if (type_raw and type_raw not in ["全部", ""]) else None res_data = operation_service.get_relationships_subset(page, page_size, source, target, rel_type) return create_response(200, {"code": 200, "data": res_data, "msg": "success"}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"获取关系失败: {str(e)}"}) # --- 5. 新增节点 --- @app.post("/api/kg/node/add") def add_node(req): try: body = parse_request_body(req) label = str(body.get("label", "Drug")).strip() name = str(body.get("name", "")).strip() if not name: return create_response(200, {"code": 400, "msg": "名称不能为空"}) result = operation_service.add_node(label, name) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"新增异常: {str(e)}"}) # --- 6. 修改节点 --- @app.post("/api/kg/node/update") def update_node(req): try: body = parse_request_body(req) # 兼容两种写法:id (elementId) 或 nodeId (业务ID) node_id = body.get("id") or body.get("nodeId") name = str(body.get("name", "")).strip() label = str(body.get("label", "")).strip() if not node_id or not name: return create_response(200, {"code": 400, "msg": "参数缺失: 修改必须包含ID和名称"}) result = operation_service.update_node(node_id, name, label) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"更新异常: {str(e)}"}) # --- 7. 新增关系 --- @app.post("/api/kg/rel/add") def add_relationship(req): try: body = parse_request_body(req) source = str(body.get("source", "")).strip() target = str(body.get("target", "")).strip() rel_type = str(body.get("type", "")).strip() rel_label = str(body.get("label", "")).strip() or rel_type if not all([source, target, rel_type]): return create_response(200, {"code": 400, "msg": "参数缺失: 起点、终点和类型为必填项"}) result = operation_service.add_relationship(source, target, rel_type, rel_label) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"新增关系异常: {str(e)}"}) # --- 8. 修改关系 --- @app.post("/api/kg/rel/update") def update_rel(req): try: body = parse_request_body(req) rel_id = body.get("id") source = str(body.get("source", "")).strip() target = str(body.get("target", "")).strip() rel_type = str(body.get("type", "")).strip() rel_label = str(body.get("label", "")).strip() or rel_type if not rel_id: return create_response(200, {"code": 400, "msg": "修改失败:关系ID缺失"}) result = operation_service.update_relationship(rel_id, source, target, rel_type, rel_label) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"修改关系异常: {str(e)}"}) # --- 9. 删除节点 --- @app.post("/api/kg/node/delete") def delete_node(req): try: body = parse_request_body(req) node_id = body.get("id") if not node_id: return create_response(200, {"code": 400, "msg": "删除失败: 未指定节点系统ID"}) result = operation_service.delete_node(node_id) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"删除节点异常: {str(e)}"}) # --- 10. 删除关系 --- @app.post("/api/kg/rel/delete") def delete_rel(req): try: body = parse_request_body(req) rel_id = body.get("id") if not rel_id: return create_response(200, {"code": 400, "msg": "删除失败: 未指定关系系统ID"}) result = operation_service.delete_relationship(rel_id) return create_response(200, { "code": 200 if result.get("success") else 400, "msg": result.get("msg") }) except Exception as e: return create_response(200, {"code": 500, "msg": f"删除关系异常: {str(e)}"}) # --- 11. 获取图谱全局统计数据 --- @app.get("/api/kg/stats") def get_kg_stats(req): try: result = operation_service.get_kg_stats() if result and result.get("success"): return create_response(200, {"code": 200, "data": result.get("data"), "msg": "success"}) else: msg = result.get("msg") if result else "未能获取统计数据" return create_response(200, {"code": 400, "msg": msg}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"统计数据异常: {str(e)}"}) # --- 12. 数据导出接口 --- @app.get("/api/kg/export/nodes") def export_nodes(req): """ 节点导出接口:全量导出满足筛选条件的节点 """ try: # 1. 提取筛选参数 name_raw = get_query_param(req, "name", "") label_raw = get_query_param(req, "label", "") # 2. 参数清洗 name = name_raw.strip() if name_raw and str(name_raw).lower() not in ["null", "undefined"] else None label = label_raw.strip() if label_raw and label_raw not in ["全部", "", "null", "undefined"] else None # 3. 调用 Service:移除 limit 参数,执行全量导出 result = operation_service.export_nodes_to_json(label=label, name=name) if result.get("success"): return create_response(200, { "code": 200, "data": result.get("data"), "total": result.get("count", 0), "msg": "success" }) else: return create_response(200, {"code": 500, "msg": result.get("msg", "获取导出数据失败")}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"导出节点接口异常: {str(e)}"}) @app.get("/api/kg/export/relationships") def export_relationships(req): """ 关系导出接口:全量导出满足筛选条件的关系 """ try: # 1. 提取筛选参数 source_raw = get_query_param(req, "source", "") target_raw = get_query_param(req, "target", "") type_raw = get_query_param(req, "type", "") # 2. 参数清洗 source = source_raw.strip() if source_raw and str(source_raw).lower() not in ["null", "undefined"] else None target = target_raw.strip() if target_raw and str(target_raw).lower() not in ["null", "undefined"] else None rel_type = type_raw.strip() if type_raw and type_raw not in ["全部", "", "null", "undefined"] else None # 3. 执行导出查询:移除 limit 参数,执行全量导出 result = operation_service.export_relationships_to_json( source=source, target=target, rel_type=rel_type ) if result.get("success"): return create_response(200, { "code": 200, "data": result.get("data"), "total": result.get("count", 0), "msg": "success" }) else: return create_response(200, { "code": 500, "msg": result.get("msg", "获取导出关系失败") }) except Exception as e: traceback.print_exc() return create_response(200, { "code": 500, "msg": f"导出关系接口异常: {str(e)}" }) # --- 13. 批量导入核心接口 (预检 & 执行) --- @app.post("/api/kg/import/nodes/precheck") def import_nodes_precheck(req): """ 预检接口:接收前端上传的 nodes 数组,进行冲突和有效性扫描。 """ try: body = parse_request_body(req) nodes = body.get("nodes", []) if not nodes: return create_response(200, {"code": 400, "msg": "数据为空"}) result = operation_service.precheck_nodes_batch(nodes) return create_response(200, { "code": 200, "data": { "conflicts": result.get("conflicts", []), "invalid": result.get("invalid", []), # 这里会包含因缺少 nodeId 而被过滤的数据 "summary": result.get("summary", {}) }, "msg": "预检完成" }) except Exception as e: return create_response(200, {"code": 500, "msg": f"预检异常: {str(e)}"}) @app.post("/api/kg/import/nodes/execute") def import_nodes_execute(req): """ 执行导入接口:支持模式 mode (strict/skip/update)。 """ try: body = parse_request_body(req) nodes = body.get("nodes", []) mode = body.get("mode", "skip") if not nodes: return create_response(200, {"code": 400, "msg": "批次数据为空"}) result = operation_service.execute_node_import_batch(nodes, mode=mode) if result.get("success"): return create_response(200, {"code": 200, "msg": result.get("msg")}) else: # 严格模式下如果因冲突失败,返回详情 return create_response(200, { "code": 500, "msg": result.get("msg"), "data": {"conflicts": result.get("conflicts", [])} }) except Exception as e: return create_response(200, {"code": 500, "msg": f"执行异常: {str(e)}"}) # --- 关系导入核心修正 --- @app.post("/api/kg/import/relationships/precheck") def import_rels_precheck(req): """ 修正说明:适配 Service 层返回的 summary 结构,确保前端准确识别冲突数 """ try: body = parse_request_body(req) # 兼容 relationships 和 rels 两种写法 rels = body.get("relationships") or body.get("rels") or [] if not rels: return create_response(200, {"code": 400, "msg": "预检数据为空"}) result = operation_service.precheck_rels_batch(rels) return create_response(200, { "code": 200, "data": { "conflicts": result.get("conflicts", []), "invalid": result.get("invalid", []), "summary": result.get("summary", {}) # 包含准确的 conflict 计数 }, "msg": "预检成功" }) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"关系预检异常: {str(e)}"}) @app.post("/api/kg/import/relationships/execute") def import_rels_execute(req): """ 修正说明:将 Service 层生成的包含 elementId 的标准 Graph 数据透传回前端 """ try: body = parse_request_body(req) rels = body.get("relationships") or body.get("rels") or [] mode = body.get("mode", "skip") if not rels: return create_response(200, {"code": 400, "msg": "执行数据为空"}) result = operation_service.execute_rel_import_batch(rels, mode=mode) if result.get("success"): return create_response(200, { "code": 200, "msg": result.get("msg"), "data": result.get("data"), # 核心:返回标准嵌套结构的数组 "count": result.get("count", 0) }) else: return create_response(200, {"code": 500, "msg": result.get("msg")}) except Exception as e: traceback.print_exc() return create_response(200, {"code": 500, "msg": f"关系导入执行异常: {str(e)}"})