package batch_msg

import (
	"context"
	"encoding/json"
	"errors"
	"strings"
	"time"

	"wechat-api/ent"
	"wechat-api/ent/contact"
	"wechat-api/ent/custom_types"
	"wechat-api/ent/label"
	"wechat-api/ent/labelrelationship"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/dberrorhandler"

	"github.com/suyuan32/simple-admin-common/msg/errormsg"
	"github.com/suyuan32/simple-admin-common/utils/uuidx"
	"github.com/zeromicro/go-zero/core/logx"
)

// CreateBatchMsgLogic creates batch (mass-send) message tasks.
type CreateBatchMsgLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
}

// NewCreateBatchMsgLogic returns a CreateBatchMsgLogic bound to ctx, with a
// logger derived from the same context.
func NewCreateBatchMsgLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateBatchMsgLogic {
	return &CreateBatchMsgLogic{
		ctx:    ctx,
		svcCtx: svcCtx,
		Logger: logx.WithContext(ctx),
	}
}

// CreateBatchMsg resolves the recipients selected by req.Labels (contacts,
// type 1) and req.GroupLabels (groups, type 2) for the sender req.Fromwxid,
// and inserts one batch_msg row describing the task. A label id of 0 in
// either list means "all contacts" / "all groups" respectively.
//
// req.Msg is stored verbatim; it is only decoded here (as a JSON array of
// custom_types.Action) to count how many messages each recipient will get.
// The stored total is (number of recipients) * (number of messages).
//
// If req.StartTimeStr is set ("2006-01-02 15:04:05", server-local time) the
// task is scheduled and send_time is set; otherwise start_time is now and
// send_time is left unset.
//
// Returns a BaseMsgResp with Code 3 when there is nothing to send, and an
// error for invalid input or database failures.
func (l *CreateBatchMsgLogic) CreateBatchMsg(req *types.BatchMsgInfo) (*types.BaseMsgResp, error) {
	// NOTE(review): the organization id is expected to be put into the context
	// by upstream middleware (not visible here). A bare type assertion would
	// panic when it is missing, so fail the request instead.
	orgID, ok := l.ctx.Value("organizationId").(uint64)
	if !ok {
		l.Logger.Errorf("organizationId missing from request context")
		return nil, errors.New("organizationId missing from request context")
	}
	// Every recipient query below is scoped by the sender's wxid, and the
	// original code dereferenced it unconditionally.
	if req.Fromwxid == nil {
		return nil, errors.New("fromwxid is required")
	}
	fromWxid := *req.Fromwxid

	// Label id 0 selects "everything" of that kind.
	hasAll := func(ids []uint64) bool {
		for _, id := range ids {
			if id == 0 {
				return true
			}
		}
		return false
	}
	allContact, allGroup := hasAll(req.Labels), hasAll(req.GroupLabels)

	startTime := time.Now()
	sendNow := true
	if req.StartTimeStr != nil && *req.StartTimeStr != "" {
		// Parse in the server's local zone so the value is consistent with the
		// time.Now() default above; time.Parse would silently assume UTC.
		parsed, err := time.ParseInLocation("2006-01-02 15:04:05", *req.StartTimeStr, time.Local)
		if err != nil {
			l.Logger.Errorf("时间字符串转换错误: %v", err)
			return nil, err
		}
		startTime = parsed
		sendNow = false
	}

	// Decode req.Msg only to learn how many messages it holds.
	var msgArray []custom_types.Action
	if req.Msg != nil && *req.Msg != "" {
		if err := json.Unmarshal([]byte(*req.Msg), &msgArray); err != nil {
			l.Logger.Errorf("解析JSON失败: %v", err)
			return nil, errors.New("解析JSON失败")
		}
	}

	var (
		userList, groupList []*ent.Contact
		tagMap              map[string][]uint64 // persisted as JSON in tagids
		tag                 string              // human-readable summary persisted in tag
		err                 error
	)
	if allContact && allGroup {
		// Everything: one query over both contact types.
		userList, err = l.svcCtx.DB.Contact.Query().
			Where(contact.WxWxid(fromWxid), contact.TypeIn(1, 2)).
			All(l.ctx)
		if err != nil {
			return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
		}
		tagMap = map[string][]uint64{"contact_tag": {0}, "group_tag": {0}}
		tag = "全部"
	} else {
		if allContact {
			userList, err = l.svcCtx.DB.Contact.Query().
				Where(contact.WxWxid(fromWxid), contact.TypeEQ(1)).
				All(l.ctx)
			tag = "全部联系人"
		} else {
			userList, err = l.getContactList(req.Labels, fromWxid, 1)
		}
		if err != nil {
			return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
		}

		if allGroup {
			groupList, err = l.svcCtx.DB.Contact.Query().
				Where(contact.WxWxid(fromWxid), contact.TypeEQ(2)).
				All(l.ctx)
			tag = "全部群"
		} else {
			groupList, err = l.getContactList(req.GroupLabels, fromWxid, 2)
		}
		if err != nil {
			return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
		}

		tagMap = map[string][]uint64{"contact_tag": req.Labels, "group_tag": req.GroupLabels}

		// Summary = names of the selected labels, plus the "all …" marker if any.
		tagNames := l.getLabelListByIds(req.Labels, req.GroupLabels)
		if tag != "" {
			tagNames = append(tagNames, tag)
		}
		tag = strings.Join(tagNames, ",")
	}

	tagByte, err := json.Marshal(tagMap)
	if err != nil {
		return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
	}
	tagstring := string(tagByte)

	// Every recipient (contact or group) receives every decoded message.
	total := int32(len(userList)+len(groupList)) * int32(len(msgArray))
	if total == 0 {
		return &types.BaseMsgResp{Msg: "未查询到收信人,请重新选择", Code: 3}, nil
	}

	batchNo := uuidx.NewUUID().String()

	// send_time is only recorded for scheduled tasks.
	var sendTime *time.Time
	if !sendNow {
		sendTime = &startTime
	}

	_, err = l.svcCtx.DB.BatchMsg.Create().
		SetNotNilBatchNo(&batchNo).
		SetNotNilFromwxid(req.Fromwxid).
		SetNotNilMsg(req.Msg).
		SetNotNilTag(&tag).
		SetNotNilTagids(&tagstring).
		SetTotal(total).
		SetNotNilTaskName(req.TaskName).
		SetNotNilStartTime(&startTime).
		SetNillableSendTime(sendTime).
		SetNotNilType(req.Type).
		SetNotNilOrganizationID(&orgID).
		SetCtype(1).
		Save(l.ctx)
	if err != nil {
		return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
	}

	return &types.BaseMsgResp{Msg: errormsg.CreateSuccess}, nil
}

