plugin.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from typing import Optional
  2. from openai import OpenAI
  3. from wcferry import WxMsg
  4. from common.log import logger
  5. from config import conf
  6. from openai import OpenAI, AuthenticationError, APIConnectionError, APIError
  class Plugin:
      """Answer incoming ``WxMsg`` messages by forwarding them to a chat model.

      One ``OpenAI`` client is created at construction time from the
      ``openai_key`` and ``openai_base`` configuration entries and is reused
      for every request.
      """

      def __init__(self):
          self.config = conf()
          self.LOG = logger
          self.openAiClient = OpenAI(
              api_key=self.config.get("openai_key"),
              base_url=self.config.get("openai_base"),
          )

      def answer(self, msg: WxMsg, wx_wxid: Optional[str] = None) -> str:
          """Return the model's reply to ``msg.content``.

          Failures are logged instead of raised, so callers always get a
          string back.

          Args:
              msg: Incoming message; only its ``content`` attribute is read.
              wx_wxid: Accepted for interface compatibility; not used here.

          Returns:
              The cleaned-up reply text, or ``""`` when the request failed.
          """
          rsp = ""
          try:
              messages = [{"role": "user", "content": msg.content}]
              rsp = self._client_reply(self.openAiClient, messages)
              self.LOG.info(rsp)
          # The two specific handlers must stay ahead of APIError so that
          # they get the first chance to match.
          except AuthenticationError:
              self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
          except APIConnectionError:
              self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
          except APIError as e1:
              self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
          except Exception as e0:
              # Last-resort boundary: a failed reply must not crash the caller.
              self.LOG.error(f"发生未知错误:{str(e0)}")
          return rsp

      def _client_reply(self, client: OpenAI, messages: list, sender: Optional[str] = None, roomid: Optional[str] = None) -> str:
          """Send ``messages`` to the chat-completions endpoint and tidy the reply.

          Args:
              client: Client whose ``chat.completions.create`` is called.
              messages: Chat messages passed through unchanged.
              sender: Optional sender id; when given it becomes part of the
                  ``chatId`` value sent in ``extra_body``.
              roomid: Optional room id; when given it becomes part of the
                  ``chatId`` value sent in ``extra_body``.

          Returns:
              The first choice's text with a leading blank line removed and
              every ``"\\n\\n"`` collapsed to ``"\\n"``; ``""`` when the
              response carries no text.
          """
          # "chatId" is only added when the caller supplies an id, so the
          # default request body stays free of extra fields.
          extra_body = {}
          id_parts = [part for part in (sender, roomid) if part is not None]
          if id_parts:
              extra_body["chatId"] = "_".join(["chatId", *id_parts])

          ret = client.chat.completions.create(
              model=self.config.get("open_ai_model", "gpt-4o"),
              max_tokens=self.config.get("open_ai_max_tokens", 8192),
              temperature=self.config.get("open_ai_temperature", 0.7),
              top_p=self.config.get("open_ai_top_p", 1),
              extra_body=extra_body,
              messages=messages,
          )

          content = ret.choices[0].message.content
          if not content:
              # content may be None; treat it like an empty reply instead of
              # failing on a string method call.
              return ""
          if content.startswith("\n\n"):
              content = content[2:]
          return content.replace("\n\n", "\n")