gorm_logger.go 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package database
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/zeromicro/go-zero/core/logx"
  6. "github.com/zeromicro/go-zero/core/service"
  7. "gorm.io/gorm"
  8. "gorm.io/gorm/logger"
  9. "strconv"
  10. "time"
  11. )
  // GormLogger implements GORM's logger.Interface by forwarding every entry to
  // go-zero's logx through logx.WithContext, so entries carry the caller's ctx.
  type GormLogger struct {
  	// SlowThreshold is the elapsed time above which Trace emits a slow-query
  	// entry; a zero value turns slow-query reporting off.
  	SlowThreshold time.Duration
  	// Mode is the go-zero service mode. Any value other than service.ProMode
  	// makes Trace log every executed statement.
  	Mode string
  }
  16. func NewGormLogger(mode string) *GormLogger {
  17. return &GormLogger{
  18. SlowThreshold: 200 * time.Millisecond, // 一般超过200毫秒就算慢查所以不使用配置进行更改
  19. Mode: mode,
  20. }
  21. }
  22. var _ logger.Interface = (*GormLogger)(nil)
  23. func (l *GormLogger) LogMode(lev logger.LogLevel) logger.Interface {
  24. return &GormLogger{}
  25. }
  26. func (l *GormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
  27. logx.WithContext(ctx).Infof(msg, data)
  28. }
  29. func (l *GormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
  30. logx.WithContext(ctx).Errorf(msg, data)
  31. }
  32. func (l *GormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
  33. logx.WithContext(ctx).Errorf(msg, data)
  34. }
  35. // MicrosecondsStr 将时间间隔转换为微秒并返回字符串
  36. func (l *GormLogger) MicrosecondsStr(d time.Duration) string {
  37. microseconds := d.Microseconds()
  38. return strconv.FormatInt(microseconds, 10)
  39. }
  40. func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
  41. // 获取运行时间
  42. elapsed := time.Since(begin)
  43. // 获取 SQL 语句和返回条数
  44. sql, rows := fc()
  45. // 通用字段
  46. logFields := []logx.LogField{
  47. logx.Field("sql", sql),
  48. logx.Field("time", l.MicrosecondsStr(elapsed)),
  49. logx.Field("rows", rows),
  50. }
  51. // Gorm 错误
  52. if err != nil {
  53. // 记录未找到的错误使用 warning 等级
  54. if errors.Is(err, gorm.ErrRecordNotFound) {
  55. logx.WithContext(ctx).Infow("Database ErrRecordNotFound", logFields...)
  56. } else {
  57. // 其他错误使用 error 等级
  58. logFields = append(logFields, logx.Field("catch error", err))
  59. logx.WithContext(ctx).Errorw("Database Error", logFields...)
  60. }
  61. }
  62. // 慢查询日志
  63. if l.SlowThreshold != 0 && elapsed > l.SlowThreshold {
  64. logx.WithContext(ctx).Sloww("Database Slow Log", logFields...)
  65. }
  66. // 非生产模式下,记录所有 SQL 请求
  67. if l.Mode != service.ProMode {
  68. logx.WithContext(ctx).Infow("Database Query", logFields...)
  69. }
  70. }