一、前言
Go 语言在设计上对同步(Synchronization,数据同步和线程同步)提供大量的反对,比方 goroutine 和 channel 同步原语,库层面有
- sync:提供根本的同步原语(比方 Mutex、RWMutex、Locker)和 工具类(Once、WaitGroup、Cond、Pool、Map)- sync/atomic:提供变量的原子操作(基于硬件指令 compare-and-swap)
— 援用自《Golang package sync 分析(一):sync.Once》
上一期中,咱们介绍了 sync.Once
如何保障 exactly once
语义,本期文章咱们介绍 package sync
下的另一个工具类:sync.WaitGroup
。
二、为什么须要 WaitGroup
?
设想一个场景:咱们有一个用户画像服务,当一个申请到来时,须要
- 从 request 里解析出 user_id 和 画像维度参数
- 依据 user_id 从 ABCDE 五个子服务(数据库服务、存储服务、rpc 服务等)拉取不同维度的信息
- 将读取的信息进行整合,返回给调用方
假如 ABCDE 五个服务的响应工夫 p99 是 20~50ms 之间。如果咱们顺序调用 ABCDE 读取信息,不思考数据整合耗费工夫,服务端整体响应工夫 p99 是:
sum(A, B, C, D, E) => [100ms, 250ms]
先不说业务上能不能承受,响应工夫上显然有很大的优化空间。最直观的优化方向就是,取数逻辑的总工夫耗费:
sum(A, B, C, D, E) -> max(A, B, C, D, E)
具体到 coding 上,咱们须要并行调用 ABCDE 五个子服务,待调用 全副 返回当前,进行数据整合。如何保障 全副
返回呢?
此时,sync.WaitGroup
闪耀退场。
三、WaitGroup
用法
官网文档对 WaitGroup 的形容是:一个 WaitGroup 对象能够期待一组协程完结
。应用办法是:
- main 协程通过调用
wg.Add(delta int)
设置 worker 协程的个数,而后创立 worker 协程; - worker 协程执行完结当前,都要调用
wg.Done()
; - main 协程调用
wg.Wait()
且被 block,直到所有 worker 协程全副执行完结后返回。
这里先看一个典型的例子:
// src/cmd/compile/internal/ssa/gen/main.go
func main() {
// 省略局部代码 ...
var wg sync.WaitGroup
for _, task := range tasks {
task := task
wg.Add(1)
go func() {task()
wg.Done()}()}
wg.Wait()
// 省略局部代码...
}
这个例子具备了 WaitGroup
正确应用的大部分因素,包含:
wg.Done
必须在 所有wg.Add
之后执行,所以要保障两个函数都在 main 协程中调用;wg.Done
在 worker 协程里调用,尤其要保障调用一次,不能因为 panic 或任何起因导致没有执行(倡议应用defer wg.Done()
);wg.Done
和wg.Wait
在时序上是没有先后。
仔细的敌人可能会发现一行十分诡异的代码:
task := task
Go 对 array/slice 进行遍历时,runtime 会把 task[i]
拷贝到 task
的内存地址,下标 i
会变,而 task
的内存地址不会变。如果不进行这次赋值操作,所有 goroutine 可能读到的都是最初一个 task。为了让大家有一个直观的感觉,咱们用上面这段代码做试验:
package main
import (
"fmt"
"unsafe"
)
func main() {tasks := []func(){func() {fmt.Printf("1.") },
func() { fmt.Printf("2.") },
}
for idx, task := range tasks {task()
fmt.Printf("遍历 = %v,", unsafe.Pointer(&task))
fmt.Printf("下标 = %v,", unsafe.Pointer(&tasks[idx]))
task := task
fmt.Printf("局部变量 = %vn", unsafe.Pointer(&task))
}
}
这段代码的打印后果是:
1. 遍历 = 0x40c140, 下标 = 0x40c138, 局部变量 = 0x40c150
2. 遍历 = 0x40c140, 下标 = 0x40c13c, 局部变量 = 0x40c158
不同机器上执行打印后果有所不同,但共同点是:
- 遍历时,数据的内存地址不变
- 通过下标取数时,内存地址不同
- for-loop 内创立的局部变量,即使名字雷同,内存地址也不会复用
应用 WaitGroup
时,除了下面提到的注意事项,还须要解决数据回收和异样解决的问题。这里咱们也提供两种形式供参考:
- 对于 rpc 调用,能够通过 data channel 和 error channel 收集信息,或者二合一的 channel
- 共享变量,比方加锁的 map
四、WaitGroup
实现
在探讨这个主题之前,倡议读者先思考一下:如果让你去实现 WaitGroup
,你会怎么做?
锁?必定不行!
信号量?怎么实现?
———— 切入正题 ————
在 Go 源码里,WaitGroup
在逻辑上蕴含:
- worker 计数器:main 协程调用
wg.Add(delta int)
时减少delta
,调用wg.Done
时减一。 - waiter 计数器:调用
wg.Wait
时,计数器加一; worker 计数器升高到 0 时,重置 waiter 计数器。 - 信号量:用于阻塞 main 协程。调用
wg.Wait
时,通过runtime_Semacquire
获取信号量;升高 waiter 计数器时,通过runtime_Semrelease
开释信号量。
为了便于演示,咱们魔改一下下面的例子:
package main
import (
"fmt"
"sync"
"time"
)
func main() {tasks := []func(){func() {time.Sleep(time.Second); fmt.Println("1 sec later") },
func() { time.Sleep(time.Second * 2); fmt.Println("2 sec later") },
}
var wg sync.WaitGroup // 1-1
wg.Add(len(tasks)) // 1-2
for _, task := range tasks {
task := task
go func() { // 1-3-1
defer wg.Done() // 1-3-2
task() // 1-3-3}() // 1-3-1}
wg.Wait() // 1-4
fmt.Println("exit")
}
下面这段代码中,
- 1-1 创立一个
WaitGroup
对象,worker 计数器和 waiter 计数器默认值均为 0。 - 1-2 设置 worker 计数器为
len(tasks)
。 - 1-3-1 创立 worker 协程,并启动工作。
- 1-4 设置 waiter 计数器,获取信号量,main 协程被阻塞。
-
1-3-3 中执行完结后,1-3-2 升高 worker 计数器。当 worker 计数器升高到 0 时,
- 重置 waiter 计数器
- 开释信号量,main 协程被激活,1-4
wg.Wait
返回
只管 Add(delta int)
里 delta 能够是负数、0、正数。咱们在应用时,delta
总是负数。
wg.Done
等价于 wg.Add(-1)
。在本文中,咱们提到 wg.Add
时,默认 delta > 0
。
理解了 WaitGroup
的原理当前,咱们看下它的源码。为了便于了解,我只保留了外围逻辑。对于这部分逻辑,咱们分三局部解说:
WaitGroup
构造Add
和Done
Wait
提醒:如果只想理解 WaitGroup 的正确用法,本文读到这儿就足够了。对底层有趣味的敌人能够持续读,不过最好关上 IDE,参考源码一起读。
4.1 WaitGroup 构造
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
WaitGroup
构造体里有 noCopy
和 state1
两个字段。
编译代码时,go vet
工具会查看 noCopy
字段,防止 WaitGroup
对象被拷贝。
state1
字段比拟秀,在逻辑上它蕴含了 worker 计数器、waiter 计数器和信号量。具体如何读这三个变量,参考上面代码:
// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
// 读取计数器和信号量
statep, semap := wg.state()
state := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)
三个变量的取数逻辑是:
- worker 计数器:
v
是statep *uint64
的左 32 位
- waiter 计数器:
w
是statep *uint64
的右 32 位
- 信号量:
semap
是state1 [3]uint32
的第一个字节 / 最初一个字节
所以,更新 worker 计数器,须要这样做:
state := atomic.AddUint64(statep, uint64(delta)<<32)
更新 waiter 计数器,须要这样做:
statep, semap := wg.state()
for {state := atomic.LoadUint64(statep)
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 疏忽其余逻辑
return
}
}
仔细的敌人可能会发现,worker 计数器的更新是间接累加,而 waiter 计数器的更新是 CompareAndSwap。这是因为在 main 协程中执行 wg.Add
时,只有 main 协程对 state1
做批改;而 wg.Wait
中批改 waiter 计数器时,可能有很多个协程在更新 state1
。如果你还不太了解这段话,无妨先往下走,理解 wg.Add
和 wg.Wait
的细节之后再回头看。
4.2 Add 和 Done
wg.Add
操作的外围逻辑比较简单,即批改 worker 计数器,依据 worker 计数器的状态进行后续操作。简化版的代码如下:
func (wg *WaitGroup) Add(delta int) {statep, semap := wg.state()
// 1. 批改 worker 计数器
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32)
w := uint32(state)
if v < 0 {panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 2. 判断计数器
if v > 0 || w == 0 {return}
// 3. 当 worker 计数器升高到 0 时
// 重置 waiter 计数器,并开释信号量
*statep = 0
for ; w != 0; w-- {runtime_Semrelease(semap, false)
}
}
func (wg *WaitGroup) Done() {wg.Add(-1)
}
4.3 Wait
wg.Wait
的逻辑是批改 waiter 计数器,并期待信号量被开释。简化版的代码如下:
func (wg *WaitGroup) Wait() {statep, semap := wg.state()
for {
// 1. 读取计数器
state := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)
if v == 0 {return}
// 2. 减少 waiter 计数器
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 3. 获取信号量
runtime_Semacquire(semap)
if *statep != 0 {panic("sync: WaitGroup is reused before previous Wait has returned")
}
// 4. 信号量获取胜利
return
}
}
}
因为源码比拟长,蕴含了很多校验逻辑和正文,本文中在援用时,在保留外围逻辑的同时均做了不同水平的删减。最初,举荐各位把源码下载下来,细细研读一番,从细节上对 WaitGroup
的设计有更深刻的了解。