oai.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import asyncio
  2. import json
  3. import re
  4. from typing import Dict
  5. import aiohttp
  6. from openai import OpenAI, AsyncOpenAI
  7. from openai.types import ResponseFormatJSONSchema
  8. from openai.types.chat import completion_create_params, ChatCompletion
  9. from pydantic import BaseModel
  10. from common.log import log
  11. async def generate_text(api_key: str, openai_base: str, model: str, messages: list[dict]) -> (str | None, ChatCompletion | None):
  12. try:
  13. client_args = {}
  14. if api_key:
  15. client_args["api_key"] = api_key
  16. if openai_base:
  17. client_args["base_url"] = openai_base
  18. oai_client = AsyncOpenAI(**client_args)
  19. completion = await oai_client.chat.completions.create(
  20. model=model,
  21. temperature=0,
  22. messages=messages
  23. )
  24. if completion and isinstance(completion.choices, list) and len(completion.choices) > 0:
  25. first_choice = completion.choices[0]
  26. if first_choice and first_choice.message:
  27. return first_choice.message.content, completion
  28. except Exception as e:
  29. log.error(f"[oai] generate_json failed: {e}")
  30. async def generate_json(api_key: str, openai_base: str, model: str, messages: list[dict], json_schema: dict | str) -> (dict | list | None, ChatCompletion | None):
  31. try:
  32. # 拼接参数
  33. client_args = {}
  34. client_args["api_key"] = api_key
  35. client_args["base_url"] = openai_base
  36. if model == "deepseek-v3" or model == "DeepSeek-V3":
  37. response_format = { "type": "json_object" }
  38. messages[-1]["content"] = messages[-1]["content"] + f"""
  39. # 请以下方的json结构输出
  40. {json_schema}
  41. """
  42. else:
  43. response_format = {
  44. "type": "json_schema",
  45. "json_schema": json_schema
  46. }
  47. oai_client = AsyncOpenAI(**client_args)
  48. log.info(f"[oai] model: {model}")
  49. log.info(f"[oai] messages: {messages}")
  50. log.info(f"[oai] response_format: {response_format}")
  51. completion = await oai_client.chat.completions.create(
  52. model=model,
  53. messages=messages,
  54. response_format=response_format
  55. )
  56. log.info(f"[oai] completion: {completion}")
  57. if completion and isinstance(completion.choices, list) and len(completion.choices) > 0:
  58. first_choice = completion.choices[0]
  59. if first_choice and first_choice.message and first_choice.message.content:
  60. if model == "deepseek-v3" or model == "DeepSeek-V3":
  61. # match = re.search(r'\{.*\}', first_choice.message.content, re.DOTALL)
  62. # if not match:
  63. # return None
  64. # content = match.group(0)
  65. # 去除字符串开头和结尾的 ```json 和 ```
  66. if first_choice.message.content.startswith('```json'):
  67. # 提取中间部分
  68. content = first_choice.message.content[len('```json'):-len('```')].strip()
  69. else:
  70. content = first_choice.message.content
  71. log.info(f"[oai] content: {content}")
  72. try:
  73. response_json = json.loads(content)
  74. except json.JSONDecodeError:
  75. response_json = None
  76. return response_json, completion
  77. else:
  78. content = first_choice.message.content
  79. response_json = json.loads(content)
  80. result = response_json.get('properties', response_json)
  81. return result, completion
  82. except Exception as e:
  83. log.error(f"[oai] generate_json failed: {e}")
  84. # async def generate_json_by_class(api_key: str, openai_base: str, model: str, messages: list[dict], json_schema: any):
  85. # try:
  86. # client_args = {}
  87. # if api_key:
  88. # client_args["api_key"] = api_key
  89. # if openai_base:
  90. # client_args["base_url"] = openai_base
  91. #
  92. # oai_client = AsyncOpenAI(**client_args)
  93. #
  94. # completion = await oai_client.beta.chat.completions.parse(
  95. # model=model,
  96. # messages=messages,
  97. # response_format=json_schema
  98. # )
  99. # if completion and isinstance(completion.choices, list) and len(completion.choices) > 0:
  100. # first_choice = completion.choices[0]
  101. # if first_choice and first_choice.message:
  102. # # return first_choice.message.content
  103. # return completion
  104. # except Exception as e:
  105. # log.error(f"[oai] generate_json failed: {e}")
  106. async def send_request_with_retry(url: str, data: Dict, headers: Dict[str, str], max_retries: int, delay_between_retries: int) -> bool:
  107. for attempt in range(max_retries):
  108. try:
  109. async with aiohttp.ClientSession() as session:
  110. async with session.post(url, json=data, headers=headers, timeout=10) as response:
  111. response_data = await response.json()
  112. log.info(f"send_request_with_retry {url}: {response_data}")
  113. if response.status == 200:
  114. return True
  115. except (aiohttp.ClientError, asyncio.TimeoutError) as e:
  116. log.error(f"请求异常:{e}")
  117. if attempt < max_retries - 1:
  118. print("重试中...")
  119. await asyncio.sleep(delay_between_retries)
  120. return False