关于golang:用GO写一个RPC框架-s05-客户端编写

61次阅读

共计 9376 个字符,预计需要花费 24 分钟才能阅读完成。

前言

前面几章我们实现了服务端的编写，现在开始编写客户端。

https://github.com/dollarkill…

Client

// Client holds the client-wide configuration. Every Connect created by
// NewConnect keeps a pointer back to its Client and reads its settings
// (discovery, load balancing, codecs, pool size, timeouts) from options.
type Client struct {options *Options}

// NewClient builds a Client that locates servers through discover.
//
// Configuration starts from defaultOptions(). The discovery plugin is
// installed first, then the supplied functional options are applied in
// order, so an Option can still override any default (including Discovery).
func NewClient(discover discovery.Discovery, options ...Option) *Client {
	opts := defaultOptions()
	opts.Discovery = discover
	for i := range options {
		options[i](opts)
	}
	return &Client{options: opts}
}

option

// Options is the client configuration. defaultOptions fills in the defaults
// and NewClient applies the caller's Option functions on top of them.
type Options struct {
    Discovery         discovery.Discovery                 // service discovery plugin
    loadBalancing     load_banlancing.LoadBalancing       // load-balancing plugin (picks the server instance to dial)
    serializationType codes.SerializationType             // serialization plugin
    compressorType    codes.CompressorType                // compression plugin

    pool         int                                      // connection pool size: connections dialed per Connect
    cryptology   cryptology.Cryptology                    // symmetric cipher choice; NOTE(review): not read by the code shown, which calls the AES helpers directly
    rsaPublicKey []byte                                   // PEM public key used to RSA-encrypt the AUTH token and the AES key in the handshake
    writeTimeout time.Duration                            // write deadline for requests and heartbeats
    readTimeout  time.Duration                            // read deadline applied before reading each incoming frame
    heartBeat    time.Duration                            // interval between heartbeat frames
    Trace        bool                                     // enables extra diagnostic logging
    AUTH         string                                   // AUTH TOKEN, sent in the handshake and attached to every call
}

// defaultOptions returns a fresh Options holding the built-in defaults:
//   - a pool of 4 connections per CPU, but never fewer than 20;
//   - MsgPack serialization, Snappy compression, AES encryption;
//   - the polling load balancer (presumably round-robin);
//   - a one-minute write timeout and heartbeat interval, and a three-minute
//     read timeout;
//   - tracing disabled and an empty AUTH token.
//
// NOTE(review): the embedded RSA public key is an empty PEM placeholder (the
// key body appears to have been elided); confirm a real key is configured.
func defaultOptions() *Options {
	poolSize := runtime.NumCPU() * 4
	if poolSize < 20 {
		poolSize = 20
	}

	publicKey := []byte(`
-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`)

	opts := new(Options)
	opts.pool = poolSize
	opts.serializationType = codes.MsgPack
	opts.compressorType = codes.Snappy
	opts.loadBalancing = load_banlancing.NewPolling()
	opts.cryptology = cryptology.AES
	opts.rsaPublicKey = publicKey
	opts.writeTimeout = time.Minute
	opts.readTimeout = 3 * time.Minute
	opts.heartBeat = time.Minute
	opts.Trace = false
	opts.AUTH = ""
	return opts
}

具体的每个连接（Connect）

// Connect is a handle to one named service. It pairs the owning Client with
// a pool of established connections to that service.
type Connect struct {
    Client     *Client        // owning client; its options configure the pool
    pool       *connectPool   // connections to serverName, filled by initPool
    close      chan struct{}  // created in NewConnect; NOTE(review): never closed in the code shown
    serverName string         // service name passed to service discovery
}

// NewConnect creates a Connect for the service registered as serverName and
// eagerly fills its connection pool via initPool.
//
// On error the returned *Connect is nil rather than a partially initialized
// value whose pool may be incomplete, following the Go convention that the
// result is meaningless when err != nil.
func (c *Client) NewConnect(serverName string) (conn *Connect, err error) {
	connect := &Connect{
		Client:     c,
		serverName: serverName,
		close:      make(chan struct{}),
	}

	connect.pool, err = initPool(connect)
	if err != nil {
		return nil, err
	}
	return connect, nil
}

初始化连接池

// initPool allocates the connection pool for c, sized by the client's pool
// option, and then fills it. The pool is returned even when filling fails,
// together with that error.
func initPool(c *Connect) (*connectPool, error) {
	size := c.Client.options.pool
	cp := &connectPool{connect: c}
	cp.pool = make(chan LightClient, size)
	err := cp.initPool()
	return cp, err
}

