create_batch_msg_logic.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package batch_msg
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "strings"
  7. "time"
  8. "wechat-api/ent/custom_types"
  9. "wechat-api/ent/label"
  10. "wechat-api/ent"
  11. "wechat-api/ent/contact"
  12. "wechat-api/ent/labelrelationship"
  13. "wechat-api/internal/svc"
  14. "wechat-api/internal/types"
  15. "wechat-api/internal/utils/dberrorhandler"
  16. "github.com/suyuan32/simple-admin-common/msg/errormsg"
  17. "github.com/suyuan32/simple-admin-common/utils/uuidx"
  18. "github.com/zeromicro/go-zero/core/logx"
  19. )
  20. type CreateBatchMsgLogic struct {
  21. ctx context.Context
  22. svcCtx *svc.ServiceContext
  23. logx.Logger
  24. }
  25. func NewCreateBatchMsgLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateBatchMsgLogic {
  26. return &CreateBatchMsgLogic{
  27. ctx: ctx,
  28. svcCtx: svcCtx,
  29. Logger: logx.WithContext(ctx),
  30. }
  31. }
  32. func (l *CreateBatchMsgLogic) CreateBatchMsg(req *types.BatchMsgInfo) (*types.BaseMsgResp, error) {
  33. organizationId := l.ctx.Value("organizationId").(uint64)
  34. var err error
  35. var tagstring, tag string
  36. allContact, allGroup := false, false
  37. tagIdArray := make([]uint64, 0)
  38. for _, labelId := range req.Labels {
  39. tagIdArray = append(tagIdArray, labelId)
  40. if labelId == uint64(0) { //全部
  41. allContact = true
  42. }
  43. }
  44. for _, labelId := range req.GroupLabels {
  45. tagIdArray = append(tagIdArray, labelId)
  46. if labelId == uint64(0) { //全部
  47. allGroup = true
  48. }
  49. }
  50. startTime := time.Now()
  51. sendNow := true
  52. // req.StartTimeStr 不为nil并且不为空
  53. if req.StartTimeStr != nil && *req.StartTimeStr != "" {
  54. // 将 req.StartTimeStr 转换为 req.StartTime
  55. startTime, err = time.Parse("2006-01-02 15:04:05", *req.StartTimeStr)
  56. if err != nil {
  57. // 处理错误,例如打印错误并返回
  58. l.Logger.Errorf("时间字符串转换错误: %v", err)
  59. return nil, err
  60. }
  61. sendNow = false
  62. }
  63. // 把 req.Msg 字符串的内容 json_decode 到 msgArray
  64. // 然后再把每一条信息都给所有指定的用户记录到 message_records 表
  65. var msgArray []custom_types.Action
  66. if req.Msg != nil && *req.Msg != "" {
  67. err = json.Unmarshal([]byte(*req.Msg), &msgArray)
  68. if err != nil {
  69. return nil, errors.New("解析JSON失败")
  70. }
  71. }
  72. var msgActionList []custom_types.Action
  73. if len(msgArray) > 0 {
  74. msgActionList = make([]custom_types.Action, len(msgArray))
  75. for i, msg := range msgArray {
  76. if msg.Type == 1 {
  77. msgActionList[i] = custom_types.Action{
  78. Type: msg.Type,
  79. Content: msg.Content,
  80. }
  81. } else {
  82. msgActionList[i] = custom_types.Action{
  83. Type: msg.Type,
  84. Content: msg.Content,
  85. Meta: &custom_types.Meta{
  86. Filename: msg.Meta.Filename,
  87. },
  88. }
  89. }
  90. }
  91. }
  92. userList, groupList := make([]*ent.Contact, 0), make([]*ent.Contact, 0)
  93. tagMap := make(map[string][]uint64)
  94. if allContact && allGroup {
  95. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为1或2的数据
  96. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeIn(1, 2)).All(l.ctx)
  97. if err != nil {
  98. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  99. }
  100. tagids := make([]uint64, 0, 1)
  101. tagids = append(tagids, uint64(0))
  102. tagMap["contact_tag"] = tagids
  103. tagMap["group_tag"] = tagids
  104. tagByte, err := json.Marshal(tagMap)
  105. if err != nil {
  106. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  107. }
  108. tagstring = string(tagByte)
  109. tag = "全部"
  110. } else {
  111. if allContact { // 所有联系人
  112. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为1的数据
  113. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeEQ(1)).All(l.ctx)
  114. if err != nil {
  115. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  116. }
  117. tag = "全部联系人"
  118. } else {
  119. userList, err = l.getContactList(req.Labels, *req.Fromwxid, 1)
  120. if err != nil {
  121. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  122. }
  123. }
  124. if allGroup { //所有群
  125. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为2的数据
  126. groupList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeEQ(2)).All(l.ctx)
  127. if err != nil {
  128. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  129. }
  130. tag = "全部群"
  131. } else {
  132. groupList, err = l.getContactList(req.GroupLabels, *req.Fromwxid, 2)
  133. if err != nil {
  134. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  135. }
  136. }
  137. tagMap["contact_tag"] = req.Labels
  138. tagMap["group_tag"] = req.GroupLabels
  139. tagByte, err := json.Marshal(tagMap)
  140. if err != nil {
  141. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  142. }
  143. tagstring = string(tagByte)
  144. tagArray := l.getLabelListByIds(req.Labels, req.GroupLabels)
  145. if tag != "" {
  146. tagArray = append(tagArray, tag)
  147. }
  148. tag = strings.Join(tagArray, ",")
  149. }
  150. // 这里是根据userlist 和 批量消息数 获得最终待发送消息总数
  151. total := int32(len(userList))*int32(len(msgActionList)) + int32(len(groupList))*int32(len(msgActionList))
  152. if total == 0 {
  153. return &types.BaseMsgResp{Msg: "未查询到收信人,请重新选择", Code: 3}, nil
  154. }
  155. uuid := uuidx.NewUUID()
  156. batchNo := uuid.String()
  157. var sendTime *time.Time
  158. if !sendNow {
  159. sendTime = &startTime
  160. }
  161. _, err = l.svcCtx.DB.BatchMsg.Create().
  162. SetNotNilBatchNo(&batchNo).
  163. SetNotNilFromwxid(req.Fromwxid).
  164. SetNotNilMsg(req.Msg).
  165. SetNotNilTag(&tag).
  166. SetNotNilTagids(&tagstring).
  167. SetTotal(total).
  168. SetNotNilTaskName(req.TaskName).
  169. SetNotNilStartTime(&startTime).
  170. SetNillableSendTime(sendTime).
  171. SetNotNilType(req.Type).
  172. SetNotNilOrganizationID(&organizationId).
  173. SetCtype(1).
  174. Save(l.ctx)
  175. if err != nil {
  176. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  177. }
  178. return &types.BaseMsgResp{Msg: errormsg.CreateSuccess}, nil
  179. }
  180. func (l *CreateBatchMsgLogic) getContactList(labels []uint64, fromWxId string, stype int) ([]*ent.Contact, error) {
  181. // 获取 label_relationship 表中,label_id 等于 labids 的 contact_id
  182. labelrelationships, err := l.svcCtx.DB.LabelRelationship.Query().Where(labelrelationship.LabelIDIn(labels...)).All(l.ctx)
  183. if err != nil {
  184. return nil, dberrorhandler.DefaultEntError(l.Logger, err, nil)
  185. }
  186. contact_ids := make([]uint64, 0, len(labelrelationships))
  187. for _, labelrelationship := range labelrelationships {
  188. contact_ids = append(contact_ids, labelrelationship.ContactID)
  189. }
  190. userList := make([]*ent.Contact, 0)
  191. if len(contact_ids) > 0 {
  192. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 并且 id 等于 contact_ids 并且 type 为1或2 的数据
  193. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(fromWxId), contact.IDIn(contact_ids...), contact.TypeEQ(stype)).All(l.ctx)
  194. if err != nil {
  195. return nil, dberrorhandler.DefaultEntError(l.Logger, err, nil)
  196. }
  197. }
  198. return userList, nil
  199. }
  200. func (l *CreateBatchMsgLogic) getLabelListByIds(labels, groupLabels []uint64) []string {
  201. result := make([]string, 0)
  202. labels = append(labels, groupLabels...)
  203. if len(labels) > 0 {
  204. contacts, err := l.svcCtx.DB.Label.Query().Where(
  205. label.IDIn(labels...),
  206. ).Select("name").All(l.ctx)
  207. l.Logger.Infof("contacts=%v", contacts)
  208. if err != nil {
  209. return result
  210. }
  211. for _, val := range contacts {
  212. result = append(result, val.Name)
  213. }
  214. }
  215. return result
  216. }