浏览代码

增加表单提取接口

boweniac 1 月之前
父节点
当前提交
15f07aa167

+ 16 - 0
app/call_center/api/form/__init__.py

@@ -0,0 +1,16 @@
+from fastapi import APIRouter
+
+from app.call_center.api.form.form import router
+
+# Router for the form endpoints, served under the /form prefix.
+v1 = APIRouter(
+    prefix='/form',
+    tags=['表单提取'],
+)
+
+# Attach the routes declared in form.py under the /extract sub-prefix.
+v1.include_router(
+    router,
+    prefix='/extract',
+    tags=['提取'],
+)

+ 35 - 0
app/call_center/api/form/form.py

@@ -0,0 +1,35 @@
+import uuid
+
+from app.call_center.schema.form_records import CreateFormRecordsParam
+from app.call_center.service.form_records_service import form_records_service
+from common.response.response_schema import ResponseModel, response_base
+from common.security.jwt_call_center import DependsJwtAuth
+from fastapi import APIRouter, Request
+
+router = APIRouter()
+
+
+@router.post(
+    '',
+    summary='创建record',
+    dependencies=[DependsJwtAuth],
+)
+async def create_form_record(obj: CreateFormRecordsParam, request: Request) -> ResponseModel:
+    """
+    Extract form values from a chat history and store them as a new record.
+
+    The record id is generated here and the organisation id is taken from
+    ``request.user.id``, overwriting whatever the request body carried.
+
+    :param obj: request body with the form definition and chat history
+    :param request: current request
+    :return: success response carrying the extracted values, or a failure response
+    """
+    obj.id = str(uuid.uuid4())
+    obj.org_id = request.user.id
+    form_value = await form_records_service.create(obj=obj)
+    # The service signals failure with None; test for None explicitly so an
+    # empty (falsy) extraction result is not reported as a failure.
+    if form_value is None:
+        return response_base.fail()
+    return response_base.success(data=form_value)

+ 22 - 0
app/call_center/crud/crud_form_records.py

@@ -0,0 +1,22 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy_crud_plus import CRUDPlus
+
+from app.call_center.schema.form_records import CreateFormRecordsParam
+from model.form_records import FormRecords
+
+
+class CRUDFormRecords(CRUDPlus[FormRecords]):
+    """Data-access object for the ``FormRecords`` model."""
+
+    async def create(self, db: AsyncSession, obj_in: CreateFormRecordsParam) -> None:
+        """
+        Create a form record through ``create_model``.
+
+        :param db: database session the operation runs on
+        :param obj_in: payload describing the record to create
+        :return:
+        """
+        await self.create_model(db, obj_in)
+
+
+form_records_dao: CRUDFormRecords = CRUDFormRecords(FormRecords)

+ 45 - 0
app/call_center/schema/form_records.py

@@ -0,0 +1,45 @@
+
+from common.schema import SchemaBase
+from datetime import datetime
+
+
+class FormRecordsSchemaBase(SchemaBase):
+    """Common base class for the form-record schemas."""
+
+
+class FormFieldOptions(FormRecordsSchemaBase):
+    """Option entry (``value``/``label``) that may nest further options in ``children``."""
+
+    value: str
+    label: str
+    children: list['FormFieldOptions'] | None = None
+
+
+class FormFieldProps(FormRecordsSchemaBase):
+    """Field properties referenced by ``FormData.fieldProps``."""
+
+    options: list[FormFieldOptions]
+
+
+class FormData(FormRecordsSchemaBase):
+    """Definition of one form field as sent by the caller."""
+
+    valueType: str
+    title: str
+    dataIndex: str
+    fieldProps: FormFieldProps | None = None
+
+
+class CreateFormRecordsParam(FormRecordsSchemaBase):
+    """Payload for creating a form record."""
+
+    # Default is None (not 0) so that it agrees with the declared str type.
+    id: str | None = None
+    external_id: str
+    chat_history: str
+    form_data: list[FormData]
+    # Holds the decoded model output, which may be a JSON object or an array.
+    form_value: list | dict | None = None
+    org_id: int | None = 0
+    created_at: datetime | None = None
+    updated_at: datetime | None = None

+ 158 - 0
app/call_center/service/form_records_service.py

