# similarity_answer_json.py (20 KB)

import json
import os

import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from final.ByRules.util import calculate_sum_by_time_range
  7. def jieba_tokenizer(text):
  8. return list(jieba.cut(text))
  9. # 定义问题模板
  10. template_dict = {
  11. "1": ["某年全年累计省间交易电量是多少?"],
  12. "2": ["某年某月交易电量是多少?"],
  13. "3": ["某年某月到某月累计交易电量是多少?"],
  14. "8.1": ["某年省间交易电量按交易周期划分的电量是多少?"],
  15. "8.2": ["某年省间交易电量按交易类型划分的电量是多少?"],
  16. "8.3": ["某年省间交易电量按发电类型划分的电量是多少?"],
  17. "8.4": ["某年省间交易电量按交易方式划分的电量是多少?"],
  18. "9.1": ["某年省间交易电量年度交易电量是多少?"],
  19. "9.2": ["某年省间交易电量月度交易电量是多少?"],
  20. "9.3": ["某年省间交易电量现货交易电量是多少?"],
  21. "9.4": ["某年省间交易电量应急交易电量是多少?"],
  22. "9.5": ["某年省间交易电量月内交易电量是多少?"],
  23. "9.6": ["某年省间交易电量省间外送交易电量是多少?"],
  24. "9.7": ["某年省间交易电量电力直接交易电量是多少?"],
  25. "9.8": ["某年省间交易电量合同交易电量是多少?"],
  26. "9.9": ["某年省间交易电量绿电交易电量是多少?"],
  27. "9.10": ["某年省间交易电量非市场化交易电量是多少?"],
  28. "9.11": ["某年省间交易电量新能源交易电量是多少?"],
  29. "9.12": ["某年省间交易电量火电交易电量是多少?"],
  30. "9.13": ["某年省间交易电量水电交易电量是多少?"],
  31. "9.14": ["某年省间交易电量核电交易电量是多少?"],
  32. "9.15": ["某年省间交易电量双边交易电量是多少?"],
  33. "9.16": ["某年省间交易电量集中交易电量是多少?"],
  34. "9.17": ["某年省间交易电量挂牌交易电量是多少?"],
  35. "17.1": ["那个省送出电量最高?是多少?"],
  36. "19": ["省间交易正在组织的交易有多少?"],
  37. "20": ["省间交易当月完成的交易有多少?"],
  38. "21": ["省间交易当年完成的交易有多少?"],
  39. "22": ["省间交易当年达成的电量有多少?"],
  40. "23": ["省间交易当年参与交易的家次有多少?"],
  41. }
  42. # 将地点映射成相应的代码
  43. def map_location_to_unit(location: str) -> str:
  44. mapping_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../省间关系映射.json'))
  45. if not os.path.exists(mapping_path):
  46. print(f"映射文件未找到: {mapping_path}")
  47. return '未知单位'
  48. with open(mapping_path, 'r', encoding='utf-8') as f:
  49. mapping = json.load(f)
  50. for code, name in mapping.items():
  51. if name == location:
  52. return code
  53. return '未知单位'
  54. # 提取时间和地点
  55. from typing import Tuple, List, Dict
  56. from datetime import datetime
  57. import re
  58. def extract_time_location(question: str) -> Tuple[List[Dict], List[str]]:
  59. current_date = datetime.now()
  60. current_year = current_date.year
  61. current_month = current_date.month
  62. # 匹配绝对时间
  63. absolute_patterns = [
  64. r'(?P<year>\d{4})年(?P<month>\d{1,2})月(?P<day>\d{1,2})日',
  65. r'(?P<year>\d{4})年(?P<month>\d{1,2})月',
  66. r'(?P<year>\d{4})年'
  67. ]
  68. relative_year_mapping = {
  69. '明年': current_year + 1,
  70. '今年': current_year,
  71. '去年': current_year - 1,
  72. '前年': current_year - 2
  73. }
  74. season_mapping = {
  75. '一季度': (1, 3),
  76. '二季度': (4, 6),
  77. '三季度': (7, 9),
  78. '四季度': (10, 12),
  79. '上半年': (1, 6),
  80. '下半年': (7, 12)
  81. }
  82. provinces = [
  83. '北京', '天津', '上海', '重庆', '河北', '山西', '辽宁', '吉林', '黑龙江',
  84. '江苏', '浙江', '安徽', '福建', '江西', '山东', '河南', '湖北', '湖南',
  85. '广东', '海南', '四川', '贵州', '云南', '陕西', '甘肃', '青海', '台湾',
  86. '内蒙古', '广西', '西藏', '宁夏', '新疆', '香港', '澳门'
  87. ]
  88. time_results = []
  89. used_keywords = set()
  90. # 🆕 处理“起止时间段”,格式:2023年1月到2024年2月、去年1月到今年2月、2023年1月到1月等
  91. range_pattern = r'(?P<start>(\d{4}|今|去|前|明)年(\d{1,2})?月?)到(?P<end>(\d{4}|今|去|前|明)年(\d{1,2})?月?)'
  92. for match in re.finditer(range_pattern, question):
  93. start_raw, end_raw = match.group('start'), match.group('end')
  94. def parse_relative(text):
  95. year = current_year
  96. month = None
  97. if '明年' in text:
  98. year = current_year + 1
  99. elif '今年' in text or '今' in text:
  100. year = current_year
  101. elif '去年' in text or '去' in text:
  102. year = current_year - 1
  103. elif '前年' in text or '前' in text:
  104. year = current_year - 2
  105. m = re.search(r'(\d{1,2})月', text)
  106. if m:
  107. month = int(m.group(1))
  108. return year, month
  109. def parse_absolute(text):
  110. m = re.match(r'(?P<year>\d{4})年(?P<month>\d{1,2})?月?', text)
  111. if m:
  112. year = int(m.group('year'))
  113. month = int(m.group('month')) if m.group('month') else None
  114. return year, month
  115. return None, None
  116. def parse_any(text):
  117. if any(key in text for key in relative_year_mapping.keys()) or text[:1] in ['今', '去', '前', '明']:
  118. return parse_relative(text)
  119. else:
  120. return parse_absolute(text)
  121. start_y, start_m = parse_any(start_raw)
  122. end_y, end_m = parse_any(end_raw)
  123. time_results.append({
  124. 'start_year': start_y, 'start_month': start_m,
  125. 'end_year': end_y, 'end_month': end_m,
  126. 'label': f'{start_raw}到{end_raw}',
  127. 'raw': match.group()
  128. })
  129. used_keywords.add(match.group())
  130. # 🆕 新增匹配“2024年1月到2月”,结束时间没有写年份,默认与开始时间同年
  131. partial_range_pattern = r'(?P<year>\d{4})年(?P<start_month>\d{1,2})月到(?P<end_month>\d{1,2})月'
  132. for match in re.finditer(partial_range_pattern, question):
  133. # 避免重复匹配已经被上面时间段匹配使用过的字符串
  134. if match.group() in used_keywords:
  135. continue
  136. year = int(match.group('year'))
  137. start_month = int(match.group('start_month'))
  138. end_month = int(match.group('end_month'))
  139. time_results.append({
  140. 'start_year': year,
  141. 'start_month': start_month,
  142. 'end_year': year,
  143. 'end_month': end_month,
  144. 'label': match.group(),
  145. 'raw': match.group()
  146. })
  147. used_keywords.add(match.group())
  148. # 相对+具体月份
  149. relative_absolute_pattern = r'(?P<relative>今|去|前)年(?P<month>\d{1,2})月'
  150. for match in re.finditer(relative_absolute_pattern, question):
  151. if match.group() in used_keywords:
  152. continue
  153. rel = match.group('relative')
  154. month = int(match.group('month'))
  155. year = {'今': current_year, '去': current_year - 1, '前': current_year - 2}.get(rel, current_year)
  156. time_results.append({'year': year, 'month': month, 'raw': match.group()})
  157. used_keywords.add(match.group())
  158. # 绝对时间
  159. for pattern in absolute_patterns:
  160. for match in re.finditer(pattern, question):
  161. if match.group() in used_keywords:
  162. continue
  163. time_info = {'raw': match.group()}
  164. gd = match.groupdict()
  165. if gd.get('year'):
  166. time_info['year'] = int(gd['year'])
  167. if gd.get('month'):
  168. time_info['month'] = int(gd['month'])
  169. if gd.get('day'):
  170. time_info['day'] = int(gd['day'])
  171. time_results.append(time_info)
  172. used_keywords.add(match.group())
  173. # 单独的相对年份关键词
  174. for term, year in relative_year_mapping.items():
  175. if term in question and term not in used_keywords:
  176. time_results.append({'year': year, 'label': term, 'raw': term})
  177. used_keywords.add(term)
  178. # 当前/上个月
  179. if '当前' in question and '当前' not in used_keywords:
  180. time_results.append({'year': current_year, 'month': current_month, 'label': '当前', 'raw': '当前'})
  181. used_keywords.add('当前')
  182. if '上个月' in question and '上个月' not in used_keywords:
  183. prev_year = current_year if current_month > 1 else current_year - 1
  184. prev_month = current_month - 1 if current_month > 1 else 12
  185. time_results.append({'year': prev_year, 'month': prev_month, 'label': '上个月', 'raw': '上个月'})
  186. used_keywords.add('上个月')
  187. # 季度和半年
  188. for term, (start_month, end_month) in season_mapping.items():
  189. if term in question and term not in used_keywords:
  190. time_results.append({
  191. 'year': current_year,
  192. 'label': term,
  193. 'start_month': start_month,
  194. 'end_month': end_month,
  195. 'raw': term
  196. })
  197. used_keywords.add(term)
  198. # 地点识别
  199. locations = [p for p in provinces if p in question]
  200. return time_results, locations
  201. # 先用 extract_time_location 判断问句包含哪类时间信息,然后只对结构匹配的模板子集做余弦匹配。
  202. # def classify_by_time_type(query, time_info):
  203. # if any('start_year' in t and 'end_year' in t for t in time_info):
  204. # return ['3'] # 时间段
  205. # return list(template_dict.keys()) # fallback 所有模板
  206. def classify_by_time_type(query, time_info):
  207. if not time_info:
  208. # 无时间信息时,返回指定模板 19-23
  209. return ['19', '20', '21', '22', '23']
  210. time = time_info[0]
  211. # 情况 1:起始时间和结束时间都有,判断为时间段
  212. if 'start_year' in time and 'end_year' in time:
  213. return ['3'] # 某年某月到某月累计交易电量
  214. # 情况 2:有 year 和 month,精确到月
  215. if 'year' in time and 'month' in time:
  216. return ['2'] # 某年某月交易电量
  217. # 情况 3:仅 year,全年
  218. if 'year' in time and 'month' not in time:
  219. return ['1','8.1','8.2','8.3','8.4','9.1','9.2','9.3','9.4','9.5','9.6','9.7','9.8','9.9','9.10','9.11','9.12','9.13','9.14','9.15','9.16','9.17'] # 某年全年累计交易电量
  220. def match_template_with_time_filter(query, template_dict, tokenizer, extract_time_location_func):
  221. """
  222. 先基于时间信息筛选候选模板,再进行TF-IDF匹配。
  223. """
  224. # 提取时间
  225. time_info, _ = extract_time_location_func(query)
  226. # 通过时间判断候选模板 key
  227. candidate_keys = classify_by_time_type(query, time_info)
  228. # 构造候选子模板字典
  229. filtered_template_dict = {k: template_dict[k] for k in candidate_keys}
  230. # 使用你原来的 TF-IDF 匹配函数
  231. return match_template(query, filtered_template_dict, tokenizer)
  232. # 找相似度最高的模板
  233. def match_template(query, template_dict, tokenizer):
  234. """
  235. 匹配与 query 最相似的模板句,并返回对应的 key、模板句和相似度分数。
  236. 参数:
  237. query (str): 用户输入的问句。
  238. template_dict (dict): 模板字典,格式为 {key: [模板句1, 模板句2, ...]}。
  239. tokenizer (callable): 分词器函数,例如 jieba.lcut。
  240. 返回:
  241. matched_key (str): 最相似模板的 key。
  242. best_match_sentence (str): 最相似的模板句。
  243. similarity_score (float): 相似度得分。
  244. """
  245. # 构造模板和 key 映射
  246. templates = []
  247. key_map = []
  248. for key, sentences in template_dict.items():
  249. for s in sentences:
  250. templates.append(s)
  251. key_map.append(key)
  252. # TF-IDF 向量化
  253. vectorizer = TfidfVectorizer(tokenizer=tokenizer)
  254. tfidf_matrix = vectorizer.fit_transform([query] + templates)
  255. # 计算余弦相似度
  256. cos_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
  257. most_similar_idx = cos_sim.argmax()
  258. # 获取最相似结果
  259. best_match_sentence = templates[most_similar_idx]
  260. matched_key = key_map[most_similar_idx]
  261. similarity_score = cos_sim[0][most_similar_idx]
  262. return matched_key, best_match_sentence, similarity_score
  263. # 根据模板去对应的json文件中找数据
  264. def load_template_info(matched_key, json_folder):
  265. """
  266. 根据 matched_key 从指定文件夹中加载对应的模板 JSON 文件内容。
  267. 参数:
  268. matched_key (str): 匹配到的模板 key,一般作为 JSON 文件名(不含扩展名)。
  269. json_folder (str): JSON 文件所在的文件夹路径。
  270. 返回:
  271. dict: 解析后的 JSON 内容字典。
  272. 异常:
  273. - 如果文件不存在或解析出错,会抛出异常。
  274. """
  275. # 构造完整文件路径
  276. file_path = os.path.join(json_folder, f"{matched_key}.json")
  277. if not os.path.exists(file_path):
  278. raise FileNotFoundError(f"未找到 JSON 文件:{file_path}")
  279. # 读取并解析 JSON 文件
  280. with open(file_path, 'r', encoding='utf-8') as f:
  281. data = json.load(f)
  282. return data
  283. def process_query(query, template_dict, json_folder, tokenizer=jieba_tokenizer):
  284. # 提取条件
  285. time_info, location_info = extract_time_location(query)
  286. conditions = {}
  287. # 匹配模板
  288. matched_key, best_sentence, score = match_template_with_time_filter(
  289. query,
  290. template_dict,
  291. tokenizer,
  292. extract_time_location_func=extract_time_location
  293. )
  294. # 定义阈值
  295. similarity_threshold = 0.3
  296. # ★ 判断相似度阈值
  297. if score < similarity_threshold:
  298. return {
  299. "matched_key": None,
  300. "matched_template": None,
  301. "similarity_score": score,
  302. "type": None,
  303. "keywords": None,
  304. "target": None,
  305. "name": None,
  306. "conditions": conditions,
  307. "content": "您提问的问题目前我还没有掌握",
  308. "query": query,
  309. "play":"疑问"
  310. }
  311. if time_info:
  312. ti = time_info[0]
  313. # 先判断是否是区间时间(有start_year/end_year等字段)
  314. if 'start_year' in ti and 'end_year' in ti:
  315. conditions['start_year'] = ti.get('start_year')
  316. conditions['start_month'] = ti.get('start_month')
  317. conditions['end_year'] = ti.get('end_year')
  318. conditions['end_month'] = ti.get('end_month')
  319. else:
  320. # 单时间点
  321. if 'year' in ti:
  322. conditions['年'] = ti['year']
  323. if 'month' in ti:
  324. conditions['月'] = ti['month']
  325. if location_info:
  326. unit = map_location_to_unit(location_info[0])
  327. if unit and unit != '未知单位':
  328. conditions['单位'] = unit
  329. # 查询模板json
  330. template_info = load_template_info(matched_key, json_folder)
  331. # 模板的关键词
  332. keywords = template_info.get("keyword")
  333. # 模板中的映射关系
  334. target = template_info.get("target")
  335. # 模板的类型
  336. type_ = template_info.get("type", "")
  337. # 模板的名字
  338. dataJsonName = template_info.get("dataJsonName", "")
  339. # ---------------- 比较类 -----------------
  340. value_key = template_info.get("value_key", "")
  341. name_key = template_info.get("name_key", "")
  342. find_max = template_info.get("find_max")
  343. # block名称
  344. name = template_info.get("name", "")
  345. # 输出内容
  346. content = template_info.get("content", "")
  347. # 动作类型
  348. play = template_info.get("play", "")
  349. return {
  350. "matched_key": matched_key,
  351. "matched_template": best_sentence,
  352. "similarity_score": score,
  353. "type": type_,
  354. "keywords": keywords,
  355. "target": target,
  356. "dataJsonName": dataJsonName,
  357. "name": name,
  358. "conditions": conditions,
  359. "content": content,
  360. "query": query,
  361. "play": play,
  362. "find_max": find_max,
  363. "value_key": value_key,
  364. "name_key": name_key
  365. }
  366. # 查询类
  367. def smart_find_value(folder_path, file_name, conditions: dict, target_key: str):
  368. file_name = file_name + ".json"
  369. file_path = os.path.join(folder_path, file_name)
  370. if not os.path.exists(file_path):
  371. print(f"文件 {file_path} 不存在")
  372. return None
  373. with open(file_path, 'r', encoding='utf-8') as f:
  374. try:
  375. data = json.load(f)
  376. except json.JSONDecodeError as e:
  377. print(f"JSON 解析失败:{e}")
  378. return None
  379. def match_conditions(record):
  380. return all(record.get(k) == v for k, v in conditions.items())
  381. # 情况一:数据是 dict
  382. if isinstance(data, dict):
  383. if not conditions or match_conditions(data):
  384. values = find_key_recursively(data, target_key)
  385. return values[0] if len(values) == 1 else values if values else None
  386. return None
  387. # 情况二:数据是 list
  388. elif isinstance(data, list):
  389. results = []
  390. for record in data:
  391. if isinstance(record, dict) and match_conditions(record):
  392. matches = find_key_recursively(record, target_key)
  393. results.extend(matches)
  394. if not results:
  395. return None
  396. elif len(results) == 1:
  397. return results[0]
  398. else:
  399. return results
  400. # 查询类的辅助函数
  401. def find_key_recursively(data, target_key):
  402. results = []
  403. def _search(obj):
  404. if isinstance(obj, dict):
  405. for k, v in obj.items():
  406. if k == target_key:
  407. results.append(v)
  408. _search(v)
  409. elif isinstance(obj, list):
  410. for item in obj:
  411. _search(item)
  412. _search(data)
  413. return results
  414. # query = "当月省间交易完成的交易是多少?"
  415. # query = "2024年1月到2月累计交易电量是多少?"
  416. query = "2023年省间交易电量新能源交易电量是多少??"
  417. # query = "但同样阿贾克斯大口径的话我可合金外壳设计文件突然发?"
  418. json_folder = "templatesJson"
  419. result = process_query(query, template_dict, json_folder)
  420. print("匹配的模板 key:", result["matched_key"])
  421. print("最相似的模板句:", result["matched_template"])
  422. print("相似度分数:", result["similarity_score"])
  423. print("类型:", result["type"])
  424. print("关键词:", result["keywords"])
  425. print("查询字段:", result["target"])
  426. print("模型名字", result["name"])
  427. print("条件", result["conditions"])
  428. print("返回的内容是:", result["content"])
  429. print("问句是:", result["query"])
  430. print("动作是:", result["play"])
  431. type = result["type"]
  432. content = result["content"]
  433. json_data_folder = "..\Json\json_data"
  434. if type == "query":
  435. fileName = result["dataJsonName"]
  436. result = smart_find_value(json_data_folder, fileName,result["conditions"],result["target"] )
  437. print(result)
  438. elif type == "calculate":
  439. conditions = result["conditions"]
  440. start_conditions = {('年' if 'year' in k else '月'): v for k, v in conditions.items() if k.startswith('start_')}
  441. end_conditions = {('年' if 'year' in k else '月'): v for k, v in conditions.items() if k.startswith('end_')}
  442. print(start_conditions)
  443. print(end_conditions)
  444. fileName = result["dataJsonName"] + ".json"
  445. result = calculate_sum_by_time_range(json_data_folder,fileName,result["target"],start_conditions, end_conditions)
  446. print(result)
  447. #
  448. # # 最终回答的文本
  449. # final_content = content.replace("?", str(result))
  450. # # print(f"{content}{result}")
  451. #
  452. # print(final_content)