create_batch_msg_logic.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package batch_msg
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "strings"
  7. "time"
  8. "wechat-api/ent/custom_types"
  9. "wechat-api/ent/label"
  10. "wechat-api/ent"
  11. "wechat-api/ent/contact"
  12. "wechat-api/ent/labelrelationship"
  13. "wechat-api/internal/svc"
  14. "wechat-api/internal/types"
  15. "wechat-api/internal/utils/dberrorhandler"
  16. "github.com/suyuan32/simple-admin-common/msg/errormsg"
  17. "github.com/suyuan32/simple-admin-common/utils/uuidx"
  18. "github.com/zeromicro/go-zero/core/logx"
  19. )
  20. type CreateBatchMsgLogic struct {
  21. ctx context.Context
  22. svcCtx *svc.ServiceContext
  23. logx.Logger
  24. }
  25. func NewCreateBatchMsgLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateBatchMsgLogic {
  26. return &CreateBatchMsgLogic{
  27. ctx: ctx,
  28. svcCtx: svcCtx,
  29. Logger: logx.WithContext(ctx),
  30. }
  31. }
  32. func (l *CreateBatchMsgLogic) CreateBatchMsg(req *types.BatchMsgInfo) (*types.BaseMsgResp, error) {
  33. organizationId := l.ctx.Value("organizationId").(uint64)
  34. var err error
  35. var tagstring, tag string
  36. allContact, allGroup := false, false
  37. tagIdArray := make([]uint64, 0)
  38. for _, labelId := range req.Labels {
  39. tagIdArray = append(tagIdArray, labelId)
  40. if labelId == uint64(0) { //全部
  41. allContact = true
  42. }
  43. }
  44. for _, labelId := range req.GroupLabels {
  45. tagIdArray = append(tagIdArray, labelId)
  46. if labelId == uint64(0) { //全部
  47. allGroup = true
  48. }
  49. }
  50. startTime := time.Now()
  51. sendNow := true
  52. // req.StartTimeStr 不为nil并且不为空
  53. if req.StartTimeStr != nil && *req.StartTimeStr != "" {
  54. // 将 req.StartTimeStr 转换为 req.StartTime
  55. startTime, err = time.Parse("2006-01-02 15:04:05", *req.StartTimeStr)
  56. if err != nil {
  57. // 处理错误,例如打印错误并返回
  58. l.Logger.Errorf("时间字符串转换错误: %v", err)
  59. return nil, err
  60. }
  61. startTime = startTime.Add(-8 * time.Hour)
  62. sendNow = false
  63. }
  64. // 把 req.Msg 字符串的内容 json_decode 到 msgArray
  65. // 然后再把每一条信息都给所有指定的用户记录到 message_records 表
  66. var msgArray []custom_types.Action
  67. if req.Msg != nil && *req.Msg != "" {
  68. err = json.Unmarshal([]byte(*req.Msg), &msgArray)
  69. if err != nil {
  70. return nil, errors.New("解析JSON失败")
  71. }
  72. }
  73. var msgActionList []custom_types.Action
  74. if len(msgArray) > 0 {
  75. msgActionList = make([]custom_types.Action, len(msgArray))
  76. for i, msg := range msgArray {
  77. if msg.Type == 1 {
  78. msgActionList[i] = custom_types.Action{
  79. Type: msg.Type,
  80. Content: msg.Content,
  81. }
  82. } else {
  83. msgActionList[i] = custom_types.Action{
  84. Type: msg.Type,
  85. Content: msg.Content,
  86. Meta: &custom_types.Meta{
  87. Filename: msg.Meta.Filename,
  88. },
  89. }
  90. }
  91. }
  92. }
  93. userList, groupList := make([]*ent.Contact, 0), make([]*ent.Contact, 0)
  94. tagMap := make(map[string][]uint64)
  95. if allContact && allGroup {
  96. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为1或2的数据
  97. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeIn(1, 2)).All(l.ctx)
  98. if err != nil {
  99. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  100. }
  101. tagids := make([]uint64, 0, 1)
  102. tagids = append(tagids, uint64(0))
  103. tagMap["contact_tag"] = tagids
  104. tagMap["group_tag"] = tagids
  105. tagByte, err := json.Marshal(tagMap)
  106. if err != nil {
  107. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  108. }
  109. tagstring = string(tagByte)
  110. tag = "全部"
  111. } else {
  112. if allContact { // 所有联系人
  113. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为1的数据
  114. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeEQ(1)).All(l.ctx)
  115. if err != nil {
  116. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  117. }
  118. tag = "全部联系人"
  119. } else {
  120. userList, err = l.getContactList(req.Labels, *req.Fromwxid, 1)
  121. if err != nil {
  122. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  123. }
  124. }
  125. if allGroup { //所有群
  126. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 的 type 为2的数据
  127. groupList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(*req.Fromwxid), contact.TypeEQ(2)).All(l.ctx)
  128. if err != nil {
  129. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  130. }
  131. tag = "全部群"
  132. } else {
  133. groupList, err = l.getContactList(req.GroupLabels, *req.Fromwxid, 2)
  134. if err != nil {
  135. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  136. }
  137. }
  138. tagMap["contact_tag"] = req.Labels
  139. tagMap["group_tag"] = req.GroupLabels
  140. tagByte, err := json.Marshal(tagMap)
  141. if err != nil {
  142. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  143. }
  144. tagstring = string(tagByte)
  145. tagArray := l.getLabelListByIds(req.Labels, req.GroupLabels)
  146. if tag != "" {
  147. tagArray = append(tagArray, tag)
  148. }
  149. tag = strings.Join(tagArray, ",")
  150. }
  151. // 这里是根据userlist 和 批量消息数 获得最终待发送消息总数
  152. total := int32(len(userList))*int32(len(msgActionList)) + int32(len(groupList))*int32(len(msgActionList))
  153. if total == 0 {
  154. return &types.BaseMsgResp{Msg: "未查询到收信人,请重新选择", Code: 3}, nil
  155. }
  156. uuid := uuidx.NewUUID()
  157. batchNo := uuid.String()
  158. var sendTime *time.Time
  159. if !sendNow {
  160. sendTime = &startTime
  161. }
  162. _, err = l.svcCtx.DB.BatchMsg.Create().
  163. SetNotNilBatchNo(&batchNo).
  164. SetNotNilFromwxid(req.Fromwxid).
  165. SetNotNilMsg(req.Msg).
  166. SetNotNilTag(&tag).
  167. SetNotNilTagids(&tagstring).
  168. SetTotal(total).
  169. SetNotNilTaskName(req.TaskName).
  170. SetNotNilStartTime(&startTime).
  171. SetNillableSendTime(sendTime).
  172. SetNotNilType(req.Type).
  173. SetNotNilOrganizationID(&organizationId).
  174. SetNotNilCtype(req.Ctype).
  175. Save(l.ctx)
  176. if err != nil {
  177. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  178. }
  179. return &types.BaseMsgResp{Msg: errormsg.CreateSuccess}, nil
  180. }
  181. func (l *CreateBatchMsgLogic) getContactList(labels []uint64, fromWxId string, stype int) ([]*ent.Contact, error) {
  182. // 获取 label_relationship 表中,label_id 等于 labids 的 contact_id
  183. labelrelationships, err := l.svcCtx.DB.LabelRelationship.Query().Where(labelrelationship.LabelIDIn(labels...)).All(l.ctx)
  184. if err != nil {
  185. return nil, dberrorhandler.DefaultEntError(l.Logger, err, nil)
  186. }
  187. contact_ids := make([]uint64, 0, len(labelrelationships))
  188. for _, labelrelationship := range labelrelationships {
  189. contact_ids = append(contact_ids, labelrelationship.ContactID)
  190. }
  191. userList := make([]*ent.Contact, 0)
  192. if len(contact_ids) > 0 {
  193. // 获取 contact 表中 wx_wxid 等于 req.Fromwxid 并且 id 等于 contact_ids 并且 type 为1或2 的数据
  194. userList, err = l.svcCtx.DB.Contact.Query().Where(contact.WxWxid(fromWxId), contact.IDIn(contact_ids...), contact.TypeEQ(stype)).All(l.ctx)
  195. if err != nil {
  196. return nil, dberrorhandler.DefaultEntError(l.Logger, err, nil)
  197. }
  198. }
  199. return userList, nil
  200. }
  201. func (l *CreateBatchMsgLogic) getLabelListByIds(labels, groupLabels []uint64) []string {
  202. result := make([]string, 0)
  203. labels = append(labels, groupLabels...)
  204. if len(labels) > 0 {
  205. contacts, err := l.svcCtx.DB.Label.Query().Where(
  206. label.IDIn(labels...),
  207. ).Select("name").All(l.ctx)
  208. l.Logger.Infof("contacts=%v", contacts)
  209. if err != nil {
  210. return result
  211. }
  212. for _, val := range contacts {
  213. result = append(result, val.Name)
  214. }
  215. }
  216. return result
  217. }