|
@@ -10,7 +10,8 @@ def jieba_tokenizer(text):
|
|
|
return list(jieba.cut(text))
|
|
|
# 定义问题模板
|
|
|
template_dict = {
|
|
|
- "1": ["某年全年累计省间交易电量是多少?"],
|
|
|
+ # "1": ["某年全年累计省间交易电量是多少?"],
|
|
|
+ "1": ["某年全年累计XXX是多少?"],
|
|
|
"2": ["某年某月交易电量是多少?"],
|
|
|
"3": ["某年某月到某月累计交易电量是多少?"],
|
|
|
"8.1": ["某年省间交易电量按交易周期划分的电量是多少?"],
|
|
@@ -49,6 +50,8 @@ template_dict = {
|
|
|
"21": ["省间交易当年完成的交易有多少?"],
|
|
|
"22": ["省间交易当年达成的电量有多少?"],
|
|
|
"23": ["省间交易当年参与交易的家次有多少?"],
|
|
|
+ "24": ["某年送出电量前五名是谁?"],
|
|
|
+ "25": ["某年受入电量前五名是谁?"],
|
|
|
}
|
|
|
# 将地点映射成相应的代码
|
|
|
def map_location_to_unit(location: str) -> str:
|
|
@@ -236,17 +239,48 @@ def extract_time_location(question: str) -> Tuple[List[Dict], List[str]]:
|
|
|
# 地点识别
|
|
|
locations = [p for p in provinces if p in question]
|
|
|
|
|
|
- return time_results, locations
|
|
|
+ # 提取“前N名”或“Top N”格式
|
|
|
+ # 中文数字转阿拉伯数字映射(可扩展)
|
|
|
+ chinese_digit_map = {
|
|
|
+ '一': 1, '二': 2, '三': 3, '四': 4, '五': 5,
|
|
|
+ '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
|
|
+ }
|
|
|
+
|
|
|
+ # question = "2022年受入电量前五名是谁"
|
|
|
+
|
|
|
+ # 匹配“前五”或“top 5”等形式
|
|
|
+ rank_match = re.search(r'(前|top\s*)(\d+|[一二三四五六七八九十])', question, re.IGNORECASE)
|
|
|
+
|
|
|
+ # rank_match = re.search(
|
|
|
+ # r'(前|top\s*|第\s*)(\d+|[一二三四五六七八九十])\s*(名)?',
|
|
|
+ # question,
|
|
|
+ # re.IGNORECASE
|
|
|
+ # )
|
|
|
+
|
|
|
+ if rank_match:
|
|
|
+ rank_str = rank_match.group(2)
|
|
|
+ if rank_str.isdigit():
|
|
|
+ rank = int(rank_str)
|
|
|
+ else:
|
|
|
+ rank = chinese_digit_map.get(rank_str, None)
|
|
|
+
|
|
|
+ # print(f"匹配到的排名为:{rank}")
|
|
|
+ else:
|
|
|
+ rank = None
|
|
|
+ print("未匹配到排名")
|
|
|
+
|
|
|
+ return time_results, locations, rank
|
|
|
|
|
|
# 先用 extract_time_location 判断问句包含哪类时间信息,然后只对结构匹配的模板子集做余弦匹配。
|
|
|
# def classify_by_time_type(query, time_info):
|
|
|
# if any('start_year' in t and 'end_year' in t for t in time_info):
|
|
|
# return ['3'] # 时间段
|
|
|
# return list(template_dict.keys()) # fallback 所有模板
|
|
|
+#888
|
|
|
def classify_by_time_type(query, time_info):
|
|
|
if not time_info:
|
|
|
# 无时间信息时,返回指定模板 19-23
|
|
|
- return ['19', '20', '21', '22', '23', '17.1', '17.2', '17.3', '17.4', '18.1', '18.2', '18.3', '18.4']
|
|
|
+ return ['19', '20', '21', '22', '23', '17.1', '17.2', '17.3', '17.4', '18.1', '18.2', '18.3', '18.4','24','25']
|
|
|
|
|
|
time = time_info[0]
|
|
|
|
|
@@ -256,17 +290,17 @@ def classify_by_time_type(query, time_info):
|
|
|
|
|
|
# 情况 2:有 year 和 month,精确到月
|
|
|
if 'year' in time and 'month' in time:
|
|
|
- return ['2','16.1','16.2'] # 某年某月交易电量
|
|
|
+ return ['2','16.1','16.2','20'] # 某年某月交易电量
|
|
|
|
|
|
# 情况 3:仅 year,全年
|
|
|
if 'year' in time and 'month' not in time:
|
|
|
- return ['1','8.1','8.2','8.3','8.4','9.1','9.2','9.3','9.4','9.5','9.6','9.7','9.8','9.9','9.10','9.11','9.12','9.13','9.14','9.15','9.16','9.17','16.1','16.2'] # 某年全年累计交易电量
|
|
|
+ return ['1','8.1','8.2','8.3','8.4','9.1','9.2','9.3','9.4','9.5','9.6','9.7','9.8','9.9','9.10','9.11','9.12','9.13','9.14','9.15','9.16','9.17','16.1','16.2','21','22','23','24','25'] # 某年全年累计交易电量
|
|
|
def match_template_with_time_filter(query, template_dict, tokenizer, extract_time_location_func):
|
|
|
"""
|
|
|
先基于时间信息筛选候选模板,再进行TF-IDF匹配。
|
|
|
"""
|
|
|
# 提取时间
|
|
|
- time_info, _ = extract_time_location_func(query)
|
|
|
+ time_info, _, _ = extract_time_location_func(query)
|
|
|
print(time_info)
|
|
|
# 通过时间判断候选模板 key
|
|
|
candidate_keys = classify_by_time_type(query, time_info)
|
|
@@ -341,7 +375,7 @@ def load_template_info(matched_key, json_folder):
|
|
|
return data
|
|
|
def process_query(query, template_dict, json_folder, tokenizer=jieba_tokenizer):
|
|
|
# 提取条件
|
|
|
- time_info, location_info = extract_time_location(query)
|
|
|
+ time_info, location_info, rank_info = extract_time_location(query)
|
|
|
|
|
|
conditions = {}
|
|
|
# 匹配模板
|
|
@@ -388,6 +422,9 @@ def process_query(query, template_dict, json_folder, tokenizer=jieba_tokenizer):
|
|
|
unit = map_location_to_unit(location_info[0])
|
|
|
if unit and unit != '未知单位':
|
|
|
conditions['单位'] = unit
|
|
|
+
|
|
|
+ if rank_info:
|
|
|
+ conditions['rank'] = rank_info
|
|
|
# 查询模板json
|
|
|
template_info = load_template_info(matched_key, json_folder)
|
|
|
# 模板的关键词
|
|
@@ -516,11 +553,12 @@ def find_key_recursively(data, target_key):
|
|
|
# print("动作是:", result["play"])
|
|
|
|
|
|
# query = "当月送出均价最高的是哪个省??"
|
|
|
+# query = ("2025年送出电量前五名是谁??")
|
|
|
#
|
|
|
# json_folder = "templatesJson"
|
|
|
#
|
|
|
# result = process_query(query, template_dict, json_folder)
|
|
|
-#
|
|
|
+
|
|
|
# print("匹配的模板 key:", result["matched_key"])
|
|
|
# print("最相似的模板句:", result["matched_template"])
|
|
|
# print("相似度分数:", result["similarity_score"])
|