# Residue of the web page this script was copied from (not part of the program):
#   "You can not select more than 25 topics"
#   "Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long."
#   File stats: 165 lines, 5.2 KiB
import json
import os
import re
from collections import defaultdict

from neo4j import GraphDatabase


# === Configuration ===
# Each setting can be overridden through the environment variable named in its
# os.environ.get() call; the literal values are only fallbacks. Set
# NEO4J_PASSWORD in the environment rather than keeping real credentials here.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")  # make sure this password is correct
# Folder scanned for relations_*.json files.
RELATIONSHIP_FOLDER = os.environ.get("RELATIONSHIP_FOLDER", r"D:\temp\669")
# Maximum number of relationship rows sent in a single Cypher query.
BATCH_SIZE = int(os.environ.get("RELATIONSHIP_BATCH_SIZE", "100"))


def sanitize_relationship_type(rel_type: str) -> str:
    """Turn an arbitrary value into a safe Neo4j relationship type name.

    Every character outside ``[A-Za-z0-9_]`` is removed. A result that starts
    with a digit is prefixed with ``"REL_"``. A result that ends up empty (the
    input held only punctuation or non-ASCII characters) falls back to
    ``"RELATED"``.

    NOTE(review): non-ASCII types (e.g. Chinese names) are stripped entirely
    and therefore all collapse into ``"RELATED"``.

    Args:
        rel_type: The raw relationship type; non-string values are converted
            with ``str()`` first.

    Returns:
        A non-empty name made only of ASCII letters, digits and underscores.
    """
    if not isinstance(rel_type, str):
        rel_type = str(rel_type)

    sanitized = re.sub(r"[^a-zA-Z0-9_]", "", rel_type)
    # Handle the empty case first so it maps to "RELATED" instead of a bare "REL_".
    if not sanitized:
        return "RELATED"
    # Cypher relationship type names conventionally should not start with a digit.
    if sanitized[0].isdigit():
        sanitized = "REL_" + sanitized
    return sanitized


def extract_start_end(rel: dict):
    """Return the (start, end) endpoint values of a relationship record.

    The key pairs ("start", "end"), ("source", "target") and ("from", "to")
    are tried in that order. The first pair whose two values are both not
    None is returned. If no pair is complete, (None, None) is returned.
    """
    key_pairs = (("start", "end"), ("source", "target"), ("from", "to"))
    candidates = ((rel.get(a_key), rel.get(b_key)) for a_key, b_key in key_pairs)
    return next(
        (pair for pair in candidates if pair[0] is not None and pair[1] is not None),
        (None, None),
    )


def load_relationships_from_file(filepath):
    """Load the raw relationship records from one JSON file.

    Accepted layouts:

    * ``{"relationships": [...]}``: returns that list.
    * ``[...]``: for each dict item, a ``"relationships"`` key contributes the
      items of its list. Any other dict item is taken as a relationship itself.
      Non-dict items are ignored.
    * anything else (a dict without ``"relationships"``, a scalar): returns ``[]``.

    A ``"relationships"`` value that is not a list (e.g. ``null``) is treated
    as empty instead of crashing later processing.

    Args:
        filepath: Path to a UTF-8 JSON file; a leading BOM is tolerated.

    Returns:
        A list of raw relationship records (not yet validated).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 JSON (``json.JSONDecodeError``
            and ``UnicodeDecodeError`` are both ValueError subclasses).
    """
    # "utf-8-sig" strips a UTF-8 BOM if present and otherwise reads plain UTF-8.
    with open(filepath, "r", encoding="utf-8-sig") as f:
        data = json.load(f)

    if isinstance(data, dict):
        nested = data.get("relationships")
        return nested if isinstance(nested, list) else []

    if not isinstance(data, list):
        return []

    relationships = []
    for item in data:
        if not isinstance(item, dict):
            continue
        if "relationships" in item:
            nested = item["relationships"]
            if isinstance(nested, list):
                relationships.extend(nested)
        else:
            relationships.append(item)
    return relationships


def _coerce_node_id(value):
    """Convert an endpoint value to an integer node id, or return None.

    Integer strings are parsed with ``int()`` directly, so large ids keep full
    precision (``int(float(s))`` silently rounds values above 2**53). Whole
    float-like values such as ``3.0`` or ``"3.0"`` are accepted. Fractional,
    infinite and NaN values are rejected instead of being truncated or raising.
    """
    if isinstance(value, int):  # bool is an int subclass and maps to 0/1
        return int(value)
    if isinstance(value, str):
        text = value.strip()
        try:
            return int(text)
        except ValueError:
            value = text  # may still be "3.0" or "1e3"; try the float path
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    # is_integer() is False for inf and NaN as well as for fractional values.
    if not number.is_integer():
        return None
    return int(number)


def process_relationships(relationships):
    """Validate raw relationship records and normalize them for import.

    A record is kept when it is a dict whose endpoints (found by
    ``extract_start_end``) convert to whole-number node ids. Each kept record
    becomes ``{"start": int, "end": int, "type": str, "props": dict}``:

    * ``type`` is the record's ``"type"`` (``"RELATED"`` when missing or null),
      passed through ``sanitize_relationship_type``;
    * ``props`` is the record's ``"properties"`` when it is a dict, otherwise
      ``{}``. The import query applies it with ``SET rel += r.props``, which
      needs a map.

    Invalid records are skipped silently.

    Args:
        relationships: Iterable of raw records, e.g. the output of
            ``load_relationships_from_file``.

    Returns:
        The normalized records, in input order.
    """
    valid_rels = []
    for rel in relationships:
        # A non-dict entry would raise AttributeError on .get() and abort the file.
        if not isinstance(rel, dict):
            continue

        start_raw, end_raw = extract_start_end(rel)
        if start_raw is None or end_raw is None:
            continue

        start_id = _coerce_node_id(start_raw)
        end_id = _coerce_node_id(end_raw)
        if start_id is None or end_id is None:
            continue

        rel_type = rel.get("type")
        if rel_type is None:
            rel_type = "RELATED"

        props = rel.get("properties")
        if not isinstance(props, dict):
            props = {}

        valid_rels.append({
            "start": start_id,
            "end": end_id,
            "type": sanitize_relationship_type(rel_type),
            "props": props,
        })
    return valid_rels


