|
@@ -1,215 +0,0 @@
|
|
|
-import json
|
|
|
-import os
|
|
|
-import re
|
|
|
-
|
|
|
-from final.ByRules.similarity_answer_json import load_template_info
|
|
|
-
|
|
|
-
|
|
|
class EntityExtractor:
    """Rule-based entity extractor for Chinese questions about electricity trading.

    The extractor scans a question for five entity kinds, in the order given
    by ``self.states`` (time, market, indicator, calculation, constraint).
    Every match is removed from the working text before the next pattern is
    tried, so an earlier, more specific pattern wins over a later one.

    The first matched indicator selects a template through
    ``load_template_info``; ``extract`` merges that template with the time
    conditions parsed from the question.
    """

    def __init__(self, json_folder="templatesJson"):
        """Set up the pattern tables.

        Args:
            json_folder: Folder name handed to ``load_template_info`` as its
                second argument. Defaults to the previously hard-coded
                ``"templatesJson"``.
        """
        self.entity_patterns = {
            # Ordered from most to least specific: a full range must be tried
            # before the bare year / month patterns that are its substrings.
            "time": [
                r'\d{4}年\d{1,2}月到\d{4}年\d{1,2}月',
                r'\d{4}年\d{1,2}月到\d{1,2}月',
                r'\d{4}年\d{1,2}月',
                r'\d{4}年',
                r'\d{1,2}月',
                r'去年|今年|前年|明年|上个月|本月|上季度|下季度',
                r'过去\d+年|过去\d+个月|近\d+年|近\d+个月'
            ],
            "market": ["山东省", "河南省", "山西省", "江苏省"],
            # Template key -> indicator name. Matched by substring, not regex.
            "indicator": {
                "1": "累计省间交易电量",
                "2": "交易电量",
            },
            "calculation": ["累计", "均值", "是多少", "有多少"],
            "constraint": [
                r'最高|最低|超过|低于|不少于',
                r'第[零一二三四五六七八九十百千万\d]+名到第[零一二三四五六七八九十百千万\d]+名',
                r'第[零一二三四五六七八九十百千万\d]+名',
                r'前[零一二三四五六七八九十百千万\d]+名'
            ]
        }
        self.states = ["time", "market", "indicator", "calculation", "constraint"]

        # Keep the indicator table and its names handy for substring matching.
        self.indicator_dict = self.entity_patterns["indicator"]
        self.indicator_list = list(self.indicator_dict.values())
        self.json_folder = json_folder

    def extract(self, text):
        """Extract entities from ``text`` and merge them with the matching template.

        Args:
            text: The question to analyse. It is returned unchanged under
                the ``"query"`` key.

        Returns:
            A dict with the keys ``type``, ``keywords``, ``target``,
            ``dataJsonName``, ``name``, ``conditions``, ``content``,
            ``query``, ``play``, ``find_max``, ``value_key``, ``name_key``,
            ``qcode``, ``unit`` and ``flag``. ``conditions`` and ``query``
            come from the question; everything else is read from the
            template.

        Raises:
            IndexError: If no known indicator occurs in ``text``. The
                exception type matches what the previous implementation
                raised in this situation, now with a descriptive message.
        """
        extracted_entities = self._extract_raw_entities(text)

        extracted_entities["time"] = self._process_time_entities(extracted_entities["time"])
        conditions = self._build_conditions(extracted_entities["time"])

        extracted_entities["indicator"] = self._process_indicator_entities(
            extracted_entities["indicator"]
        )
        if not extracted_entities["indicator"]:
            raise IndexError("no known indicator found in question: %r" % (text,))

        # The first matched indicator decides which template is used.
        key = next(iter(extracted_entities["indicator"]))

        # NOTE(review): load_template_info is a project helper; presumably it
        # reads the template for `key` from `json_folder` and returns a dict
        # -- confirm against its definition.
        template_info = load_template_info(key, self.json_folder)

        return {
            "type": template_info.get("type", ""),
            "keywords": template_info.get("keyword"),
            "target": template_info.get("target"),
            "dataJsonName": template_info.get("dataJsonName", ""),
            "name": template_info.get("name", ""),
            "conditions": conditions,
            "content": template_info.get("content", ""),
            "query": text,
            "play": template_info.get("play", ""),
            "find_max": template_info.get("find_max"),
            "value_key": template_info.get("value_key", ""),
            "name_key": template_info.get("name_key", ""),
            "qcode": template_info.get("qcode", ""),
            "unit": template_info.get("unit", ""),
            "flag": template_info.get("flag", "")
        }

    def _extract_raw_entities(self, text):
        """Return ``{state: [matched strings]}`` for every state in ``self.states``."""
        remaining = text
        found = {state: [] for state in self.states}

        for state in self.states:
            if state == "indicator":
                # Indicators are plain substrings. Try longer names first so
                # that a name which contains another one is consumed as a
                # whole instead of being split by the shorter match.
                for indicator_name in sorted(self.indicator_list, key=len, reverse=True):
                    if indicator_name in remaining:
                        found[state].append(indicator_name)
                        # Drop the match so later states cannot re-match it.
                        remaining = remaining.replace(indicator_name, "")
                continue

            for pattern in self.entity_patterns[state]:
                matches = re.findall(pattern, remaining)
                found[state].extend(matches)
                for match in matches:
                    remaining = remaining.replace(match, "")

        return found

    @staticmethod
    def _build_conditions(time_entity):
        """Turn the output of ``_process_time_entities`` into query conditions.

        Only the first time entity is used. A range produces the
        ``start_year`` / ``start_month`` / ``end_year`` / ``end_month`` keys;
        a single point in time produces the ``年`` and/or ``月`` keys. An
        empty list or a relative expression yields no conditions.
        """
        if isinstance(time_entity, list):
            if not time_entity:
                return {}
            time_entity = time_entity[0]

        if 'start_year' in time_entity and 'end_year' in time_entity:
            return {
                'start_year': time_entity.get('start_year'),
                'start_month': time_entity.get('start_month'),
                'end_year': time_entity.get('end_year'),
                'end_month': time_entity.get('end_month'),
            }

        conditions = {}
        if 'year' in time_entity:
            conditions['年'] = time_entity['year']
        if 'month' in time_entity:
            conditions['月'] = time_entity['month']
        return conditions

    @staticmethod
    def _parse_time_entity(time_str):
        """Parse one matched time string into a dict of integer fields.

        Patterns are tried from most to least specific and are anchored at
        the start of ``time_str``. Anything that is not an absolute date is
        returned as ``{"text": time_str}``.
        """
        m = re.match(r'(\d{4})年(\d{1,2})月到(\d{4})年(\d{1,2})月', time_str)
        if m:
            return {
                "start_year": int(m.group(1)),
                "start_month": int(m.group(2)),
                "end_year": int(m.group(3)),
                "end_month": int(m.group(4))
            }

        m = re.match(r'(\d{4})年(\d{1,2})月到(\d{1,2})月', time_str)
        if m:
            return {
                "start_year": int(m.group(1)),
                "start_month": int(m.group(2)),
                "end_year": int(m.group(1)),  # range stays within one year
                "end_month": int(m.group(3))
            }

        m = re.match(r'(\d{4})年(\d{1,2})月', time_str)
        if m:
            return {"year": int(m.group(1)), "month": int(m.group(2))}

        m = re.match(r'(\d{4})年', time_str)
        if m:
            return {"year": int(m.group(1))}

        m = re.match(r'(\d{1,2})月', time_str)
        if m:
            return {"month": int(m.group(1))}

        return {"text": time_str}

    def _process_time_entities(self, time_entities):
        """Parse every matched time string.

        Returns:
            The single parsed dict when exactly one entity was given,
            otherwise a list of parsed dicts (empty when there were none).
            This mixed return shape is kept for backward compatibility.
        """
        conditions = [self._parse_time_entity(time_str) for time_str in time_entities]

        if len(conditions) == 1:
            return conditions[0]
        return conditions

    def _process_indicator_entities(self, indicators):
        """Map matched indicator names back to ``{template key: name}``.

        The result keeps the order of ``indicators``; names that are not in
        ``self.indicator_dict`` are skipped.
        """
        result = {}
        for indicator in indicators:
            for key, name in self.indicator_dict.items():
                if name == indicator:
                    result[key] = name
                    break
        return result
