|
@@ -373,6 +373,40 @@ def match_template(query, template_dict, tokenizer):
|
|
|
|
|
|
return matched_key, best_match_sentence, similarity_score
|
|
|
|
|
|
+def match_template_all(query, template_dict, tokenizer):
|
|
|
+ """
|
|
|
+ 返回所有模板句与 query 的匹配得分列表,格式为 [(key, 模板句, 相似度分数), ...],按分数降序排列。
|
|
|
+ """
|
|
|
+ templates = []
|
|
|
+ key_map = []
|
|
|
+ for key, sentences in template_dict.items():
|
|
|
+ for s in sentences:
|
|
|
+ templates.append(s)
|
|
|
+ key_map.append(key)
|
|
|
+
|
|
|
+ if not templates:
|
|
|
+ return []
|
|
|
+
|
|
|
+ vectorizer = TfidfVectorizer(tokenizer=tokenizer)
|
|
|
+ tfidf_matrix = vectorizer.fit_transform([query] + templates)
|
|
|
+ cos_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])[0]
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for idx, score in enumerate(cos_sim):
|
|
|
+ results.append((key_map[idx], templates[idx], score))
|
|
|
+
|
|
|
+ # 按相似度降序排序
|
|
|
+ results.sort(key=lambda x: x[2], reverse=True)
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+def match_template_with_time_filter_all(query, template_dict, tokenizer, extract_time_location_func):
|
|
|
+ time_info, _, _, _ = extract_time_location_func(query)
|
|
|
+ candidate_keys = classify_by_time_type(query, time_info)
|
|
|
+ filtered_template_dict = {k: template_dict[k] for k in candidate_keys if k in template_dict}
|
|
|
+ return match_template_all(query, filtered_template_dict, tokenizer)
|
|
|
+
|
|
|
+
|
|
|
# 根据模板去对应的json文件中找数据
|
|
|
def load_template_info(matched_key, json_folder):
|
|
|
"""
|
|
@@ -405,40 +439,60 @@ def process_query(query, template_dict, json_folder, tokenizer=jieba_tokenizer):
|
|
|
|
|
|
conditions = {}
|
|
|
# 匹配模板
|
|
|
- matched_key, best_sentence, score = match_template_with_time_filter(
|
|
|
- query,
|
|
|
- template_dict,
|
|
|
- tokenizer,
|
|
|
- extract_time_location_func=extract_time_location
|
|
|
- )
|
|
|
- # 定义阈值
|
|
|
+ # matched_key, best_sentence, score = match_template_with_time_filter(
|
|
|
+ # query,
|
|
|
+ # template_dict,
|
|
|
+ # tokenizer,
|
|
|
+ # extract_time_location_func=extract_time_location
|
|
|
+ # )
|
|
|
+
|
|
|
+ # 这里match_template_with_time_filter改为返回所有匹配的列表 [(matched_key, best_sentence, score), ...]
|
|
|
+ all_matches = match_template_with_time_filter_all(query, template_dict, tokenizer,
|
|
|
+ extract_time_location_func=extract_time_location)
|
|
|
+ for idx, (key, sentence, score) in enumerate(all_matches):
|
|
|
+ print(f"排名 {idx + 1}: key={key}, 模板句='{sentence}', 相似度={score:.4f}")
|
|
|
+
|
|
|
+ # 按得分排序(降序)
|
|
|
+ all_matches.sort(key=lambda x: x[2], reverse=True)
|
|
|
+ best_match = all_matches[0]
|
|
|
+ best_score = best_match[2]
|
|
|
+ second_score = all_matches[1][2] if len(all_matches) > 1 else 0
|
|
|
+
|
|
|
+ # 判断阈值
|
|
|
similarity_threshold = 0.25
|
|
|
- # ★ 判断相似度阈值
|
|
|
- if score < similarity_threshold:
|
|
|
+ diff_threshold = 0.05 # 差距阈值,可调
|
|
|
+
|
|
|
+ if best_score < similarity_threshold:
|
|
|
return {
|
|
|
"matched_key": None,
|
|
|
"matched_template": None,
|
|
|
- "similarity_score": score,
|
|
|
- "type": None,
|
|
|
- "keywords": None,
|
|
|
- "target": None,
|
|
|
- "name": None,
|
|
|
- "conditions": conditions,
|
|
|
+ "similarity_score": best_score,
|
|
|
"content": "您提问的问题目前我还没有掌握",
|
|
|
"query": query,
|
|
|
- "play":"疑问"
|
|
|
+ "play": "疑问"
|
|
|
+ }
|
|
|
+
|
|
|
+ if (best_score - second_score) < diff_threshold:
|
|
|
+ # 差距太小,匹配不准确
|
|
|
+ return {
|
|
|
+ "content": "您提问的问题不太准确,我无法理解",
|
|
|
+ "query": query,
|
|
|
+ "play": "疑问",
|
|
|
+ "name": "疑问",
|
|
|
}
|
|
|
|
|
|
+ # 匹配准确,正常返回第一个匹配的模板信息
|
|
|
+ matched_key, best_sentence, score = best_match
|
|
|
+
|
|
|
+ # 下面是你已有的提取条件逻辑
|
|
|
if time_info:
|
|
|
ti = time_info[0]
|
|
|
- # 先判断是否是区间时间(有start_year/end_year等字段)
|
|
|
if 'start_year' in ti and 'end_year' in ti:
|
|
|
conditions['start_year'] = ti.get('start_year')
|
|
|
conditions['start_month'] = ti.get('start_month')
|
|
|
conditions['end_year'] = ti.get('end_year')
|
|
|
conditions['end_month'] = ti.get('end_month')
|
|
|
else:
|
|
|
- # 单时间点
|
|
|
if 'year' in ti:
|
|
|
conditions['年'] = ti['year']
|
|
|
if 'month' in ti:
|
|
@@ -455,28 +509,19 @@ def process_query(query, template_dict, json_folder, tokenizer=jieba_tokenizer):
|
|
|
if rank_info2:
|
|
|
conditions['rank2'] = rank_info2
|
|
|
|
|
|
- # 查询模板json
|
|
|
template_info = load_template_info(matched_key, json_folder)
|
|
|
- # 模板的关键词
|
|
|
keywords = template_info.get("keyword")
|
|
|
- # 模板中的映射关系
|
|
|
target = template_info.get("target")
|
|
|
- # 模板的类型
|
|
|
type_ = template_info.get("type", "")
|
|
|
- # 模板的名字
|
|
|
dataJsonName = template_info.get("dataJsonName", "")
|
|
|
- # ---------------- 比较类 -----------------
|
|
|
value_key = template_info.get("value_key", "")
|
|
|
name_key = template_info.get("name_key", "")
|
|
|
find_max = template_info.get("find_max")
|
|
|
- # block名称
|
|
|
name = template_info.get("name", "")
|
|
|
- # 输出内容
|
|
|
content = template_info.get("content", "")
|
|
|
- # 动作类型
|
|
|
play = template_info.get("play", "")
|
|
|
- # 问题序号
|
|
|
qcode = template_info.get("qcode", "")
|
|
|
+
|
|
|
return {
|
|
|
"matched_key": matched_key,
|
|
|
"matched_template": best_sentence,
|
|
@@ -582,7 +627,7 @@ def find_key_recursively(data, target_key):
|
|
|
# print("动作是:", result["play"])
|
|
|
|
|
|
# query = "当月送出均价最高的是哪个省??"
|
|
|
-# query = ("2025年送出电量前五名是谁??")
|
|
|
+# query = ("交易?")
|
|
|
#
|
|
|
# json_folder = "templatesJson"
|
|
|
#
|
|
@@ -599,6 +644,7 @@ def find_key_recursively(data, target_key):
|
|
|
# print("返回的内容是:", result["content"])
|
|
|
# print("问句是:", result["query"])
|
|
|
# print("动作是:", result["play"])
|
|
|
+# print("描述:", result["content"])
|
|
|
|
|
|
#
|
|
|
# type = result["type"]
|