// initPool discovers the hosts of the service, initializes the load balancer
// with them, and dials options.pool connections into the pool.
//
// It returns an error when discovery fails, when no host is found, or when
// any dial fails. On a dial failure, the connections opened so far are
// closed and nothing is put into the pool. Closing them makes their reader
// goroutines fail, which also stops their heartbeats, so a failed
// initialization leaks neither sockets nor goroutines.
func (c *connectPool) initPool() error {
	opts := c.connect.Client.options

	// Ask service discovery where the service lives.
	hosts, err := opts.Discovery.Discovery(c.connect.serverName)
	if err != nil {
		return err
	}
	if len(hosts) == 0 {
		return errors.New(fmt.Sprintf("%s server 404", c.connect.serverName))
	}

	// Feed the discovered hosts to the load-balancing plugin.
	opts.loadBalancing.InitBalancing(hosts)

	// Dial every connection first so a failure can be rolled back cleanly.
	clients := make([]*BaseClient, 0, opts.pool)
	for i := 0; i < opts.pool; i++ {
		client, err := newBaseClient(c.connect.serverName, opts)
		if err != nil {
			for _, opened := range clients {
				_ = opened.conn.Close() // best effort: the dial error is what we report
			}
			return errors.WithStack(err)
		}
		clients = append(clients, client)
	}

	// The channel was created with capacity opts.pool, so these sends do not block.
	for _, client := range clients {
		c.pool <- client
	}
	return nil
}

// Get takes a connection out of the pool. It blocks until one is available
// or ctx is done. If ctx ends first, it returns the error "pool get timeout",
// whether ctx was cancelled or timed out.
func (c *connectPool) Get(ctx context.Context) (LightClient, error) {
	var client LightClient
	select {
	case client = <-c.pool:
	case <-ctx.Done():
		return nil, errors.New("pool get timeout")
	}
	return client, nil
}

// Put returns a connection to the pool.
//
// A healthy connection (client.Error() == nil) goes straight back into the
// pool. A broken one is discarded instead:
//   - if it is a *BaseClient, its socket is closed so the descriptor is
//     released and its reader and heartbeat goroutines terminate;
//   - a background goroutine (restore) opens a replacement to keep the pool
//     at full size.
func (c *connectPool) Put(client LightClient) {
	if client.Error() == nil {
		c.pool <- client
		return
	}

	// Release the broken connection. Closing the socket makes its reader
	// goroutine fail, which closes b.close and so stops its heartbeat.
	if bc, ok := client.(*BaseClient); ok {
		_ = bc.conn.Close() // best effort: the connection is already unusable
	}

	go c.restore()
}

// restore refills one pool slot after a broken connection was discarded.
// Once per second it re-runs service discovery, re-initializes the load
// balancer, and dials a new connection. It stops when a connection is added
// to the pool or when the owning Connect's close channel is closed.
func (c *connectPool) restore() {
	fmt.Println("The server starts to restore")
	opts := c.connect.Client.options
	for {
		select {
		case <-c.connect.close:
			return // the Connect was shut down; stop trying to refill the pool
		case <-time.After(time.Second):
		}

		hosts, err := opts.Discovery.Discovery(c.connect.serverName)
		if err != nil {
			log.Println(err)
			continue
		}
		if len(hosts) == 0 {
			log.Println(errors.New(fmt.Sprintf("%s server 404", c.connect.serverName)))
			continue
		}

		opts.loadBalancing.InitBalancing(hosts)
		baseClient, err := newBaseClient(c.connect.serverName, opts)
		if err != nil {
			log.Println(err)
			continue
		}

		// One slot was freed when the broken client was dropped, so this send has room.
		c.pool <- baseClient
		fmt.Println("Service recovery success")
		return
	}
}

Connect 调用具体服务