|
|
|
-
|
|
|
-
|
|
|
# ==== Manual smoke test ====
# NOTE(review): this section runs on import; consider moving it under an
# `if __name__ == "__main__":` guard once no caller relies on the
# module-level `extractor` / `question` / `result` names.
extractor = EntityExtractor()

# Other sample questions that can be swapped in for manual checks:
#   "请问2023年5月到2024年3月、2024年5月山东省、河南省的交易电量累计是多少?"
#   "2024年全年累计省间交易电量是多少??"
#   "2023年12月交易电量是多少?"
#   "2024年1月到3月累计交易电量是多少?"
#   "2023年省间交易电量年度交易电量是多少?"
#   "今年哪个省火电送出电量最多?分别是多少?"
#   "哪个省送出电量最高?是多少?"
#   "省间交易正在组织的交易有多少?"
#   "2024年送出电量前十名的省份是?"
question = "2023年省间交易电量按交易周期划分的电量是多少?"

result = extractor.extract(question)

# One "<label> <value>" line per field, in the order listed here
# ("content" is intentionally shown twice, under two labels).
print("\n".join(
    label + " " + str(result[field])
    for label, field in (
        ("类型:", "type"),
        ("关键词:", "keywords"),
        ("查询字段:", "target"),
        ("模型名字", "name"),
        ("条件", "conditions"),
        ("返回的内容是:", "content"),
        ("问句是:", "query"),
        ("动作是:", "play"),
        ("描述:", "content"),
    )
))
|