def import_relationships_in_batches(tx, rels, batch_size):
    """MERGE relationships between existing nodes, ``batch_size`` rows at a time.

    Within each batch, rows are grouped by relationship type because a Cypher
    relationship type cannot be passed as a parameter. The type is inlined
    into the query text as a backtick-quoted name, with embedded backticks
    escaped so the name cannot break out of the quotes.

    Endpoints are looked up by their ``nodeId`` property. Rows whose start or
    end node is not found are dropped by ``MATCH`` and are not counted.

    Args:
        tx: Open transaction whose ``run(query, **params)`` returns a result
            with ``single()``, e.g. the one passed in by ``session.execute_write``.
        rels: Records shaped ``{"start", "end", "type", "props"}``, as produced
            by ``process_relationships``.
        batch_size: Maximum number of records per batch; must be positive.

    Returns:
        The sum of ``count(rel)`` over all queries, i.e. the number of rows
        that found both endpoints and were merged. This includes relationships
        that already existed, so on a re-run it does not equal the number of
        newly created relationships.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    created_total = 0
    for offset in range(0, len(rels), batch_size):
        batch = rels[offset:offset + batch_size]

        rel_groups = defaultdict(list)
        for rel in batch:
            rel_groups[rel["type"]].append({
                "start": rel["start"],
                "end": rel["end"],
                "props": rel["props"],
            })

        created_this_batch = 0
        for rel_type, group in rel_groups.items():
            # Doubling a backtick escapes it inside a quoted Cypher identifier.
            quoted_type = "`" + str(rel_type).replace("`", "``") + "`"
            # NOTE(review): the MATCH patterns carry no label, so Neo4j cannot
            # use an index on nodeId; adding the node label would speed up
            # large imports considerably.
            cypher = f"""
            UNWIND $rels AS r
            MATCH (a {{nodeId: r.start}})
            MATCH (b {{nodeId: r.end}})
            MERGE (a)-[rel:{quoted_type}]->(b)
            SET rel += r.props
            RETURN count(rel) AS c
            """
            # An aggregation without grouping keys always yields exactly one row.
            result = tx.run(cypher, rels=group).single()
            created_this_batch += result["c"]

        created_total += created_this_batch
        print(f" ➤ 本批创建关系: {created_this_batch} 条")

    return created_total


def main():
    """Import every relations_*.json file in RELATIONSHIP_FOLDER into Neo4j.

    Files are handled one at a time, each in its own write transaction. An
    error in one file is printed and the remaining files are still processed.
    A summary of processed and merged relationship counts is printed at the end.
    """

    def natural_key(name):
        # re.split with a capturing group alternates text / digit runs, so the
        # odd positions are digit runs; compare those numerically. The name
        # itself breaks ties (e.g. "01" vs "1") deterministically.
        parts = re.split(r"(\d+)", name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    # Collect the relationship files in natural order: relations_2.json sorts
    # before relations_10.json, and zero-padded names (001, 002, ...) keep
    # their usual order.
    json_files = sorted(
        (
            name
            for name in os.listdir(RELATIONSHIP_FOLDER)
            if name.startswith("relations_") and name.endswith(".json")
        ),
        key=natural_key,
    )

    if not json_files:
        print("❌ 文件夹中没有找到 relations_*.json 文件")
        return

    total_global_created = 0
    total_global_processed = 0

    print(f"📁 找到 {len(json_files)} 个关系文件,开始逐个导入...\n")

    # Create the driver only once there is work to do, and close it on every
    # exit path, including exceptions and interrupts.
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    try:
        for idx, filename in enumerate(json_files, 1):
            filepath = os.path.join(RELATIONSHIP_FOLDER, filename)
            print(f"\n📄 [{idx}/{len(json_files)}] 正在处理: {filename}")

            # Per-file boundary: report any failure (bad JSON, Neo4j error, ...)
            # and carry on with the next file.
            try:
                raw_rels = load_relationships_from_file(filepath)
                print(f" ➤ 原始关系数: {len(raw_rels)}")

                valid_rels = process_relationships(raw_rels)
                print(f" ➤ 有效关系数: {len(valid_rels)}")

                if not valid_rels:
                    print(" ⚠️ 跳过:无有效关系")
                    continue

                # One write transaction per file.
                with driver.session() as session:
                    created = session.execute_write(
                        import_relationships_in_batches, valid_rels, BATCH_SIZE
                    )

                total_global_created += created
                total_global_processed += len(valid_rels)
                print(f" ✅ 文件 {filename} 导入完成,创建 {created} 条关系")
            except Exception as e:
                print(f" ❌ 处理 {filename} 时出错: {e}")

        print("\n" + "=" * 60)
        print("🎉 全部导入完成!")
        print(f"📊 总共处理有效关系: {total_global_processed}")
        print(f"✅ 总共成功创建关系: {total_global_created}")
    finally:
        driver.close()


if __name__ == "__main__":
    main()