概述
go 语言 sync 库中的 WaitGroup 是用于等待一个协程或者一组携程。使用 Add 函数增加计数器,使用 Done 函数减少计数器。当使用 Wait 函数等待计数器归零之后则唤醒主携程。需要注意的是:
- Add 和 Done 函数一定要配对,否则可能发生死锁
- WaitGroup 结构体不能复制
源码分析
WaitGroup 对象
type WaitGroup struct {
noCopy noCopy
// 位值: 高 32 位是计数器,低 32 位是 goroution 等待计数。state1 [12]byte
// 信号量,用于唤醒 goroution
sema uint32
}
func (wg *WaitGroup) state() *uint64 {if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {return (*uint64)(unsafe.Pointer(&wg.state1))
} else {return (*uint64)(unsafe.Pointer(&wg.state1[4]))
}
}
Add,Done,Wait
func (wg *WaitGroup) Add(delta int) {
// 获取状态码
statep := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
if delta < 0 {
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()}
// 把传入的 delta 用原子操作加入到 statep,state := atomic.AddUint64(statep, uint64(delta)<<32)
// 获取计数器数值
v := int32(state >> 32)
// 获取等待数量
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
// 计数器小于 0 报错
if v < 0 {panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 如果等待为 0 或者计数器大于 0 意味着没有等待或者还有读锁 不需要唤醒 goroutine 则返回 add 操作完毕
if v > 0 || w == 0 {return}
if *statep != state {panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
//
// 唤醒所有等待的线程
for ; w != 0; w-- {runtime_Semrelease(&wg.sema, false)
}
}
// Done 函数 调用了 Add 函数传入 -1 相当于锁的数量减 1
func (wg *WaitGroup) Done() {wg.Add(-1)
}
func (wg *WaitGroup) Wait() {
// 获取 waitGroup 的状态码
statep := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
race.Disable()}
// 循环
for {
// 调用 load 获取状态
state := atomic.LoadUint64(statep)
// 获取计数器数值
v := int32(state >> 32)
// 获取等待数量
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// 添加等待数量 如果 cas 失败则重新获取状态 避免计数有错
if atomic.CompareAndSwapUint64(statep, state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
// 阻塞 goroutine 等待唤醒
runtime_Semacquire(&wg.sema)
if *statep != 0 {panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}