// getContactList returns the contacts of type stype owned by fromWxId that are
// linked (via label_relationship) to at least one of the given label ids.
// Database errors are returned unconverted; the caller maps them with
// dberrorhandler exactly once.
func (l *CreateBatchMsgLogic) getContactList(labels []uint64, fromWxId string, stype int) ([]*ent.Contact, error) {
	// No labels selects nothing; skip the round-trips.
	if len(labels) == 0 {
		return []*ent.Contact{}, nil
	}

	relationships, err := l.svcCtx.DB.LabelRelationship.Query().
		Where(labelrelationship.LabelIDIn(labels...)).
		All(l.ctx)
	if err != nil {
		return nil, err
	}

	contactIDs := make([]uint64, 0, len(relationships))
	for _, rel := range relationships {
		contactIDs = append(contactIDs, rel.ContactID)
	}
	if len(contactIDs) == 0 {
		return []*ent.Contact{}, nil
	}

	return l.svcCtx.DB.Contact.Query().
		Where(contact.WxWxid(fromWxId), contact.IDIn(contactIDs...), contact.TypeEQ(stype)).
		All(l.ctx)
}

// getLabelListByIds returns the names of the labels whose ids appear in
// labels or groupLabels. Ids with no matching label (e.g. the "all" id 0)
// contribute nothing. This is best-effort: on a query error it logs and
// returns an empty slice.
func (l *CreateBatchMsgLogic) getLabelListByIds(labels, groupLabels []uint64) []string {
	result := make([]string, 0)

	// Build a fresh slice; appending to labels directly could write into the
	// caller's backing array when it has spare capacity.
	ids := make([]uint64, 0, len(labels)+len(groupLabels))
	ids = append(ids, labels...)
	ids = append(ids, groupLabels...)
	if len(ids) == 0 {
		return result
	}

	rows, err := l.svcCtx.DB.Label.Query().
		Where(label.IDIn(ids...)).
		Select("name").
		All(l.ctx)
	if err != nil {
		l.Logger.Errorf("query label names failed: %v", err)
		return result
	}
	for _, row := range rows {
		result = append(result, row.Name)
	}
	return result
}