# neo4j_util.py
"""Convenience wrapper around the official Neo4j Python driver.

The module exposes :class:`Neo4jUtil`, a small helper that builds and runs
Cypher statements for common node / relationship operations, and a
module-level singleton ``neo4j_client`` that is connected at import time.
"""
import contextlib
import logging
import os
import traceback  # noqa: F401  (kept: other modules may import it from here)
from datetime import datetime  # noqa: F401  (kept for backward compatibility)
from typing import Any, Dict, List, Optional, Tuple

from neo4j import Driver, GraphDatabase

logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# Each value can be overridden through an environment variable of the same
# name; the literals below are the previous hard-coded defaults.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")


class Neo4jUtil:
    """Helper that builds and executes Cypher queries against one database.

    Labels, relationship types and property keys are always emitted as
    backtick-quoted identifiers (embedded backticks are doubled), and every
    value is passed as a query parameter whose name is derived from a
    positional index, so arbitrary property keys are safe to use.
    """

    # RETURN clause shared by the relationship queries (variables a, r, b).
    _RELATIONSHIP_RETURN = (
        " RETURN"
        " elementId(r) AS relId, type(r) AS type, r{.*} AS relProps,"
        " elementId(a) AS sourceId, head(labels(a)) AS sourceLabel,"
        " a{.*} AS sourceProps,"
        " elementId(b) AS targetId, head(labels(b)) AS targetLabel,"
        " b{.*} AS targetProps"
    )

    # RETURN clause shared by the neighbor queries (variables s, r, t).
    _NEIGHBOR_RETURN = (
        " RETURN"
        " elementId(s) AS sourceId, head(labels(s)) AS sourceLabel,"
        " s{.*} AS sourceProps,"
        " elementId(t) AS targetId, head(labels(t)) AS targetLabel,"
        " t{.*} AS targetProps,"
        " elementId(r) AS relId, type(r) AS relType, r{.*} AS relProps"
    )

    def __init__(self, uri: str, username: str, password: str, database: str = "neo4j"):
        """Store the connection settings; no connection is opened here."""
        self.uri = uri
        self.username = username
        self.password = password
        self.database = database
        self.driver: Optional[Driver] = None

    def connect(self) -> bool:
        """Create the driver and verify connectivity.

        Returns:
            True when the server is reachable, False otherwise. On failure
            the half-open driver is closed and ``self.driver`` is reset to
            None so later calls fail with a clear "not initialised" error.
        """
        driver = None
        try:
            driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
            driver.verify_connectivity()
        except Exception as e:
            logger.error(f"Neo4j 连接失败: {e}")
            if driver is not None:
                # Best-effort cleanup; a failure to close must not mask the
                # documented "return False" contract.
                with contextlib.suppress(Exception):
                    driver.close()
            self.driver = None
            return False

        self.driver = driver
        logger.info(f"Neo4jUtil 初始化完成,连接地址: {self.uri}, 数据库: {self.database}")
        return True

    def close(self) -> None:
        """Close the driver if one was created."""
        if self.driver:
            self.driver.close()
            logger.info("Neo4jUtil 驱动已关闭")

    # ==================== Core execution methods ====================

    def _require_driver(self) -> Driver:
        """Return the driver or raise RuntimeError when not connected."""
        if not self.driver:
            raise RuntimeError("Neo4j 驱动未初始化")
        return self.driver

    def execute_write(self, cypher: str, params: Optional[Dict[str, Any]] = None) -> None:
        """Run a write statement in a managed write transaction.

        Raises:
            RuntimeError: if :meth:`connect` has not succeeded.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            session.execute_write(
                lambda tx: tx.run(cypher, parameters=params).consume()
            )
        logger.debug(f"执行写操作: {cypher}")

    def execute_read(
        self, cypher: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a read query and return every record as a plain dict.

        Raises:
            RuntimeError: if :meth:`connect` has not succeeded.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            # Records are materialised inside the transaction function
            # because the result cannot be consumed after it returns.
            return session.execute_read(
                lambda tx: [record.data() for record in tx.run(cypher, parameters=params)]
            )

    def execute_write_and_return(
        self, cypher: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a write statement that returns rows (e.g. ``CREATE ... RETURN``).

        Raises:
            RuntimeError: if :meth:`connect` has not succeeded.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            return session.execute_write(
                lambda tx: [record.data() for record in tx.run(cypher, parameters=params)]
            )

    # ==================== Node operations ====================

    def insert_node(self, label: str, properties: Dict[str, Any]) -> str:
        """Create one node with ``label`` and ``properties``; return its elementId."""
        cypher = (
            f"CREATE (n:{self._quote_identifier(label)} $props) "
            "RETURN elementId(n) AS id"
        )
        rows = self.execute_write_and_return(cypher, {"props": properties})
        return rows[0]["id"]

    def delete_all_nodes_by_label(self, label: str) -> None:
        """Delete every node carrying ``label`` together with its relationships."""
        self.execute_write(f"MATCH (n:{self._quote_identifier(label)}) DETACH DELETE n")

    def delete_nodes_by_condition(self, label: str, conditions: Dict[str, Any]) -> None:
        """Delete the nodes of ``label`` whose properties equal ``conditions``.

        Raises:
            ValueError: if ``conditions`` is empty (guards against wiping
                every node of the label by accident).
        """
        if not conditions:
            raise ValueError("删除条件不能为空,防止误删全表!")
        where_clause, params = self._build_where_conditions("n", conditions, "cond")
        cypher = (
            f"MATCH (n:{self._quote_identifier(label)}) "
            f"WHERE {where_clause} DETACH DELETE n"
        )
        self.execute_write(cypher, params)

    def find_all_nodes(self, label: str) -> List[Dict[str, Any]]:
        """Return every node of ``label`` as a flat dict with id, labels and properties."""
        # `labels` must be returned as well: _merge_id_and_props reads it.
        cypher = (
            f"MATCH (n:{self._quote_identifier(label)}) "
            "RETURN elementId(n) AS id, labels(n) AS labels, n{.*} AS props"
        )
        return [self._merge_id_and_props(row) for row in self.execute_read(cypher)]

    def count_nodes_by_label(self, label: str) -> int:
        """Return the number of nodes carrying ``label``."""
        cypher = f"MATCH (n:{self._quote_identifier(label)}) RETURN count(n) AS total"
        # count() always yields exactly one record.
        return self.execute_read(cypher)[0]["total"]

    def find_nodes_with_element_id(
        self,
        label: Optional[str],
        properties: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Find nodes by property equality, optionally restricted to a label.

        Args:
            label: Node label; ``None`` matches nodes of any label.
            properties: Property filter; all entries must match. Empty or
                ``None`` means no filter.

        Returns:
            One flat dict per node containing ``id``, ``labels`` and every
            node property.
        """
        node_pattern = "(n)" if label is None else f"(n:{self._quote_identifier(label)})"
        cypher = f"MATCH {node_pattern}"
        params: Dict[str, Any] = {}
        if properties:
            where_clause, params = self._build_where_conditions("n", properties, "prop")
            cypher += f" WHERE {where_clause}"
        cypher += " RETURN elementId(n) AS id, labels(n) AS labels, n{.*} AS props"
        return [self._merge_id_and_props(row) for row in self.execute_read(cypher, params)]

    def update_node_by_properties(
        self, label: str, where: Dict[str, Any], updates: Dict[str, Any]
    ) -> None:
        """Set ``updates`` on every node of ``label`` matching ``where``.

        Raises:
            ValueError: if ``where`` or ``updates`` is empty.
        """
        if not where:
            raise ValueError("WHERE 条件不能为空!")
        if not updates:
            raise ValueError("更新内容不能为空!")
        where_clause, params = self._build_where_conditions("n", where, "where")
        set_clause, set_params = self._build_set_clause("n", updates, "update")
        params.update(set_params)
        cypher = (
            f"MATCH (n:{self._quote_identifier(label)}) "
            f"WHERE {where_clause} SET {set_clause}"
        )
        self.execute_write(cypher, params)

    # ==================== Relationship operations ====================

    def create_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: str,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create a relationship between existing nodes.

        Both end nodes are located with MATCH, so nothing is created when
        either side does not exist.

        Raises:
            ValueError: if ``source_props`` or ``target_props`` is empty.
        """
        if not source_props or not target_props:
            raise ValueError("源或目标节点条件不能为空")

        match_a = self._build_match_clause("a", source_label, source_props, "src")
        match_b = self._build_match_clause("b", target_label, target_props, "tgt")
        params = {**match_a["params"], **match_b["params"]}

        rel_part = self._quote_identifier(rel_type)
        if rel_properties:
            rel_part += " $rel_props"
            params["rel_props"] = rel_properties

        cypher = (
            f"MATCH {match_a['clause']}, {match_b['clause']} "
            f"CREATE (a)-[r:{rel_part}]->(b)"
        )
        self.execute_write(cypher, params)

    def find_all_relationships(self, rel_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return every relationship, optionally restricted to ``rel_type``."""
        return self.find_relationships_by_condition(rel_type=rel_type)

    def find_neighbors_with_relationshipsAI(
        self,
        node_label: Optional[str],
        node_properties: Dict[str, Any],
        direction: str = "both",
        rel_type: Optional[str] = None,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Return the neighborhood of a node as flat node / relationship lists.

        No row limit is applied.

        Args:
            node_label: Label of the start node; ``None`` matches any label.
            node_properties: Property filter locating the start node.
            direction: ``"out"``, ``"in"`` or ``"both"``.
            rel_type: Optional relationship type filter.

        Returns:
            ``{"nodes": [...], "relationships": [...]}`` where nodes are
            de-duplicated by elementId.

        Raises:
            ValueError: if ``node_properties`` is empty or ``direction`` is
                not one of the accepted values.
            RuntimeError: if the query fails.
        """
        rows = self._query_neighbor_rows(node_label, node_properties, direction, rel_type)

        nodes_by_id: Dict[str, Dict[str, Any]] = {}
        relationships = []
        for row in rows:
            for side in ("source", "target"):
                node_id = row[f"{side}Id"]
                if node_id not in nodes_by_id:
                    nodes_by_id[node_id] = self._node_from_row(row, side)
            relationships.append(
                {
                    "id": row["relId"],
                    "type": row["relType"],
                    "sourceId": row["sourceId"],
                    "targetId": row["targetId"],
                    "properties": dict(row["relProps"]) if row["relProps"] else {},
                }
            )

        return {"nodes": list(nodes_by_id.values()), "relationships": relationships}

    def find_relationships_by_condition(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Find directed relationships matching any combination of filters.

        Every argument is optional; omitted filters match anything.
        """
        cypher, params = self._build_relationship_match(
            source_label, source_props, target_label, target_props, rel_type, rel_properties
        )
        cypher += self._RELATIONSHIP_RETURN
        return [self._enrich_relationship(row) for row in self.execute_read(cypher, params)]

    def find_neighbors_with_relationships(
        self,
        node_label: Optional[str],
        node_properties: Dict[str, Any],
        direction: str = "both",
        rel_type: Optional[str] = None,
        limit: int = 2000,
    ) -> List[Dict[str, Any]]:
        """Return the neighbors of a node, one entry per relationship.

        Args:
            node_label: Label of the start node; ``None`` matches any label.
            node_properties: Property filter locating the start node.
            direction: ``"out"``, ``"in"`` or ``"both"``.
            rel_type: Optional relationship type filter.
            limit: Maximum number of rows returned (default 2000).

        Returns:
            A list of ``{"source": ..., "target": ..., "relationship": ...}``
            dicts, where source / target follow the stored direction of the
            relationship.

        Raises:
            ValueError: if ``node_properties`` is empty or ``direction`` is
                not one of the accepted values.
            RuntimeError: if the query fails.
        """
        rows = self._query_neighbor_rows(
            node_label, node_properties, direction, rel_type, limit=limit
        )
        return [
            {
                "source": self._node_from_row(row, "source"),
                "target": self._node_from_row(row, "target"),
                "relationship": {
                    "id": row["relId"],
                    "type": row["relType"],
                    "properties": dict(row["relProps"]) if row["relProps"] else {},
                },
            }
            for row in rows
        ]

    def delete_all_relationships_by_node_label(self, node_label: str) -> None:
        """Delete every relationship attached to nodes of ``node_label`` (nodes stay)."""
        cypher = f"MATCH (n:{self._quote_identifier(node_label)})-[r]-() DELETE r"
        self.execute_write(cypher)

    def delete_all_relationships_by_node_props(
        self, label: str, properties: Dict[str, Any]
    ) -> None:
        """Delete every relationship of the nodes matching ``properties``.

        NOTE: an empty ``properties`` matches every node of ``label``.
        """
        where_clause, params = self._build_where_conditions("n", properties, "node")
        cypher = (
            f"MATCH (n:{self._quote_identifier(label)}) WHERE {where_clause} "
            "MATCH (n)-[r]-() DELETE r"
        )
        self.execute_write(cypher, params)

    def delete_relationships_advanced(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete directed relationships matching any combination of filters.

        NOTE: calling this without any filter deletes every relationship in
        the database.
        """
        cypher, params = self._build_relationship_match(
            source_label, source_props, target_label, target_props, rel_type, rel_properties
        )
        self.execute_write(cypher + " DELETE r", params)

    def update_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: Optional[str] = None,
        new_rel_properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set properties on the relationships between two sets of nodes.

        Raises:
            ValueError: if ``new_rel_properties`` is empty.
        """
        if not new_rel_properties:
            raise ValueError("至少需要提供一个要更新的关系属性")

        quote = self._quote_identifier
        r_label = f":{quote(rel_type)}" if rel_type else ""
        match = (
            f"MATCH (a:{quote(source_label)})-[r{r_label}]->(b:{quote(target_label)})"
        )

        where_a, params = self._build_where_conditions("a", source_props, "src")
        where_b, params_b = self._build_where_conditions("b", target_props, "tgt")
        params.update(params_b)
        set_clause, set_params = self._build_set_clause("r", new_rel_properties, "rel_update")
        params.update(set_params)

        cypher = f"{match} WHERE {where_a} AND {where_b} SET {set_clause}"
        self.execute_write(cypher, params)

    # ==================== Internal helpers ====================

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return ``name`` as a backtick-quoted Cypher identifier.

        Embedded backticks are doubled so the value cannot terminate the
        quoted identifier early.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def _merge_id_and_props(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten a row with ``id``, ``labels`` and ``props`` into one dict."""
        return {
            "id": row["id"],
            "labels": row["labels"],  # label list, e.g. ["Disease"]
            **row["props"],  # node properties are spread at the top level
        }

    @staticmethod
    def _node_from_row(row: Dict[str, Any], side: str) -> Dict[str, Any]:
        """Build a node dict from the ``<side>Props/Id/Label`` columns of ``row``."""
        node = dict(row[f"{side}Props"])
        node.update({"id": row[f"{side}Id"], "label": row[f"{side}Label"]})
        return node

    def _enrich_relationship(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Shape a relationship row into ``relId/type/relProps/source/target``."""
        return {
            "relId": row["relId"],
            "type": row["type"],
            "relProps": row["relProps"],
            "source": self._node_from_row(row, "source"),
            "target": self._node_from_row(row, "target"),
        }

    def _build_match_clause(
        self, var: str, label: str, props: Dict[str, Any], prefix: str
    ) -> Dict[str, Any]:
        """Build a node pattern with inline property map.

        Example clause: ``(a:`Disease` {`name`: $src_0})``.

        Returns:
            ``{"clause": <pattern>, "params": <parameter dict>}``.
        """
        label_part = self._quote_identifier(label)
        if not props:
            return {"clause": f"({var}:{label_part})", "params": {}}

        attr_parts = []
        params: Dict[str, Any] = {}
        # Parameter names use the position, not the key, so keys containing
        # characters that are illegal in a parameter name still work.
        for index, (key, value) in enumerate(props.items()):
            param_key = f"{prefix}_{index}"
            attr_parts.append(f"{self._quote_identifier(key)}: ${param_key}")
            params[param_key] = value

        clause = f"({var}:{label_part} {{{', '.join(attr_parts)}}})"
        return {"clause": clause, "params": params}

    def _build_where_conditions(
        self, var: str, props: Dict[str, Any], prefix: str
    ) -> Tuple[str, Dict[str, Any]]:
        """Build an AND-joined equality predicate and its parameters.

        Returns ``("1=1", {})`` when ``props`` is empty so the result can
        always be placed after ``WHERE``.
        """
        if not props:
            return "1=1", {}

        conditions = []
        params: Dict[str, Any] = {}
        for index, (key, value) in enumerate(props.items()):
            param_key = f"{prefix}_{index}"
            conditions.append(f"{var}.{self._quote_identifier(key)} = ${param_key}")
            params[param_key] = value
        return " AND ".join(conditions), params

    def _build_set_clause(
        self, var: str, updates: Dict[str, Any], prefix: str
    ) -> Tuple[str, Dict[str, Any]]:
        """Build the assignment list of a SET clause and its parameters."""
        assignments = []
        params: Dict[str, Any] = {}
        for index, (key, value) in enumerate(updates.items()):
            param_key = f"{prefix}_{index}"
            assignments.append(f"{var}.{self._quote_identifier(key)} = ${param_key}")
            params[param_key] = value
        return ", ".join(assignments), params

    def _build_relationship_match(
        self,
        source_label: Optional[str],
        source_props: Optional[Dict[str, Any]],
        target_label: Optional[str],
        target_props: Optional[Dict[str, Any]],
        rel_type: Optional[str],
        rel_properties: Optional[Dict[str, Any]],
    ) -> Tuple[str, Dict[str, Any]]:
        """Build ``MATCH (a)-[r]->(b) [WHERE ...]`` from optional filters."""
        quote = self._quote_identifier
        a_label = f":{quote(source_label)}" if source_label else ""
        b_label = f":{quote(target_label)}" if target_label else ""
        r_label = f":{quote(rel_type)}" if rel_type else ""
        cypher = f"MATCH (a{a_label})-[r{r_label}]->(b{b_label})"

        where_parts = []
        params: Dict[str, Any] = {}
        filters = (
            ("a", source_props, "src"),
            ("b", target_props, "tgt"),
            ("r", rel_properties, "rel"),
        )
        for var, props, prefix in filters:
            if props:
                part, part_params = self._build_where_conditions(var, props, prefix)
                where_parts.append(part)
                params.update(part_params)

        if where_parts:
            cypher += " WHERE " + " AND ".join(where_parts)
        return cypher, params

    def _query_neighbor_rows(
        self,
        node_label: Optional[str],
        node_properties: Dict[str, Any],
        direction: str,
        rel_type: Optional[str],
        limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Run the neighbor query shared by both public neighbor methods.

        ``limit=None`` omits the LIMIT clause.
        """
        if not node_properties:
            raise ValueError("node_properties 不能为空,用于定位起始节点")

        where_clause, params = self._build_where_conditions("n", node_properties, "node")
        rel_filter = f":{self._quote_identifier(rel_type)}" if rel_type else ""
        if node_label is not None:
            node_pattern = f"(n:{self._quote_identifier(node_label)})"
        else:
            node_pattern = "(n)"

        if direction == "out":
            pattern = f"{node_pattern}-[r{rel_filter}]->(m)"
        elif direction == "in":
            # Incoming relationships need the full `<-` arrow head.
            pattern = f"{node_pattern}<-[r{rel_filter}]-(m)"
        elif direction == "both":
            pattern = f"{node_pattern}-[r{rel_filter}]-(m)"
        else:
            raise ValueError("direction 必须是 'out', 'in' 或 'both'")

        # startNode/endNode report the stored direction of r regardless of
        # which side of the pattern the start node was matched on.
        cypher = (
            f"MATCH {pattern} WHERE {where_clause}"
            " WITH r, startNode(r) AS s, endNode(r) AS t"
            + self._NEIGHBOR_RETURN
        )
        if limit is not None:
            cypher += " LIMIT $limit"
            params["limit"] = limit

        try:
            return self.execute_read(cypher, params)
        except Exception as e:
            # Log with traceback, then re-raise with the original as cause.
            logger.exception("查询邻居节点时发生数据库错误")
            raise RuntimeError(f"查询邻居节点时发生数据库错误: {str(e)}") from e

    def _grouped_tree(
        self,
        cypher: str,
        parent_field: str,
        parent_type: str,
        children_field: str,
        child_type: str,
    ) -> List[Dict[str, Any]]:
        """Turn ``parent, collect(children)`` rows into el-tree style nodes."""
        return [
            {
                "label": record[parent_field],
                "type": parent_type,
                "children": [
                    {"label": name, "type": child_type}
                    for name in record[children_field]
                ],
            }
            for record in self.execute_read(cypher)
        ]

    def get_department_disease_tree(self) -> List[Dict[str, Any]]:
        """Return departments with their related diseases in el-tree format."""
        cypher = """
        MATCH (d:Department)--(dis:Disease)
        RETURN d.name AS dept_name, collect(dis.name) AS diseases
        ORDER BY d.name
        """
        return self._grouped_tree(cypher, "dept_name", "Department", "diseases", "Disease")

    def get_subject_drug_tree(self) -> List[Dict[str, Any]]:
        """Return drug categories (Subject) with their drugs in el-tree format."""
        cypher = """
        MATCH (s:Subject)--(d:Drug)
        RETURN s.name AS subject_name, collect(d.name) AS drugs
        ORDER BY s.name
        """
        return self._grouped_tree(cypher, "subject_name", "Subject", "drugs", "Drug")


# ==================== Module-level singleton ====================
neo4j_client = Neo4jUtil(
    uri=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    database=NEO4J_DATABASE,
)

# Connect at import time; importing this module fails fast when the
# database is unreachable.
if not neo4j_client.connect():
    raise RuntimeError("Failed to connect to Neo4j at module load time!")