123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- package intercept
- import (
- "context"
- "fmt"
- "wechat-api/ent"
- "wechat-api/ent/contact"
- "wechat-api/ent/predicate"
- "wechat-api/ent/server"
- "wechat-api/ent/wx"
- "entgo.io/ent/dialect/sql"
- )
- type Query interface {
-
- Type() string
-
- Limit(int)
-
- Offset(int)
-
- Unique(bool)
-
- Order(...func(*sql.Selector))
-
-
- WhereP(...func(*sql.Selector))
- }
- type Func func(context.Context, Query) error
- func (f Func) Intercept(next ent.Querier) ent.Querier {
- return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) {
- query, err := NewQuery(q)
- if err != nil {
- return nil, err
- }
- if err := f(ctx, query); err != nil {
- return nil, err
- }
- return next.Query(ctx, q)
- })
- }
- type TraverseFunc func(context.Context, Query) error
- func (f TraverseFunc) Intercept(next ent.Querier) ent.Querier {
- return next
- }
- func (f TraverseFunc) Traverse(ctx context.Context, q ent.Query) error {
- query, err := NewQuery(q)
- if err != nil {
- return err
- }
- return f(ctx, query)
- }
- type ContactFunc func(context.Context, *ent.ContactQuery) (ent.Value, error)
- func (f ContactFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
- if q, ok := q.(*ent.ContactQuery); ok {
- return f(ctx, q)
- }
- return nil, fmt.Errorf("unexpected query type %T. expect *ent.ContactQuery", q)
- }
- type TraverseContact func(context.Context, *ent.ContactQuery) error
- func (f TraverseContact) Intercept(next ent.Querier) ent.Querier {
- return next
- }
- func (f TraverseContact) Traverse(ctx context.Context, q ent.Query) error {
- if q, ok := q.(*ent.ContactQuery); ok {
- return f(ctx, q)
- }
- return fmt.Errorf("unexpected query type %T. expect *ent.ContactQuery", q)
- }
- type ServerFunc func(context.Context, *ent.ServerQuery) (ent.Value, error)
- func (f ServerFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
- if q, ok := q.(*ent.ServerQuery); ok {
- return f(ctx, q)
- }
- return nil, fmt.Errorf("unexpected query type %T. expect *ent.ServerQuery", q)
- }
- type TraverseServer func(context.Context, *ent.ServerQuery) error
- func (f TraverseServer) Intercept(next ent.Querier) ent.Querier {
- return next
- }
- func (f TraverseServer) Traverse(ctx context.Context, q ent.Query) error {
- if q, ok := q.(*ent.ServerQuery); ok {
- return f(ctx, q)
- }
- return fmt.Errorf("unexpected query type %T. expect *ent.ServerQuery", q)
- }
- type WxFunc func(context.Context, *ent.WxQuery) (ent.Value, error)
- func (f WxFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
- if q, ok := q.(*ent.WxQuery); ok {
- return f(ctx, q)
- }
- return nil, fmt.Errorf("unexpected query type %T. expect *ent.WxQuery", q)
- }
- type TraverseWx func(context.Context, *ent.WxQuery) error
- func (f TraverseWx) Intercept(next ent.Querier) ent.Querier {
- return next
- }
- func (f TraverseWx) Traverse(ctx context.Context, q ent.Query) error {
- if q, ok := q.(*ent.WxQuery); ok {
- return f(ctx, q)
- }
- return fmt.Errorf("unexpected query type %T. expect *ent.WxQuery", q)
- }
- func NewQuery(q ent.Query) (Query, error) {
- switch q := q.(type) {
- case *ent.ContactQuery:
- return &query[*ent.ContactQuery, predicate.Contact, contact.OrderOption]{typ: ent.TypeContact, tq: q}, nil
- case *ent.ServerQuery:
- return &query[*ent.ServerQuery, predicate.Server, server.OrderOption]{typ: ent.TypeServer, tq: q}, nil
- case *ent.WxQuery:
- return &query[*ent.WxQuery, predicate.Wx, wx.OrderOption]{typ: ent.TypeWx, tq: q}, nil
- default:
- return nil, fmt.Errorf("unknown query type %T", q)
- }
- }
- type query[T any, P ~func(*sql.Selector), R ~func(*sql.Selector)] struct {
- typ string
- tq interface {
- Limit(int) T
- Offset(int) T
- Unique(bool) T
- Order(...R) T
- Where(...P) T
- }
- }
- func (q query[T, P, R]) Type() string {
- return q.typ
- }
- func (q query[T, P, R]) Limit(limit int) {
- q.tq.Limit(limit)
- }
- func (q query[T, P, R]) Offset(offset int) {
- q.tq.Offset(offset)
- }
- func (q query[T, P, R]) Unique(unique bool) {
- q.tq.Unique(unique)
- }
- func (q query[T, P, R]) Order(orders ...func(*sql.Selector)) {
- rs := make([]R, len(orders))
- for i := range orders {
- rs[i] = orders[i]
- }
- q.tq.Order(rs...)
- }
- func (q query[T, P, R]) WhereP(ps ...func(*sql.Selector)) {
- p := make([]P, len(ps))
- for i := range ps {
- p[i] = ps[i]
- }
- q.tq.Where(p...)
- }
|