Gorm-源码分析二-简单query分析

30次阅读

共计 7167 个字符,预计需要花费 18 分钟才能阅读完成。

简单使用

上一篇文章我们已经知道了不使用 orm 如何调用 mysql 数据库,这篇文章我们要查看的是 Gorm 的源码,从最简单的一个查询语句作为切入点。当然 Gorm 的功能很多支持 where 条件支持外键 group 等等功能,这些功能大体的流程都是差不多先从简单的看起。下面先看如何使用

package main

import (
    "fmt"
    _ "github.com/go-sql-driver/mysql"
    "github.com/panlei/gorm"
)

var db *gorm.DB

func main() {InitMysql()
    var u User
    db.Where("Id = ?", 2).First(&u)

}

func InitMysql() {
    var err error
    db, err = gorm.Open("mysql", "root:***@******@tcp(**.***.***.***:****)/databasename?charset=utf8&loc=Asia%2FShanghai&parseTime=True")
    fmt.Println(err)
}

type User struct {
    Id       int    `gorm:"primary_key;column:Id" json:"id"`
    UserName string `json:"userName" gorm:"column:UserName"`
    Password string `json:"password" gorm:"column:Password"`
}

func (User) TableName() string {return "user"}
  1. 首先注册对象 添加 tag 标注主键 设置数据库 column 名 数据库名可以和字段名不一样
  2. 设置 table 名字,如果类名和数据库名一样则不需要设置
  3. 初始化数据库连接 创建 GormDB 对象 使用 Open 方法返回 DB
  4. 使用最简单的 where 函数和 First 来获取 翻译过来的 sql 语句就是
    select * from user where Id = 2 limit 1

源码分析

1. DB、search、callback 对象

DB 对象包含所有处理 mysql 的方法,主要的还是 search 和 callbacks
search 对象存放了所有的查询条件
Callback 对象存放了 sql 的调用链 存放了一系列的 callback 函数

// Gorm 中使用的 DB 对象
type DB struct {
    sync.RWMutex                // 锁
    Value        interface{}
    Error        error
    RowsAffected int64

    // single db
    db                SQLCommon  // 原生 db.sql 对象,包含 query 相关的原生方法
    blockGlobalUpdate bool
    logMode           logModeValue
    logger            logger
    search            *search      // 保存搜索的条件 where, limit, group,比如调用 db.clone()时,会指定 search
    values            sync.Map

    // global db
    parent        *DB
    callbacks     *Callback        // 当前 sql 绑定的函数调用链
    dialect       Dialect           // 不同数据库适配注册 sql.db
    singularTable bool
}
// search 对象存放了所有查询的条件 从名字就能看出来 有 where or having 各种条件
type search struct {
    db               *DB
    whereConditions  []map[string]interface{}
    orConditions     []map[string]interface{}
    notConditions    []map[string]interface{}
    havingConditions []map[string]interface{}
    joinConditions   []map[string]interface{}
    initAttrs        []interface{}
    assignAttrs      []interface{}
    selects          map[string]interface{}
    omits            []string
    orders           []interface{}
    preload          []searchPreload
    offset           interface{}
    limit            interface{}
    group            string
    tableName        string
    raw              bool
    Unscoped         bool
    ignoreOrderQuery bool
}

// Callback 记录了调用链 区分了 update delete query create 等不同
// 这些 callback 都在 callback.go 中的 init 方法中注册
type Callback struct {
    logger     logger
    creates    []*func(scope *Scope)
    updates    []*func(scope *Scope)
    deletes    []*func(scope *Scope)
    queries    []*func(scope *Scope)
    rowQueries []*func(scope *Scope)
    processors []*CallbackProcessor}
2. Scope 对象 每一个 sql 操作所有的信息
// 包含每一个 sql 操作的相关信息
type Scope struct {
    Search          *search            // 检索条件在 1 中是同一个对象
    Value           interface{}     // 保存实体类
    SQL             string            // sql 语句
    SQLVars         []interface{}
    db              *DB                // DB 对象
    instanceID      string
    primaryKeyField *Field
    skipLeft        bool
    fields          *[]*Field        // 字段
    selectAttrs     *[]string}
3. Open 函数 创建数据库 DB 对象初始化数据库连接

Open 函数主要是根据输入的数据库信息

  1. 创建连接
  2. 初始化 DB 对象
  3. 设置调用链函数
  4. 发送一个 ping 测试是否能可用
func Open(dialect string, args ...interface{}) (db *DB, err error) {if len(args) == 0 {err = errors.New("invalid database source")
        return nil, err
    }
    var source string
    // 接口对应 database/sql 接口
    var dbSQL SQLCommon
    var ownDbSQL bool

    switch value := args[0].(type) {
    // 如果第一个参数是 string 则使用 sql.open 创建连接 返回 sql.Db 对象
    case string:
        var driver = dialect
        if len(args) == 1 {source = value} else if len(args) >= 2 {
            driver = value
            source = args[1].(string)
        }
        dbSQL, err = sql.Open(driver, source)
        ownDbSQL = true
        // 如果是 SQLCommon 直接赋值
    case SQLCommon:
        dbSQL = value
        ownDbSQL = false
    default:
        return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
    }
    // 初始化 DB 对象
    db = &DB{
        db:        dbSQL,
        logger:    defaultLogger,
        // 在 callback_create.go
        // callback_deleta.go
        // callback_query.go
        // callback_save.go
        // callback_update.go  等等 注册了默认的 callback
        // callback 文件中的 init 方法中注册了默认的 callback 方法
        // 主要处理的逻辑几乎都在各个不同的 callback 中
        callbacks: DefaultCallback,
        dialect:   newDialect(dialect, dbSQL),
    }
    db.parent = db
    if err != nil {return}
    // 发送一个 ping 确认这个连接是可用的
    if d, ok := dbSQL.(*sql.DB); ok {if err = d.Ping(); err != nil && ownDbSQL {d.Close()
        }
    }
    return
}
4. where 函数 创建数据库 DB 对象初始化数据库连接