@@ -0,0 +1,158 @@
+import asyncio
+import json
+
+from openai import OpenAI
+
+from app.admin.schema.intent_org import CurrentIntentOrgIns
+from app.call_center.crud.crud_form_records import form_records_dao
+from app.call_center.schema.form_records import CreateFormRecordsParam
+from core.conf import settings
+from database.db_mysql import async_db_session
+from database.db_redis import redis_client
+from common.log import log
+from utils.serializers import select_as_dict
+
+
+class FormRecordsService:
+    """Service that extracts form values from chat histories and stores them."""
+
+    @staticmethod
+    async def create(*, obj: CreateFormRecordsParam) -> dict | list | None:
+        """
+        Extract form values from ``obj.chat_history`` and persist the record.
+
+        The organisation named by ``obj.org_id`` is loaded from Redis, falling
+        back to the database, and its OpenAI settings are used for the request.
+
+        :param obj: record to create; ``form_value`` is filled in on success
+        :return: the decoded model output, or ``None`` when the organisation
+            is unavailable or the model produced no usable JSON
+        """
+        # Look the organisation up in the cache first
+        key = f'{settings.TOKEN_CALL_REDIS_PREFIX}:{obj.org_id}'
+        org_json = await redis_client.get(key)
+        if not org_json:
+            # Cache miss: load it from the database
+            from app.admin.crud.crud_intent_org import intent_org_dao
+            async with async_db_session.begin() as db:
+                org = await intent_org_dao.get(db, obj.org_id)
+                # Reject a missing organisation as well as one whose status is not 1
+                if not org or org.status != 1:
+                    log.error(f"表单信息提取时,机构不存在 org_id: {obj.org_id}")
+                    return None
+                org_data = CurrentIntentOrgIns(**select_as_dict(org))
+                # Cache the organisation for later requests
+                await redis_client.setex(
+                    key,
+                    settings.JWT_USER_REDIS_EXPIRE_SECONDS,
+                    org_data.model_dump_json(),
+                )
+        else:
+            org_data = CurrentIntentOrgIns(**json.loads(org_json))
+
+        # Structured-output schema handed to the model
+        intent_schema = {
+            "name": "array",
+            "schema": {  # the schema field itself
+                "type": "object",
+                "description": "从通话记录中提取表单值",
+                "properties": {
+                    "dataIndex": {"type": "string", "description": "表单ID"},
+                    "value": {
+                        "type": "array",
+                        "description": "表单值",
+                        "items": {"type": "string"}
+                    },
+                },
+                "required": ["dataIndex", "value"]
+            }
+        }
+        messages = [
+            {"role": "system", "content": f"""# 任务
+请帮助user从通话记录中提取表单值,并返回一个JSON格式的表单值。
+
+# 返回值示例
+* 如表单类型为 input、autoComplete、textarea,返回示例:["表单值"]
+* 如表单类型为 radio、select,返回示例:["值1"]
+* 如表单类型为 checkbox,返回示例:["值1", "值2"]
+* 如表单类型为 cascader,返回示例:["一级值1", "二级值3"]
+* 如表单类型为 date,返回示例:["2025-01-01"]"""
+             },
+            {
+                "role": "user",
+                "content": f"""# 表单数据
+{obj.form_data}
+
+# 聊天记录
+{obj.chat_history}
+"""
+            }
+        ]
+        # The OpenAI client used by generate_json is synchronous; run it in a
+        # worker thread so the event loop is not blocked during the request.
+        response_data = await asyncio.to_thread(
+            generate_json, org_data.openai_key, org_data.openai_base, messages, intent_schema
+        )
+        if not (response_data and isinstance(response_data.choices, list) and response_data.choices):
+            return None
+        first_choice = response_data.choices[0]
+        if not (first_choice and first_choice.message and first_choice.message.content):
+            return None
+        try:
+            form_value = json.loads(first_choice.message.content)
+        except json.JSONDecodeError as e:
+            log.error(f"[oai] completion is not valid JSON: {e}")
+            return None
+        obj.form_value = form_value
+        async with async_db_session.begin() as db:
+            await form_records_dao.create(db, obj)
+        return form_value
+
+
+def generate_json(
+    api_key: str,
+    openai_base: str,
+    messages: list[dict],
+    json_schema: dict,
+    model: str = "gpt-4o",
+):
+    """
+    Request a chat completion whose output is constrained by ``json_schema``.
+
+    :param api_key: API key; left out of the client arguments when empty
+    :param openai_base: base URL; left out of the client arguments when empty
+    :param messages: chat messages to send
+    :param json_schema: value passed as ``response_format['json_schema']``
+    :param model: model name to request
+    :return: the completion when its first choice carries a message, else
+        ``None`` (also when any exception was raised and logged)
+    """
+    try:
+        client_args = {}
+        if api_key:
+            client_args["api_key"] = api_key
+        if openai_base:
+            client_args["base_url"] = openai_base
+
+        oai_client = OpenAI(**client_args)
+
+        completion = oai_client.chat.completions.create(
+            model=model,
+            messages=messages,
+            response_format={
+                "type": "json_schema",
+                "json_schema": json_schema
+            }
+        )
+        # The API key and the raw completion are deliberately not logged:
+        # the key is a secret and this is the success path, not a failure.
+        if completion and isinstance(completion.choices, list) and len(completion.choices) > 0:
+            first_choice = completion.choices[0]
+            if first_choice and first_choice.message:
+                return completion
+    except Exception as e:
+        log.error(f"[oai] generate_json failed: {e}")
+    return None
+
+
+form_records_service: FormRecordsService = FormRecordsService()

