agent.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from typing import Optional
  2. import requests
  3. from openai import OpenAI
  4. from wcferry import WxMsg
  5. from common.log import logger
  6. from common.sql_lite import init_new_db_connection
  7. from config import Config
  8. from openai import OpenAI, AuthenticationError, APIConnectionError, APIError
  9. from db.msg_history import msg_history_get, msg_history_update
  10. from .plugin import Plugin
  11. class Agent(Plugin):
  def __init__(self):
      """Create the agent by running the parent class initializer only.

      No Agent-specific state is set up here. ``self.config`` and
      ``self.openAiClient``, which the other methods read, are presumably
      provided by the ``Plugin`` base class -- TODO confirm.
      """
      super().__init__()
  def answer(self, msg: WxMsg, wx_wxid: Optional[str] = None):
      """Reply to ``msg`` using the stored chat history and, if configured, a knowledge base.

      Steps:
        1. Load the conversation history for this sender/room.
        2. If ``dataset_id`` is configured, rewrite the user's message into a
           self-contained question and search the dataset with it.
        3. Build the system prompt (role, background, retrieved Q/A pairs,
           reply rules) and ask the chat model for a reply.
        4. Store the new user/assistant turns and commit.

      Args:
          msg: Incoming message; ``content``, ``sender`` and ``roomid`` are read.
          wx_wxid: Passed through unchanged to the history helpers
              (presumably the bot account's wxid -- TODO confirm).

      Returns:
          Whatever ``self._client_reply`` returns for the final prompt, or
          ``None`` if any step raised. In that case the transaction is
          rolled back and the error is logged.
      """
      connection = init_new_db_connection()
      cursor = connection.cursor()
      try:
          # Treat a missing history as an empty one so list building below
          # cannot fail on ``None`` (return shape of msg_history_get is not
          # visible here -- TODO confirm).
          history = msg_history_get(cursor, wx_wxid, msg.sender, msg.roomid) or []
          user_turn = {"role": "user", "content": msg.content}

          # Only search when a dataset is actually configured. The previous
          # ``if not dataset_id`` check was inverted and searched with a
          # missing id while skipping the search for configured datasets.
          knowledge = None
          dataset_id = self.config.get("dataset_id")
          if dataset_id:
              query = self._expand_question(msg.content, history)
              knowledge = self._dataset_search(dataset_id, query)

          system_turn = {"role": "system", "content": self._build_system_prompt(knowledge)}
          bot_reply = self._client_reply(self.openAiClient, [system_turn, *history, user_turn])

          updated_history = [*history, user_turn, {"role": "assistant", "content": bot_reply}]
          msg_history_update(cursor, wx_wxid, updated_history, msg.sender, msg.roomid)
          connection.commit()
          return bot_reply
      except Exception as e:
          # Best-effort boundary: undo the partial history write and report
          # the failure through the project logger (with traceback) instead
          # of stdout. Callers still receive ``None`` as before.
          connection.rollback()
          logger.exception(f"发生错误: {e}")
          return None
      finally:
          # Always release the database resources.
          cursor.close()
          connection.close()

  def _expand_question(self, question: str, history: list) -> str:
      """Ask the model to rewrite ``question`` into a self-contained sentence using ``history``."""
      expand_prompt = (
          "# 任务:\n"
          "请根据上下文信息,优化用户发送的最后一条消息,补齐消息中可能缺失的主语、谓语、宾语、定语、状语、补语句子成分。\n"
          f"# 用户发送的最后一条消息:{question}\n"
          "# 回复要求\n"
          "1. 直接输出优化后的消息"
      )
      expand_messages = [{"role": "system", "content": expand_prompt}, *history]
      # Fall back to the raw question if the rewrite came back empty.
      return self._client_reply(self.openAiClient, expand_messages) or question

  def _build_system_prompt(self, knowledge) -> str:
      """Assemble the system prompt from config role/background, optional Q/A hits and reply rules."""
      parts = [
          f"# 角色\n{self.config.get('role')}\n# 背景:\n{self.config.get('background')}\n"
      ]
      # Skip the knowledge section entirely when there are no hits, so the
      # prompt does not contain an empty heading.
      if knowledge:
          parts.append("# 相关知识:\n")
          parts.extend(f"问题:{item['q']}\n答案:{item['a']}\n" for item in knowledge)
      parts.append(
          "# 回复要求:\n"
          "1. 直接以角色设定的角度回答问题,并以第一人称输出。\n"
          "2. 不要在回复前加角色、姓名。\n"
          "3. 回复要正式"
      )
      return "".join(parts)
  def _dataset_search(self, dataset_id: str, text: str, *, timeout: float = 30.0):
      """Search the configured dataset service for entries matching ``text``.

      Sends a JSON POST to the URL in ``config["dataset_base"]`` with a
      bearer token taken from ``config["dataset_key"]``.

      Args:
          dataset_id: Sent as ``datasetId`` in the request body.
          text: Query text sent as ``text``.
          timeout: Seconds to wait for the service before giving up. Added so
              a stalled service cannot block the caller indefinitely.

      Returns:
          The value of ``data.list`` from the JSON response, or ``None`` when
          the request fails, the status is not 200, the body is not valid
          JSON, or the expected ``data``/``list`` keys are missing.
      """
      headers = {
          "Content-Type": "application/json",
          "Authorization": f'Bearer {self.config.get("dataset_key")}',
      }
      payload = {
          "datasetId": dataset_id,
          "text": text,
          "limit": 100,
          "similarity": 0.8,
          "searchMode": "mixedRecall",
          "usingReRank": False,
      }
      try:
          response = requests.post(
              self.config.get("dataset_base"),
              headers=headers,
              json=payload,
              timeout=timeout,
          )
          if response.status_code != 200:
              return None
          body = response.json()
      except (requests.RequestException, ValueError) as exc:
          # Treat transport errors and undecodable bodies like a non-200
          # answer: no hits, rather than an exception for the caller.
          logger.warning(f"dataset search failed: {exc}")
          return None

      # Only dict-shaped payloads can carry the expected data/list structure.
      data = body.get("data") if isinstance(body, dict) else None
      if isinstance(data, dict) and "list" in data:
          return data["list"]
      return None