# neo4j_util.py
"""Convenience wrapper around the official Neo4j Python driver.

Provides a small CRUD-style API for nodes and relationships and exposes a
module-level ``neo4j_client`` instance that connects when the module is
imported.
"""
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

from neo4j import GraphDatabase, Driver

logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# Each value may be overridden through an environment variable of the same
# name; the defaults are the values that were previously hard-coded.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")


class Neo4jUtil:
    """Helper for running Cypher against a single Neo4j database.

    Labels, relationship types and property keys are interpolated into the
    query text as backtick-quoted identifiers (with embedded backticks
    escaped); property *values* are always sent as query parameters.
    """

    # Projection shared by the relationship queries; the column names are
    # the ones `_enrich_relationship` reads.
    _RELATIONSHIP_RETURN = (
        " RETURN elementId(r) AS relId, type(r) AS type, r{.*} AS relProps,"
        " elementId(a) AS sourceId, head(labels(a)) AS sourceLabel,"
        " a{.*} AS sourceProps,"
        " elementId(b) AS targetId, head(labels(b)) AS targetLabel,"
        " b{.*} AS targetProps"
    )

    def __init__(self, uri: str, username: str, password: str, database: str = "neo4j"):
        """Store connection settings; no connection is opened until `connect`."""
        self.uri = uri
        self.username = username
        self.password = password
        self.database = database
        self.driver: Optional[Driver] = None

    def connect(self) -> bool:
        """Create the driver and verify connectivity.

        Returns:
            True when the driver was created and verified, False otherwise.
            On failure the half-initialised driver is closed and
            ``self.driver`` is left as ``None`` so that later calls fail fast
            with "driver not initialised" instead of using a broken driver.
        """
        # Release a driver left over from a previous connect() call.
        if self.driver is not None:
            self.close()

        driver = None
        try:
            driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
            driver.verify_connectivity()
        except Exception as e:
            logger.error("Neo4j 连接失败: %s", e)
            if driver is not None:
                try:
                    driver.close()
                except Exception:
                    # Best-effort cleanup; the connection failure above is
                    # the error that matters to the caller.
                    logger.debug("Closing the unverified driver failed", exc_info=True)
            self.driver = None
            return False

        self.driver = driver
        logger.info(
            "Neo4jUtil 初始化完成,连接地址: %s, 数据库: %s", self.uri, self.database
        )
        return True

    def close(self):
        """Close the driver (if any) and forget it."""
        if self.driver:
            self.driver.close()
            # Drop the reference so a closed driver is never reused.
            self.driver = None
            logger.info("Neo4jUtil 驱动已关闭")

    # ==================== Core execution methods ====================

    def execute_write(self, cypher: str, params: Optional[Dict[str, Any]] = None):
        """Run a write query inside a managed write transaction.

        Args:
            cypher: Cypher statement to execute.
            params: Query parameters; ``None`` is treated as no parameters.

        Raises:
            RuntimeError: If the driver has not been initialised.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            session.execute_write(
                lambda tx: tx.run(cypher, parameters=params).consume()
            )
        logger.debug("执行写操作: %s", cypher)

    def execute_read(self, cypher: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run a read query and return every record as a dict.

        Raises:
            RuntimeError: If the driver has not been initialised.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            # Records are materialised inside the transaction function.
            return session.execute_read(
                lambda tx: [record.data() for record in tx.run(cypher, parameters=params)]
            )

    def execute_write_and_return(self, cypher: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run a write query and return its records (e.g. ``CREATE ... RETURN``).

        Raises:
            RuntimeError: If the driver has not been initialised.
        """
        driver = self._require_driver()
        params = params or {}
        with driver.session(database=self.database) as session:
            return session.execute_write(
                lambda tx: [record.data() for record in tx.run(cypher, parameters=params)]
            )

    # ==================== Node operations ====================

    def insert_node(self, label: str, properties: Dict[str, Any]) -> str:
        """Create a node with the given label and properties; return its elementId."""
        cypher = f"CREATE (n:{self._quote(label)} $props) RETURN elementId(n) AS id"
        result = self.execute_write_and_return(cypher, {"props": properties})
        return result[0]["id"]

    def delete_all_nodes_by_label(self, label: str):
        """Delete every node with the given label, together with its relationships."""
        cypher = f"MATCH (n:{self._quote(label)}) DETACH DELETE n"
        self.execute_write(cypher)

    def delete_nodes_by_condition(self, label: str, conditions: Dict[str, Any]):
        """Delete the nodes of `label` whose properties equal `conditions`.

        Raises:
            ValueError: If `conditions` is empty (guards against wiping the
                whole label by accident).
        """
        if not conditions:
            raise ValueError("删除条件不能为空,防止误删全表!")
        where_clause, params = self._build_where_conditions("n", conditions, "cond")
        cypher = f"MATCH (n:{self._quote(label)}) WHERE {where_clause} DETACH DELETE n"
        self.execute_write(cypher, params)

    def find_all_nodes(self, label: str) -> List[Dict[str, Any]]:
        """Return every node of `label` as its property dict plus an ``id`` key."""
        cypher = (
            f"MATCH (n:{self._quote(label)})"
            " RETURN elementId(n) AS id, n{.*} AS props"
        )
        raw = self.execute_read(cypher)
        return [self._merge_id_and_props(row) for row in raw]

    def find_nodes_with_element_id(self, label: str, properties: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Return nodes of `label` matching `properties` (all nodes when empty).

        Each result is the node's property dict plus an ``id`` key holding
        its elementId.
        """
        cypher = f"MATCH (n:{self._quote(label)})"
        params: Dict[str, Any] = {}
        if properties:
            where_clause, params = self._build_where_conditions("n", properties, "prop")
            cypher += f" WHERE {where_clause}"
        cypher += " RETURN elementId(n) AS id, n{.*} AS props"
        raw = self.execute_read(cypher, params)
        return [self._merge_id_and_props(row) for row in raw]

    def update_node_by_properties(self, label: str, where: Dict[str, Any], updates: Dict[str, Any]):
        """Set `updates` on every node of `label` whose properties equal `where`.

        Raises:
            ValueError: If `where` or `updates` is empty.
        """
        if not where:
            raise ValueError("WHERE 条件不能为空!")
        if not updates:
            raise ValueError("更新内容不能为空!")
        where_clause, params = self._build_where_conditions("n", where, "where")
        set_clause, set_params = self._build_set_clause("n", updates, "update")
        params.update(set_params)
        cypher = f"MATCH (n:{self._quote(label)}) WHERE {where_clause} SET {set_clause}"
        self.execute_write(cypher, params)

    # ==================== Relationship operations ====================

    def create_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: str,
        rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Create a relationship between already-existing nodes.

        Both endpoints are located with MATCH, so nothing is created when
        either side matches no node; when several nodes match, one
        relationship is created per matched pair.

        Raises:
            ValueError: If `source_props` or `target_props` is empty.
        """
        if not source_props or not target_props:
            raise ValueError("源或目标节点条件不能为空")
        match_a = self._build_match_clause("a", source_label, source_props, "src")
        match_b = self._build_match_clause("b", target_label, target_props, "tgt")
        params = {**match_a["params"], **match_b["params"]}
        rel_part = self._quote(rel_type)
        if rel_properties:
            rel_part += " $rel_props"
            params["rel_props"] = rel_properties
        cypher = (
            f"MATCH {match_a['clause']}, {match_b['clause']} "
            f"CREATE (a)-[r:{rel_part}]->(b)"
        )
        self.execute_write(cypher, params)

    def find_all_relationships(self, rel_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return every relationship, optionally restricted to `rel_type`."""
        match, params = self._build_relationship_match(rel_type=rel_type)
        raw = self.execute_read(match + self._RELATIONSHIP_RETURN, params)
        return [self._enrich_relationship(row) for row in raw]

    def find_relationships_by_condition(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Return relationships filtered by any combination of endpoint
        label/properties and relationship type/properties.

        Every filter is optional; omitted filters do not constrain the match.
        """
        match, params = self._build_relationship_match(
            source_label, source_props, target_label, target_props,
            rel_type, rel_properties,
        )
        raw = self.execute_read(match + self._RELATIONSHIP_RETURN, params)
        return [self._enrich_relationship(row) for row in raw]

    def find_neighbors_with_relationships(
        self,
        node_label: str,
        node_properties: Dict[str, Any],
        direction: str = "both",  # one of "out", "in", "both"
        rel_type: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return the neighbours of the matched node(s) with the connecting
        relationships.

        Args:
            node_label: Label of the start node.
            node_properties: Property equality filter locating the start node.
            direction: "out" for (n)-[r]->(m), "in" for (n)<-[r]-(m),
                "both" for an undirected match.
            rel_type: Optional relationship type filter.

        Returns:
            One dict per matched relationship with keys ``source`` (the start
            node), ``target`` (the neighbour) and ``relationship``.

        Raises:
            ValueError: If `node_properties` is empty or `direction` is not
                one of the accepted values.
        """
        if not node_properties:
            raise ValueError("node_properties 不能为空,用于定位起始节点")

        where_clause, params = self._build_where_conditions("n", node_properties, "node")
        rel_filter = f":{self._quote(rel_type)}" if rel_type else ""
        start = f"(n:{self._quote(node_label)})"
        patterns = {
            "out": f"{start}-[r{rel_filter}]->(m)",
            "in": f"{start}<-[r{rel_filter}]-(m)",
            "both": f"{start}-[r{rel_filter}]-(m)",
        }
        if direction not in patterns:
            raise ValueError("direction 必须是 'out', 'in' 或 'both'")

        cypher = (
            f"MATCH {patterns[direction]} WHERE {where_clause}"
            " RETURN elementId(n) AS sourceId, head(labels(n)) AS sourceLabel,"
            " n{.*} AS sourceProps,"
            " elementId(m) AS targetId, head(labels(m)) AS targetLabel,"
            " m{.*} AS targetProps,"
            " elementId(r) AS relId, type(r) AS relType, r{.*} AS relProps"
        )
        raw_results = self.execute_read(cypher, params)

        neighbors = []
        for row in raw_results:
            # "id" and "label" are written last, so they override node
            # properties that happen to use the same names.
            source = dict(row["sourceProps"])
            source.update({"id": row["sourceId"], "label": row["sourceLabel"]})
            target = dict(row["targetProps"])
            target.update({"id": row["targetId"], "label": row["targetLabel"]})
            neighbors.append({
                "source": source,
                "target": target,
                "relationship": {
                    "id": row["relId"],
                    "type": row["relType"],
                    "properties": row["relProps"],
                },
            })
        return neighbors

    def delete_all_relationships_by_node_label(self, node_label: str):
        """Delete every relationship attached to nodes of `node_label`; nodes are kept."""
        cypher = f"MATCH (n:{self._quote(node_label)})-[r]-() DELETE r"
        self.execute_write(cypher)

    def delete_all_relationships_by_node_props(self, label: str, properties: Dict[str, Any]):
        """Delete every relationship attached to the nodes of `label` that
        match `properties`.

        NOTE: empty `properties` produces an always-true filter, i.e. the
        relationships of *all* nodes with this label are deleted.
        """
        where_clause, params = self._build_where_conditions("n", properties, "node")
        cypher = (
            f"MATCH (n:{self._quote(label)}) WHERE {where_clause} "
            "MATCH (n)-[r]-() DELETE r"
        )
        self.execute_write(cypher, params)

    def delete_relationships_advanced(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Delete relationships filtered by any combination of endpoint
        label/properties and relationship type/properties.

        NOTE: with no filter at all this deletes every relationship in the
        database.
        """
        match, params = self._build_relationship_match(
            source_label, source_props, target_label, target_props,
            rel_type, rel_properties,
        )
        self.execute_write(match + " DELETE r", params)

    def update_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: Optional[str] = None,
        new_rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Set `new_rel_properties` on the relationships between the matched
        source and target nodes.

        Raises:
            ValueError: If `new_rel_properties` is empty.
        """
        if not new_rel_properties:
            raise ValueError("至少需要提供一个要更新的关系属性")
        match, params = self._build_relationship_match(
            source_label, source_props, target_label, target_props, rel_type,
        )
        set_clause, set_params = self._build_set_clause("r", new_rel_properties, "rel_update")
        params.update(set_params)
        self.execute_write(f"{match} SET {set_clause}", params)

    # ==================== Internal helpers ====================

    def _require_driver(self) -> Driver:
        """Return the active driver or raise if `connect` has not succeeded."""
        if not self.driver:
            raise RuntimeError("Neo4j 驱动未初始化")
        return self.driver

    @staticmethod
    def _quote(identifier: str) -> str:
        """Backtick-quote a label, relationship type or property key.

        Embedded backticks are doubled so the value cannot terminate the
        quoted identifier and smuggle extra Cypher into the query.
        """
        return "`" + str(identifier).replace("`", "``") + "`"

    def _merge_id_and_props(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten a ``{id, props}`` row into the props dict plus an ``id`` key."""
        props = dict(row.get("props", {}))
        props["id"] = row["id"]
        return props

    def _enrich_relationship(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Reshape a `_RELATIONSHIP_RETURN` row into nested source/target dicts."""
        source = dict(row["sourceProps"])
        source.update({"id": row["sourceId"], "label": row["sourceLabel"]})
        target = dict(row["targetProps"])
        target.update({"id": row["targetId"], "label": row["targetLabel"]})
        return {
            "relId": row["relId"],
            "type": row["type"],
            "relProps": row["relProps"],
            "source": source,
            "target": target,
        }

    def _build_relationship_match(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build ``MATCH (a)-[r]->(b) [WHERE ...]`` plus its parameters.

        Filters that are ``None`` or empty are simply left out.
        """
        a_label = f":{self._quote(source_label)}" if source_label else ""
        b_label = f":{self._quote(target_label)}" if target_label else ""
        r_label = f":{self._quote(rel_type)}" if rel_type else ""
        cypher = f"MATCH (a{a_label})-[r{r_label}]->(b{b_label})"

        where_parts: List[str] = []
        params: Dict[str, Any] = {}
        filters = (
            ("a", source_props, "src"),
            ("b", target_props, "tgt"),
            ("r", rel_properties, "rel"),
        )
        for var, props, prefix in filters:
            if props:
                part, part_params = self._build_where_conditions(var, props, prefix)
                where_parts.append(part)
                params.update(part_params)
        if where_parts:
            cypher += " WHERE " + " AND ".join(where_parts)
        return cypher, params

    def _build_match_clause(self, var: str, label: str, props: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Build a node pattern with an inline property map.

        Example output: (a:`Disease` {`name`: $src_0})

        Returns:
            ``{"clause": <pattern text>, "params": <parameter dict>}``.
        """
        quoted_label = self._quote(label)
        if not props:
            return {"clause": f"({var}:{quoted_label})", "params": {}}
        attr_parts = []
        params = {}
        for i, (k, v) in enumerate(props.items()):
            # Index-based parameter names stay valid whatever characters
            # the property key contains.
            param_key = f"{prefix}_{i}"
            attr_parts.append(f"{self._quote(k)}: ${param_key}")
            params[param_key] = v
        attrs_str = ", ".join(attr_parts)
        return {"clause": f"({var}:{quoted_label} {{{attrs_str}}})", "params": params}

    def _build_where_conditions(self, var: str, props: Dict[str, Any], prefix: str) -> Tuple[str, Dict[str, Any]]:
        """Build an AND-joined equality filter and its parameters.

        Empty `props` yields the always-true condition ``1=1`` so the
        result can still be placed after ``WHERE``.
        """
        if not props:
            return "1=1", {}
        conditions = []
        params = {}
        for i, (k, v) in enumerate(props.items()):
            param_key = f"{prefix}_{i}"
            conditions.append(f"{var}.{self._quote(k)} = ${param_key}")
            params[param_key] = v
        return " AND ".join(conditions), params

    def _build_set_clause(self, var: str, updates: Dict[str, Any], prefix: str) -> Tuple[str, Dict[str, Any]]:
        """Build a comma-joined ``var.key = $param`` SET list and its parameters."""
        assignments = []
        params = {}
        for i, (k, v) in enumerate(updates.items()):
            param_key = f"{prefix}_{i}"
            assignments.append(f"{var}.{self._quote(k)} = ${param_key}")
            params[param_key] = v
        return ", ".join(assignments), params


# ==================== Global singleton (initialised on import) ====================
neo4j_client = Neo4jUtil(
    uri=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    database=NEO4J_DATABASE,
)

# Connect automatically when the module is imported; importing fails loudly
# if the database is unreachable.
if not neo4j_client.connect():
    raise RuntimeError("Failed to connect to Neo4j at module load time!")