Gorm的使用心得和一些常用扩展(二)

Table of Contents

上一篇文章,我分享了自己在新增和更新的场景下,自己使用gorm的一些心得和扩展。本文,我将分享一些在查询的方面的心得。

首先,我把查询按照涉及到的表的数量分为:

  • 单表查询
  • 多表查询

按照查询范围又可以分为:

  • 查询一个
  • 范围查询
    • 查询一组
    • 有序查询
    • 查询前几个
    • 分页查询

在日常使用中,单表查询占据了多半的场景,把这部分的代码按照查询范围做一些封装,可以大大减少冗余的代码。

单表查询

于是,我仿照gorm API的风格,做了如下的封装:

ps:以下例子均以假定已定义user对象

查询一个

func (dw *DBExtension) GetOne(result interface{}, query interface{}, args ...interface{}) (found bool, err error) {
	var (
		tableNameAble TableNameAble
		ok            bool
	)

	if tableNameAble, ok = query.(TableNameAble); !ok {
		if tableNameAble, ok = result.(TableNameAble); !ok {
			return false, errors.New("neither the query nor result implement TableNameAble")
		}
	}

	err = dw.Table(tableNameAble.TableName()).Where(query, args...).First(result).Error

	if err == gorm.ErrRecordNotFound {
		dw.logger.LogInfoc("mysql", fmt.Sprintf("record not found for query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
		return false, nil
	}

	if err != nil {
		dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
		return false, err
	}

	return true, nil
}

这段值得说明的就是对查询不到数据时的处理,gorm是报了gorm.ErrRecordNotFound的error, 我是对这个错误做了特殊处理,用found这个boolean值表述这个特殊状态。

调用代码如下:

condition := User{Id:1}
result := User{}

if  found, err := dw.GetOne(&result, condition); !found {
	//not found
    if err != nil {
    	// has error
        return err
    }
    
}

也可以这样写,更加灵活的指定的查询条件:

result := User{}

if  found, err := dw.GetOne(&result, "id = ?" 1); !found {
	//not found
    if err != nil {
    	// has error
        return err
    }
    
}

两种写法执行的语句都是:

select * from test.user where id = 1

范围查询

针对四种范国查询,我做了如下封装:


func (dw *DBExtension) GetList(result interface{}, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, "", 0, 0, query, args)
}

func (dw *DBExtension) GetOrderedList(result interface{}, order string, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, 0, 0, query, args)
}

func (dw *DBExtension) GetFirstNRecords(result interface{}, order string, limit int, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, limit, 0, query, args)
}

func (dw *DBExtension) GetPageRangeList(result interface{}, order string, limit, offset int, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, limit, offset, query, args)
}

func (dw *DBExtension) getListCore(result interface{}, order string, limit, offset int, query interface{}, args []interface{}) error {
	var (
		tableNameAble TableNameAble
		ok            bool
	)

	if tableNameAble, ok = query.(TableNameAble); !ok {
		// type Result []*Item{}
		// result := &Result{}
		resultType := reflect.TypeOf(result)
		if resultType.Kind() != reflect.Ptr {
			return errors.New("result is not a pointer")
		}

		sliceType := resultType.Elem()
		if sliceType.Kind() != reflect.Slice {
			return errors.New("result doesn't point to a slice")
		}
		// *Item
		itemPtrType := sliceType.Elem()
		// Item
		itemType := itemPtrType.Elem()

		elemValue := reflect.New(itemType)
		elemValueType := reflect.TypeOf(elemValue)
		tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

		if elemValueType.Implements(tableNameAbleType) {
			return errors.New("neither the query nor result implement TableNameAble")
		}

		tableNameAble = elemValue.Interface().(TableNameAble)
	}

	db := dw.Table(tableNameAble.TableName()).Where(query, args...)
	if len(order) != 0 {
		db = db.Order(order)
	}

	if offset > 0 {
		db = db.Offset(offset)
	}

	if limit > 0 {
		db = db.Limit(limit)
	}

	if err := db.Find(result).Error; err != nil {
		dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, query is %+v, args are %+v, order is %s, limit is %d", tableNameAble.TableName(), query, args, order, limit))
		return err
	}

	return nil
}

为了减少冗余的代码,通用的逻辑写在getListCore函数里,里面用到了一些golang反射的知识。

但只要记得golang的反射和其它语言的反射最大的不同,是golang的反射是基本值而不是类型的,一切就好理解了。