+ 2 - 0
app/router.py

@@ -5,6 +5,7 @@ from fastapi import APIRouter
 from app.gpt.api import v1 as gpt_v1
 from app.admin.api import v1 as admin_v1
 from app.call_center.api.intent import v1 as intent_v1
+from app.call_center.api.form import v1 as form_v1
 
 route = APIRouter()
 
@@ -12,4 +13,5 @@ route.include_router(gpt_v1)
 
 call_center_route = APIRouter()
 call_center_route.include_router(intent_v1)
+call_center_route.include_router(form_v1)
 call_center_route.include_router(admin_v1)

+ 42 - 0
model/form_records.py

@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from datetime import datetime
+from uuid import UUID
+
+import sqlalchemy as sa
+from sqlalchemy import TIMESTAMP, text
+
+from common.model import Base, id_key_str
+from sqlalchemy.dialects import mysql
+from sqlalchemy.orm import Mapped, mapped_column
+
+
+class FormRecords(Base):
+    """Form record (table ``form_records``)."""
+
+    __tablename__ = 'form_records'
+
+    id: Mapped[id_key_str] = mapped_column()
+    # Both timestamps are filled in by the database via server_default.
+    created_at: Mapped[datetime | None] = mapped_column(
+        TIMESTAMP,
+        nullable=False,
+        server_default=text('CURRENT_TIMESTAMP'),
+        comment='Create Time | 创建日期',
+    )
+    updated_at: Mapped[datetime | None] = mapped_column(
+        TIMESTAMP,
+        nullable=False,
+        server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'),
+        comment='Update Time | 修改日期',
+    )
+    external_id: Mapped[str] = mapped_column(sa.String(255), default='', sort_order=2, comment='外部id')
+    industry_type: Mapped[int] = mapped_column(sa.Integer(), default=0, sort_order=3, comment='评分规则代码 0 通用 1 教育')
+    chat_history: Mapped[str] = mapped_column(sa.TEXT(), default='', sort_order=4, comment='通话记录')
+    # NOTE(review): sort_order 12 and 13 are each used twice below, and the request schema sends form_data as a list while it is annotated dict here — confirm both.
+    form_data: Mapped[dict | None] = mapped_column(sa.JSON(), default=None, sort_order=12, comment='')
+    form_value: Mapped[dict | None] = mapped_column(sa.JSON(), default=None, sort_order=13, comment='')
+    org_id: Mapped[int] = mapped_column(sa.BIGINT(), default=0, sort_order=7, comment='机构 ID')
+    status: Mapped[int] = mapped_column(mysql.TINYINT(), default=0, sort_order=10, comment='状态 0 入库 1 已判断 2 已回调')
+    request_data: Mapped[dict | None] = mapped_column(sa.JSON(), default=None, sort_order=12, comment='')
+    response_data: Mapped[dict | None] = mapped_column(sa.JSON(), default=None, sort_order=13, comment='')