models.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package credit
  2. import (
  3. "fmt"
  4. "math"
  5. "regexp"
  6. "sort"
  7. "strings"
  8. "wechat-api/ent/custom_types"
  9. "wechat-api/hook/dify"
  10. )
  11. var modelArray = []string{
  12. "o1",
  13. "gpt-4o",
  14. "gpt-4.1",
  15. "o3-mini",
  16. "moonshot-v1-32k",
  17. "deepseek-r1",
  18. "moonshot-v1-8k",
  19. "gpt-4.1-mini",
  20. "gpt-3.5-turbo",
  21. "qwen-max",
  22. "doubao1.5-pro-256k",
  23. "deepseek-v3",
  24. "qwq-32b-preview",
  25. "gpt-4o-mini",
  26. "qwen2.5-14b-instruct-1m",
  27. "gpt-4.1-nano",
  28. "doubao1.5-pro",
  29. "doubao1.5-pro-32k",
  30. "chatglm3",
  31. "qwen-turbo",
  32. "doubao1.5-lite-32k",
  33. }
  34. var priceArray = []float64{
  35. 0.01,
  36. 0.001667,
  37. 0.001333,
  38. 0.000733,
  39. 0.000548,
  40. 0.000365,
  41. 0.000274,
  42. 0.000267,
  43. 0.00025,
  44. 0.000219,
  45. 0.000205,
  46. 0.000183,
  47. 0.000137,
  48. 0.0001,
  49. 0.000068,
  50. 0.000067,
  51. 0.000046,
  52. 0.000046,
  53. 0.000023,
  54. 0.000014,
  55. 0.000014,
  56. }
  57. func getModelName(modelName string) string {
  58. // 将字符串转换为小写
  59. return strings.ToLower(modelName)
  60. }
  61. func GetModelPrice(modelName string) (model string, price float64) {
  62. difyModelName := getModelName(modelName)
  63. for i, v := range modelArray {
  64. if v == difyModelName {
  65. return v, priceArray[i]
  66. }
  67. }
  68. return modelArray[0], priceArray[0]
  69. }
  70. func ComputePrice(price float64, tokens uint64) float64 {
  71. scale := float64(1000000)
  72. return math.Round(price*float64(tokens)*scale) / scale
  73. }
  74. // Subtraction 保留小数点后6位的精确减法
  75. func Subtraction(number1, number2 float64) float64 {
  76. d1 := number1 * 1000000
  77. d2 := number2 * 1000000
  78. res := math.Floor(d1-d2) / 1000000
  79. return res
  80. }
  81. func ComputeModelPrice(response interface{}) (model string, price float64) {
  82. fmt.Printf("response=%v \n", response)
  83. // 先获取所有本次响应里所有的model
  84. modelInputArray := make([]string, 0)
  85. // 如果是 gpt/chat/submit 处调用这里会是非正常的结构,所以要提前处理
  86. if _, ok := response.(dify.ChatResp); ok {
  87. return modelArray[0], priceArray[0]
  88. }
  89. // 如果是标准 VResponse 结构
  90. if _, ok := response.(custom_types.VResponse); ok {
  91. if modelName := response.(custom_types.VResponse).Model; modelName != "" {
  92. modelInputArray = append(modelInputArray, modelName)
  93. }
  94. // 获取 responseData 下所有model
  95. responseDataVal := response.(custom_types.VResponse).ResponseData
  96. if responseDataVal != nil {
  97. for _, res := range responseDataVal {
  98. fmt.Printf("model=%v \n", res.Model)
  99. if res.Model == "" {
  100. continue
  101. } else {
  102. modelInputArray = append(modelInputArray, strings.ToLower(res.Model))
  103. }
  104. }
  105. }
  106. }
  107. fmt.Printf("modelInputArray=%v \n", modelInputArray)
  108. // 如果精确搜索能找到model,返回model及其价格
  109. if len(modelInputArray) > 0 {
  110. for i, providedModelName := range modelInputArray {
  111. for _, modelName := range modelArray {
  112. if modelName == providedModelName {
  113. return modelName, priceArray[i]
  114. }
  115. }
  116. }
  117. }
  118. // 如果精确搜索搜不到,则把所有model按从长倒短,挨个匹配下,看有无符合条件
  119. // 如果模型里带日期,这时候要在筛选一次,比如 gpt-4.1-mini-2024-07-18 其实应该匹配到 gpt-4.1-mini
  120. // 对模型名数组按名字长度排序,然后按长到短匹配
  121. sortedModelArray := make([]string, len(modelArray))
  122. copy(sortedModelArray, modelArray)
  123. // 使用長度降序排序模型陣列
  124. sort.Slice(sortedModelArray, func(i, j int) bool {
  125. return len(sortedModelArray[i]) > len(sortedModelArray[j])
  126. })
  127. // 排序后按长到短挨个匹配合适的model
  128. if len(sortedModelArray) > 0 {
  129. for _, model := range sortedModelArray {
  130. for _, modelName := range modelInputArray {
  131. re := regexp.MustCompile(`(?i)` + regexp.QuoteMeta(model))
  132. if re.MatchString(modelName) {
  133. fmt.Printf("In models.go the given model_name: %s match: %s\n", modelName, model)
  134. model, price := GetModelPrice(model)
  135. return model, price
  136. }
  137. }
  138. }
  139. }
  140. // 最后返回最高的价格
  141. return modelArray[0], priceArray[0]
  142. }