乐趣区

Go-WaitGroup-源码分析

概述

go 语言 sync 库中的 WaitGroup 是用于等待一个协程或者一组携程。使用 Add 函数增加计数器,使用 Done 函数减少计数器。当使用 Wait 函数等待计数器归零之后则唤醒主携程。需要注意的是:

  • Add 和 Done 函数一定要配对,否则可能发生死锁
  • WaitGroup 结构体不能复制

源码分析

WaitGroup 对象

type WaitGroup struct {
    noCopy noCopy
    // 位值: 高 32 位是计数器,低 32 位是 goroution 等待计数。state1 [12]byte
    // 信号量,用于唤醒 goroution
    sema   uint32
}

func (wg *WaitGroup) state() *uint64 {if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {return (*uint64)(unsafe.Pointer(&wg.state1))
    } else {return (*uint64)(unsafe.Pointer(&wg.state1[4]))
    }
}

Add,Done,Wait

func (wg *WaitGroup) Add(delta int) {
    // 获取状态码
    statep := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()}
    // 把传入的 delta 用原子操作加入到 statep,state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 获取计数器数值
    v := int32(state >> 32)
    // 获取等待数量
    w := uint32(state)
    if race.Enabled && delta > 0 && v == int32(delta) {
        // The first increment must be synchronized with Wait.
        // Need to model this as a read, because there can be
        // several concurrent wg.counter transitions from 0.
        race.Read(unsafe.Pointer(&wg.sema))
    }
    // 计数器小于 0 报错
    if v < 0 {panic("sync: negative WaitGroup counter")
    }
    
    if w != 0 && delta > 0 && v == int32(delta) {panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 如果等待为 0 或者计数器大于 0 意味着没有等待或者还有读锁 不需要唤醒 goroutine 则返回 add 操作完毕
    if v > 0 || w == 0 {return}
    
    if *statep != state {panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 
    // 唤醒所有等待的线程
    for ; w != 0; w-- {runtime_Semrelease(&wg.sema, false)
    }
}

// Done 函数 调用了 Add 函数传入 -1 相当于锁的数量减 1
func (wg *WaitGroup) Done() {wg.Add(-1)
}

func (wg *WaitGroup) Wait() {
    // 获取 waitGroup 的状态码
    statep := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        race.Disable()}
    // 循环
    for {
        // 调用 load 获取状态
        state := atomic.LoadUint64(statep)
        // 获取计数器数值
        v := int32(state >> 32)
        // 获取等待数量
        w := uint32(state)
        
        if v == 0 {
            // Counter is 0, no need to wait.
            if race.Enabled {race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
        // 添加等待数量 如果 cas 失败则重新获取状态 避免计数有错
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(&wg.sema))
            }
            // 阻塞 goroutine 等待唤醒
            runtime_Semacquire(&wg.sema)
            if *statep != 0 {panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            if race.Enabled {race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
    }
}
退出移动版