Pārlūkot izejas kodu

1.增加agent回复
2.回复方式插件化
3.解耦部分页面代码(未完成)

宋伯文 4 mēneši atpakaļ
vecāks
revīzija
9320c17f21

+ 32 - 3
config.py

@@ -10,10 +10,10 @@ from common.log import logger
 # 将所有可用的配置项写在字典里, 请使用小写字母
 # 此处的配置值无实际意义,程序不会读取此处的配置,仅用于提示格式,请将配置加入到config.json中
 available_setting = {
-    "debug":False,
+    "debug": False,
     "api_base": "http://fastgpt.ascrm.cn/api/v1",
     "api_key": "fastgpt-sKABkv3PTHxlFZYPn9Mo35HHsZSdzdFNBH4XeWIRn5CwdkG7aXqEDmXwDwK",
-    "token":"bowen-test",
+    "token": "bowen-test",
     "open_ai_model": "gpt-4o",
     "open_ai_temperature": 0.7,
     "open_ai_max_tokens": 1024,
@@ -23,7 +23,20 @@ available_setting = {
     "open_ai_stop": ["<|im_end|>"],
     "open_ai_stream": True,
     "contacts_white_list": [],
-    "appdata_dir":"."
+    "appdata_dir": ".",
+
+    "agent_id": 0,
+    "role": "",
+    "background": "",
+    "examples": "",
+    "dataset_id": "",
+    "collection_id": "",
+    "custom_agent_base": "",
+    "custom_agent_key": "",
+    "openai_base": "",
+    "openai_key": "",
+    "dataset_base": "",
+    "dataset_key": "",
 }
 
 
@@ -82,6 +95,8 @@ class Config(dict):
 
 
 config = Config()
+
+
 # parser = argparse.ArgumentParser(description='消息中间处理程序')
 # parser.add_argument('-c', '--config', help='设置配置文件,默认值是 ./config.json')
 
@@ -134,6 +149,20 @@ def load_config():
     # logger.info("[INIT] load user datas: {}".format(config.get_user_data("api_base")))
 
 
+def set_config(name, value):
+    global config
+    name = name.lower()
+    if name in available_setting:
+        logger.info("[INIT] override config by environ args: {}={}".format(name, value))
+        try:
+            config[name] = eval(value)
+        except:
+            if value == "false":
+                config[name] = False
+            elif value == "true":
+                config[name] = True
+            else:
+                config[name] = value
 
 
 def get_root():

+ 67 - 0
db/msg_history.py

@@ -0,0 +1,67 @@
+import json
+from sqlite3 import Cursor
+from typing import TypedDict, Optional
+
+
+class MsgHistoryModel(TypedDict):
+    wx_wxid: str
+    sender: str
+    roomid: str
+    messages: str
+
+
+def msg_history_get(cur: Cursor, wx_wxid: str, sender: Optional[str], roomid: Optional[str]) -> list:
+    # 查询历史消息
+    fields = ["wx_wxid"]
+    params = [wx_wxid]
+
+    if roomid:
+        fields.append("roomid")
+        params.append(roomid)
+
+    if sender:
+        fields.append("sender")
+        params.append(sender)
+
+    query = f"SELECT * FROM msg_history WHERE {' AND '.join([f'{field} = ?' for field in fields])} LIMIT 1"
+    cur.execute(query, params)
+    result = cur.fetchone()
+    if result is None:
+        fields = ["wx_wxid", "messages"]
+        values = [wx_wxid, "[]"]
+
+        if roomid is not None:
+            fields.append("roomid")
+            values.append(roomid)
+
+        if sender is not None:
+            fields.append("sender")
+            values.append(sender)
+
+        operation = f"INSERT INTO msg_history ({', '.join(fields)}) VALUES ({', '.join(['?' for _ in fields])})"
+        cur.execute(operation, values)
+        content = []
+    else:
+        content = json.loads(result[3])
+
+    return content
+
+
+def msg_history_update(cur: Cursor, wx_wxid: str, history_messages: list, sender: Optional[str],
+                       roomid: Optional[str]):
+    messages_str = json.dumps(history_messages[-5:])
+
+    fields = ["wx_wxid"]
+    params = [messages_str, wx_wxid]
+
+    if roomid:
+        fields.append("roomid")
+        params.append(roomid)
+
+    if sender:
+        fields.append("sender")
+        params.append(sender)
+
+    update_query = f"UPDATE msg_history SET messages = ? WHERE {' AND '.join([f'{field} = ?' for field in fields])}"
+
+    cur.execute(update_query, params)

BIN
local.db


+ 0 - 0
logic/__init__.py


+ 1 - 1
logic/logic_batch_task_create.py

@@ -5,7 +5,7 @@ from common.sql_lite import init_new_db_connection
 from db.batch_msg import batch_msg_create_many
 from db.batch_task import batch_task_create
 from service.robot import get_robot
-from ui.ui_batch_task_create import WinGUIBatchTaskCreate
+from ui.batch_task.ui_batch_task_create import WinGUIBatchTaskCreate
 
 
 class WinBatchTaskCreate(WinGUIBatchTaskCreate):

+ 1 - 1
logic/logic_batch_task_detail.py

@@ -2,7 +2,7 @@ from datetime import datetime
 
 from common.sql_lite import get_global_db_connection
 from db.batch_msg import batch_msg_get_list, batch_msg_status, batch_msg_type
-from ui.ui_batch_task_detail import WinGUIBatchTaskDetail
+from ui.batch_task.ui_batch_task_detail import WinGUIBatchTaskDetail
 
 
 class WinBatchTaskDetail(WinGUIBatchTaskDetail):

+ 69 - 42
logic/logic_ui.py

@@ -9,12 +9,14 @@ from tkinter import messagebox
 import requests
 
 from common.sql_lite import get_global_db_connection
-from config import conf
+from config import conf, set_config
 from db.batch_task import batch_task_get_list, batch_task_status, batch_task_update_status
 from logic.logic_batch_task_create import open_batch_task_create_win
 from logic.logic_batch_task_detail import open_batch_task_detail_win
+from plugins.agent import Agent
+from plugins.custom_agent import CustomAgent
 from service.batch_task import stop_batch_task, start_batch_task
-from service.robot import get_robot
+from service.robot import get_robot, init_robot
 from ui.ui import WinGUI
 
 
@@ -27,16 +29,16 @@ class Win(WinGUI):
         self.selection_batch_task_id = None
 
     def __event_bind(self):
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_save.bind('<Button-1>', self.save_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_start.bind('<Button-1>', self.start_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_pause.bind('<Button-1>', self.stop_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_version.bind('<Button-1>', self.version_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_button_create.bind('<Button-1>', self.batch_task_create_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.bind('<Double-1>',
-                                                                                 self.batch_task_detail_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.bind('<Button-1>', self.just_click_event)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.bind('<Button-3>',
-                                                                                 self.batch_task_action_event)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_save.bind('<Button-1>', self.save_event)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_start.bind('<Button-1>', self.start_event)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_pause.bind('<Button-1>', self.stop_event)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_version.bind('<Button-1>', self.version_event)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_button_create.bind('<Button-1>', self.batch_task_create_event)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.bind('<Double-1>',
+                                                                                self.batch_task_detail_event)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.bind('<Button-1>', self.just_click_event)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.bind('<Button-3>',
+                                                                                self.batch_task_action_event)
         pass
 
     def __style_config(self):
@@ -47,15 +49,15 @@ class Win(WinGUI):
 
     def stop_event(self, event):
         stop_batch_task()
-        if self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_pause.cget('state').__str__() == tkinter.DISABLED:
+        if self.tk_tabs_main_tabs.tk_tabs_start.tk_button_pause.cget('state').__str__() == tkinter.DISABLED:
             return
         robot = get_robot()
         if robot is not None:
             robot.wcf.cleanup()
             robot = None
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_start.config(state=tkinter.NORMAL)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_pause.config(state=tkinter.DISABLED)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_button_create.config(state=tkinter.DISABLED)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_start.config(state=tkinter.NORMAL)
+        self.tk_tabs_main_tabs.tk_tabs_start.tk_button_pause.config(state=tkinter.DISABLED)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_button_create.config(state=tkinter.DISABLED)
         self.__is_started = False
         messagebox.showinfo('提示', '助手已停止运行!')
 
@@ -113,6 +115,23 @@ class Win(WinGUI):
                     messagebox.showerror('错误', 'token验证失败!')
                     return False
 
+                agent_info = resp['agent_info']
+                if agent_info is None:
+                    return False
+
+                set_config("agent_id", agent_info.get("id", 0))
+                set_config("role", agent_info.get("role", ""))
+                set_config("background", agent_info.get("background", ""))
+                set_config("examples", agent_info.get("examples", ""))
+                set_config("dataset_id", agent_info.get("dataset_id", ""))
+                set_config("collection_id", agent_info.get("collection_id", ""))
+                set_config("custom_agent_base", resp['custom_agent_base'])
+                set_config("custom_agent_key", resp['custom_agent_key'])
+                set_config("openai_base", resp['openai_base'])
+                set_config("openai_key", resp['openai_key'])
+                set_config("dataset_base", resp['dataset_base'])
+                set_config("dataset_key", resp['dataset_key'])
+
                 return True
 
             else:
@@ -124,7 +143,7 @@ class Win(WinGUI):
             return False
 
     def start_event(self, event):
-        if self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_start.cget('state').__str__() == tkinter.DISABLED:
+        if self.tk_tabs_main_tabs.tk_tabs_start.tk_button_start.cget('state').__str__() == tkinter.DISABLED:
             return
 
         if not self.check_token():
@@ -132,24 +151,32 @@ class Win(WinGUI):
 
         # 接收消息
         # robot.enableRecvMsg()     # 可能会丢消息?
+        # init_robot()
         robot = get_robot()
         robot.enableReceivingMsg()  # 加队列
 
         if robot.wcf.is_login():
             self.__is_started = True
-            self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_start.config(state='disabled')
-            self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_button_pause.config(state='normal')
-            self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_button_create.config(state='normal')
+            self.tk_tabs_main_tabs.tk_tabs_start.tk_button_start.config(state='disabled')
+            self.tk_tabs_main_tabs.tk_tabs_start.tk_button_pause.config(state='normal')
+            self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_button_create.config(state='normal')
             messagebox.showinfo('提示', '助手开始运行!')
             start_batch_task()
 
+            # 注册插件
+            if conf().get("agent_id") != 0 and conf().get("openai_base") and conf().get("openai_key"):
+                robot.register_plugin(Agent())
+
+            if conf().get("agent_id") == 0 and conf().get("custom_agent_base") and conf().get("custom_agent_key"):
+                robot.register_plugin(CustomAgent())
+
             wx_wxid = robot.wcf.get_self_wxid()
             connection = get_global_db_connection()
 
             def refresh_list():
                 # 清空列表
-                self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.delete(
-                    *self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.get_children())
+                self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.delete(
+                    *self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.get_children())
 
                 # 查询数据
                 cursor = connection.cursor()
@@ -163,12 +190,12 @@ class Win(WinGUI):
                     values = (
                     created_at, status, f"{result['success'] + result['fail']}/{result['total']}", result['fail'],
                     result['content'])
-                    self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.insert('', "end",
-                                                                                               iid=result['id'],
-                                                                                               values=values)
+                    self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.insert('', "end",
+                                                                                              iid=result['id'],
+                                                                                              values=values)
                     # 维持选中状态
                     if result['id'] == self.selection_batch_task_id:
-                        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.selection_set(result['id'])
+                        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.selection_set(result['id'])
 
                 # 每 3秒调用一次这个函数
                 self.tk_tabs_main_tabs.after(3000, refresh_list)
@@ -177,9 +204,9 @@ class Win(WinGUI):
 
     def save_event(self, event):
         conf().update({
-            "api_base": self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_input_api_base.get(),
-            "api_key": self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_input_api_key.get(),
-            "token": self.tk_tabs_main_tabs.tk_tabs_main_tabs_0.tk_input_token.get()
+            # "api_base": self.tk_tabs_main_tabs.tk_tabs_start.tk_input_api_base.get(),
+            # "api_key": self.tk_tabs_main_tabs.tk_tabs_start.tk_input_api_key.get(),
+            "token": self.tk_tabs_main_tabs.tk_tabs_start.tk_input_token.get()
         })
         # 将字典写入 JSON 文件
         with open('config.json', 'w') as json_file:
@@ -193,9 +220,9 @@ class Win(WinGUI):
 
     def batch_task_detail_event(self, event):
         # 获取双击项的标识符
-        item_id = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.identify_row(event.y)
+        item_id = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.identify_row(event.y)
         # 获取该项的值
-        item_values = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.item(item_id, "values")
+        item_values = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.item(item_id, "values")
         # print("Double-clicked item:", item_values)
         if len(item_values) > 0:
             self.selection_batch_task_id = int(item_id)
@@ -205,9 +232,9 @@ class Win(WinGUI):
 
     def batch_task_start_event(self, event):
         # 获取双击项的标识符
-        item_id = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.identify_row(event.y)
+        item_id = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.identify_row(event.y)
         # 获取该项的值
-        item_values = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.item(item_id, "values")
+        item_values = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.item(item_id, "values")
         # print("Double-clicked item:", item_values)
         if len(item_values) > 0:
             self.selection_batch_task_id = int(item_id)
@@ -228,9 +255,9 @@ class Win(WinGUI):
 
     def batch_task_stop_event(self, event):
         # 获取双击项的标识符
-        item_id = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.identify_row(event.y)
+        item_id = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.identify_row(event.y)
         # 获取该项的值
-        item_values = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.item(item_id, "values")
+        item_values = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.item(item_id, "values")
         # print("Double-clicked item:", item_values)
         if len(item_values) > 0:
             self.selection_batch_task_id = int(item_id)
@@ -251,18 +278,18 @@ class Win(WinGUI):
 
     def batch_task_action_event(self, event):
         # 获取右键项的标识符
-        item_id = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.identify_row(event.y)
+        item_id = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.identify_row(event.y)
         print(f"Right-clicked item: {item_id}")
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.selection_set(item_id)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.selection_set(item_id)
         # 获取该项的值
-        item_values = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.item(item_id, "values")
+        item_values = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.item(item_id, "values")
         if len(item_values) > 0:
             self.selection_batch_task_id = int(item_id)
 
             # print("Right-clicked item:", item_values)
 
             # 创建菜单
-            menu = tkinter.Menu(self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list, tearoff=0)
+            menu = tkinter.Menu(self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list, tearoff=0)
             menu.add_command(label="查看详情", command=lambda: self.batch_task_detail_event(event))
             if item_values[1] == "等待中" or item_values[1] == "已开始":
                 menu.add_command(label="停止", command=lambda: self.batch_task_stop_event(event))
@@ -273,9 +300,9 @@ class Win(WinGUI):
             self.selection_batch_task_id = None
 
     def just_click_event(self, event):
-        item_id = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.identify_row(event.y)
-        self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.selection_set(item_id)
-        item_values = self.tk_tabs_main_tabs.tk_tabs_main_tabs_1.tk_table_batch_task_list.item(item_id, "values")
+        item_id = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.identify_row(event.y)
+        self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.selection_set(item_id)
+        item_values = self.tk_tabs_main_tabs.tk_tabs_batch_task.tk_table_batch_task_list.item(item_id, "values")
         if len(item_values) > 0:
             self.selection_batch_task_id = int(item_id)
         else:

+ 0 - 0
plugins/__init__.py


+ 91 - 0
plugins/agent.py

@@ -0,0 +1,91 @@
+from typing import Optional
+
+import requests
+from openai import OpenAI
+from wcferry import WxMsg
+
+from common.log import logger
+from common.sql_lite import init_new_db_connection
+from config import Config
+from openai import OpenAI, AuthenticationError, APIConnectionError, APIError
+
+from db.msg_history import msg_history_get, msg_history_update
+
+from .plugin import Plugin
+
+
+class Agent(Plugin):
+    def __init__(self):
+        super().__init__()
+
+    def answer(self, msg: WxMsg, wx_wxid: Optional[str] = None):
+        connection = init_new_db_connection()
+        dataset_id = self.config.get("dataset_id")
+        answer = None
+        cursor = connection.cursor()
+        try:
+            history_messages = msg_history_get(cursor, wx_wxid, msg.sender, msg.roomid)
+            if dataset_id is not None:
+                # 优化问题
+                expand_system_prompt = f"""# 任务:
+请根据上下文信息,优化用户发送的最后一条消息,补齐消息中可能缺失的主语、谓语、宾语、定语、状语、补语句子成分。
+
+# 用户发送的最后一条消息:{msg.content}
+
+# 回复要求
+1. 直接输出优化后的消息"""
+                expand_messages = [{"role": "system", "content": expand_system_prompt}] + history_messages
+                expand_bot_reply = self._client_reply(self.openAiClient, expand_messages)
+                answer = self._dataset_search(self.config.get("dataset_id"), expand_bot_reply)
+
+            system_prompt = f"""# 角色
+    {self.config.get("role")}
+    
+    # 背景:
+    {self.config.get("background")}
+"""
+            if answer is not None:
+                system_prompt += """# 相关知识:
+"""
+                for data in answer:
+                    system_prompt += f"""问题:{data['q']}
+答案:{data['a']}
+"""
+            system_prompt += """# 回复要求:
+1. 直接以角色设定的角度回答问题,并以第一人称输出。
+2. 不要在回复前加角色、姓名。
+3. 回复要正式"""
+            messages = [{"role": "system", "content": system_prompt}] + history_messages + [{"role": "user", "content": msg.content}]
+            bot_reply = self._client_reply(self.openAiClient, messages)
+            new_history_messages = history_messages + [{"role": "user", "content": msg.content}, {"role": "assistant", "content": bot_reply}]
+            msg_history_update(cursor, wx_wxid, new_history_messages, msg.sender, msg.roomid)
+            return bot_reply
+        except Exception as e:
+            # 回滚事务
+            connection.rollback()
+            print(f"发生错误: {e}")
+        finally:
+            # 确保资源被正确释放
+            connection.commit()
+            cursor.close()
+            connection.close()
+
+    def _dataset_search(self, dataset_id: str, text: str):
+        headers = {"Content-Type": "application/json",
+                   'Authorization': f'Bearer {self.config.get("dataset_key")}'}
+        args = {
+            "datasetId": dataset_id,
+            "text": text,
+            "limit": 100,
+            "similarity": 0.8,
+            "searchMode": "mixedRecall",
+            "usingReRank": False
+        }
+        response = requests.post(self.config.get("dataset_base"), headers=headers, json=args)
+        if response.status_code == 200:
+            response_json = response.json()
+            if response_json and "data" in response_json and response_json["data"] and "list" in response_json["data"]:
+                return response_json["data"]["list"]
+            else:
+                return None
+        return None

+ 31 - 0
plugins/custom_agent.py

@@ -0,0 +1,31 @@
+from typing import Optional
+
+from openai import OpenAI
+from wcferry import WxMsg
+
+from plugins.plugin import Plugin
+from openai import OpenAI, AuthenticationError, APIConnectionError, APIError
+
+
+class CustomAgent(Plugin):
+    def __init__(self):
+        super().__init__()
+        self.customAiClient = OpenAI(api_key=self.config.get("custom_agent_key"), base_url=self.config.get("custom_agent_base"))
+
+    def answer(self, msg: WxMsg, wx_wxid: Optional[str] = None):
+        rsp = ""
+        try:
+            messages = [
+                {"role": "user", "content": msg.content}
+            ]
+            rsp = self._client_reply(self.customAiClient, messages, msg.sender, msg.roomid)
+            self.LOG.info(rsp)
+        except AuthenticationError:
+            self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
+        except APIConnectionError:
+            self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
+        except APIError as e1:
+            self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
+        except Exception as e0:
+            self.LOG.error(f"发生未知错误:{str(e0)}")
+        return rsp

+ 60 - 0
plugins/plugin.py

@@ -0,0 +1,60 @@
+from typing import Optional
+
+from openai import OpenAI
+from wcferry import WxMsg
+
+from common.log import logger
+from config import conf
+from openai import OpenAI, AuthenticationError, APIConnectionError, APIError
+
+
+class Plugin:
+    def __init__(self):
+        self.config = conf()
+        self.LOG = logger
+        self.openAiClient = OpenAI(api_key=self.config.get("openai_key"),
+                                   base_url=self.config.get("openai_base"))
+
+    def answer(self, msg: WxMsg, wx_wxid: Optional[str] = None) -> str:
+        rsp = ""
+        try:
+            messages = [
+                {"role": "user", "content": msg.content}
+            ]
+            rsp = self._client_reply(self.openAiClient, messages)
+            self.LOG.info(rsp)
+        except AuthenticationError:
+            self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
+        except APIConnectionError:
+            self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
+        except APIError as e1:
+            self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
+        except Exception as e0:
+            self.LOG.error(f"发生未知错误:{str(e0)}")
+        return rsp
+
+    def _client_reply(self, client: OpenAI, messages: list, sender: Optional[str] = None, roomid: Optional[str] = None):
+        if sender is None and roomid is None:
+            extra_body = {}
+        else:
+            chat_id = "chatId"
+            if sender is not None:
+                chat_id += "_" + sender
+            if roomid is not None:
+                chat_id += "_" + roomid
+            extra_body = {
+                "chatId": chat_id
+            }
+        extra_body = {}
+        ret = client.chat.completions.create(
+            model=self.config.get("open_ai_model", "gpt-4o"),
+            max_tokens=self.config.get("open_ai_max_tokens", 8192),
+            temperature=self.config.get("open_ai_temperature", 0.7),
+            top_p=self.config.get("open_ai_top_p", 1),
+            extra_body=extra_body,
+            messages=messages
+        )
+        rsp = ret.choices[0].message.content
+        rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
+        rsp = rsp.replace("\n\n", "\n")
+        return rsp

+ 46 - 35
service/robot.py

@@ -12,6 +12,7 @@ from wcferry import Wcf, WxMsg, wcf_pb2
 
 from common.log import logger
 from config import Config, conf
+from plugins.plugin import Plugin
 
 
 class Robot():
@@ -22,10 +23,15 @@ class Robot():
         self.wxid = self.wcf.get_self_wxid()
         self.user = self.wcf.get_user_info()
         self.allContacts = self.getAllContacts()
-        self.aiClient = OpenAI(api_key=self.config.get("api_key"), base_url=self.config.get("api_base"))
+        self.plugins: list[Plugin] = []
+
+        # self.aiClient = OpenAI(api_key=self.config.get("openai_key"), base_url=self.config.get("openai_base"))
 
         self.LOG.info(f"{self.user} 登录成功")
 
+    def register_plugin(self, plugin: Plugin):
+        self.plugins.append(plugin)
+
     def enableRecvMsg(self) -> None:
         """
         打开消息通知,可能会丢消息
@@ -131,41 +137,46 @@ class Robot():
 
     def get_answer(self, msg: WxMsg) -> str:
         rsp = ""
-        try:
-            self.aiClient.api_key = self.config.get("api_key")
-            self.aiClient.base_url = self.config.get("api_base")
-
-            # 在fastgpt的时候增加chatId字段
-            if "fastgpt" in self.config.get("api_base"):
-                extra_body = {
-                    "chatId": "chatId-" + msg.sender
-                }
-            else:
-                extra_body = {}
-
-            ret = self.aiClient.chat.completions.create(
-                model=self.config.get("open_ai_model", "gpt-3.5-turbo"),
-                max_tokens=self.config.get("open_ai_max_tokens", 8192),
-                temperature=self.config.get("open_ai_temperature", 0.7),
-                top_p=self.config.get("open_ai_top_p", 1),
-                extra_body=extra_body,
-                messages=[
-                    {"role": "user", "content": msg.content}
-                ]
-            )
-            rsp = ret.choices[0].message.content
-            rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
-            rsp = rsp.replace("\n\n", "\n")
-            self.LOG.info(rsp)
-        except AuthenticationError:
-            self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
-        except APIConnectionError:
-            self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
-        except APIError as e1:
-            self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
-        except Exception as e0:
-            self.LOG.error(f"发生未知错误:{str(e0)}")
+        for plugin in self.plugins:
+            rsp = plugin.answer(msg, self.wxid)
+            if rsp != "":
+                break
         return rsp
+        # try:
+        #     # self.aiClient.api_key = self.config.get("api_key")
+        #     # self.aiClient.base_url = self.config.get("api_base")
+        #
+        #     # 在fastgpt的时候增加chatId字段
+        #     if "fastgpt" in self.config.get("api_base"):
+        #         extra_body = {
+        #             "chatId": "chatId-" + msg.sender + "-" + msg.roomid
+        #         }
+        #     else:
+        #         extra_body = {}
+        #
+        #     ret = self.aiClient.chat.completions.create(
+        #         model=self.config.get("open_ai_model", "gpt-3.5-turbo"),
+        #         max_tokens=self.config.get("open_ai_max_tokens", 8192),
+        #         temperature=self.config.get("open_ai_temperature", 0.7),
+        #         top_p=self.config.get("open_ai_top_p", 1),
+        #         extra_body=extra_body,
+        #         messages=[
+        #             {"role": "user", "content": msg.content}
+        #         ]
+        #     )
+        #     rsp = ret.choices[0].message.content
+        #     rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
+        #     rsp = rsp.replace("\n\n", "\n")
+        #     self.LOG.info(rsp)
+        # except AuthenticationError:
+        #     self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
+        # except APIConnectionError:
+        #     self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
+        # except APIError as e1:
+        #     self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
+        # except Exception as e0:
+        #     self.LOG.error(f"发生未知错误:{str(e0)}")
+        # return rsp
 
     def sendTextMsg(self, msg: str, receiver: str, at_list: str = "") -> int:
         """ 发送消息

+ 0 - 0
ui/__init__.py


+ 41 - 0
ui/batch_task/frame_tab.py

@@ -0,0 +1,41 @@
+import tkinter
+from tkinter import *
+from tkinter.ttk import *
+
+
+class FrameTabsBatchTask(Frame):
+    def __init__(self, parent):
+        super().__init__(parent)
+        self.__frame()
+        self.tk_table_batch_task_list = self.__tk_table_batch_task_list(self)
+        self.tk_button_create = self.__tk_button_create(self)
+
+    def __frame(self):
+        self.place(x=0, y=0, width=675, height=510)
+
+    def __tk_table_batch_task_list(self, parent):
+        # 表头字段 表头宽度
+        columns = {"创建时间": 121, "状态": 121, "进度": 121, "失败": 121, "内容": 121}
+        # 初始化表格 表格是基于Treeview,tkinter本身没有表格。show="headings" 为隐藏首列。
+        tk_table = Treeview(parent, show="headings", columns=list(columns))
+        for text, width in columns.items():  # 批量设置列属性
+            tk_table.heading(text, text=text, anchor='center')
+            tk_table.column(text, anchor='center', width=width, stretch=False)  # stretch 不自动拉伸
+
+        # 插入数据示例
+        # data = [
+        #     [1, "github", "https://github.com/iamxcd/tkinter-helper"],
+        #     [2, "演示地址", "https://www.pytk.net/tkinter-helper"]
+        # ]
+        #
+        # # 导入初始数据
+        # for values in data:
+        #     tk_table.insert('', END, values=values)
+
+        tk_table.place(x=30, y=50, width=608, height=420)
+        return tk_table
+
+    def __tk_button_create(self, parent):
+        btn = Button(parent, text="新建任务", takefocus=False, state=tkinter.DISABLED)
+        btn.place(x=560, y=10, width=80, height=30)
+        return btn

+ 0 - 0
ui/ui_batch_task_create.py → ui/batch_task/ui_batch_task_create.py


+ 0 - 0
ui/ui_batch_task_detail.py → ui/batch_task/ui_batch_task_detail.py


+ 76 - 0
ui/start/frame_tab.py

@@ -0,0 +1,76 @@
+import tkinter
+from tkinter import *
+from tkinter.ttk import *
+from config import conf
+
+
+class FrameTabsStart(Frame):
+    def __init__(self, parent):
+        super().__init__(parent)
+        self.__frame()
+        # self.tk_label_lab_api_base = self.__tk_label_lab_api_base(self)
+        # self.tk_input_api_base = self.__tk_input_api_base(self)
+        # self.tk_label_lab_api_key = self.__tk_label_lab_api_key(self)
+        # self.tk_input_api_key = self.__tk_input_api_key(self)
+        self.tk_label_lab_token = self.__tk_label_lab_token(self)
+        self.tk_input_token = self.__tk_input_token(self)
+        self.tk_button_save = self.__tk_button_save(self)
+        self.tk_button_start = self.__tk_button_start(self)
+        self.tk_button_pause = self.__tk_button_pause(self)
+        self.tk_button_version = self.__tk_button_version(self)
+
+    def __frame(self):
+        self.place(x=0, y=0, width=675, height=510)
+
+    # def __tk_label_lab_api_base(self, parent):
+    #     label = Label(parent, text="API_BASE", anchor="center", )
+    #     label.place(x=30, y=30, width=60, height=30)
+    #     return label
+    #
+    # def __tk_label_lab_api_key(self, parent):
+    #     label = Label(parent, text="API_KEY", anchor="center", )
+    #     label.place(x=30, y=80, width=60, height=30)
+    #     return label
+    #
+    # def __tk_input_api_base(self, parent):
+    #     ipt = Entry(parent, )
+    #     ipt.insert(0, conf().get("api_base"))
+    #     ipt.place(x=100, y=30, width=520, height=30)
+    #     return ipt
+    #
+    # def __tk_input_api_key(self, parent):
+    #     ipt = Entry(parent, )
+    #     ipt.insert(0, conf().get("api_key"))
+    #     ipt.place(x=100, y=80, width=520, height=30)
+    #     return ipt
+
+    def __tk_input_token(self, parent):
+        ipt = Entry(parent, )
+        ipt.insert(0, conf().get("token"))
+        ipt.place(x=100, y=30, width=520, height=30)
+        return ipt
+
+    def __tk_label_lab_token(self, parent):
+        label = Label(parent, text="TOKEN", anchor="center", )
+        label.place(x=30, y=30, width=60, height=30)
+        return label
+
+    def __tk_button_save(self, parent):
+        btn = Button(parent, text="保存", takefocus=False, )
+        btn.place(x=30, y=80, width=80, height=30)
+        return btn
+
+    def __tk_button_start(self, parent):
+        btn = Button(parent, text="启动", takefocus=False, )
+        btn.place(x=130, y=80, width=80, height=30)
+        return btn
+
+    def __tk_button_pause(self, parent):
+        btn = Button(parent, text="暂停", takefocus=False, state=tkinter.DISABLED)
+        btn.place(x=230, y=80, width=50, height=30)
+        return btn
+
+    def __tk_button_version(self, parent):
+        btn = Button(parent, text="版本", takefocus=False)
+        btn.place(x=530, y=80, width=50, height=30)
+        return btn

+ 6 - 114
ui/ui.py

@@ -2,7 +2,8 @@ import tkinter
 from tkinter import *
 from tkinter.ttk import *
 
-from config import conf
+from .batch_task.frame_tab import FrameTabsBatchTask
+from .start.frame_tab import FrameTabsStart
 
 
 class WinGUI(Tk):
@@ -68,120 +69,11 @@ class Frame_main_tabs(Notebook):
         self.__frame()
 
     def __frame(self):
-        self.tk_tabs_main_tabs_0 = Frame_main_tabs_0(self)
-        self.add(self.tk_tabs_main_tabs_0, text="启动服务")
+        self.tk_tabs_start = FrameTabsStart(self)
+        self.add(self.tk_tabs_start, text="启动服务")
 
-        self.tk_tabs_main_tabs_1 = Frame_main_tabs_1(self)
-        self.add(self.tk_tabs_main_tabs_1, text="群发管理")
+        self.tk_tabs_batch_task = FrameTabsBatchTask(self)
+        self.add(self.tk_tabs_batch_task, text="群发管理")
 
         self.place(x=0, y=0, width=675, height=510)
 
-
-class Frame_main_tabs_0(Frame):
-    def __init__(self, parent):
-        super().__init__(parent)
-        self.__frame()
-        self.tk_label_lab_api_base = self.__tk_label_lab_api_base(self)
-        self.tk_input_api_base = self.__tk_input_api_base(self)
-        self.tk_label_lab_api_key = self.__tk_label_lab_api_key(self)
-        self.tk_input_api_key = self.__tk_input_api_key(self)
-        self.tk_label_lab_token = self.__tk_label_lab_token(self)
-        self.tk_input_token = self.__tk_input_token(self)
-        self.tk_button_save = self.__tk_button_save(self)
-        self.tk_button_start = self.__tk_button_start(self)
-        self.tk_button_pause = self.__tk_button_pause(self)
-        self.tk_button_version = self.__tk_button_version(self)
-
-    def __frame(self):
-        self.place(x=0, y=0, width=675, height=510)
-
-    def __tk_label_lab_api_base(self, parent):
-        label = Label(parent, text="API_BASE", anchor="center", )
-        label.place(x=30, y=30, width=60, height=30)
-        return label
-
-    def __tk_label_lab_api_key(self, parent):
-        label = Label(parent, text="API_KEY", anchor="center", )
-        label.place(x=30, y=80, width=60, height=30)
-        return label
-
-    def __tk_input_api_base(self, parent):
-        ipt = Entry(parent, )
-        ipt.insert(0, conf().get("api_base"))
-        ipt.place(x=100, y=30, width=520, height=30)
-        return ipt
-
-    def __tk_input_api_key(self, parent):
-        ipt = Entry(parent, )
-        ipt.insert(0, conf().get("api_key"))
-        ipt.place(x=100, y=80, width=520, height=30)
-        return ipt
-
-    def __tk_input_token(self, parent):
-        ipt = Entry(parent, )
-        ipt.insert(0, conf().get("token"))
-        ipt.place(x=100, y=130, width=520, height=30)
-        return ipt
-
-    def __tk_label_lab_token(self, parent):
-        label = Label(parent, text="TOKEN", anchor="center", )
-        label.place(x=30, y=130, width=60, height=30)
-        return label
-
-    def __tk_button_save(self, parent):
-        btn = Button(parent, text="保存", takefocus=False, )
-        btn.place(x=30, y=190, width=80, height=30)
-        return btn
-
-    def __tk_button_start(self, parent):
-        btn = Button(parent, text="启动", takefocus=False, )
-        btn.place(x=130, y=190, width=80, height=30)
-        return btn
-
-    def __tk_button_pause(self, parent):
-        btn = Button(parent, text="暂停", takefocus=False, state=tkinter.DISABLED)
-        btn.place(x=230, y=190, width=50, height=30)
-        return btn
-
-    def __tk_button_version(self, parent):
-        btn = Button(parent, text="版本", takefocus=False)
-        btn.place(x=530, y=190, width=50, height=30)
-        return btn
-
-
-class Frame_main_tabs_1(Frame):
-    def __init__(self, parent):
-        super().__init__(parent)
-        self.__frame()
-        self.tk_table_batch_task_list = self.__tk_table_batch_task_list(self)
-        self.tk_button_create = self.__tk_button_create(self)
-
-    def __frame(self):
-        self.place(x=0, y=0, width=675, height=510)
-
-    def __tk_table_batch_task_list(self, parent):
-        # 表头字段 表头宽度
-        columns = {"创建时间": 121, "状态": 121, "进度": 121, "失败": 121, "内容": 121}
-        # 初始化表格 表格是基于Treeview,tkinter本身没有表格。show="headings" 为隐藏首列。
-        tk_table = Treeview(parent, show="headings", columns=list(columns))
-        for text, width in columns.items():  # 批量设置列属性
-            tk_table.heading(text, text=text, anchor='center')
-            tk_table.column(text, anchor='center', width=width, stretch=False)  # stretch 不自动拉伸
-
-        # 插入数据示例
-        # data = [
-        #     [1, "github", "https://github.com/iamxcd/tkinter-helper"],
-        #     [2, "演示地址", "https://www.pytk.net/tkinter-helper"]
-        # ]
-        #
-        # # 导入初始数据
-        # for values in data:
-        #     tk_table.insert('', END, values=values)
-
-        tk_table.place(x=30, y=50, width=608, height=420)
-        return tk_table
-
-    def __tk_button_create(self, parent):
-        btn = Button(parent, text="新建任务", takefocus=False, state=tkinter.DISABLED)
-        btn.place(x=560, y=10, width=80, height=30)
-        return btn