其中的一个小技巧是如何判断一个类型是否实现了某个接口,用到了指向nil的指针。

	elemValue := reflect.New(itemType)
	elemValueType := reflect.TypeOf(elemValue)
	tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

	if elemValueType.Implements(tableNameAbleType) {
		return errors.New("neither the query nor result implement TableNameAble")
	}

关于具体的使用,就不再一一举例子了,熟悉gorm api的同学可以一眼看出。

多表查询

关于多表查询,因为不同场景很难抽取出不同,也就没有再做封装,但是我的经验是优先多使用gorm的方法,而不是自己拼sql。你想要做的gorm都可以实现。

这里,我偷个懒,贴出自己在项目中写的最复杂的一段代码,供各位看官娱乐。

一个复杂的例子

这段代码是从埋点数据的中间表,为了用通用的代码实现不同展示场景下的查询,代码设计的比较灵活,其中涉及了关联多表的查询,按查询条件动态过滤和聚合,还有分页查询的逻辑。

func buildCommonStatisticQuery(tableName, startDate, endDate string) *gorm.DB {
	query := models.DB().Table(tableName)

	if startDate == endDate || endDate == "" {
		query = query.Where("date = ?", startDate)
	} else {
		query = query.Where("date >= ? and date <= ?", startDate, endDate)
	}

	return query
}

func buildElementsStatisticQuery(startDate, endDate,  elemId string,  elemType int32) *gorm.DB {
	query := buildCommonStatisticQuery("spotanalysis.element_statistics", startDate, endDate)

	if elemId != "" && elemType != 0 {
		query = query.Where("element_id = ? and element_type = ?", elemId, elemType)
	}

	return query
}

func CountElementsStatistics(count *int32, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string ) error {
	query := buildElementsStatisticQuery(startDate, endDate,  elemId, elemType)

	query = whereInstAndApp(query, instId, appId)

	if len(groupFields) != 0 {
		query = query.Select(fmt.Sprintf("count(distinct(concat(%s)))", strings.Join(groupFields, ",")))
	} else {
		query = query.Select("count(id)")
	}

	query = query.Count(count)
	return query.Error
}


func GetElementsStatistics(result interface{}, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string, orderBy string, ascOrder bool, limit, offset int32) error {
	query := buildElementsStatisticQuery(startDate, endDate, elemId, elemType)
	if len(groupFields) != 0 {
		groupBy := strings.Join(groupFields, "`,`")
		groupBy = "`" + groupBy + "`"
		query = query.Group(groupBy)
		query = havingInstAndApp(query, instId, appId)

		sumFields := strings.Join([]string{
			"SUM(`element_statistics`.`mp_count`) AS `mp_count`",
			"SUM(`element_statistics`.`h5_count`) AS `h5_count`",
			"SUM(`element_statistics`.`total_count`) AS `total_count`",
			"SUM(`element_statistics`.`collection_count`) AS `collection_count`",
			"SUM(`element_statistics`.`mp_share_count`) AS `mp_share_count`",
			"SUM(`element_statistics`.`h5_share_count`) AS `h5_share_count`",
			"SUM(`element_statistics`.`poster_share_count`) AS `poster_share_count`",
			"SUM(`element_statistics`.`total_share_count`) AS `total_share_count`",
		}, ",")

		query = query.Select(groupBy + "," + sumFields)
	} else {
		query = whereInstAndApp(query, instId, appId)
	}

	query = getPagedList(query, orderBy, ascOrder, limit, offset)

	return query.Find(result).Error
}

func getPagedList(query *gorm.DB, orderBy string, ascOrder bool, limit , offset int32) *gorm.DB {
	if orderBy != "" {
		if ascOrder {
			orderBy += " asc"
		} else {
			orderBy += " desc"
		}
		query = query.Order(orderBy)
	}

	if offset != 0 {
		query = query.Offset(offset)
	}
	if limit != 0 {
		query = query.Limit(limit)
	}
	return query
}

func whereInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
	query = query.Where("inst_id = ?", instId)
	if appId != "" {
		query = query.Where("app_id = ?", appId)
	}
	return query
}

func havingInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
	query = query.Having("inst_id = ?", instId)
	if appId != "" {
		query = query.Having("app_id = ?", appId)
	}
	return query
}

感谢各位看官耐心看完,如果本文对你有用,请点个赞~~~

如果能到代码仓库:Github:Ksloveyuan/gorm-ex 给个✩star✩, 楼主就更加感谢了!