-
+ From 3b53b2d7d1b93747596d9b952f2f8f732437d40a Mon Sep 17 00:00:00 2001 From: hanyuqing <1106611654@qq.com> Date: Wed, 14 Jan 2026 08:47:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=A0=8F=E5=A4=8D=E7=94=A8+?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/GraphStyleController.py | 117 +++-- controller/OperationController.py | 223 ++++++++- controller/QAController.py | 129 +++--- service/GraphStyleService.py | 283 ++++++------ service/OperationService.py | 519 +++++++++++++++++++-- util/auth_interceptor.py | 1 + vue/src/api/data.js | 95 ++++ vue/src/components/GraphToolbar.vue | 260 +++++++++++ vue/src/system/GraphDemo.vue | 281 +++++++----- vue/src/system/GraphQA.vue | 280 +++++++++--- vue/src/system/GraphStyle.vue | 362 +++++++++++---- vue/src/system/KGData.vue | 874 ++++++++++++++++++++++++++++-------- 12 files changed, 2702 insertions(+), 722 deletions(-) create mode 100644 vue/src/components/GraphToolbar.vue diff --git a/controller/GraphStyleController.py b/controller/GraphStyleController.py index f601d03..ceced76 100644 --- a/controller/GraphStyleController.py +++ b/controller/GraphStyleController.py @@ -1,9 +1,12 @@ -# controller/GraphStyleController.py import json +import logging from robyn import jsonify, Response from app import app from service.GraphStyleService import GraphStyleService +# 配置日志记录 +logger = logging.getLogger(__name__) + # --- 核心工具函数:解决乱码 --- def create_response(status_code, data_dict): @@ -20,46 +23,76 @@ def create_response(status_code, data_dict): @app.post("/api/graph/style/save") async def save_style_config(request): """ - 保存配置接口 - 增强防跑偏版 - 逻辑: - 1. 如果 body 中包含 is_auto_save: true,则强制忽略 group_name,防止自动保存篡改归属。 - 2. 如果是普通保存或移动,则正常传递 group_name。 + 保存配置接口 - 增强校验版 + 支持: + 1. 跨组精准移动 (id + target_group_id) - 优先级最高 + 2. 跨组名称移动 (id + group_name + is_auto_save: false) + 3. 手动/自动保存 (id + is_auto_save: true/false) + 4. 新建保存 (无 id) """ try: + # 1. 解析请求体 body = request.json() + # 提取参数 config_id = body.get('id') canvas_name = body.get('canvas_name') current_label = body.get('current_label') styles = body.get('styles') - # 核心改动:获取 group_name + # 核心改动点:接收精准 ID + target_group_id = body.get('target_group_id') group_name = body.get('group_name') - # 增加一个前端标识:如果是实时同步(防抖保存),前端可以传这个字段 + # 默认为 False,代表这是一次手动操作(可能是保存,也可能是移动) is_auto_save = body.get('is_auto_save', False) + # 2. 基础参数校验 if not all([canvas_name, current_label, styles]): - return create_response(200, {"code": 400, "msg": "参数不完整"}) - - # 如果是自动保存模式,显式清空 group_name,强制 Service 进入“仅更新样式”逻辑 - final_group_name = None if is_auto_save else group_name - - # 将处理后的参数传给 Service 层 - success = GraphStyleService.save_config( - canvas_name=canvas_name, - current_label=current_label, - styles_dict=styles, - group_name=final_group_name, - config_id=config_id - ) - - if success: - return create_response(200, {"code": 200, "msg": "操作成功"}) + return create_response(200, {"code": 400, "msg": "参数不完整:缺失标签名或样式数据"}) + + # --- 核心逻辑分流 --- + + # 情况 A:更新记录 (前端传了 ID) + if config_id: + # 判断动作类型用于日志和反馈 + # 只要传了 target_group_id 或提供了 group_name 且非自动保存,就视为移动 + is_moving = (target_group_id is not None) or (group_name is not None and not is_auto_save) + action_label = "移动" if is_moving else "更新" + + # 修改点:将 target_group_id 显式传递给 Service 层 + success = GraphStyleService.update_config( + config_id=int(config_id), + canvas_name=canvas_name, + current_label=current_label, + styles_dict=styles, + group_name=group_name, + target_group_id=target_group_id, # 确保这一行存在! + is_auto_save=is_auto_save + ) + + if success: + return create_response(200, {"code": 200, "msg": f"{action_label}操作完成"}) + else: + return create_response(200, + {"code": 500, "msg": f"{action_label}失败,请确认配置是否存在或内容是否有变化"}) + + # 情况 B:新增记录 (前端未传 ID) else: - return create_response(200, {"code": 500, "msg": "操作失败"}) + success = GraphStyleService.create_config( + canvas_name=canvas_name, + current_label=current_label, + styles_dict=styles, + group_name=group_name + ) + if success: + return create_response(200, {"code": 200, "msg": "新配置已创建成功"}) + else: + return create_response(200, {"code": 500, "msg": "新建配置失败,请重试"}) + except Exception as e: - return create_response(200, {"code": 500, "msg": f"系统异常: {str(e)}"}) + logger.error(f"Controller 异常: {str(e)}", exc_info=True) + return create_response(200, {"code": 500, "msg": f"服务器内部错误: {str(e)}"}) @app.get("/api/graph/style/list/grouped") @@ -69,12 +102,13 @@ async def get_grouped_style_list(request): data = GraphStyleService.get_grouped_configs() return create_response(200, {"code": 200, "data": data, "msg": "查询成功"}) except Exception as e: + logger.error(f"查询异常: {str(e)}") return create_response(200, {"code": 500, "msg": f"查询异常: {str(e)}"}) @app.post("/api/graph/style/group/apply") async def apply_style_group(request): - """应用全案""" + """应用全案:一键切换当前激活的样式组""" try: body = request.json() group_id = body.get('group_id') @@ -86,7 +120,7 @@ async def apply_style_group(request): if success: return create_response(200, {"code": 200, "msg": "方案已成功应用"}) else: - return create_response(200, {"code": 500, "msg": "应用全案失败"}) + return create_response(200, {"code": 500, "msg": "应用全案失败,请检查方案是否存在"}) except Exception as e: return create_response(200, {"code": 500, "msg": f"操作异常: {str(e)}"}) @@ -112,7 +146,7 @@ async def set_default_style_group(request): @app.get("/api/graph/style/groups") async def get_group_names(request): - """获取所有已存在的方案组列表""" + """获取所有已存在的方案组列表(用于下拉选择)""" try: data = GraphStyleService.get_group_list() return create_response(200, {"code": 200, "data": data, "msg": "查询成功"}) @@ -120,19 +154,9 @@ async def get_group_names(request): return create_response(200, {"code": 500, "msg": f"查询异常: {str(e)}"}) -@app.get("/api/graph/style/list") -async def get_style_list(request): - """获取原始扁平配置列表""" - try: - data = GraphStyleService.get_all_configs() - return create_response(200, {"code": 200, "data": data, "msg": "查询成功"}) - except Exception as e: - return create_response(200, {"code": 500, "msg": f"查询异常: {str(e)}"}) - - @app.post("/api/graph/style/delete") async def delete_style_config(request): - """删除单条画布配置""" + """删除单条配置记录""" try: body = request.json() config_id = body.get('id') @@ -144,14 +168,14 @@ async def delete_style_config(request): if success: return create_response(200, {"code": 200, "msg": "删除成功"}) else: - return create_response(200, {"code": 500, "msg": "删除失败"}) + return create_response(200, {"code": 404, "msg": "删除失败,配置可能已被删除"}) except Exception as e: return create_response(200, {"code": 500, "msg": f"操作异常: {str(e)}"}) @app.post("/api/graph/style/group/delete") async def delete_style_group(request): - """删除整个方案组及其下属所有配置""" + """级联删除整个方案组及其下属所有配置""" try: body = request.json() group_id = body.get('group_id') @@ -161,7 +185,7 @@ async def delete_style_group(request): success = GraphStyleService.delete_group(group_id) if success: - return create_response(200, {"code": 200, "msg": "方案组已彻底删除"}) + return create_response(200, {"code": 200, "msg": "方案组及关联配置已彻底删除"}) else: return create_response(200, {"code": 500, "msg": "方案组删除失败"}) except Exception as e: @@ -175,6 +199,7 @@ async def batch_delete_style(request): body = request.json() config_ids = body.get('ids') + # 容错:处理前端可能以 JSON 字符串形式发送的列表 if isinstance(config_ids, str): try: config_ids = json.loads(config_ids) @@ -182,9 +207,13 @@ async def batch_delete_style(request): pass if not config_ids or not isinstance(config_ids, list): - return create_response(200, {"code": 400, "msg": "参数格式错误"}) + return create_response(200, {"code": 400, "msg": "参数格式错误,请提供ID列表"}) count = GraphStyleService.batch_delete_configs(config_ids) - return create_response(200, {"code": 200, "msg": f"成功删除 {count} 条配置", "count": count}) + return create_response(200, { + "code": 200, + "msg": f"成功删除 {count} 条配置数据", + "count": count + }) except Exception as e: return create_response(200, {"code": 500, "msg": f"批量删除异常: {str(e)}"}) \ No newline at end of file diff --git a/controller/OperationController.py b/controller/OperationController.py index afe162a..fc30c91 100644 --- a/controller/OperationController.py +++ b/controller/OperationController.py @@ -9,15 +9,40 @@ from urllib.parse import unquote operation_service = OperationService() -# --- 核心工具函数 --- - +# 2. 新增/替换这个深度转换函数 +def deep_convert(data): + """ + 专门解决 Neo4j 对象序列化问题的工具函数 + """ + if isinstance(data, dict): + return {k: deep_convert(v) for k, v in data.items()} + elif isinstance(data, list): + return [deep_convert(i) for i in data] + # 增加对数值类型的保护,其他的全部转为字符串 + elif isinstance(data, (int, float, bool, type(None))): + return data + elif isinstance(data, str): + return data + # 如果是 Neo4j 的 ID、Long 或其他不可识别对象,一律强转字符串 + else: + try: + return str(data) + except: + return None + +# 3. 替换原来的 create_response def create_response(status_code, data_dict): """ - 统一响应格式封装,强制使用 UTF-8 防止中文乱码。 + 统一响应格式封装。 + 不再直接用 jsonify(data_dict),因为那处理不了嵌套的 Neo4j 对象。 """ + # 第一步:清洗数据,把所有特殊对象转为标准 Python 类型 + clean_data = deep_convert(data_dict) + + # 第二步:手动序列化,确保中文不乱码,且 elementId 等长字符串不被截断 return Response( status_code=status_code, - description=jsonify(data_dict), + description=json.dumps(clean_data, ensure_ascii=False), headers={"Content-Type": "application/json; charset=utf-8"} ) @@ -333,4 +358,192 @@ def get_kg_stats(req): return create_response(200, {"code": 400, "msg": msg}) except Exception as e: traceback.print_exc() - return create_response(200, {"code": 500, "msg": f"统计数据异常: {str(e)}"}) \ No newline at end of file + return create_response(200, {"code": 500, "msg": f"统计数据异常: {str(e)}"}) + + +# --- 12. 数据导出接口 --- + +@app.get("/api/kg/export/nodes") +def export_nodes(req): + """ + 节点导出接口:全量导出满足筛选条件的节点 + """ + try: + # 1. 提取筛选参数 + name_raw = get_query_param(req, "name", "") + label_raw = get_query_param(req, "label", "") + + # 2. 参数清洗 + name = name_raw.strip() if name_raw and str(name_raw).lower() not in ["null", "undefined"] else None + label = label_raw.strip() if label_raw and label_raw not in ["全部", "", "null", "undefined"] else None + + # 3. 调用 Service:移除 limit 参数,执行全量导出 + result = operation_service.export_nodes_to_json(label=label, name=name) + + if result.get("success"): + return create_response(200, { + "code": 200, + "data": result.get("data"), + "total": result.get("count", 0), + "msg": "success" + }) + else: + return create_response(200, {"code": 500, "msg": result.get("msg", "获取导出数据失败")}) + + except Exception as e: + traceback.print_exc() + return create_response(200, {"code": 500, "msg": f"导出节点接口异常: {str(e)}"}) + + +@app.get("/api/kg/export/relationships") +def export_relationships(req): + """ + 关系导出接口:全量导出满足筛选条件的关系 + """ + try: + # 1. 提取筛选参数 + source_raw = get_query_param(req, "source", "") + target_raw = get_query_param(req, "target", "") + type_raw = get_query_param(req, "type", "") + + # 2. 参数清洗 + source = source_raw.strip() if source_raw and str(source_raw).lower() not in ["null", "undefined"] else None + target = target_raw.strip() if target_raw and str(target_raw).lower() not in ["null", "undefined"] else None + rel_type = type_raw.strip() if type_raw and type_raw not in ["全部", "", "null", "undefined"] else None + + # 3. 执行导出查询:移除 limit 参数,执行全量导出 + result = operation_service.export_relationships_to_json( + source=source, + target=target, + rel_type=rel_type + ) + + if result.get("success"): + return create_response(200, { + "code": 200, + "data": result.get("data"), + "total": result.get("count", 0), + "msg": "success" + }) + else: + return create_response(200, { + "code": 500, + "msg": result.get("msg", "获取导出关系失败") + }) + + except Exception as e: + traceback.print_exc() + return create_response(200, { + "code": 500, + "msg": f"导出关系接口异常: {str(e)}" + }) + + +# --- 13. 批量导入核心接口 (预检 & 执行) --- + +@app.post("/api/kg/import/nodes/precheck") +def import_nodes_precheck(req): + """ + 预检接口:接收前端上传的 nodes 数组,进行冲突和有效性扫描。 + """ + try: + body = parse_request_body(req) + nodes = body.get("nodes", []) + if not nodes: return create_response(200, {"code": 400, "msg": "数据为空"}) + + result = operation_service.precheck_nodes_batch(nodes) + return create_response(200, { + "code": 200, + "data": { + "conflicts": result.get("conflicts", []), + "invalid": result.get("invalid", []), # 这里会包含因缺少 nodeId 而被过滤的数据 + "summary": result.get("summary", {}) + }, + "msg": "预检完成" + }) + except Exception as e: + return create_response(200, {"code": 500, "msg": f"预检异常: {str(e)}"}) + + +@app.post("/api/kg/import/nodes/execute") +def import_nodes_execute(req): + """ + 执行导入接口:支持模式 mode (strict/skip/update)。 + """ + try: + body = parse_request_body(req) + nodes = body.get("nodes", []) + mode = body.get("mode", "skip") + + if not nodes: return create_response(200, {"code": 400, "msg": "批次数据为空"}) + + result = operation_service.execute_node_import_batch(nodes, mode=mode) + + if result.get("success"): + return create_response(200, {"code": 200, "msg": result.get("msg")}) + else: + # 严格模式下如果因冲突失败,返回详情 + return create_response(200, { + "code": 500, + "msg": result.get("msg"), + "data": {"conflicts": result.get("conflicts", [])} + }) + except Exception as e: + return create_response(200, {"code": 500, "msg": f"执行异常: {str(e)}"}) + +# --- 关系导入核心修正 --- + +@app.post("/api/kg/import/relationships/precheck") +def import_rels_precheck(req): + """ + 修正说明:适配 Service 层返回的 summary 结构,确保前端准确识别冲突数 + """ + try: + body = parse_request_body(req) + # 兼容 relationships 和 rels 两种写法 + rels = body.get("relationships") or body.get("rels") or [] + if not rels: + return create_response(200, {"code": 400, "msg": "预检数据为空"}) + + result = operation_service.precheck_rels_batch(rels) + return create_response(200, { + "code": 200, + "data": { + "conflicts": result.get("conflicts", []), + "invalid": result.get("invalid", []), + "summary": result.get("summary", {}) # 包含准确的 conflict 计数 + }, + "msg": "预检成功" + }) + except Exception as e: + traceback.print_exc() + return create_response(200, {"code": 500, "msg": f"关系预检异常: {str(e)}"}) + + +@app.post("/api/kg/import/relationships/execute") +def import_rels_execute(req): + """ + 修正说明:将 Service 层生成的包含 elementId 的标准 Graph 数据透传回前端 + """ + try: + body = parse_request_body(req) + rels = body.get("relationships") or body.get("rels") or [] + mode = body.get("mode", "skip") + + if not rels: + return create_response(200, {"code": 400, "msg": "执行数据为空"}) + + result = operation_service.execute_rel_import_batch(rels, mode=mode) + + if result.get("success"): + return create_response(200, { + "code": 200, + "msg": result.get("msg"), + "data": result.get("data"), # 核心:返回标准嵌套结构的数组 + "count": result.get("count", 0) + }) + else: + return create_response(200, {"code": 500, "msg": result.get("msg")}) + except Exception as e: + traceback.print_exc() + return create_response(200, {"code": 500, "msg": f"关系导入执行异常: {str(e)}"}) \ No newline at end of file diff --git a/controller/QAController.py b/controller/QAController.py index d61e52d..aeac479 100644 --- a/controller/QAController.py +++ b/controller/QAController.py @@ -9,48 +9,62 @@ from robyn import jsonify, Response from app import app from controller.client import client -import uuid + +# --- 核心工具函数:解决元组返回错误及中文乱码 --- +def create_response(status_code, data_dict): + """ + 统一响应格式封装。 + 1. 确保返回的是 Robyn 预期的 Response 对象。 + 2. description 必须是字符串(json.dumps 结果)。 + 3. 强制使用 UTF-8 防止中文乱码。 + """ + return Response( + status_code=status_code, + description=json.dumps(data_dict, ensure_ascii=False), + headers={"Content-Type": "application/json; charset=utf-8"} + ) + def convert_to_g6_format(data): - entities = data["entities"] - relations = data["relations"] + entities = data.get("entities", []) + relations = data.get("relations", []) # 创建实体名称到唯一ID的映射 name_to_id = {} nodes = [] for ent in entities: - name = ent["n"] - if name not in name_to_id: + name = ent.get("n") + if name and name not in name_to_id: node_id = str(uuid.uuid4()) name_to_id[name] = node_id nodes.append({ "id": node_id, "label": name, "data": { - "type": ent["t"] # 可用于 G6 的节点样式区分 + "type": ent.get("t") # 用于 G6 的节点样式区分 } }) # 构建边,并为每条边生成唯一 ID edges = [] for rel in relations: - e1 = rel["e1"] - e2 = rel["e2"] - r = rel["r"] + e1 = rel.get("e1") + e2 = rel.get("e2") + r = rel.get("r") source_id = name_to_id.get(e1) target_id = name_to_id.get(e2) if source_id and target_id: - edge_id = str(uuid.uuid4()) # 👈 为边生成唯一 ID + edge_id = str(uuid.uuid4()) edges.append({ - "id": edge_id, # ✅ 添加 id 字段 + "id": edge_id, "source": source_id, "target": target_id, - "label": r, # G6 支持直接使用 label(非必须放 data) + "label": r, "data": { - "label": r # 保留 data.label 便于扩展 + "label": r } }) else: @@ -60,85 +74,88 @@ def convert_to_g6_format(data): "nodes": nodes, "edges": edges } + + @app.post("/api/qa/analyze") async def analyze(request): body = request.json() input_text = body.get("text", "").strip() + if not input_text: - return jsonify({"error": "缺少 text 字段"}), 400 + # 使用 create_response 统一返回格式,避免使用 jsonify 可能带来的元组嵌套问题 + return create_response(400, {"error": "缺少 text 字段"}) + try: - # 直接转发到大模型服务(假设它返回 { "task_id": "xxx" }) + # 1. 提取实体 resp = await client.post( "/getEntity", json={"text": input_text}, - timeout=1800.0 # 30分钟 + timeout=1800.0 ) - qaList = [] - if resp.status_code == 202 or resp.status_code == 200: + qaList = [] + if resp.status_code in (200, 202): resp_json = resp.json() - resp_json_data = resp_json.get("data",{}) - resp_json_data = json.loads(resp_json_data) + # 处理字符串形式的 data 字段 + resp_json_data = resp_json.get("data", "{}") + if isinstance(resp_json_data, str): + resp_json_data = json.loads(resp_json_data) + entities = resp_json_data.get("entities", []) - print(entities) - data = [] + print(f"提取到的实体: {entities}") + + # 查询 Neo4j 邻居(此逻辑保留,虽目前未直接放入 qaList,可能用于后续扩展) for name in entities: - neighbors =neo4j_client.find_neighbors_with_relationshipsAI( + neo4j_client.find_neighbors_with_relationshipsAI( node_label=None, direction="both", node_properties={"name": name}, rel_type=None ) - data.append({ - name:neighbors - }) - resp = await client.post( + + # 2. 问答代理获取答案列表 + resp_agent = await client.post( "/question_agent", - json={"neo4j_data": [], - "text": input_text}, - timeout=1800.0 # 30分钟 + json={"neo4j_data": [], "text": input_text}, + timeout=1800.0 ) - resp_data = resp.json() - inner_data = json.loads(resp_data["data"]) - # 第二步:获取 json 数组 - items = inner_data["json"] - # 第三步:按 sort 排序(虽然所有都是 0.9,但为了通用性还是排序) - # 如果 sort 相同,可以保留原始顺序(使用 stable sort),或按 xh 排序等 - sorted_items = sorted(items, key=lambda x: x["sort"], reverse=True) + resp_data = resp_agent.json() + inner_data = json.loads(resp_data["data"]) + items = inner_data.get("json", []) - # 第四步:取前5个 + # 按权重排序取前5 + sorted_items = sorted(items, key=lambda x: x.get("sort", 0), reverse=True) top5 = sorted_items[:5] + + # 3. 对每个答案提取关系图谱 for item in top5: - resp = await client.post( + resp_ext = await client.post( "/extract_entities_and_relations", json={"text": item['answer']}, - timeout=1800.0 # 30分钟 + timeout=1800.0 ) - if resp.status_code in (200, 202): - result = resp.json() - print(result) + + if resp_ext.status_code in (200, 202): + result = resp_ext.json() g6_data = convert_to_g6_format(result) - print(g6_data) qaList.append({ "answer": item["answer"], "result": g6_data, - }) - print(f"xh: {item['xh']}, answer: {item['answer']}, sort: {item['sort']}") - print(resp.json()) - return Response( - status_code=200, - description=jsonify(qaList), - headers={"Content-Type": "text/plain; charset=utf-8"} - ) + }) + print(f"处理成功 xh: {item.get('xh')}, sort: {item.get('sort')}") + + # --- 修复点:使用 create_response 返回解析后的数组 --- + return create_response(200, qaList) + else: - return jsonify({ + return create_response(resp.status_code, { "error": "提交失败", "detail": resp.text - }), resp.status_code + }) + except Exception as e: error_trace = traceback.format_exc() print("❌ 发生异常:") print(error_trace) - - return jsonify({"error": str(e),"traceback": error_trace}), 500 \ No newline at end of file + return create_response(500, {"error": str(e), "traceback": error_trace}) \ No newline at end of file diff --git a/service/GraphStyleService.py b/service/GraphStyleService.py index e4202ad..62e96e8 100644 --- a/service/GraphStyleService.py +++ b/service/GraphStyleService.py @@ -1,191 +1,218 @@ -# service/GraphStyleService.py import json +import logging from util.mysql_utils import mysql_client +# 配置日志 +logger = logging.getLogger(__name__) + class GraphStyleService: @staticmethod - def save_config(canvas_name: str, current_label: str, styles_dict: dict, group_name: str = None, config_id: int = None) -> bool: + def _get_or_create_group(group_name: str) -> int: + """内部辅助方法:获取或创建方案组 ID""" + if not group_name or group_name.strip() == "": + group_name = "默认方案" + + group_name = group_name.strip() + + # 1. 查询是否存在 + check_sql = "SELECT id FROM graph_style_groups WHERE group_name = %s LIMIT 1" + existing = mysql_client.execute_query(check_sql, (group_name,)) + if existing: + return int(existing[0]['id']) + + # 2. 不存在则插入 + insert_sql = "INSERT INTO graph_style_groups (group_name, is_active, is_default) VALUES (%s, %s, %s)" + mysql_client.execute_update(insert_sql, (group_name, False, False)) + + # 3. 获取新生成的 ID + final_check = mysql_client.execute_query(check_sql, (group_name,)) + return int(final_check[0]['id']) if final_check else 1 + + @staticmethod + def create_config(canvas_name: str, current_label: str, styles_dict: dict, group_name: str = None) -> bool: + """【纯新增】用于另存为或初始保存""" + config_json = json.dumps(styles_dict, ensure_ascii=False) + target_group_id = GraphStyleService._get_or_create_group(group_name) + + sql = """ + INSERT INTO graph_configs (canvas_name, current_label, config_json, group_id) + VALUES (%s, %s, %s, %s) """ - 保存图谱样式配置(修复版:防止自动保存导致的分组乱跑) + affected_rows = mysql_client.execute_update(sql, (canvas_name, current_label, config_json, target_group_id)) + return affected_rows > 0 + + @staticmethod + def update_config(config_id: int, canvas_name: str, current_label: str, styles_dict: dict, + group_name: str = None, is_auto_save: bool = False, target_group_id: int = None) -> bool: """ - # 2. 转换样式 JSON - config_json = json.dumps(styles_dict, ensure_ascii=False) + 核心更新逻辑:支持精准 ID 移动,优化了逻辑优先级判断 + """ + if not config_id: + logger.error("更新失败:缺少 config_id") + return False - # 3. 【核心修改点】:区分 更新 还是 新建 - if config_id: - # --- 更新逻辑 --- - # 如果带了 ID,我们要极其谨慎地处理 group_id,防止在自动保存时被误改 - - # A. 如果调用者明确传了 group_name,说明是“移动”或“初次保存到某组” - if group_name and group_name.strip() != "": - # 检查/创建 目标方案组 - check_group_sql = "SELECT id FROM graph_style_groups WHERE group_name = %s LIMIT 1" - existing_group = mysql_client.execute_query(check_group_sql, (group_name,)) - - if existing_group: - target_group_id = existing_group[0]['id'] - else: - create_group_sql = "INSERT INTO graph_style_groups (group_name, is_active, is_default) VALUES (%s, %s, %s)" - mysql_client.execute_update(create_group_sql, (group_name, False, False)) - target_group_id = mysql_client.execute_query("SELECT LAST_INSERT_ID() as last_id")[0]['last_id'] - - # 执行带分组更新的 SQL - sql = """ - UPDATE graph_configs - SET canvas_name = %s, current_label = %s, config_json = %s, group_id = %s - WHERE id = %s - """ - affected_rows = mysql_client.execute_update(sql, (canvas_name, current_label, config_json, target_group_id, config_id)) + config_json_str = json.dumps(styles_dict, ensure_ascii=False) + + try: + # --- 步骤 1:查询当前数据库状态 --- + curr_sql = "SELECT group_id, canvas_name, current_label, config_json FROM graph_configs WHERE id = %s" + current_data = mysql_client.execute_query(curr_sql, (config_id,)) + + if not current_data: + logger.warning(f"更新失败:找不到 ID 为 {config_id} 的配置") + return False + + curr_row = current_data[0] + old_group_id = int(curr_row['group_id']) + + # --- 步骤 2:确定目标组 ID (调整优先级) --- + # 优先级 1: 只要传了 target_group_id,就说明是移动操作,优先级最高 + if target_group_id is not None: + final_group_id = int(target_group_id) + logger.info(f"【移动模式】配置 {config_id}: 强制设定目标组 ID 为 {final_group_id}") + + # 优先级 2: 自动保存模式下,锁定 group_id 不允许变动 + elif is_auto_save: + final_group_id = old_group_id + logger.debug(f"【自保模式】配置 {config_id}: 锁定原组 ID {final_group_id}") + + # 优先级 3: 传了 group_name 但没传 target_group_id (旧版移动逻辑) + elif group_name: + final_group_id = GraphStyleService._get_or_create_group(group_name) + logger.info(f"【名称模式】配置 {config_id}: 根据名称 [{group_name}] 获得 ID {final_group_id}") + + # 兜底:保持不变 else: - # B. 如果没有传 group_name,说明是“实时自动保存”,严禁修改 group_id - # 这样即使前端变量乱了,数据库的分组也不会变 - sql = """ + final_group_id = old_group_id + + # --- 步骤 3:差异比对 --- + # 增加对数据一致性的判定 + has_changed = ( + int(final_group_id) != old_group_id or + canvas_name != curr_row['canvas_name'] or + current_label != curr_row['current_label'] or + config_json_str != curr_row['config_json'] + ) + + if not has_changed: + logger.info( + f"配置 {config_id} 内容无变化 (最终目标ID:{final_group_id}, 原ID:{old_group_id}),跳过数据库更新") + return True + + # --- 步骤 4:执行更新 --- + sql = """ UPDATE graph_configs - SET canvas_name = %s, current_label = %s, config_json = %s + SET group_id = %s, canvas_name = %s, current_label = %s, config_json = %s WHERE id = %s """ - affected_rows = mysql_client.execute_update(sql, (canvas_name, current_label, config_json, config_id)) - else: - # --- 新建逻辑 --- - # 新建时必须有组名,默认“默认方案” - if not group_name or group_name.strip() == "": - group_name = "默认方案" - - check_group_sql = "SELECT id FROM graph_style_groups WHERE group_name = %s LIMIT 1" - existing_group = mysql_client.execute_query(check_group_sql, (group_name,)) - if existing_group: - target_group_id = existing_group[0]['id'] - else: - create_group_sql = "INSERT INTO graph_style_groups (group_name, is_active, is_default) VALUES (%s, %s, %s)" - mysql_client.execute_update(create_group_sql, (group_name, False, False)) - target_group_id = mysql_client.execute_query("SELECT LAST_INSERT_ID() as last_id")[0]['last_id'] + params = (final_group_id, canvas_name, current_label, config_json_str, config_id) - sql = """ - INSERT INTO graph_configs (canvas_name, current_label, config_json, group_id) - VALUES (%s, %s, %s, %s) - """ - affected_rows = mysql_client.execute_update(sql, (canvas_name, current_label, config_json, target_group_id)) + affected_rows = mysql_client.execute_update(sql, params) - return affected_rows > 0 + if affected_rows > 0: + logger.info(f"更新成功,ID: {config_id}, 归属组已变更为: {final_group_id}") + return True + else: + logger.error(f"数据库更新执行成功但受影响行数为 0,ID: {config_id}") + return False + + except Exception as e: + logger.error(f"Service 层更新异常: {str(e)}", exc_info=True) + return False @staticmethod def get_grouped_configs() -> list: - """ - 获取嵌套结构的方案列表,按默认/激活状态排序 - """ - groups_sql = """ - SELECT id, group_name, is_active, is_default - FROM graph_style_groups - ORDER BY is_default DESC, id ASC - """ + """获取嵌套结构的方案列表""" + groups_sql = "SELECT id, group_name, is_active, is_default FROM graph_style_groups ORDER BY is_default DESC, id ASC" groups = mysql_client.execute_query(groups_sql) or [] configs_sql = "SELECT id, group_id, canvas_name, current_label, config_json, create_time FROM graph_configs" configs = mysql_client.execute_query(configs_sql) or [] + # 格式化配置数据 for conf in configs: - if conf.get('config_json'): - try: - conf['styles'] = json.loads(conf['config_json']) - except: - conf['styles'] = {} + conf['styles'] = json.loads(conf['config_json']) if conf.get('config_json') else {} + # 保持 key 简洁 + if 'config_json' in conf: del conf['config_json'] - if conf.get('create_time') and not isinstance(conf['create_time'], str): conf['create_time'] = conf['create_time'].strftime('%Y-%m-%d %H:%M:%S') - result = [] + # 组装嵌套结构 for g in groups: g['is_active'] = bool(g['is_active']) g['is_default'] = bool(g['is_default']) - g_children = [c for c in configs if c['group_id'] == g['id']] - g['configs'] = g_children + g['configs'] = [c for c in configs if c['group_id'] == g['id']] g['expanded'] = g['is_active'] - result.append(g) - - return result + return groups @staticmethod def apply_group_all(group_id: int) -> bool: - """应用全案:设置激活状态""" + """切换当前激活的方案组""" try: - reset_sql = "UPDATE graph_style_groups SET is_active = %s" - mysql_client.execute_update(reset_sql, (False,)) - apply_sql = "UPDATE graph_style_groups SET is_active = %s WHERE id = %s" - affected_rows = mysql_client.execute_update(apply_sql, (True, group_id)) + # 重置所有组的激活状态 + mysql_client.execute_update("UPDATE graph_style_groups SET is_active = %s", (False,)) + # 激活目标组 + affected_rows = mysql_client.execute_update( + "UPDATE graph_style_groups SET is_active = %s WHERE id = %s", + (True, group_id) + ) return affected_rows > 0 except Exception as e: - print(f"Apply group error: {e}") + logger.error(f"应用全案异常: {str(e)}") return False @staticmethod def set_default_group(group_id: int) -> bool: - """设为系统初始默认方案""" + """设为默认方案组""" try: - reset_sql = "UPDATE graph_style_groups SET is_default = %s" - mysql_client.execute_update(reset_sql, (False,)) - set_sql = "UPDATE graph_style_groups SET is_default = %s WHERE id = %s" - affected_rows = mysql_client.execute_update(set_sql, (True, group_id)) + mysql_client.execute_update("UPDATE graph_style_groups SET is_default = %s", (False,)) + affected_rows = mysql_client.execute_update( + "UPDATE graph_style_groups SET is_default = %s WHERE id = %s", + (True, group_id) + ) return affected_rows > 0 except Exception as e: - print(f"Set default error: {e}") + logger.error(f"设置默认方案异常: {str(e)}") return False @staticmethod - def get_all_configs() -> list: - """获取扁平查询""" - sql = """ - SELECT c.id, c.group_id, c.canvas_name, c.current_label, c.config_json, c.create_time, g.is_active - FROM graph_configs c - LEFT JOIN graph_style_groups g ON c.group_id = g.id - ORDER BY c.create_time DESC - """ - rows = mysql_client.execute_query(sql) - if not rows: return [] - - for row in rows: - if row.get('config_json'): - try: - row['styles'] = json.loads(row['config_json']) - except: - row['styles'] = {} - del row['config_json'] - if row.get('create_time') and not isinstance(row['create_time'], str): - row['create_time'] = row['create_time'].strftime('%Y-%m-%d %H:%M:%S') - row['is_active'] = bool(row.get('is_active', False)) - return rows - - @staticmethod def delete_group(group_id: int) -> bool: - """级联删除""" - del_configs_sql = "DELETE FROM graph_configs WHERE group_id = %s" - mysql_client.execute_update(del_configs_sql, (group_id,)) - del_group_sql = "DELETE FROM graph_style_groups WHERE id = %s" - affected_rows = mysql_client.execute_update(del_group_sql, (group_id,)) - return affected_rows > 0 + """级联删除组及其下的所有配置""" + try: + # 先删配置,再删组(如果没设外键级联) + mysql_client.execute_update("DELETE FROM graph_configs WHERE group_id = %s", (group_id,)) + affected_rows = mysql_client.execute_update("DELETE FROM graph_style_groups WHERE id = %s", (group_id,)) + return affected_rows > 0 + except Exception as e: + logger.error(f"删除方案组异常: {str(e)}") + return False @staticmethod def delete_config(config_id: int) -> bool: - """删除配置""" - sql = "DELETE FROM graph_configs WHERE id = %s" - affected_rows = mysql_client.execute_update(sql, (config_id,)) - return affected_rows > 0 + """删除单个配置""" + try: + affected_rows = mysql_client.execute_update("DELETE FROM graph_configs WHERE id = %s", (config_id,)) + return affected_rows > 0 + except Exception as e: + logger.error(f"删除配置异常: {str(e)}") + return False @staticmethod def batch_delete_configs(config_ids: list) -> int: """批量删除""" if not config_ids: return 0 try: - clean_ids = [int(cid) for cid in config_ids if str(cid).isdigit()] - except: return 0 - if not clean_ids: return 0 - placeholders = ', '.join(['%s'] * len(clean_ids)) - sql = f"DELETE FROM graph_configs WHERE id IN ({placeholders})" - return mysql_client.execute_update(sql, tuple(clean_ids)) + placeholders = ', '.join(['%s'] * len(config_ids)) + sql = f"DELETE FROM graph_configs WHERE id IN ({placeholders})" + return mysql_client.execute_update(sql, tuple(config_ids)) + except Exception as e: + logger.error(f"批量删除异常: {str(e)}") + return 0 @staticmethod def get_group_list() -> list: - """获取方案列表""" + """简单的方案名称列表""" sql = "SELECT id, group_name, is_active, is_default FROM graph_style_groups ORDER BY is_default DESC, id DESC" return mysql_client.execute_query(sql) or [] \ No newline at end of file diff --git a/service/OperationService.py b/service/OperationService.py index 1237811..19c7aff 100644 --- a/service/OperationService.py +++ b/service/OperationService.py @@ -11,6 +11,96 @@ class OperationService: def __init__(self): self.db = neo4j_client + # --- 0. 内部辅助工具:格式标准化 --- + def _format_node_data(self, node): + """ + 统一转换前端传来的平铺或嵌套 JSON 格式。 + 输出标准结构: { "labels": [...], "properties": {...} } + """ + # 1. 提取并标准化 labels + raw_labels = node.get("labels") or node.get("label") + if isinstance(raw_labels, str): + labels = [raw_labels] + elif isinstance(raw_labels, list): + labels = raw_labels + else: + labels = [] + + # 2. 提取并标准化 properties + if "properties" in node and isinstance(node["properties"], dict): + # 嵌套格式:直接取属性字典 + props = node["properties"] + else: + # 平铺格式:打包除特殊键外的所有键值对 + props = { + k: v for k, v in node.items() + if k not in ["label", "labels", "identity", "elementId"] + } + + return {"labels": labels, "properties": props} + + def _normalize_rel_data(self, item): + """ + 参考节点导入的逻辑,统一转换关系数据。 + 支持:标准嵌套 JSON、平铺 JSON、以及带有 properties 包装的格式。 + """ + def clean_str(val): + return str(val).strip() if val is not None else None + + # 1. 尝试识别标准嵌套结构 (start/end 对象) + if isinstance(item.get("start"), dict) and isinstance(item.get("end"), dict): + s_node = item["start"] + e_node = item["end"] + + # 使用获取属性的通用逻辑 ( properties 优先 ) + s_props = s_node.get("properties") if isinstance(s_node.get("properties"), dict) else s_node + e_props = e_node.get("properties") if isinstance(e_node.get("properties"), dict) else e_node + + # 提取标签 (取第一个标签用于精确匹配) + s_label = s_node.get("labels", [""])[0] if s_node.get("labels") else "" + e_label = e_node.get("labels", [""])[0] if e_node.get("labels") else "" + + # 获取关系信息 + rel_obj = {} + if item.get("segments") and len(item["segments"]) > 0: + rel_obj = item["segments"][0].get("relationship", {}) + else: + rel_obj = item.get("relationship", {}) + + r_props = rel_obj.get("properties") if isinstance(rel_obj.get("properties"), dict) else rel_obj + + return { + "source_name": clean_str(s_props.get("name")), + "source_label": s_label, + "target_name": clean_str(e_props.get("name")), + "target_label": e_label, + "rel_type": clean_str(rel_obj.get("type")), + "rel_label": clean_str(r_props.get("label") or r_props.get("name") or "") + } + + # 2. 扁平格式适配 + alias_map = { + "source": ["source_name", "source", "start_name", "起点"], + "target": ["target_name", "target", "end_name", "终点"], + "type": ["rel_type", "type", "relationship"], + "label": ["rel_label", "label", "关系标签"] + } + + def find_value(keys): + for k in keys: + val = item.get(k) + if val and not isinstance(val, dict): return val + return None + + return { + "source_name": clean_str(find_value(alias_map["source"])), + "source_label": clean_str(item.get("source_label") or ""), + "target_name": clean_str(find_value(alias_map["target"])), + "target_label": clean_str(item.get("target_label") or ""), + "rel_type": clean_str(find_value(alias_map["type"])), + "rel_label": clean_str(find_value(alias_map["label"])) or "" + } + # --- 0. 数据修复工具 --- def fix_all_missing_node_ids(self): try: @@ -392,86 +482,417 @@ class OperationService: return {"success": False, "msg": f"删除失败: {str(e)}"} # --- 7. 导出功能 --- - def export_nodes_to_json(self, label=None, name=None): - """ - 按照条件导出节点,确保包含 identity, elementId, labels, properties 等所有原始字段 - """ + def export_nodes_to_json(self, label=None, name=None): # 删除了参数中的 limit=20 try: conditions = [] params = {} - # 构建过滤条件(复用查询逻辑,但去掉分页) - if name: + if name and str(name).strip() and name not in ["null", "undefined"]: params["name"] = unquote(str(name)).strip() conditions.append("n.name CONTAINS $name") - lb_clause = "" - if label and label not in ["全部", ""]: - # 为了保证原生对象的完整性,这里直接 MATCH 标签 - lb_clause = f":`{label}`" + if label and str(label).strip() and label not in ["全部", "", "null", "undefined"]: + params["export_label"] = str(label).strip() + label_cypher = f":`{label}`" + else: + label_cypher = "" where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" - # 注意:这里 RETURN n,返回的是整个节点对象 - cypher = f"MATCH (n{lb_clause}) {where_clause} RETURN n" - - raw_data = self.db.execute_read(cypher, params) + # 彻底移除 limit_clause + cypher = f""" + MATCH (n{label_cypher}) + {where_clause} + RETURN elementId(n) AS elementId, + labels(n) AS labels, + properties(n) AS properties + """ export_items = [] - for row in raw_data: - node = row['n'] - # 核心逻辑:提取 Neo4j 节点对象的所有原生属性 - node_data = { - "identity": node.id, # 对应你截图中的 identity (旧版 ID) - "elementId": node.element_id, # 对应你截图中的 elementId (新版 ID) - "labels": list(node.labels), - "properties": dict(node.items()) - } - export_items.append(node_data) + with self.db.driver.session() as session: + result = session.run(cypher, params) + for index, row in enumerate(result): + export_items.append({ + "identity": index, + "elementId": row.get("elementId"), + "labels": row.get("labels"), + "properties": row.get("properties") + }) - return {"success": True, "data": export_items} + return {"success": True, "data": export_items, "count": len(export_items)} except Exception as e: traceback.print_exc() return {"success": False, "msg": f"导出节点失败: {str(e)}"} - def export_relationships_to_json(self, source=None, target=None, rel_type=None): - """ - 按照条件导出关系,确保包含起始/结束节点信息及完整属性 - """ + def export_relationships_to_json(self, source=None, target=None, rel_type=None): # 删除了参数中的 limit=20 try: conditions = [] params = {} + if source: params["source"] = unquote(str(source)).strip() conditions.append("a.name CONTAINS $source") if target: params["target"] = unquote(str(target)).strip() conditions.append("b.name CONTAINS $target") - if rel_type and rel_type not in ["全部", ""]: - conditions.append(f"type(r) = $rel_type") - params["rel_type"] = rel_type + if rel_type and str(rel_type).strip() and rel_type not in ["全部", "", "null", "undefined"]: + params["rel_type"] = str(rel_type).strip() + conditions.append("type(r) = $rel_type") where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" - # 返回关系对象 r 以及起止节点的 elementId 以便追溯 - cypher = f"MATCH (a)-[r]->(b) {where_clause} RETURN r, elementId(a) as startNode, elementId(b) as endNode" - - raw_data = self.db.execute_read(cypher, params) + # 彻底移除 limit_clause + cypher = f""" + MATCH (a)-[r]->(b) + {where_clause} + RETURN + {{ + elementId: elementId(a), + labels: labels(a), + properties: properties(a) + }} AS start_node, + {{ + elementId: elementId(b), + labels: labels(b), + properties: properties(b) + }} AS end_node, + {{ + type: type(r), + properties: properties(r), + elementId: elementId(r), + startNodeElementId: elementId(a), + endNodeElementId: elementId(b) + }} AS rel_info + """ export_items = [] - for row in raw_data: - rel = row['r'] - rel_data = { - "identity": rel.id, - "elementId": rel.element_id, - "type": rel.type, - "startNodeElementId": row['startNode'], - "endNodeElementId": row['endNode'], - "properties": dict(rel.items()) - } - export_items.append(rel_data) + with self.db.driver.session() as session: + result = session.run(cypher, params) + for index, record in enumerate(result): + s = record["start_node"] + e = record["end_node"] + r = record["rel_info"] + + node_id_base = index * 2 + s["identity"] = node_id_base + e["identity"] = node_id_base + 1 + + r["identity"] = index + r["start"] = s["identity"] + r["end"] = e["identity"] + + export_items.append({ + "start": s, + "end": e, + "segments": [{"start": s, "relationship": r, "end": e}], + "length": 1.0 + }) + + return {"success": True, "data": export_items, "count": len(export_items)} + except Exception as e: + traceback.print_exc() + return {"success": False, "msg": f"导出关系失败: {str(e)}"} + + # --- 8. 节点导入核心功能 --- + def precheck_nodes_batch(self, nodes_batch): + """ + 全量预检:针对一批数据,检查格式无效性、nodeId冲突、name+label冲突 + """ + conflicts = [] + invalid_data = [] + valid_nodes = [] + + # 1. 内存清洗:先转换格式,再严格校验“三要素” + for index, raw_node in enumerate(nodes_batch): + # 格式标准化 (处理平铺/嵌套) + node = self._format_node_data(raw_node) + props = node["properties"] + + name = props.get("name") + labels = node["labels"] + node_id = props.get("nodeId") # 关键:获取 nodeId + + # 严格判定逻辑:name、labels、nodeId 缺一不可 + if not name or not labels or node_id is None: + reasons = [] + if not name: reasons.append("缺少 name") + if not labels: reasons.append("缺少 label") + if node_id is None: reasons.append("缺少 nodeId") + + invalid_data.append({ + "index": index, + "name": name or "未知", + "reason": " | ".join(reasons) + }) + continue + + valid_nodes.append(node) + + if not valid_nodes: + return {"success": True, "conflicts": conflicts, "invalid": invalid_data} + + # 2. 批量数据库比对 (查询潜在冲突) + all_node_ids = [n["properties"]["nodeId"] for n in valid_nodes] + all_names = [n["properties"]["name"] for n in valid_nodes] + + # 查询 nodeId 冲突 + db_id_map = {} + if all_node_ids: + id_results = self.db.execute_read( + "MATCH (n) WHERE n.nodeId IN $ids RETURN n.nodeId as nodeId, n.name as name", {"ids": all_node_ids}) + db_id_map = {row["nodeId"]: row for row in id_results} + + # 查询 name+label 冲突 + db_name_set = set() + if all_names: + name_results = self.db.execute_read( + "MATCH (n) WHERE n.name IN $names RETURN n.name as name, labels(n) as labels", {"names": all_names}) + db_name_set = {f"{row['name']}_{lbl}" for row in name_results for lbl in row['labels']} + + # 3. 组装冲突报告 + for node in valid_nodes: + p = node["properties"] + n_id, name, labels = p["nodeId"], p["name"], node["labels"] + + # 优先级 1: nodeId 冲突 + if n_id in db_id_map: + conflicts.append({ + "name": name, "label": labels[0], "nodeId": n_id, + "reason": f"业务ID冲突: 已存在 nodeId={n_id}", + "type": "nodeId_duplicate" + }) + continue + + # 优先级 2: name + label 冲突 + for lbl in labels: + if f"{name}_{lbl}" in db_name_set: + conflicts.append({ + "name": name, "label": lbl, "nodeId": n_id, + "reason": f"逻辑主键冲突: {lbl} 下已存在名称 '{name}'", + "type": "logic_key_duplicate" + }) + break + + return { + "success": True, + "conflicts": conflicts, + "invalid": invalid_data, + "summary": {"total": len(nodes_batch), "valid": len(valid_nodes), "conflict": len(conflicts)} + } + + def execute_node_import_batch(self, nodes_batch, mode="skip"): + try: + formatted_batch = [self._format_node_data(n) for n in nodes_batch] + # 获取当前时间的标准字符串格式,确保与手动添加的节点一致 + current_time_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + if mode == "strict": + check = self.precheck_nodes_batch(nodes_batch) + if check.get("conflicts"): + return {"success": False, "msg": "严格模式下发现冲突,停止导入", "conflicts": check["conflicts"]} + + label_groups = {} + for node in formatted_batch: + lbls = node.get("labels") + if not lbls or len(lbls) == 0: continue + lbl = lbls[0] + if lbl not in label_groups: label_groups[lbl] = [] + props = node.get("properties") + if props and props.get("name"): + label_groups[lbl].append(props) + + total_imported = 0 + with self.db.driver.session() as session: + for lbl, batch_props in label_groups.items(): + if not batch_props: continue + + # 关键修改:将 datetime() 替换为传入的 $now 字符串 + if mode == "update": + cypher = f""" + UNWIND $batch AS props + MERGE (n:`{lbl}` {{name: props.name}}) + ON CREATE SET n = props, n.createTime = $now + ON MATCH SET n += props, n.updateTime = $now + RETURN count(n) as cnt + """ + elif mode == "skip": + cypher = f""" + UNWIND $batch AS props + MERGE (n:`{lbl}` {{name: props.name}}) + ON CREATE SET n = props, n.createTime = $now + RETURN count(n) as cnt + """ + else: + cypher = f""" + UNWIND $batch AS props + CREATE (n:`{lbl}`) + SET n = props, n.createTime = $now + RETURN count(n) as cnt + """ + + res = session.run(cypher, {"batch": batch_props, "now": current_time_str}) + record = res.single() + if record: total_imported += record["cnt"] + + return {"success": True, "msg": f"成功处理 {total_imported} 个节点", "count": total_imported} + except Exception as e: + traceback.print_exc() + return {"success": False, "msg": f"批次导入异常: {str(e)}"} + + def precheck_rels_batch(self, rels_batch): + """ + 关系导入预检 - 修复版 + """ + conflicts = [] # 关系已存在 + invalid = [] # 节点不存在或格式错误 + # 注意:这里不再使用 valid_to_check 这种模糊中间变量 + + # 统计真正可以执行导入的数量 + actual_valid_count = 0 + + for raw_item in rels_batch: + item = self._normalize_rel_data(raw_item) + + # 1. 基础格式校验 + if not item["source_name"] or not item["target_name"] or not item["rel_type"]: + invalid.append({ + "source": item.get("source_name") or "未知", + "target": item.get("target_name") or "未知", + "reason": "格式错误:缺少必要字段" + }) + continue + + # 2. 数据库存在性校验 + cypher = f""" + OPTIONAL MATCH (s {{name: $s_name}}) + OPTIONAL MATCH (t {{name: $t_name}}) + OPTIONAL MATCH (s)-[r:`{item['rel_type']}`]->(t) + RETURN s IS NOT NULL as hasS, t IS NOT NULL as hasT, r IS NOT NULL as hasR + """ + res = self.db.execute_read(cypher, {"s_name": item["source_name"], "t_name": item["target_name"]}) + + if not res: + continue + rec = res[0] + + if not rec["hasS"] or not rec["hasT"]: + # 关键修复:节点不存在,属于 invalid + invalid.append({ + "source": item["source_name"], + "target": item["target_name"], + "reason": f"节点不存在(起点:{'√' if rec['hasS'] else '×'}, 终点:{'√' if rec['hasT'] else '×'})" + }) + elif rec["hasR"]: + # 关系已存在,属于冲突 + conflicts.append({ + "source": item["source_name"], + "target": item["target_name"], + "type": item["rel_type"], + "reason": "关系已存在" + }) + else: + # 只有走到这里,才是真正的有效数据 + actual_valid_count += 1 + + return { + "success": True, + "conflicts": conflicts, + "invalid": invalid, + "summary": { + "total": len(rels_batch), + "valid": actual_valid_count, # 真正能导进去的数量 + "conflict": len(conflicts), + "invalid": len(invalid) + } + } + + def execute_rel_import_batch(self, rels_batch, mode="skip"): + """ + 执行导入:加入 Label 辅助匹配,确保 ElementId 100% 捕获 + """ + try: + now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + final_results = [] + + with self.db.driver.session() as session: + for raw_item in rels_batch: + item = self._normalize_rel_data(raw_item) + + # 基础检查:没有起点或终点名称直接跳过 + if not item["source_name"] or not item["target_name"] or not item["rel_type"]: + continue + + # 动态构建 Cypher:如果有 Label 则带上 Label,匹配更精准 + s_label_cypher = f":`{item['source_label']}`" if item['source_label'] else "" + t_label_cypher = f":`{item['target_label']}`" if item['target_label'] else "" + + op = "ON MATCH SET r.label = $label" if mode == "update" else "" + + cypher = f""" + MATCH (s{s_label_cypher} {{name: $s_name}}) + MATCH (t{t_label_cypher} {{name: $t_name}}) + MERGE (s)-[r:`{item['rel_type']}`]->(t) + ON CREATE SET r.label = $label, r.createTime = $now + {op} + RETURN s, t, r, + id(s) as s_id, id(t) as t_id, id(r) as r_id, + elementId(s) as s_eid, elementId(t) as t_eid, elementId(r) as r_eid + """ + + res = session.run(cypher, { + "s_name": item["source_name"], + "t_name": item["target_name"], + "label": item["rel_label"], + "now": now + }) - return {"success": True, "data": export_items} + record = res.single() + if record: + s_node, t_node, r_rel = record["s"], record["t"], record["r"] + + # 严格按照你要求的“理想格式”拼装 + graph_item = { + "start": { + "identity": record["s_id"], + "labels": list(s_node.labels), + "properties": dict(s_node), + "elementId": str(record["s_eid"]) + }, + "end": { + "identity": record["t_id"], + "labels": list(t_node.labels), + "properties": dict(t_node), + "elementId": str(record["t_eid"]) + }, + "segments": [{ + "start": { + "identity": record["s_id"], + "labels": list(s_node.labels), + "properties": dict(s_node), + "elementId": str(record["s_eid"]) + }, + "relationship": { + "identity": record["r_id"], + "start": record["s_id"], + "end": record["t_id"], + "type": r_rel.type, + "properties": dict(r_rel), + "elementId": str(record["r_eid"]), + "startNodeElementId": str(record["s_eid"]), + "endNodeElementId": str(record["t_eid"]) + }, + "end": { + "identity": record["identity"] if "identity" in t_node else record["t_id"], # 备选方案 + "labels": list(t_node.labels), + "properties": dict(t_node), + "elementId": str(record["t_eid"]) + } + }], + "length": 1.0 + } + final_results.append(graph_item) + + return {"success": True, "data": final_results, "count": len(final_results)} except Exception as e: traceback.print_exc() - return {"success": False, "msg": f"导出关系失败: {str(e)}"} \ No newline at end of file + return {"success": False, "msg": f"导入执行失败: {str(e)}"} diff --git a/util/auth_interceptor.py b/util/auth_interceptor.py index 2eeefc3..f0e4a8c 100644 --- a/util/auth_interceptor.py +++ b/util/auth_interceptor.py @@ -10,6 +10,7 @@ PUBLIC_PATHS = [ '/api/register', '/api/checkUsername', '/resource', + '/api/kg/export', ] diff --git a/vue/src/api/data.js b/vue/src/api/data.js index 690b355..87e6f49 100644 --- a/vue/src/api/data.js +++ b/vue/src/api/data.js @@ -125,4 +125,99 @@ export function deleteRelationship(id) { method: 'post', data: { id } }) +} + +// --- 11. 导出接口--- + +/** + * 导出节点数据到 JSON + * @param {object} params - 包含 name, label + */ +export function exportNodes(params) { + return request({ + url: '/api/kg/export/nodes', + method: 'get', + params, + timeout: 60000 + }) +} + +/** + * 导出关系数据到 JSON + * @param {object} params - 包含 source, target, type + */ +export function exportRelationships(params) { + return request({ + url: '/api/kg/export/relationships', + method: 'get', + params, + timeout: 60000 + }) +} + +// --- 12. 导入接口 --- + +/** + * 节点导入预检 + * @param {Array} nodes - 全量节点列表(前端解析 JSON 后的数组) + * @returns {Promise} - 返回冲突列表和汇总报告 + */ +export function precheckNodes(nodes) { + return request({ + url: '/api/kg/import/nodes/precheck', + method: 'post', + data: { nodes }, + // 预检 10 万条数据涉及大量内存比对,建议超时设为 2 分钟 + timeout: 120000 + }) +} + +/** + * 执行节点批量导入 + * @param {Array} nodes - 当前批次的节点数据 (建议 5000 条一包) + * @param {string} mode - 导入模式: 'skip' (忽略), 'update' (更新), 'strict' (严格) + */ +export function executeImportNodes(nodes, mode = 'skip') { + return request({ + url: '/api/kg/import/nodes/execute', + method: 'post', + data: { + nodes, + mode + }, + // 单个批次处理建议超时 1 分钟 + timeout: 60000 + }) +} + +// --- 13. 关系导入接口 --- + +/** + * 关系导入预检 + * @param {Array} relationships - 前端解析后的关系数组 (兼容嵌套和扁平格式) + */ +export function precheckRelationships(relationships) { + return request({ + url: '/api/kg/import/relationships/precheck', + method: 'post', + data: { relationships }, + timeout: 120000 // 关系预检涉及双向节点查询,超时时间设长 + }) +} + +/** + * 执行关系批量导入 + * @param {Array} relationships - 当前批次的关系数据 + * @param {string} mode - 'skip' (忽略已存在), 'update' (覆盖已存在属性) + */ +export function executeImportRelationships(relationships, mode = 'skip') { + return request({ + url: '/api/kg/import/relationships/execute', + method: 'post', + data: { + relationships, + mode + }, + timeout: 60000 + }) } \ No newline at end of file diff --git a/vue/src/components/GraphToolbar.vue b/vue/src/components/GraphToolbar.vue new file mode 100644 index 0000000..9ae4173 --- /dev/null +++ b/vue/src/components/GraphToolbar.vue @@ -0,0 +1,260 @@ + + + + + + + \ No newline at end of file diff --git a/vue/src/system/GraphDemo.vue b/vue/src/system/GraphDemo.vue index 38870ac..563e253 100644 --- a/vue/src/system/GraphDemo.vue +++ b/vue/src/system/GraphDemo.vue @@ -4,83 +4,90 @@
-
-
+