123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- import asyncio
- import json
- import re
- from typing import Dict
- import aiohttp
- from openai import OpenAI, AsyncOpenAI
- from openai.types import ResponseFormatJSONSchema
- from openai.types.chat import completion_create_params, ChatCompletion
- from pydantic import BaseModel
- from common.log import log
async def generate_text(
    api_key: str,
    openai_base: str,
    model: str,
    messages: list[dict],
) -> tuple[str | None, ChatCompletion | None]:
    """Run a chat completion with temperature=0 and return its text.

    Args:
        api_key: API key. When falsy it is not passed, leaving it to the
            AsyncOpenAI client's own defaults.
        openai_base: Base URL of the (OpenAI-compatible) endpoint. When falsy
            it is not passed, leaving it to the client's defaults.
        model: Model name forwarded to the chat completions API.
        messages: Chat messages, forwarded unchanged.

    Returns:
        ``(content, completion)``: the first choice's message content and the
        raw completion. ``(None, completion)`` when the response carries no
        usable choice, and ``(None, None)`` when the request fails. Errors are
        logged, never raised, so the result can always be unpacked as a pair.
    """
    try:
        client_args = {}
        if api_key:
            client_args["api_key"] = api_key
        if openai_base:
            client_args["base_url"] = openai_base

        # `async with` closes the client's HTTP connection pool after the call;
        # a fresh client is created per request, so it must not be leaked.
        async with AsyncOpenAI(**client_args) as oai_client:
            completion = await oai_client.chat.completions.create(
                model=model,
                temperature=0,
                messages=messages,
            )

        if completion and completion.choices:
            first_choice = completion.choices[0]
            if first_choice and first_choice.message:
                return first_choice.message.content, completion
        return None, completion
    except Exception as e:
        # Best-effort boundary: callers receive (None, None) instead of an exception.
        log.error(f"[oai] generate_text failed: {e}")
        return None, None
def _strip_json_code_fence(text: str) -> str:
    """Return *text* trimmed and without a surrounding Markdown code fence (```json ... ``` or ``` ... ```)."""
    stripped = text.strip()
    match = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", stripped, re.DOTALL | re.IGNORECASE)
    return match.group(1) if match else stripped


async def generate_json(
    api_key: str,
    openai_base: str,
    model: str,
    messages: list[dict],
    json_schema: dict | str,
) -> tuple[dict | list | None, ChatCompletion | None]:
    """Ask the chat model for a JSON answer and return it parsed.

    For DeepSeek-V3 (model name matched case-insensitively) the request uses
    ``{"type": "json_object"}`` mode and the schema is appended as text to a
    copy of the last message. For every other model the schema is sent as
    ``{"type": "json_schema", "json_schema": json_schema}``.

    Args:
        api_key: API key. When falsy it is not passed, leaving it to the
            AsyncOpenAI client's own defaults.
        openai_base: Base URL of the (OpenAI-compatible) endpoint. When falsy
            it is not passed, leaving it to the client's defaults.
        model: Model name forwarded to the chat completions API.
        messages: Chat messages. Not mutated. On the DeepSeek-V3 path the last
            message's ``content`` is concatenated with text, so it must be a str.
        json_schema: For non-DeepSeek models, presumably OpenAI's
            ``json_schema`` object (name/schema/strict); confirm with callers.
            For DeepSeek-V3, a dict is serialized with ``json.dumps`` and a
            str is inserted verbatim.

    Returns:
        ``(parsed, completion)``. For non-DeepSeek models, a top-level dict
        that has a ``"properties"`` key is replaced by that value.
        ``(None, completion)`` when the response has no content or is not
        valid JSON; ``(None, None)`` when the request itself fails. Errors are
        logged, never raised.
    """
    try:
        # Build client arguments; omit empty values so the client defaults apply.
        client_args = {}
        if api_key:
            client_args["api_key"] = api_key
        if openai_base:
            client_args["base_url"] = openai_base

        is_deepseek_v3 = model.lower() == "deepseek-v3"
        if is_deepseek_v3:
            response_format = {"type": "json_object"}
            if isinstance(json_schema, str):
                schema_text = json_schema
            else:
                # Embed real JSON, not a Python repr (single quotes, True/None).
                schema_text = json.dumps(json_schema, ensure_ascii=False)
            last_message = messages[-1]
            # Build a new list/dict instead of editing the caller's message, so
            # repeated calls (e.g. retries) do not append the schema again.
            messages = [
                *messages[:-1],
                {
                    **last_message,
                    "content": last_message["content"] + f"\n# 请以下方的json结构输出\n{schema_text}\n",
                },
            ]
        else:
            response_format = {
                "type": "json_schema",
                "json_schema": json_schema,
            }

        log.info(f"[oai] model: {model}")
        log.info(f"[oai] messages: {messages}")
        log.info(f"[oai] response_format: {response_format}")

        # `async with` closes the per-call client's HTTP connection pool.
        async with AsyncOpenAI(**client_args) as oai_client:
            completion = await oai_client.chat.completions.create(
                model=model,
                messages=messages,
                response_format=response_format,
            )
        log.info(f"[oai] completion: {completion}")

        if not (completion and completion.choices):
            return None, completion
        first_choice = completion.choices[0]
        if not (first_choice and first_choice.message and first_choice.message.content):
            return None, completion

        # Models may wrap the JSON in a Markdown fence (with or without a trailing
        # newline); strip it before parsing.
        content = _strip_json_code_fence(first_choice.message.content)
        log.info(f"[oai] content: {content}")
        try:
            response_json = json.loads(content)
        except json.JSONDecodeError as e:
            log.error(f"[oai] generate_json: response is not valid JSON: {e}")
            return None, completion

        if not is_deepseek_v3 and isinstance(response_json, dict):
            # NOTE(review): presumably guards against the model echoing a
            # JSON-Schema-shaped wrapper; confirm this unwrapping is intended.
            response_json = response_json.get("properties", response_json)
        return response_json, completion
    except Exception as e:
        # Best-effort boundary: callers receive (None, None) instead of an exception.
        log.error(f"[oai] generate_json failed: {e}")
        return None, None
