|
@@ -0,0 +1,215 @@
|
|
|
+import json
|
|
|
+import os
|
|
|
+import re
|
|
|
+
|
|
|
+from final.ByRules.similarity_answer_json import load_template_info
|
|
|
+
|
|
|
+
|
|
|
class EntityExtractor:
    """Rule-based entity extractor for Chinese natural-language questions.

    The extractor scans a question for five categories of entities, in the
    order given by ``self.states``: ``time``, ``market``, ``indicator``,
    ``calculation`` and ``constraint``.  Every matched substring is removed
    from the working text so that later (shorter / more general) patterns
    cannot match the same characters again.

    The first matched indicator's key is then used to look up a template via
    ``load_template_info`` and the template fields are returned together with
    the time conditions parsed from the question.
    """

    # Ordered (compiled pattern, builder) pairs used to turn a matched time
    # string into a structured dict.  Order matters: the most specific range
    # forms come first, the bare "N月" form last.
    _TIME_PARSERS = (
        (
            re.compile(r'(\d{4})年(\d{1,2})月到(\d{4})年(\d{1,2})月'),
            lambda m: {
                "start_year": int(m.group(1)),
                "start_month": int(m.group(2)),
                "end_year": int(m.group(3)),
                "end_month": int(m.group(4)),
            },
        ),
        (
            re.compile(r'(\d{4})年(\d{1,2})月到(\d{1,2})月'),
            lambda m: {
                "start_year": int(m.group(1)),
                "start_month": int(m.group(2)),
                # The end month has no year of its own: reuse the start year.
                "end_year": int(m.group(1)),
                "end_month": int(m.group(3)),
            },
        ),
        (
            re.compile(r'(\d{4})年(\d{1,2})月'),
            lambda m: {"year": int(m.group(1)), "month": int(m.group(2))},
        ),
        (
            re.compile(r'(\d{4})年'),
            lambda m: {"year": int(m.group(1))},
        ),
        (
            re.compile(r'(\d{1,2})月'),
            lambda m: {"month": int(m.group(1))},
        ),
    )

    def __init__(self):
        """Build the in-memory pattern tables used by :meth:`extract`."""
        self.entity_patterns = {
            # Regexes ordered from most specific to most general, so that a
            # full range is consumed before its year/month parts can match.
            "time": [
                r'\d{4}年\d{1,2}月到\d{4}年\d{1,2}月',
                r'\d{4}年\d{1,2}月到\d{1,2}月',
                r'\d{4}年\d{1,2}月',
                r'\d{4}年',
                r'\d{1,2}月',
                r'去年|今年|前年|明年|上个月|本月|上季度|下季度',
                r'过去\d+年|过去\d+个月|近\d+年|近\d+个月'
            ],
            # Literal strings; they are still applied through re.findall.
            "market": ["山东省", "河南省", "山西省", "江苏省"],
            # Maps an indicator key to the indicator name searched for in the
            # question.  The longer name comes first so it wins over its own
            # suffix.
            "indicator": {
                "1": "累计省间交易电量",
                "2": "交易电量",
            },
            "calculation": ["累计", "均值", "是多少", "有多少"],
            "constraint": [
                r'最高|最低|超过|低于|不少于',
                r'第[零一二三四五六七八九十百千万\d]+名到第[零一二三四五六七八九十百千万\d]+名',
                r'第[零一二三四五六七八九十百千万\d]+名',
                r'前[零一二三四五六七八九十百千万\d]+名'
            ]
        }
        self.states = ["time", "market", "indicator", "calculation", "constraint"]

        # Keep the indicator mapping and the plain list of indicator names at
        # hand so matching does not have to dig into entity_patterns.
        self.indicator_dict = self.entity_patterns["indicator"]
        self.indicator_list = list(self.indicator_dict.values())

    def extract(self, text):
        """Extract entities from ``text`` and merge them with template data.

        Args:
            text: The question to analyse.

        Returns:
            A dict with the template fields (``type``, ``keywords``,
            ``target``, ``dataJsonName``, ``name``, ``content``, ``play``,
            ``find_max``, ``value_key``, ``name_key``, ``qcode``, ``unit``,
            ``flag``), the parsed time ``conditions`` and the untouched
            question under ``query``.

        Raises:
            IndexError: If no known indicator name occurs in ``text``; without
                an indicator there is no key to select a template with.
        """
        question = text  # keep the original wording for the result
        extracted_entities = self._match_entities(text)

        time_entity = self._process_time_entities(extracted_entities["time"])
        conditions = self._build_time_conditions(time_entity)

        # Convert matched indicator names into a {key: name} mapping.
        indicators = self._process_indicator_entities(extracted_entities["indicator"])
        if not indicators:
            # Same exception type as the bare list index it replaces, but with
            # a message that explains what went wrong.
            raise IndexError(
                "no known indicator found in question: {!r}".format(question)
            )
        # The first matched indicator decides which template is loaded.
        key = next(iter(indicators))

        # Template data lives in the "templatesJson" folder.
        # NOTE(review): load_template_info is assumed to return a dict, since
        # it is only accessed through .get() here -- confirm against its
        # definition.
        template_info = load_template_info(key, "templatesJson")

        return {
            "type": template_info.get("type", ""),
            "keywords": template_info.get("keyword"),
            "target": template_info.get("target"),
            "dataJsonName": template_info.get("dataJsonName", ""),
            "name": template_info.get("name", ""),
            "conditions": conditions,
            "content": template_info.get("content", ""),
            "query": question,
            "play": template_info.get("play", ""),
            "find_max": template_info.get("find_max"),
            "value_key": template_info.get("value_key", ""),
            "name_key": template_info.get("name_key", ""),
            "qcode": template_info.get("qcode", ""),
            "unit": template_info.get("unit", ""),
            "flag": template_info.get("flag", "")
        }

    def _match_entities(self, text):
        """Return ``{state: [matched substrings]}`` for every state.

        Matched substrings are stripped from the working copy of ``text`` as
        soon as they are found, so an entity consumed by one pattern cannot be
        matched again by a later pattern or state.
        """
        found = {state: [] for state in self.states}

        for state in self.states:
            if state == "indicator":
                # Indicators are plain names: use substring search, not regex.
                for indicator_name in self.indicator_list:
                    if indicator_name in text:
                        found[state].append(indicator_name)
                        text = text.replace(indicator_name, "")
                continue

            for pattern in self.entity_patterns[state]:
                matches = re.findall(pattern, text)
                if not matches:
                    continue
                found[state].extend(matches)
                for match in matches:
                    text = text.replace(match, "")

        return found

    def _build_time_conditions(self, time_entity):
        """Turn the output of ``_process_time_entities`` into query conditions.

        Args:
            time_entity: Either a single parsed time dict or a list of them.
                For a list only the first entry is used; an empty list yields
                no conditions.

        Returns:
            A dict holding ``start_year``/``start_month``/``end_year``/
            ``end_month`` for a range, otherwise ``年`` and/or ``月`` for a
            single point in time, or an empty dict when nothing applies.
        """
        if isinstance(time_entity, list):
            if not time_entity:
                # The question carried no time expression at all.
                return {}
            time_entity = time_entity[0]

        conditions = {}
        if 'start_year' in time_entity and 'end_year' in time_entity:
            for field in ('start_year', 'start_month', 'end_year', 'end_month'):
                conditions[field] = time_entity.get(field)
        else:
            if 'year' in time_entity:
                conditions['年'] = time_entity['year']
            if 'month' in time_entity:
                conditions['月'] = time_entity['month']
        return conditions

    def _process_time_entities(self, time_entities):
        """Parse matched time strings into structured dicts.

        Each string is tried against ``_TIME_PARSERS`` in order; the first
        pattern matching at the start of the string wins.  Strings that match
        none of them (for example relative expressions) are kept verbatim as
        ``{"text": <string>}``.

        Args:
            time_entities: Iterable of time substrings found in the question.

        Returns:
            The single parsed dict when exactly one string was given,
            otherwise the list of parsed dicts (possibly empty).
        """
        conditions = []
        for time_str in time_entities:
            for pattern, build in self._TIME_PARSERS:
                m = pattern.match(time_str)
                if m:
                    conditions.append(build(m))
                    break
            else:
                conditions.append({"text": time_str})

        # A lone time entity is returned as a dict, anything else as a list.
        if len(conditions) == 1:
            return conditions[0]
        return conditions

    def _process_indicator_entities(self, indicators):
        """Map matched indicator names back to their keys.

        Args:
            indicators: Iterable of indicator names found in the question.

        Returns:
            A ``{key: name}`` dict, in the order the names were given; names
            that are not in ``self.indicator_dict`` are ignored.
        """
        result = {}
        for indicator in indicators:
            for key, name in self.indicator_dict.items():
                if name == indicator:
                    result[key] = name
                    break
        return result
