You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

448 lines
17 KiB

# neo4j_util.py
import logging
import os
from typing import Dict, List, Optional, Any

from neo4j import GraphDatabase, Driver
logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# Every setting can be overridden with an environment variable of the same
# name; the literals are the fallback defaults used when it is not set.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
# NOTE: the fallback password is kept only for backward compatibility; supply
# NEO4J_PASSWORD through the environment instead of relying on it.
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")
class Neo4jUtil:
    """Convenience wrapper around a Neo4j driver for node/relationship CRUD.

    Labels, relationship types and property keys are interpolated into the
    Cypher text, so each one is backtick-quoted (with embedded backticks
    doubled) by ``_quote_identifier``. Property *values* are always passed as
    query parameters named ``<prefix>_<index>``; the key itself never appears
    in a parameter name, so keys containing spaces or dashes are safe.
    """

    # Projection shared by every query that binds a relationship ``r`` going
    # from ``a`` to ``b``; rows of this shape feed ``_enrich_relationship``.
    _RELATIONSHIP_RETURN = (
        " RETURN"
        " elementId(r) AS relId,"
        " type(r) AS type,"
        " r{.*} AS relProps,"
        " elementId(a) AS sourceId,"
        " head(labels(a)) AS sourceLabel,"
        " a{.*} AS sourceProps,"
        " elementId(b) AS targetId,"
        " head(labels(b)) AS targetLabel,"
        " b{.*} AS targetProps"
    )

    def __init__(self, uri: str, username: str, password: str, database: str = "neo4j"):
        """Store the connection settings; no connection is opened here."""
        self.uri = uri
        self.username = username
        self.password = password
        self.database = database
        self.driver: Optional[Driver] = None

    def connect(self) -> bool:
        """Create the driver and verify connectivity.

        Any driver held from an earlier call is closed first. On failure the
        freshly created driver is closed and ``self.driver`` stays ``None``,
        so later ``execute_*`` calls raise instead of using a dead driver.

        Returns:
            True when the connection was verified, False otherwise.
        """
        self.close()
        driver = None
        try:
            driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
            driver.verify_connectivity()
        except Exception as e:
            logger.error("Neo4j 连接失败: %s", e)
            if driver is not None:
                driver.close()
            return False
        self.driver = driver
        logger.info("Neo4jUtil 初始化完成,连接地址: %s, 数据库: %s", self.uri, self.database)
        return True

    def close(self):
        """Close the driver (if any) and forget it."""
        # Detach before closing so the instance never keeps a closed driver.
        driver, self.driver = self.driver, None
        if driver:
            driver.close()
            logger.info("Neo4jUtil 驱动已关闭")

    # ==================== Core execution methods ====================
    def _require_driver(self):
        """Return the live driver or raise RuntimeError when not connected."""
        if not self.driver:
            raise RuntimeError("Neo4j 驱动未初始化")
        return self.driver

    @staticmethod
    def _run_and_collect(tx, cypher: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run ``cypher`` inside ``tx`` and materialise every record as a dict."""
        return [record.data() for record in tx.run(cypher, parameters=params)]

    def execute_write(self, cypher: str, params: Optional[Dict[str, Any]] = None):
        """Run a write query via ``session.execute_write`` and discard its result.

        Raises:
            RuntimeError: if no driver is available (``connect()`` not done).
        """
        driver = self._require_driver()
        query_params = params or {}
        with driver.session(database=self.database) as session:
            session.execute_write(
                lambda tx: tx.run(cypher, parameters=query_params).consume()
            )
        logger.debug("执行写操作: %s", cypher)

    def execute_read(self, cypher: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run a read query via ``session.execute_read``.

        Returns:
            One dict per record (``record.data()``).

        Raises:
            RuntimeError: if no driver is available.
        """
        driver = self._require_driver()
        query_params = params or {}
        with driver.session(database=self.database) as session:
            return session.execute_read(
                lambda tx: self._run_and_collect(tx, cypher, query_params)
            )

    def execute_write_and_return(self, cypher: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run a write query that also returns rows (e.g. ``CREATE ... RETURN``).

        Returns:
            One dict per record (``record.data()``).

        Raises:
            RuntimeError: if no driver is available.
        """
        driver = self._require_driver()
        query_params = params or {}
        with driver.session(database=self.database) as session:
            return session.execute_write(
                lambda tx: self._run_and_collect(tx, cypher, query_params)
            )

    # ==================== Node operations ====================
    def insert_node(self, label: str, properties: Dict[str, Any]) -> str:
        """Create one node with ``label`` and ``properties``; return its elementId."""
        cypher = f"CREATE (n:{self._quote_identifier(label)} $props) RETURN elementId(n) AS id"
        rows = self.execute_write_and_return(cypher, {"props": properties or {}})
        return rows[0]["id"]

    def delete_all_nodes_by_label(self, label: str):
        """Delete every node carrying ``label`` together with its relationships."""
        self.execute_write(f"MATCH (n:{self._quote_identifier(label)}) DETACH DELETE n")

    def delete_nodes_by_condition(self, label: str, conditions: Dict[str, Any]):
        """Delete the nodes of ``label`` whose properties equal ``conditions``.

        Raises:
            ValueError: if ``conditions`` is empty (guards against wiping the
                whole label by accident).
        """
        if not conditions:
            raise ValueError("删除条件不能为空,防止误删全表!")
        where_clause, params = self._build_where_conditions("n", conditions, "cond")
        cypher = f"MATCH (n:{self._quote_identifier(label)}) WHERE {where_clause} DETACH DELETE n"
        self.execute_write(cypher, params)

    def find_all_nodes(self, label: str) -> List[Dict[str, Any]]:
        """Return every node of ``label`` as its property dict plus an ``id`` key."""
        return self.find_nodes_with_element_id(label)

    def find_nodes_with_element_id(self, label: str, properties: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Return nodes of ``label`` matching ``properties`` (all nodes if empty).

        Each result is the node's property dict with its elementId stored
        under ``id`` (a node property that is itself called ``id`` is
        overwritten by the elementId).
        """
        cypher = f"MATCH (n:{self._quote_identifier(label)})"
        params: Dict[str, Any] = {}
        if properties:
            where_clause, params = self._build_where_conditions("n", properties, "prop")
            cypher += f" WHERE {where_clause}"
        cypher += " RETURN elementId(n) AS id, n{.*} AS props"
        return [self._merge_id_and_props(row) for row in self.execute_read(cypher, params)]

    def update_node_by_properties(self, label: str, where: Dict[str, Any], updates: Dict[str, Any]):
        """Set ``updates`` on every node of ``label`` whose properties equal ``where``.

        Raises:
            ValueError: if ``where`` or ``updates`` is empty.
        """
        if not where:
            raise ValueError("WHERE 条件不能为空!")
        if not updates:
            raise ValueError("更新内容不能为空!")
        where_clause, params = self._build_where_conditions("n", where, "where")
        set_clause, set_params = self._build_set_clause("n", updates, "update")
        params.update(set_params)
        cypher = f"MATCH (n:{self._quote_identifier(label)}) WHERE {where_clause} SET {set_clause}"
        self.execute_write(cypher, params)

    # ==================== Relationship operations ====================
    def create_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: str,
        rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Create a ``rel_type`` relationship from matching sources to matching targets.

        Both end nodes must already exist: the query uses MATCH, so nothing is
        created when either side matches no node.

        Raises:
            ValueError: if ``source_props`` or ``target_props`` is empty.
        """
        if not source_props or not target_props:
            raise ValueError("源或目标节点条件不能为空")
        source = self._build_match_clause("a", source_label, source_props, "src")
        target = self._build_match_clause("b", target_label, target_props, "tgt")
        params = {**source["params"], **target["params"]}
        rel_part = self._quote_identifier(rel_type)
        if rel_properties:
            rel_part += " $rel_props"
            params["rel_props"] = rel_properties
        cypher = (
            f"MATCH {source['clause']}, {target['clause']} "
            f"CREATE (a)-[r:{rel_part}]->(b)"
        )
        self.execute_write(cypher, params)

    def find_all_relationships(self, rel_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return every relationship, optionally restricted to ``rel_type``."""
        return self.find_relationships_by_condition(rel_type=rel_type)

    def find_relationships_by_condition(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Return relationships matching any combination of the given filters.

        Every filter is optional; omitted filters do not restrict the match.
        Results have the shape produced by ``_enrich_relationship``.
        """
        cypher, params = self._build_relationship_query(
            source_label, source_props, target_label, target_props, rel_type, rel_properties
        )
        rows = self.execute_read(cypher + self._RELATIONSHIP_RETURN, params)
        return [self._enrich_relationship(row) for row in rows]

    def find_neighbors_with_relationships(
        self,
        node_label: str,
        node_properties: Dict[str, Any],
        direction: str = "both",  # one of "out", "in", "both"
        rel_type: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return the neighbours of the matched node(s) with the connecting relationship.

        Args:
            node_label: Label of the start node.
            node_properties: Property filter locating the start node.
            direction: "out" for (n)-[r]->(m), "in" for (n)<-[r]-(m),
                "both" for an undirected match.
            rel_type: Optional relationship type filter.

        Returns:
            One dict per match with keys ``source`` (the start node),
            ``target`` (the neighbour) and ``relationship``
            (``id``/``type``/``properties``). ``source`` is always the start
            node, regardless of the relationship's actual direction.

        Raises:
            ValueError: if ``node_properties`` is empty or ``direction`` is
                not one of the accepted values.
        """
        if not node_properties:
            raise ValueError("node_properties 不能为空,用于定位起始节点")
        where_clause, params = self._build_where_conditions("n", node_properties, "node")

        # (left arrow, right arrow) placed around the relationship pattern.
        arrows = {"out": ("-", "->"), "in": ("<-", "-"), "both": ("-", "-")}
        if direction not in arrows:
            raise ValueError("direction 必须是 'out', 'in' 或 'both'")
        left, right = arrows[direction]
        rel_filter = f":{self._quote_identifier(rel_type)}" if rel_type else ""
        pattern = f"(n:{self._quote_identifier(node_label)}){left}[r{rel_filter}]{right}(m)"

        cypher = (
            f"MATCH {pattern} WHERE {where_clause}"
            " RETURN"
            " elementId(n) AS sourceId,"
            " head(labels(n)) AS sourceLabel,"
            " n{.*} AS sourceProps,"
            " elementId(m) AS targetId,"
            " head(labels(m)) AS targetLabel,"
            " m{.*} AS targetProps,"
            " elementId(r) AS relId,"
            " type(r) AS relType,"
            " r{.*} AS relProps"
        )
        return [
            {
                "source": self._node_view(row["sourceProps"], row["sourceId"], row["sourceLabel"]),
                "target": self._node_view(row["targetProps"], row["targetId"], row["targetLabel"]),
                "relationship": {
                    "id": row["relId"],
                    "type": row["relType"],
                    "properties": row["relProps"],
                },
            }
            for row in self.execute_read(cypher, params)
        ]

    def delete_all_relationships_by_node_label(self, node_label: str):
        """Delete every relationship attached to nodes of ``node_label`` (nodes are kept)."""
        self.execute_write(f"MATCH (n:{self._quote_identifier(node_label)})-[r]-() DELETE r")

    def delete_all_relationships_by_node_props(self, label: str, properties: Dict[str, Any]):
        """Delete every relationship attached to the nodes of ``label`` matching ``properties``.

        An empty ``properties`` matches every node of ``label``.
        """
        where_clause, params = self._build_where_conditions("n", properties, "node")
        cypher = (
            f"MATCH (n:{self._quote_identifier(label)}) WHERE {where_clause} "
            "MATCH (n)-[r]-() DELETE r"
        )
        self.execute_write(cypher, params)

    def delete_relationships_advanced(
        self,
        source_label: Optional[str] = None,
        source_props: Optional[Dict[str, Any]] = None,
        target_label: Optional[str] = None,
        target_props: Optional[Dict[str, Any]] = None,
        rel_type: Optional[str] = None,
        rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Delete relationships matching any combination of the given filters.

        WARNING: with no filter at all the query is ``MATCH (a)-[r]->(b)
        DELETE r``, i.e. every relationship in the database is deleted.
        """
        cypher, params = self._build_relationship_query(
            source_label, source_props, target_label, target_props, rel_type, rel_properties
        )
        self.execute_write(cypher + " DELETE r", params)

    def update_relationship(
        self,
        source_label: str,
        source_props: Dict[str, Any],
        target_label: str,
        target_props: Dict[str, Any],
        rel_type: Optional[str] = None,
        new_rel_properties: Optional[Dict[str, Any]] = None,
    ):
        """Set ``new_rel_properties`` on relationships between the matched nodes.

        Empty ``source_props``/``target_props`` do not restrict that side of
        the match.

        Raises:
            ValueError: if ``new_rel_properties`` is empty.
        """
        if not new_rel_properties:
            raise ValueError("至少需要提供一个要更新的关系属性")
        cypher, params = self._build_relationship_query(
            source_label, source_props, target_label, target_props, rel_type, None
        )
        set_clause, set_params = self._build_set_clause("r", new_rel_properties, "rel_update")
        params.update(set_params)
        self.execute_write(f"{cypher} SET {set_clause}", params)

    # ==================== Internal helpers ====================
    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote a label, relationship type or property key.

        Embedded backticks are doubled so the value cannot terminate the
        quoted identifier and inject Cypher.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _node_view(props: Optional[Dict[str, Any]], element_id: Any, label: Any) -> Dict[str, Any]:
        """Copy ``props`` and add the node's ``id`` and ``label`` keys."""
        view = dict(props or {})
        view.update({"id": element_id, "label": label})
        return view

    def _merge_id_and_props(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten a ``{"id": ..., "props": {...}}`` row into a single dict."""
        # ``or {}`` also covers an explicit None, which dict() cannot copy.
        merged = dict(row.get("props") or {})
        merged["id"] = row["id"]
        return merged

    def _enrich_relationship(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Reshape a ``_RELATIONSHIP_RETURN`` row into nested source/target dicts."""
        return {
            "relId": row["relId"],
            "type": row["type"],
            "relProps": row["relProps"],
            "source": self._node_view(row["sourceProps"], row["sourceId"], row["sourceLabel"]),
            "target": self._node_view(row["targetProps"], row["targetId"], row["targetLabel"]),
        }

    def _build_relationship_query(
        self,
        source_label: Optional[str],
        source_props: Optional[Dict[str, Any]],
        target_label: Optional[str],
        target_props: Optional[Dict[str, Any]],
        rel_type: Optional[str],
        rel_properties: Optional[Dict[str, Any]],
    ) -> tuple[str, Dict[str, Any]]:
        """Build ``MATCH (a)-[r]->(b) [WHERE ...]`` plus its parameters.

        Falsy labels/types/property filters are simply left out of the query.
        """
        a_label = f":{self._quote_identifier(source_label)}" if source_label else ""
        b_label = f":{self._quote_identifier(target_label)}" if target_label else ""
        r_label = f":{self._quote_identifier(rel_type)}" if rel_type else ""
        cypher = f"MATCH (a{a_label})-[r{r_label}]->(b{b_label})"

        where_parts = []
        params: Dict[str, Any] = {}
        filters = (("a", source_props, "src"), ("b", target_props, "tgt"), ("r", rel_properties, "rel"))
        for var, props, prefix in filters:
            if props:
                part, part_params = self._build_where_conditions(var, props, prefix)
                where_parts.append(part)
                params.update(part_params)
        if where_parts:
            cypher += " WHERE " + " AND ".join(where_parts)
        return cypher, params

    def _build_match_clause(self, var: str, label: str, props: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Build a node pattern with its property map inside the parentheses.

        Example result clause: (a:`Disease` {`name`: $src_0})

        Returns:
            ``{"clause": <pattern>, "params": <parameter dict>}``.
        """
        quoted_label = self._quote_identifier(label)
        if not props:
            return {"clause": f"({var}:{quoted_label})", "params": {}}
        attr_parts = []
        params = {}
        for index, (key, value) in enumerate(props.items()):
            param_key = f"{prefix}_{index}"
            attr_parts.append(f"{self._quote_identifier(key)}: ${param_key}")
            params[param_key] = value
        attrs = ", ".join(attr_parts)
        return {"clause": f"({var}:{quoted_label} {{{attrs}}})", "params": params}

    def _build_where_conditions(self, var: str, props: Dict[str, Any], prefix: str) -> tuple[str, Dict[str, Any]]:
        """Build ``var.key = $param`` equality tests joined with AND.

        Returns:
            ``(condition string, parameter dict)``; an always-true ``1=1``
            with no parameters when ``props`` is empty.
        """
        if not props:
            return "1=1", {}
        conditions = []
        params = {}
        for index, (key, value) in enumerate(props.items()):
            param_key = f"{prefix}_{index}"
            conditions.append(f"{var}.{self._quote_identifier(key)} = ${param_key}")
            params[param_key] = value
        return " AND ".join(conditions), params

    def _build_set_clause(self, var: str, updates: Dict[str, Any], prefix: str) -> tuple[str, Dict[str, Any]]:
        """Build the ``var.key = $param, ...`` body of a SET clause plus its parameters."""
        assignments = []
        params = {}
        for index, (key, value) in enumerate(updates.items()):
            param_key = f"{prefix}_{index}"
            assignments.append(f"{var}.{self._quote_identifier(key)} = ${param_key}")
            params[param_key] = value
        return ", ".join(assignments), params
# ==================== Global singleton instance (created on import) ====================
neo4j_client = Neo4jUtil(NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD, NEO4J_DATABASE)

# Connect eagerly while the module is being imported; the import itself fails
# when connect() reports that the connection could not be established.
if not neo4j_client.connect():
    raise RuntimeError("Failed to connect to Neo4j at module load time!")