- # async def generate_json_by_class(api_key: str, openai_base: str, model: str, messages: list[dict], json_schema: any):
- # try:
- # client_args = {}
- # if api_key:
- # client_args["api_key"] = api_key
- # if openai_base:
- # client_args["base_url"] = openai_base
- #
- # oai_client = AsyncOpenAI(**client_args)
- #
- # completion = await oai_client.beta.chat.completions.parse(
- # model=model,
- # messages=messages,
- # response_format=json_schema
- # )
- # if completion and isinstance(completion.choices, list) and len(completion.choices) > 0:
- # first_choice = completion.choices[0]
- # if first_choice and first_choice.message:
- # # return first_choice.message.content
- # return completion
- # except Exception as e:
- # log.error(f"[oai] generate_json failed: {e}")
async def send_request_with_retry(
    url: str,
    data: Dict,
    headers: Dict[str, str],
    max_retries: int,
    delay_between_retries: int,
    *,
    timeout: float = 10,
) -> bool:
    """POST *data* as JSON to *url*, retrying until an HTTP 200 is received.

    Args:
        url: Target URL.
        data: JSON-serializable request body.
        headers: Extra request headers.
        max_retries: Total number of attempts. When <= 0, no request is sent.
        delay_between_retries: Seconds to sleep between attempts. There is no
            sleep after the last attempt.
        timeout: Total per-attempt timeout in seconds (default 10).

    Returns:
        True as soon as an attempt gets status 200, otherwise False after all
        attempts. Client errors and timeouts are logged and retried; they are
        not raised.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    # One session (and connection pool) is shared by all attempts.
    async with aiohttp.ClientSession() as session:
        for attempt in range(max_retries):
            try:
                async with session.post(url, json=data, headers=headers, timeout=client_timeout) as response:
                    # The body is read only for logging. A non-JSON body must not
                    # turn a successful 200 into a failure (and a re-sent POST).
                    try:
                        response_data = await response.json(content_type=None)
                    except ValueError:
                        response_data = await response.text(errors="replace")
                    log.info(f"send_request_with_retry {url}: {response_data}")
                    if response.status == 200:
                        return True
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                log.error(f"请求异常:{e}")
            if attempt < max_retries - 1:
                log.info("重试中...")
                await asyncio.sleep(delay_between_retries)
    return False
|