其实跟 where 相同的还有很多比如 having、group、limit、select、or、not 等等其实操作都是类似的
调用 DB 对象函数 where 在调用 search 对象的 where
具体就是把 where 条件放到 search 对象中的 whereConditions 中等最后拼接 sql

func (s *DB) Where(query interface{}, args ...interface{}) *DB {return s.clone().search.Where(query, args...).db
}

func (s *search) Where(query interface{}, values ...interface{}) *search {s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
    return s
}
4. First 函数
  1. First 创建一个 Scpoe
  2. NewScope 中调用 DB.clone 函数 克隆 DB 对象 底层指针属性不变
  3. inlineCondition 初始化查询条件
  4. callCallbacks 函数,传入调用链 for 循环调用传入的函数
func (s *DB) First(out interface{}, where ...interface{}) *DB {newScope := s.NewScope(out)
    newScope.Search.Limit(1)

    // callCallbacks 调用 query callback 方法
    return newScope.Set("gorm:order_by_primary_key", "ASC").
        inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
// 新建 Scope
func (s *DB) NewScope(value interface{}) *Scope {
    // 克隆 DB 对象
    dbClone := s.clone()
    dbClone.Value = value
    scope := &Scope{db: dbClone, Value: value}
    if s.search != nil {scope.Search = s.search.clone()
    } else {scope.Search = &search{}
    }
    return scope
}

func (scope *Scope) inlineCondition(values ...interface{}) *Scope {if len(values) > 0 {scope.Search.Where(values[0], values[1:]...)
    }
    return scope
}

// 循环调用传入的 functions
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {defer func() {if err := recover(); err != nil {if db, ok := scope.db.db.(sqlTx); ok {db.Rollback()
            }
            panic(err)
        }
    }()
    // 使用 for 循环 调用回调函数
    for _, f := range funcs {(*f)(scope)
        if scope.skipLeft {break}
    }
    return scope
}
5. 真正查询方法 queryCallback
  1. queryCallback 方法组成 sql 语句 调用 database/sql 中的 query 方法在上一篇分析中可以看到 循环 rows 结果获取数据
  2. prepareQuerySQL 方法主要是组成 sql 语句的方法 通过反射获取字段名表明等属性
func queryCallback(scope *Scope) {if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {return}

    //we are only preloading relations, dont touch base model
    if _, skip := scope.InstanceGet("gorm:only_preload"); skip {return}

    defer scope.trace(NowFunc())

    var (
        isSlice, isPtr bool
        resultType     reflect.Type
        results        = scope.IndirectValue())
    // 找到排序字段
    if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {if primaryField := scope.PrimaryField(); primaryField != nil {scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
        }
    }

    if value, ok := scope.Get("gorm:query_destination"); ok {results = indirect(reflect.ValueOf(value))
    }

    if kind := results.Kind(); kind == reflect.Slice {
        isSlice = true
        resultType = results.Type().Elem()
        results.Set(reflect.MakeSlice(results.Type(), 0, 0))

        if resultType.Kind() == reflect.Ptr {
            isPtr = true
            resultType = resultType.Elem()}
    } else if kind != reflect.Struct {scope.Err(errors.New("unsupported destination, should be slice or struct"))
        return
    }
    // 准备查询语句
    scope.prepareQuerySQL()

    if !scope.HasError() {
        scope.db.RowsAffected = 0
        if str, ok := scope.Get("gorm:query_option"); ok {scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
        }
        // 调用 database/sql 包中的 query 来查询
        if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {defer rows.Close()

            columns, _ := rows.Columns()
            // 循环 rows 组成对象
            for rows.Next() {
                scope.db.RowsAffected++

                elem := results
                if isSlice {elem = reflect.New(resultType).Elem()}

                scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

                if isSlice {
                    if isPtr {results.Set(reflect.Append(results, elem.Addr()))
                    } else {results.Set(reflect.Append(results, elem))
                    }
                }
            }

            if err := rows.Err(); err != nil {scope.Err(err)
            } else if scope.db.RowsAffected == 0 && !isSlice {scope.Err(ErrRecordNotFound)
            }
        }
    }
}
func (scope *Scope) prepareQuerySQL() {
    // 如果是 rwa 则组织 sql 语句
    if scope.Search.raw {scope.Raw(scope.CombinedConditionSql())
    } else {
        // 组织 select 语句
        // scope.selectSQL() 组织 select 需要查询的字段
        // scope.QuotedTableName() 获取表名
        // scope.CombinedConditionSql()组织条件语句
        scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
    }
    return
}

总结

这篇文章从一个最简单的 where 条件和 first 函数入手了解 Gorm 主体的流程和主要的对象。其实可以看出 Gorm 的本质:

  1. 创建 DB 对象,注册 mysql 连接
  2. 创建对象 通过 tag 设置一些主键,外键等
  3. 通过 where 或者其他比如 group having 等设置查询的条件
  4. 通过 first 函数最终生成 sql 语句
  5. 调用 database/sql 中的方法通过 mysql 驱动真正的查询数据
  6. 通过反射来组成对象或者是数组对象提供使用

之后我们可以看一些复杂的操作,比如外键 预加载 多表查询等操作。

正文完
 0