// Call invokes serviceMethod on the service behind c.
//
// It borrows a pooled connection, waiting at most 6 seconds for one to become
// free, and returns it to the pool afterwards. Before the request is sent,
// the client's AUTH token is stored on ctx under "Light_AUTH". request and
// response are handed to the connection's Call (response is presumably
// filled with the decoded reply).
func (c *Connect) Call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}) error {
	// This timeout bounds only the wait for a free pooled connection.
	poolCtx, cancel := context.WithTimeout(context.TODO(), time.Second*6)
	defer cancel() // release the context's timer as soon as Call returns

	client, err := c.pool.Get(poolCtx)
	if err != nil {
		return errors.WithStack(err)
	}
	// Hand the connection back when done; Put replaces it if it broke.
	defer c.pool.Put(client)

	// Attach the auth token to the call.
	ctx.SetValue("Light_AUTH", c.Client.options.AUTH)
	if err = client.Call(ctx, serviceMethod, request, response); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

调用核心（重点）

复习 s03 协议设计

/**
    协议设计
    起始符 :  版本号 :  crc32 校验 :   magicNumberSize:    serverNameSize :   serverMethodSize :  metaDataSize : payloadSize:  respType :   compressorType :    serializationType :    magicNumber :  serverName :   serverMethod :  metaData :  payload
        0x05  :  0x01  :     4     :        4         :         4         :         4          :       4       :      4     :      1    :          1       :           1          :        xxx     :       xxx   :        xxx     :    xxx    :    xxx
*/

注意：每一个请求都有一个 magicNumber，也就是该请求的 ID。

单个连接的定义

type BaseClient struct {
    conn       net.Conn
    options    *Options
    serverName string

    aesKey        []byte
    serialization codes.Serialization
    compressor    codes.Compressor

    respInterMap map[string]*respMessage
    respInterRM  sync.RWMutex     // 返回构造锁
    writeMu      sync.Mutex   // 写锁

    err   error          // 谬误
    close chan struct{}  // 用于敞开服务}

// respMessage is the bookkeeping entry for one in-flight request, stored in
// BaseClient.respInterMap under the request's magic number.
type respMessage struct {response interface{} // value the reply payload is deserialized into
    ctx      *light.Context // request context; the reply's metadata is stored on it
    respChan chan error     // outcome channel: receives the error, or is closed on success
}

初始化单个连接

// newBaseClient opens one connection to an instance of serverName and
// performs the handshake.
//
// The target instance is chosen by the load-balancing plugin. The handshake
// sends the AUTH token and a freshly generated AES key, both RSA-encrypted
// with options.rsaPublicKey, then reads the server's handshake reply. A
// non-empty error in that reply fails the connection. On success, the
// heartbeat and reply-reader goroutines are started.
//
// Every failure after the connection is established closes it, so a failed
// attempt never leaks a socket.
func newBaseClient(serverName string, options *Options) (*BaseClient, error) {
	// Pick a server instance via the load balancer and dial it.
	service, err := options.loadBalancing.GetService()
	if err != nil {
		return nil, err
	}
	con, err := transport.Client.Gen(service.Protocol, service.Addr)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	// fail closes the freshly opened connection before reporting err.
	fail := func(err error) (*BaseClient, error) {
		_ = con.Close() // best effort: err is the error worth reporting
		return nil, err
	}

	serialization, ok := codes.SerializationManager.Get(options.serializationType)
	if !ok {
		return fail(pkg.ErrSerialization404)
	}
	compressor, ok := codes.CompressorManager.Get(options.compressorType)
	if !ok {
		return fail(pkg.ErrCompressor404)
	}

	// Handshake: RSA-encrypt the AUTH token.
	encrypt, err := cryptology.RsaEncrypt([]byte(options.AUTH), options.rsaPublicKey)
	if err != nil {
		return fail(err)
	}

	// Key exchange: the AES key is the hex digits of a new UUID, sent RSA-encrypted.
	aesKey := []byte(strings.ReplaceAll(uuid.New().String(), "-", ""))
	aesKey2, err := cryptology.RsaEncrypt(aesKey, options.rsaPublicKey)
	if err != nil {
		return fail(err)
	}
	handshake := protocol.EncodeHandshake(aesKey2, encrypt, []byte(""))
	if _, err := con.Write(handshake); err != nil {
		return fail(err)
	}

	hsk := &protocol.Handshake{}
	if err := hsk.Handshake(con); err != nil {
		return fail(err)
	}
	// len of a nil slice is 0, so no separate nil check is needed.
	if len(hsk.Error) > 0 {
		return fail(errors.New(string(hsk.Error)))
	}

	bc := &BaseClient{
		serverName:    serverName,
		conn:          con,
		options:       options,
		serialization: serialization,
		compressor:    compressor,
		respInterMap:  map[string]*respMessage{},
		aesKey:        aesKey,
		close:         make(chan struct{}),
	}

	go bc.heartBeat()             // keep-alive frames
	go bc.processMessageManager() // reads replies and dispatches them to waiting requests

	return bc, nil
}

heartBeat 心跳服务

// heartBeat keeps the connection alive by writing a heartbeat frame every
// options.heartBeat. It stops in two cases:
//   - b.close is closed (the reader goroutine does this when the connection
//     fails);
//   - a heartbeat write fails, in which case the write error is recorded in
//     b.err.
// A frame that fails to encode is logged and retried on the next tick.
func (b *BaseClient) heartBeat() {
	defer func() {
		fmt.Println("heartBeat Close")
	}()

	for {
		select {
		case <-b.close:
			return
		case <-time.After(b.options.heartBeat):
		}

		_, frame, err := protocol.EncodeMessage("x", []byte(""), []byte(""), []byte(""), byte(protocol.HeartBeat), byte(b.options.compressorType), byte(b.options.serializationType), []byte(""))
		if err != nil {
			log.Println(err)
			continue
		}

		// Set the write deadline and write while holding writeMu, so the
		// deadline is tied to this very Write. Only the write deadline is
		// touched: SetDeadline would also reset the read deadline that the
		// reader goroutine (processMessage) sets before each read. A
		// non-positive writeTimeout means "no deadline" rather than "now".
		b.writeMu.Lock()
		if b.options.writeTimeout > 0 {
			_ = b.conn.SetWriteDeadline(time.Now().Add(b.options.writeTimeout))
		}
		_, err = b.conn.Write(frame)
		b.writeMu.Unlock()
		if err != nil {
			b.err = err
			return
		}
	}
}

processMessageManager 返回消息的处理服务（注意：这里可以并发地处理）

// processMessageManager is the reader loop of a BaseClient. Each iteration
// handles one incoming frame via processMessage and delivers the outcome to
// the waiting request:
//   - on error, the error is sent on the request's respChan;
//   - on success, the respChan is closed.
// Results with no magic number (for example heartbeats or replies to
// requests that are no longer registered) or with no respChan deliver
// nothing. The loop ends on the first error that carries no magic number.
func (b *BaseClient) processMessageManager() {
	defer func() {
		fmt.Println("processMessageManager Close")
	}()

	for {
		magic, respChan, err := b.processMessage()
		switch {
		case magic == "" && err != nil:
			return
		case magic == "" || respChan == nil:
			// nothing to deliver
		case err != nil:
			respChan <- err
		default:
			close(respChan)
		}
	}
}

// processMessage reads one frame from the connection and decodes it.
//
// The return values tell processMessageManager what to do:
//   - ("", nil, err): the frame could not be read. b.err is set and b.close
//     is closed, so the reader loop must stop.
//   - ("", nil, nil): nothing to deliver. The frame was a heartbeat, or a
//     reply whose request is no longer registered.
//   - (magic, respChan, err): the reply for request magic; err is that
//     request's outcome (nil on success).
//
// Failures tied to a single reply are reported to the waiting request
// instead of terminating the reader loop, because the connection itself is
// still usable. These are: unknown compressor, decompression, decryption,
// deserialization, or a RespError sent by the server.
func (b *BaseClient) processMessage() (magic string, respChan chan error, err error) {
	b.conn.SetReadDeadline(time.Now().Add(b.options.readTimeout))

	proto := protocol.NewProtocol()
	msg, err := proto.IODecode(b.conn)
	if err != nil {
		b.err = err
		close(b.close)
		return "", nil, err
	}

	// Heartbeat replies carry no request.
	if msg.Header.RespType == byte(protocol.HeartBeat) {
		if b.options.Trace {
			log.Println("is HeartBeat")
		}
		return "", nil, nil
	}

	b.respInterRM.RLock()
	message, ok := b.respInterMap[msg.MagicNumber]
	b.respInterRM.RUnlock()
	if !ok { // not registered: the request is no longer waiting
		if b.options.Trace {
			log.Println("Not Ex", msg.MagicNumber)
		}
		return "", nil, nil
	}

	// From here on, every outcome belongs to this request only.
	reply := func(err error) (string, chan error, error) {
		return msg.MagicNumber, message.respChan, err
	}

	comp, ok := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
	if !ok {
		return reply(pkg.ErrCompressor404)
	}

	// 1. Decompress.
	if msg.MetaData, err = comp.Unzip(msg.MetaData); err != nil {
		return reply(err)
	}
	if msg.Payload, err = comp.Unzip(msg.Payload); err != nil {
		return reply(err)
	}

	// 2. Decrypt. An empty section is accepted as empty plaintext; a
	// non-empty section that fails to decrypt is an error.
	if msg.MetaData, err = b.decryptSection(msg.MetaData); err != nil {
		return reply(err)
	}
	if msg.Payload, err = b.decryptSection(msg.Payload); err != nil {
		return reply(err)
	}

	// 3. Deserialize the metadata; it carries the server-side RespError, if any.
	mtData := make(map[string]string)
	if err = b.serialization.Decode(msg.MetaData, &mtData); err != nil {
		return reply(err)
	}
	message.ctx.SetMetaData(mtData)

	if value := message.ctx.Value("RespError"); value != "" {
		return reply(errors.New(value))
	}

	return reply(b.serialization.Decode(msg.Payload, message.response))
}

// decryptSection AES-decrypts one message section with the connection key.
// An empty section is returned as an empty (non-nil) slice without calling
// the cipher. The emptiness check is made on the input, so a decryption
// failure on real data is never mistaken for an empty section.
func (b *BaseClient) decryptSection(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return []byte(""), nil
	}
	return cryptology.AESDecrypt(b.aesKey, data)
}

服务调用

// call encodes one request and writes it to the connection. It does not
// wait for the reply.
//
// Before the write, the request is registered in respInterMap under its
// magic number, so the reader goroutine can decode the reply into response
// and report the outcome on respChan. The returned magic identifies the
// request.
//
// Encoding pipeline:
//  1. serialize the ctx metadata and the request;
//  2. AES-encrypt both;
//  3. compress both when the metadata size lies strictly between
//     compressorMin and compressorMax, otherwise mark the frame RawData.
//
// Write deadline: ctx.GetTimeout() when positive, otherwise
// options.writeTimeout; no deadline is set when writeTimeout <= 0. If the
// write fails, the registration is removed, the error is stored in b.err,
// and the error is returned.
func (b *BaseClient) call(ctx *light.Context, serviceMethod string, request interface{}, response interface{}, respChan chan error) (magic string, err error) {
	// 1.1 Serialize the context metadata and the request.
	metaDataBytes, err := b.serialization.Encode(ctx.GetMetaData())
	if err != nil {
		return "", err
	}
	requestBytes, err := b.serialization.Encode(request)
	if err != nil {
		return "", err
	}

	// 1.2 Encrypt both sections with the connection's AES key.
	if metaDataBytes, err = cryptology.AESEncrypt(b.aesKey, metaDataBytes); err != nil {
		return "", err
	}
	if requestBytes, err = cryptology.AESEncrypt(b.aesKey, requestBytes); err != nil {
		return "", err
	}

	// 1.3 Compress, or mark the frame as raw.
	compressorType := b.options.compressorType
	// NOTE(review): the decision looks only at the metadata size, yet it also
	// governs whether the request payload is compressed; confirm this is
	// intended rather than a check on len(requestBytes).
	if len(metaDataBytes) > compressorMin && len(metaDataBytes) < compressorMax {
		if metaDataBytes, err = b.compressor.Zip(metaDataBytes); err != nil {
			return "", err
		}
		if requestBytes, err = b.compressor.Zip(requestBytes); err != nil {
			return "", err
		}
	} else {
		compressorType = codes.RawData
	}

	// 1.4 Frame the message; the returned magic identifies this request.
	magic, message, err := protocol.EncodeMessage("", []byte(b.serverName), []byte(serviceMethod), metaDataBytes, byte(protocol.Request), byte(compressorType), byte(b.options.serializationType), requestBytes)
	if err != nil {
		return "", err
	}

	// Register before writing so a fast reply is never dropped as unknown.
	b.respInterRM.Lock()
	b.respInterMap[magic] = &respMessage{
		response: response,
		ctx:      ctx,
		respChan: respChan,
	}
	b.respInterRM.Unlock()

	// 2. Send. The write deadline is set while holding writeMu, so it applies
	// to this very Write. Only the write deadline is set: SetDeadline would
	// also move the read deadline that the reader goroutine (processMessage)
	// sets before each read.
	b.writeMu.Lock()
	if b.options.writeTimeout > 0 {
		timeout := ctx.GetTimeout() // a per-call timeout on ctx wins over the default
		if timeout <= 0 {
			timeout = b.options.writeTimeout
		}
		_ = b.conn.SetWriteDeadline(time.Now().Add(timeout))
	}
	_, err = b.conn.Write(message)
	b.writeMu.Unlock()
	if err != nil {
		if b.options.Trace {
			log.Println(err)
		}
		// The caller gets no magic on failure and so cannot clean up the
		// registration; drop it here to avoid leaking it.
		b.respInterRM.Lock()
		delete(b.respInterMap, magic)
		b.respInterRM.Unlock()
		b.err = err
		return "", errors.WithStack(err)
	}

	return magic, nil
}

正文完
 0