|
|
@ -205,6 +205,93 @@ class Neo4jUtil: |
|
|
raw = self.execute_read(cypher) |
|
|
raw = self.execute_read(cypher) |
|
|
return [self._enrich_relationship(row) for row in raw] |
|
|
return [self._enrich_relationship(row) for row in raw] |
|
|
|
|
|
|
|
|
|
|
|
def find_neighbors_with_relationshipsAI( |
|
|
|
|
|
self, |
|
|
|
|
|
node_label: Optional[str], |
|
|
|
|
|
node_properties: Dict[str, Any], |
|
|
|
|
|
direction: str = "both", |
|
|
|
|
|
rel_type: Optional[str] = None, |
|
|
|
|
|
) -> Dict[str, List[Dict[str, Any]]]: |
|
|
|
|
|
""" |
|
|
|
|
|
查询指定节点的所有邻居,返回扁平化的 nodes 和 relationships(无数量限制) |
|
|
|
|
|
""" |
|
|
|
|
|
if not node_properties: |
|
|
|
|
|
raise ValueError("node_properties 不能为空,用于定位起始节点") |
|
|
|
|
|
|
|
|
|
|
|
where_clause, params = self._build_where_conditions("n", node_properties, "node") |
|
|
|
|
|
rel_filter = f":`{rel_type}`" if rel_type else "" |
|
|
|
|
|
|
|
|
|
|
|
if node_label is not None: |
|
|
|
|
|
node_pattern = f"(n:`{node_label}`)" |
|
|
|
|
|
else: |
|
|
|
|
|
node_pattern = "(n)" |
|
|
|
|
|
|
|
|
|
|
|
if direction == "out": |
|
|
|
|
|
pattern = f"{node_pattern}-[r{rel_filter}]->(m)" |
|
|
|
|
|
elif direction == "in": |
|
|
|
|
|
pattern = f"{node_pattern}<[r{rel_filter}]-(m)" |
|
|
|
|
|
elif direction == "both": |
|
|
|
|
|
pattern = f"{node_pattern}-[r{rel_filter}]-(m)" |
|
|
|
|
|
else: |
|
|
|
|
|
raise ValueError("direction 必须是 'out', 'in' 或 'both'") |
|
|
|
|
|
|
|
|
|
|
|
cypher = f""" |
|
|
|
|
|
MATCH {pattern} |
|
|
|
|
|
WHERE {where_clause} |
|
|
|
|
|
WITH r, startNode(r) AS s, endNode(r) AS t |
|
|
|
|
|
RETURN |
|
|
|
|
|
elementId(s) AS sourceId, |
|
|
|
|
|
head(labels(s)) AS sourceLabel, |
|
|
|
|
|
s{{.*}} AS sourceProps, |
|
|
|
|
|
|
|
|
|
|
|
elementId(t) AS targetId, |
|
|
|
|
|
head(labels(t)) AS targetLabel, |
|
|
|
|
|
t{{.*}} AS targetProps, |
|
|
|
|
|
|
|
|
|
|
|
elementId(r) AS relId, |
|
|
|
|
|
type(r) AS relType, |
|
|
|
|
|
r{{.*}} AS relProps |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
raw_results = self.execute_read(cypher, params) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
traceback.print_exc() |
|
|
|
|
|
raise RuntimeError(f"查询邻居节点时发生数据库错误: {str(e)}") from e |
|
|
|
|
|
|
|
|
|
|
|
# 用字典去重节点(key 为 id) |
|
|
|
|
|
node_dict = {} |
|
|
|
|
|
relationships = [] |
|
|
|
|
|
|
|
|
|
|
|
for row in raw_results: |
|
|
|
|
|
# 处理 source 节点 |
|
|
|
|
|
source_id = row["sourceId"] |
|
|
|
|
|
if source_id not in node_dict: |
|
|
|
|
|
source_node = dict(row["sourceProps"]) |
|
|
|
|
|
source_node.update({"id": source_id, "label": row["sourceLabel"]}) |
|
|
|
|
|
node_dict[source_id] = source_node |
|
|
|
|
|
|
|
|
|
|
|
# 处理 target 节点 |
|
|
|
|
|
target_id = row["targetId"] |
|
|
|
|
|
if target_id not in node_dict: |
|
|
|
|
|
target_node = dict(row["targetProps"]) |
|
|
|
|
|
target_node.update({"id": target_id, "label": row["targetLabel"]}) |
|
|
|
|
|
node_dict[target_id] = target_node |
|
|
|
|
|
|
|
|
|
|
|
# 处理 relationship |
|
|
|
|
|
rel = { |
|
|
|
|
|
"id": row["relId"], |
|
|
|
|
|
"type": row["relType"], |
|
|
|
|
|
"sourceId": source_id, |
|
|
|
|
|
"targetId": target_id, |
|
|
|
|
|
"properties": dict(row["relProps"]) if row["relProps"] else {} |
|
|
|
|
|
} |
|
|
|
|
|
relationships.append(rel) |
|
|
|
|
|
|
|
|
|
|
|
return { |
|
|
|
|
|
"nodes": list(node_dict.values()), |
|
|
|
|
|
"relationships": relationships |
|
|
|
|
|
} |
|
|
def find_relationships_by_condition( |
|
|
def find_relationships_by_condition( |
|
|
self, |
|
|
self, |
|
|
source_label: Optional[str] = None, |
|
|
source_label: Optional[str] = None, |
|
|
|