前言

通过上两篇的学习 咱们曾经理解了 服务端本地服务的注册, 服务端配置,协定 当初咱们开始写服务端的外围逻辑

https://github.com/dollarkill...

默认配置

咱们先看下默认的配置

func defaultOptions() *Options {    return &Options{        Protocol:     transport.TCP, // default TCP        Uri:          "0.0.0.0:8397",        UseHttp:      false,        readTimeout:  time.Minute * 3, // 心跳包 默认 3min        writeTimeout: time.Second * 30,        ctx:          context.Background(), // ctx 是管制服务退出的        options: map[string]interface{}{            "TCPKeepAlivePeriod": time.Minute * 3,        },        processChanSize: 1000,            Trace:           false,        RSAPublicKey: []byte(`-----BEGIN PUBLIC KEY----------END PUBLIC KEY-----`),        RSAPrivateKey: []byte(`-----BEGIN RSA PRIVATE KEY----------END RSA PRIVATE KEY-----`),        Discovery: &discovery.SimplePeerToPeer{},    }}

run

服务注册结束之后 调用Run办法 启动服务

func (s *Server) Run(options ...Option) error {        // 初始化 服务端配置    for _, fn := range options {        fn(s.options)    }    var err error        // 更具配置传入的protocol 获取到 网络插件 (KCP UDP TCP) 咱们等下细讲    s.options.nl, err = transport.Transport.Gen(s.options.Protocol, s.options.Uri)    if err != nil {        return err    }    log.Printf("LightRPC: %s  %s \n", s.options.Protocol, s.options.Uri)        // 这里是服务注册 咱们这里先跳过      if s.options.Discovery != nil {                // 读取服务配置文件        sIdb, err := ioutil.ReadFile("./light.conf")        if err != nil {                        // 如果没有 就生成 分布式ID            id, err := utils.DistributedID()            if err != nil {                return err            }            sIdb = []byte(id)        }        // 进行服务注册        sId := string(sIdb)        for k := range s.serviceMap {   // 进行服务注册             err := s.options.Discovery.Registry(k, s.options.registryAddr, s.options.weights, s.options.Protocol, s.options.MaximumLoad, &sId)            if err != nil {                return err            }            log.Printf("Discovery Registry: %s addr: %s SUCCESS", k, s.options.registryAddr)        }        ioutil.WriteFile("./light.conf", sIdb, 00666)    }                // 启动服务    return s.run()}func (s *Server) run() error {loop:    for {        select {        case <-s.options.ctx.Done():  // 查看是否须要退出服务            break loop        default:            accept, err := s.options.nl.Accept() // 获取一个链接            if err != nil {                log.Println(err)                continue            }            if s.options.Trace {                log.Println("connect: ", accept.RemoteAddr())            }            go s.process(accept) // 开一个协程去解决 该 链接        }    }    return nil}

咱们先回顾一下 上章讲的 握手逻辑

  1. 建设链接 通过非对称加密 传输 aes 密钥给服务端 (携带token)
  2. 服务端 验证 token 并记录 aes 密钥 前面与客户端交互 都采纳对称加密

具体解决 链接 process (重点!!!)

func (s *Server) process(conn net.Conn) {    defer func() {        // 网络不牢靠        if err := recover(); err != nil {            utils.PrintStack()            log.Println("Recover Err: ", err)        }    }()        // 每进来一个申请这里就ADD    s.options.Discovery.Add(1)    defer func() {        s.options.Discovery.Less(1) // 解决完 申请就退出        // 退出 回收句柄        err := conn.Close()          if err != nil {            log.Println(err)            return        }        if s.options.Trace {            log.Println("close connect: ", conn.RemoteAddr())        }    }()        // 这里定义一个xChannel 用于拆散 申请和返回    xChannel := utils.NewXChannel(s.options.processChanSize)    // 握手    handshake := protocol.Handshake{}    err := handshake.Handshake(conn)    if err != nil {        return    }                    // 非对称加密  解密 AES KEY    aesKey, err := cryptology.RsaDecrypt(handshake.Key, s.options.RSAPrivateKey)    if err != nil {        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))        conn.Write(encodeHandshake)        return    }        // 检测 AES KEY 是否正确    if len(aesKey) != 32 && len(aesKey) != 16 {        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte("aes key != 32 && key != 16"))        conn.Write(encodeHandshake)        return    }                // 解密 TOKEN    token, err := cryptology.RsaDecrypt(handshake.Token, s.options.RSAPrivateKey)    if err != nil {        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))        conn.Write(encodeHandshake)        return    }        // 对TOKEN进行校验      if s.options.AuthFunc != nil {        err := s.options.AuthFunc(light.DefaultCtx(), string(token))        if err != nil {            encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))            conn.Write(encodeHandshake)            return        }    }    // limit 限流    if s.options.Discovery.Limit() {        // 熔断        encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(pkg.ErrCircuitBreaker.Error()))        conn.Write(encodeHandshake)        log.Println(s.options.Discovery.Limit())        return    }                // 如果握手没有问题 则返回握手胜利    encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(""))    _, err = conn.Write(encodeHandshake)    if err != nil {        return    }            // send    go func() {    loop:        for {            select {                        // 这就是刚刚的xChannel 对读写进行拆散            case msg, ex := <-xChannel.Ch:                 if !ex {                    if s.options.Trace {                        log.Printf("ip: %s  close send server", conn.RemoteAddr())                    }                    break loop                }                now := time.Now()                if s.options.writeTimeout > 0 {                    conn.SetWriteDeadline(now.Add(s.options.writeTimeout))                }                // send message                _, err := conn.Write(msg)                if err != nil {                    if s.options.Trace {                        log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)                    }                    break loop                }            }        }    }()    defer func() {        xChannel.Close()    }()loop:    for { // 具体音讯获取        now := time.Now()        if s.options.readTimeout > 0 {            conn.SetReadDeadline(now.Add(s.options.readTimeout))        }        proto := protocol.NewProtocol()        msg, err := proto.IODecode(conn) // 获取一个音讯        if err != nil {            if err == io.EOF {                if s.options.Trace {                    log.Printf("ip: %s close", conn.RemoteAddr())                }                break loop            }            // 遇到谬误敞开链接            if s.options.Trace {                log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)            }            break loop        }        go s.processResponse(xChannel, msg, conn.RemoteAddr().String(), aesKey)    }}

具体解决 (重点!!!)

留神此RPC传输音讯都是编码过的 要进行转码

  • 第一层 为压缩编码
  • 第二层 为加密编码
  • 第三层 为序列化
func (s *Server) processResponse(xChannel *utils.XChannel, msg *protocol.Message, addr string, aesKey []byte) {    var err error    s.options.Discovery.Add(1)    defer func() {        s.options.Discovery.Less(1)        if err != nil {            if s.options.Trace {                log.Println("ProcessResponse Error: ", err, "  ID: ", addr)            }            xChannel.Close()        }    }()    // heartBeat 判断    if msg.Header.RespType == byte(protocol.HeartBeat) {        // 心跳返回        if s.options.Trace {            log.Println("HeartBeat: ", addr)        }        // 4. 打包        _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), []byte(""), byte(protocol.HeartBeat), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))        if err != nil {            return        }        // 5. 回写        err = xChannel.Send(message)        if err != nil {            return        }        return    }    // 限流    if s.options.Discovery.Limit() {        serialization, _ := codes.SerializationManager.Get(codes.MsgPack)        metaData := make(map[string]string)        metaData["RespError"] = pkg.ErrCircuitBreaker.Error()        meta, err := serialization.Encode(metaData)        if err != nil {            return        }        decrypt, err := cryptology.AESDecrypt(aesKey, meta)        if err != nil {            return        }        _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), decrypt, byte(protocol.Response), byte(codes.RawData), byte(codes.MsgPack), []byte(""))        if err != nil {            return        }        // 5. 回写        err = xChannel.Send(message)        if err != nil {            return        }        log.Println(s.options.Discovery.Limit())        log.Println("限流/////////////")        return    }    // 1. 解压缩    compressor, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))    if !ex {        err = errors.New("compressor 404")        return    }    msg.MetaData, err = compressor.Unzip(msg.MetaData)    if err != nil {        return    }    msg.Payload, err = compressor.Unzip(msg.Payload)    if err != nil {        return    }    // 2. 解密    msg.MetaData, err = cryptology.AESDecrypt(aesKey, msg.MetaData)    if err != nil {        return    }    msg.Payload, err = cryptology.AESDecrypt(aesKey, msg.Payload)    if err != nil {        return    }    // 3. 反序列化    serialization, ex := codes.SerializationManager.Get(codes.SerializationType(msg.Header.SerializationType))    if !ex {        err = errors.New("serialization 404")        return    }    metaData := make(map[string]string)    err = serialization.Decode(msg.MetaData, &metaData)    if err != nil {        return    }        // 初始化context    ctx := light.DefaultCtx()    ctx.SetMetaData(metaData)    // 1.3 auth    if s.options.AuthFunc != nil {        auth := metaData["Light_AUTH"]        err := s.options.AuthFunc(ctx, auth)        if err != nil {            ctx.SetValue("RespError", err.Error())            var metaDataByte []byte            metaDataByte, _ = serialization.Encode(ctx.GetMetaData())            metaDataByte, _ = cryptology.AESEncrypt(aesKey, metaDataByte)            metaDataByte, _ = compressor.Zip(metaDataByte)            // 4. 打包            _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))            if err != nil {                return            }            // 5. 回写            err = xChannel.Send(message)            if err != nil {                return            }            return        }    }        // 找到具体调用的服务    ser, ex := s.serviceMap[msg.ServiceName]    if !ex {        err = errors.New("service does not exist")        return    }        // 找到具体调用的办法    method, ex := ser.methodType[msg.ServiceMethod]    if !ex {        err = errors.New("method does not exist")        return    }        // 初始化 req, resp    req := utils.RefNew(method.RequestType)    resp := utils.RefNew(method.ResponseType)    err = serialization.Decode(msg.Payload, req)    if err != nil {        return    }        // 定义ctx paht 为   服务名称.服务办法    path := fmt.Sprintf("%s.%s", msg.ServiceName, msg.ServiceMethod)    ctx.SetPath(path)    // 前置middleware    if len(s.beforeMiddleware) != 0 {        for idx := range s.beforeMiddleware {            err := s.beforeMiddleware[idx](ctx, req, resp)            if err != nil {                return            }        }    }    funcs, ex := s.beforeMiddlewarePath[path]    if ex {        if len(funcs) != 0 {            for idx := range funcs {                err := funcs[idx](ctx, req, resp)                if err != nil {                    return                }            }        }    }    // 外围调用    callErr := ser.call(ctx, method, reflect.ValueOf(req), reflect.ValueOf(resp))    if callErr != nil {        ctx.SetValue("RespError", callErr.Error())    }    // 后置middleware    if len(s.afterMiddleware) != 0 {        for idx := range s.afterMiddleware {            err := s.afterMiddleware[idx](ctx, req, resp)            if err != nil {                return            }        }    }    funcs, ex = s.afterMiddlewarePath[path]    if ex {        if len(funcs) != 0 {            for idx := range funcs {                err := funcs[idx](ctx, req, resp)                if err != nil {                    return                }            }        }    }    // response    // 1. 序列化    var respBody []byte    respBody, err = serialization.Encode(resp)    var metaDataByte []byte    metaDataByte, _ = serialization.Encode(ctx.GetMetaData())    // 2. 加密    metaDataByte, err = cryptology.AESEncrypt(aesKey, metaDataByte)    if err != nil {        return    }    respBody, err = cryptology.AESEncrypt(aesKey, respBody)    if err != nil {        return    }    // 3. 压缩    metaDataByte, err = compressor.Zip(metaDataByte)    if err != nil {        return    }    respBody, err = compressor.Zip(respBody)    if err != nil {        return    }    // 4. 打包    _, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, respBody)    if err != nil {        return    }    // 5. 回写    err = xChannel.Send(message)    if err != nil {        return    }}

调用具体方法

func (s *service) call(ctx *light.Context, mType *methodType, request, response reflect.Value) (err error) {    // recover 捕捉堆栈音讯    defer func() {        if r := recover(); r != nil {            buf := make([]byte, 4096)            n := runtime.Stack(buf, false)            buf = buf[:n]            err = fmt.Errorf("[painc service internal error]: %v, method: %s, argv: %+v, stack: %s",                r, mType.method.Name, request.Interface(), buf)            log.Println(err)        }    }()    fn := mType.method.Func    returnValue := fn.Call([]reflect.Value{s.refVal, reflect.ValueOf(ctx), request, response})    errInterface := returnValue[0].Interface()    if errInterface != nil {        return errInterface.(error)    }    return nil}

这里就实现了服务端的根底逻辑了