|
|
|
+
|
|
|
+
|
|
|
# ==== Manual smoke test ====
# Module-level instance; constructing it only builds in-memory pattern tables.
extractor = EntityExtractor()

# Run the demo query only when this file is executed directly, so that
# importing the module does not load templates or print anything.
if __name__ == "__main__":
    # Other sample questions that can be swapped in for manual testing:
    #   "请问2023年5月到2024年3月、2024年5月山东省、河南省的交易电量累计是多少?"
    #   "2024年全年累计省间交易电量是多少??"
    #   "2023年12月交易电量是多少?"
    #   "2024年1月到3月累计交易电量是多少?"
    #   "2023年省间交易电量年度交易电量是多少?"
    #   "今年哪个省火电送出电量最多?分别是多少?"
    #   "哪个省送出电量最高?是多少?"
    #   "省间交易正在组织的交易有多少?"
    #   "2024年送出电量前十名的省份是?"
    question = "2023年省间交易电量按交易周期划分的电量是多少?"

    result = extractor.extract(question)

    print("类型:", result["type"])
    print("关键词:", result["keywords"])
    print("查询字段:", result["target"])
    print("模型名字", result["name"])
    print("条件", result["conditions"])
    print("返回的内容是:", result["content"])
    print("问句是:", result["query"])
    print("动作是:", result["play"])
    # NOTE(review): this prints result["content"] a second time under a
    # different label -- confirm whether another field was intended.
    print("描述:", result["content"])
|