You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.5 KiB
66 lines
2.5 KiB
# cmekg_aligner.py
|
|
|
|
from difflib import SequenceMatcher
|
|
from typing import List, Dict, Optional, Tuple
|
|
|
|
class CMeKGAligner:
    """Align raw medical terms to standard entity names stored in Neo4j.

    A node is a candidate for a term when its ``name`` contains the term or is
    contained in it (case-insensitive, in either direction). Among the
    candidates, the name most similar to the term according to
    ``difflib.SequenceMatcher`` is chosen, and a coarse entity type is inferred
    from that name.

    The instance owns a Neo4j driver: call :meth:`close`, or use the instance
    as a context manager, to release it.
    """

    # Name fragments that mark a drug dosage form:
    # "片" (tablet), "胶囊" (capsule), "注射" (injection).
    _DRUG_MARKERS = ("片", "胶囊", "注射")

    # One round trip for the whole batch. Nodes with a missing or empty `name`
    # are skipped: '' is contained in every term and would match everything.
    # NOTE(review): `MATCH (e)` has no label, so every node is scanned once per
    # term, and no index can serve CONTAINS on toLower(...). Consider
    # restricting the label(s) if the graph schema allows it.
    _BATCH_QUERY = """
        UNWIND $terms AS input_name
        WITH input_name, toLower(input_name) AS needle
        MATCH (e)
        WHERE e.name IS NOT NULL AND e.name <> ''
        WITH input_name, needle, e.name AS std_name, toLower(e.name) AS hay
        WHERE hay CONTAINS needle OR needle CONTAINS hay
        RETURN DISTINCT input_name, std_name
    """

    def __init__(self, uri: str, user: str, password: str) -> None:
        """Create a Neo4j driver for ``uri`` authenticated as ``user``/``password``."""
        # Imported here rather than at module level, so importing this module
        # does not require the neo4j package.
        from neo4j import GraphDatabase

        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self) -> None:
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self) -> "CMeKGAligner":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the driver; returning None lets exceptions propagate.
        self.close()

    def infer_type_by_name(self, name: str) -> str:
        """Guess an entity type from keywords in a standard name.

        Placeholder for the project's own type-inference rules (kept as-is):
        names containing a drug dosage-form marker are "Drug"; everything else
        is "Unknown".
        """
        if any(marker in name for marker in self._DRUG_MARKERS):
            return "Drug"
        # ... further rules go here
        return "Unknown"

    def find_entities_batch(self, terms: List[str]) -> Dict[str, Optional[Tuple[str, str]]]:
        """Align a large batch of raw terms with a single database query.

        :param terms: raw terms; duplicates are allowed. Empty or
            whitespace-only strings (and non-string items) are not sent to the
            database and map to ``None``.
        :return: ``{raw term: (standard name, entity type) or None}``. Keys
            follow the first-occurrence order of *terms*; ``None`` means no
            candidate was found.
        """
        if not terms:
            return {}

        # Collapse duplicates while keeping first-occurrence order.
        unique_terms = list(dict.fromkeys(terms))

        # A blank term would satisfy `name CONTAINS ''` for every node and drag
        # the whole graph back, so such terms are resolved to None locally.
        queryable = [t for t in unique_terms if isinstance(t, str) and t.strip()]
        candidates_map = self._fetch_candidates(queryable) if queryable else {}

        output: Dict[str, Optional[Tuple[str, str]]] = {}
        for term in unique_terms:
            candidates = candidates_map.get(term)
            if not candidates:
                output[term] = None
                continue
            best_std = self._best_match(term, candidates)
            output[term] = (best_std, self.infer_type_by_name(best_std))
        return output

    def _fetch_candidates(self, terms: List[str]) -> Dict[str, List[str]]:
        """Run the batch query once and group the matching names by input term."""
        candidates_map: Dict[str, List[str]] = {}
        with self.driver.session() as session:
            # Consume the result while the session is still open; the neo4j
            # driver does not allow reading a Result after its session closes.
            for record in session.run(self._BATCH_QUERY, terms=terms):
                std = record["std_name"]
                if not std:  # defensive: never align a term to an empty name
                    continue
                candidates_map.setdefault(record["input_name"], []).append(std)
        return candidates_map

    @staticmethod
    def _best_match(term: str, candidates: List[str]) -> str:
        """Return the candidate most similar to *term*.

        The primary score is case-insensitive, consistent with how the query
        found the candidates. Case-sensitive similarity breaks ties, so an
        exact-case form wins. Any remaining tie goes to the lexicographically
        smallest name (``max`` keeps the first maximal item of the sorted
        list), so the result does not depend on the order of database rows.
        """
        folded = term.lower()

        def score(candidate: str) -> Tuple[float, float]:
            return (
                SequenceMatcher(None, folded, candidate.lower()).ratio(),
                SequenceMatcher(None, term, candidate).ratio(),
            )

        return max(sorted(set(candidates)), key=score)