4 changed files with 156 additions and 6 deletions
@ -0,0 +1,48 @@ |
|||||
|
import json |
||||
|
|
||||
|
# Instruction prompt for joint entity / relation extraction from medical text.
# The text is written one prompt line per element and joined with newlines;
# empty elements produce the blank lines between sections. The literal
# "{input}" placeholder is part of the template text.
UNIFIED_PROMPT_TEMPLATE = "\n".join([
    "你是一个医疗知识图谱构建专家。请从以下文本中:",
    "1. 提取所有医学实体(去重),仅返回名称列表;",
    "2. 在这些实体之间抽取高质量、术语化的语义关系三元组。",
    "",
    "### 输出规则",
    '- 实体类型无需标注,只输出名称字符串(如 "慢性淋巴细胞白血病")。',
    "- 关系谓词必须是专业术语(2~6字),如:临床表现、诊断、相关疾病、禁忌症、治疗药物等。",
    "- e1 和 e2 必须来自提取出的实体列表,且 e1 ≠ e2。",
    '- 输出必须是纯 JSON,仅包含两个字段:"entities"(字符串列表)和 "relations"(对象列表,每个含 e1/r/e2)。',
    "- 不要任何额外文本、解释或 Markdown。",
    "",
    "文本:{input}",
    "",
    "输出:",
])
||||
|
|
||||
|
# Convert every {"input": ..., "output": ...} record of test_data.jsonl into a
# chat-style {"messages": [...]} record, written one JSON object per line.
with open("test_data.jsonl", "r", encoding="utf-8") as fin, \
        open("sft_messages_format.jsonl", "w", encoding="utf-8") as fout:
    for raw_line in fin:
        record_text = raw_line.strip()
        if record_text:
            try:
                record = json.loads(record_text)
                source_text = record["input"]
                target = record["output"]

                sample = {
                    "messages": [
                        # The system message is the template as-is; its
                        # {input} placeholder is intentionally not filled in.
                        {"role": "system", "content": UNIFIED_PROMPT_TEMPLATE},
                        # The real input text goes into the user message.
                        {"role": "user", "content": source_text},
                        # The expected output is stored as a JSON string.
                        {
                            "role": "assistant",
                            "content": json.dumps(target, ensure_ascii=False),
                        },
                    ]
                }
                fout.write(json.dumps(sample, ensure_ascii=False) + "\n")
            except Exception as e:
                # Best effort: report the bad record and move on to the next.
                print(f"处理出错: {e}")

print("✅ 转换完成!文件已保存为 sft_messages_format.jsonl")
||||
@ -0,0 +1,66 @@ |
|||||
|
# cmekg_aligner.py |
||||
|
|
||||
|
from difflib import SequenceMatcher |
||||
|
from typing import List, Dict, Optional, Tuple |
||||
|
|
||||
|
class CMeKGAligner:
    """Align free-text terms to entity names stored in a Neo4j graph.

    The instance owns a Neo4j driver (``self.driver``). Call :meth:`close`
    when finished, or use the instance as a context manager, so the driver
    is released.
    """

    # Substrings whose presence in a name marks it as a drug.
    _DRUG_MARKERS = ("片", "胶囊", "注射")

    # Returns one row per (input term, node) pair where either name contains
    # the other, compared case-insensitively.
    _CANDIDATE_QUERY = """
        UNWIND $terms AS input_name
        MATCH (e)
        WHERE toLower(e.name) CONTAINS toLower(input_name)
           OR toLower(input_name) CONTAINS toLower(e.name)
        RETURN input_name, e.name AS std_name
        """

    def __init__(self, uri, user, password):
        """Create the Neo4j driver for ``uri`` using basic authentication."""
        # Imported here so importing this module does not require neo4j.
        from neo4j import GraphDatabase

        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self) -> None:
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def infer_type_by_name(self, name: str) -> str:
        """Guess an entity type from its name.

        :param name: entity name to classify
        :return: ``"Drug"`` if the name contains one of the drug markers,
            otherwise ``"Unknown"``
        """
        if any(marker in name for marker in self._DRUG_MARKERS):
            return "Drug"
        # Further type rules can be added here.
        return "Unknown"

    def find_entities_batch(self, terms: List[str]) -> Dict[str, Optional[Tuple[str, str]]]:
        """Align many terms to graph entity names with a single query.

        Blank (empty or whitespace-only) and non-string terms are never sent
        to the database and are reported as unmatched: an empty string is
        contained in every name, so it would otherwise "match" every node.

        :param terms: raw terms; duplicates are allowed
        :return: ``{term: (standard_name, type)}`` for matched terms and
            ``{term: None}`` for unmatched ones; every distinct input term
            appears as a key
        """
        if not terms:
            return {}

        # De-duplicate while keeping first-occurrence order.
        unique_terms = list(dict.fromkeys(terms))
        queryable = [
            term for term in unique_terms
            if isinstance(term, str) and term.strip()
        ]

        # Group candidate names by input term: {term: [std_name, ...]}.
        candidates_map: Dict[str, List[str]] = {}
        if queryable:
            with self.driver.session() as session:
                result = session.run(self._CANDIDATE_QUERY, terms=queryable)
                # Consume the result while the session is still open.
                for record in result:
                    candidates_map.setdefault(
                        record["input_name"], []
                    ).append(record["std_name"])

        output: Dict[str, Optional[Tuple[str, str]]] = {}
        for term in unique_terms:
            candidates = candidates_map.get(term)
            if not candidates:
                output[term] = None
                continue
            # Pick the candidate most similar to the term; ties keep the
            # first candidate returned.
            best_std = max(
                candidates,
                key=lambda x: SequenceMatcher(None, term, x).ratio(),
            )
            output[term] = (best_std, self.infer_type_by_name(best_std))

        return output
||||
@ -0,0 +1,36 @@ |
|||||
|
# batch_test.py |
||||
|
from cmekg_aligner import CMeKGAligner |
||||
|
|
||||
|
import csv

# Batch-align the terms listed in input_terms.txt (one per line), print the
# first ten results and save all of them to batch_alignment_result.csv.
aligner = CMeKGAligner(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="your_password",
)

try:
    # Load the terms to align, skipping blank lines.
    with open("input_terms.txt", "r", encoding="utf-8") as f:
        terms = [line.strip() for line in f if line.strip()]

    print(f"🔍 开始批量对齐 {len(terms)} 条实体...")

    results = aligner.find_entities_batch(terms)
finally:
    # The driver is only needed for the query; release it even on failure.
    aligner.driver.close()

# Print only the first 10 terms as a sample.
for term in terms[:10]:
    res = results[term]
    if res:
        print(f"✅ '{term}' → '{res[0]}', {res[1]}")
    else:
        print(f"❌ '{term}' → 未匹配")

# Save every term (duplicates included) to CSV; unmatched terms get empty
# name and type columns.
with open("batch_alignment_result.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["原始词", "标准名", "类型"])
    for term in terms:
        res = results[term]
        if res:
            writer.writerow([term, res[0], res[1]])
        else:
            writer.writerow([term, "", ""])
||||
Loading…
Reference in new issue