关于人工智能:PyTorch之Checkpoint机制解析

14次阅读

共计 14033 个字符,预计需要花费 36 分钟才能阅读完成。

PyTorch 之 Checkpoint 机制解析

本文已受权极市平台, 并首发于极市平台公众号. 未经容许不得二次转载.

原文链接:https://www.yuque.com/lart/ug…

PyTorch 提供了一种十分不便的节俭显存的形式,就是 Checkpoint 机制。这篇文章的目标在于更透彻的理解其内在的机制。

Checkpoint 机制

该技术的外围是一种应用工夫换空间的策略。在现有的许多办法中被大量应用,例如 DenseNet、Swin Transformer 源码中都能够看到它的身影。

为了理解它的工作原理,咱们先得弄明确的一个问题是,PyTorch 模型在训练过程中显存占用次要是用来存储什么?

对于这一点,Connolly 的文章《PyTorch 显存机制剖析》介绍的十分具体:

单刀直入的说,PyTorch 在进行深度学习训练的时候,有 4 大部分的显存开销,别离是 模型参数 (parameters) 模型参数的梯度 (gradients) 优化器状态 (optimizer states) 以及 两头激活值(intermediate activations) 或者叫两头后果(intermediate results)。

而通过 Checkpoint 技术,咱们能够通过一种取巧的形式,应用 PyTorch 提供的“no-grad”(no_grad())模式来防止将这部分运算被 autograd 记录到反向图“backward graph”中,从而防止了对于两头激活值的存储需要。

集体了解(欢送指出谬误):

前向流传时 autograd 记录各个操作反向流传须要的一些信息和两头变量。反向流传之后,用于计算梯度的两头后果会被开释。也就是说,模型参数、优化器状态和参数梯度是始终在占用着存储空间的,两头激活值在反向流传之后就主动被清空了。具体显存占用变动可见 PyTorch 显存占用剖析,这里我简略批改了《PyTorch 显存机制剖析》中给出的例子 进行了一下验证。

这里实际上会引申出另一个问题,为什么自定义 Function 个别状况下会缩小显存占用?(在 Vision Longformer 中各种实现的比照里能够显著看到这一景象)

我感觉次要是因为自定义 Function 的时候,咱们能够从一整个模块的角度来更有针对性的在 ctx 中存储两头变量,而主动求导引擎可能关注的太细了,导致存储许多不必要的两头变量。对于这一点临时不晓得如何验证。

这能够防止存储模型特定层两头运算后果,从而无效升高了前向流传中显存的占用。 这些两头后果会在反向流传的时候被即时从新计算一次。要留神,被 checkpoint 包裹的层反向流传时依然会在第一次反向流传的时候开拓存储梯度的空间。

因为 checkpoint 是在 torch.no_grad() 模式下计算的指标操作的前向函数,这并不会批改本来的叶子结点的状态,有梯度的还会放弃。只是关联这些叶子结点的长期生成的两头变量会被设置为不须要梯度,因而梯度链式关系会被断开。

通过这样的形式,尽管缩短了反向流传的工夫,然而却也在肯定水平上缓解了存储大量两头变量带来的显存占用。

源码解析

以下代码来自 PyTorch v1.10.1 版本:https://github.com/pytorch/py…。最新的版本中补充了一些新的内容,待其最终公布后再说吧,上面的内容自身曾经将 checkpoint 的外围介绍了。

辅助函数

这部分代码中首先结构了数个辅助函数,次要是用来做一些针对输出的检查和解决,同时也要解决好随机种子的问题。

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue
            
            # 间接 detach(),从 inp 所在的计算图中剥离,默认会主动将 requires_grad 置为 False
            x = inp.detach()
            # 然而这里的理论需要中,仍须要放弃其本身的须要记录梯度的属性,且其梯度变为 None
            x.requires_grad = inp.requires_grad
            # 因为只有须要保留梯度的参数才可能构建梯度的流传门路
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError("Only tuple of tensors is supported. Got Unsupported input type:", type(inputs).__name__)

def check_backward_validity(inputs: Iterable[Any]) -> None:
    """查看输出参数是否至多有一个须要记录梯度的 Tensor,这样能力确保输入也有梯度。"""
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

因为须要反复计算,所以随机状态的一致性是须要器重的。因为前向流传的局部在反向过程中仍会计算一次,所以如果不应用原始的随机状态的话,会导致从新计算和本来失常计算过程中的随机状态不同,而影响模型的行为。

另外在这段代码的正文中提到了一点乏味的中央:

因为无奈获悉被 checkpoint 解决的操作是否在运算两头会将一些参数挪动到不同的设施上,这可能须要手动保留这些设施对应的随机状态。以后的实现间接保留了所有可见设施上的随机状态,然而这样有时可能是不必要的,然而目前尚没有较好的解决策略。

所以依照文档的意思,就是在说如果没有这样的挪动,那就能够不必保留随机状态咯?这一点其实有些令人纳闷。

# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """获取不同输出对应的 GPU 设施的随机数生成器的状态"""
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states

def set_device_states(devices, states) -> None:
    """针对不同的设施设置随机数生成器的状态"""
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

外围 Function

能够看到,这里的 Checkpoint 自身就是基于 PyTorch 的 PyTorch 自定义算子之 Function 实现的一个扩大算子,所以该局部代码也波及到了 Function 的诸多性能。

浏览它既能够帮忙咱们同时温习一下相干的常识,又能进一步理解更简单的解决逻辑该如何搭建。

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        # 暂存前向流传函数
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # 用来保留以后模型的混合精度的状态,以用在反向流传中
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:  # 保留指标模块前向流传之前,此时 CPU 和 GPU 的随机数生成器的状态
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:  
                # PyTorch 提供的一个外部变量,用于断定 CUDA 状态是否曾经被初始化了
                # torch.cuda.is_initialized 中就用到了该变量
                ctx.had_cuda_in_fwd = True
                # 保留输出变量波及的各个 GPU 设施的随机状态
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # save_for_backward()中保留反向流传中须要用到的输出和输入 tensor 量。# 因为在反向流传中须要从新计算记录梯度的 output,所以就不要保留 output 了。# 并且前面的计算也不须要在梯度模式下计算。ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():  
            # 不保留梯度的前向流传操作,也就是说这里的 output 是不会记录两头变量,无奈间接计算梯度的。outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                "is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                "argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # 获取前向流传中保留的输出 tensor

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        
        # 应用之前前向流传开始之前保留的随机数生成器的状态来进行一次截然不同的前向流传过程
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            # 应用上下文管理器爱护原始的随机数生成器的状态,外部解决后在进行还原
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # 这里将 inputs 从计算图中剥来到,然而其属性 requires_grad 和原来是一样的,这么做的目标是为了截断反向流传的门路。# 从整个操作目标来看,因为咱们须要从新计算输入,并将梯度回传到输出上,所以输出自身须要能够记录梯度。# 然而这里的回传不能够影响到 checkpoint 之外更靠前的那些操作,# backward 之后会将之前保留的两头变量开释掉,而咱们仅仅是为了计算以后一小块构造,所以梯度回传须要截断。detached_inputs = detach_variable(tuple(inputs))  # 会变成叶子结点,grad 和 grad_fn 均重置为 None
            # 解决完随机状态之后,就该筹备着手从新前向流传了。# 这次前向流传是在梯度模式 (torch.enable_grad()) 下执行的。此时会保留两头变量。with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            # 记录须要计算梯度的输入 outputs[i]以及对应的回传回来的无效梯度 args[i]
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        # 查看须要计算梯度的输入,如果没有输入须要计算梯度,那么实际上就阐明这个模块是不参加梯度计算的,# 也就是说,该模块不须要应用 checkpoint 来调整。if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                "this checkpoint() is not necessary")
        # 该操作对被包裹的指标操作计算反向流传,即计算回传到输出 detached_inputs 上的梯度。# 因为输出的 tensor 已被从整体梯度图中剥离,所以能够看做是一个叶子结点,能够在反向流传之后取得其梯度,并且两头变量也会随之开释。# 另外这里反传计算梯度也不会导致将更靠前的构造中临时保留来计算梯度的参数给开释掉。torch.autograd.backward(outputs_with_grad, args_with_grad)
        # 如果后面不执行 detach(),这里的 inp.grad 会被间接开释并置为 None,这并不合乎预期
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        # 这里返回的梯度与以后类的 forward 输出一一对应,# 因为这里的 forward 蕴含着本不须要梯度的两个参数 run_function、preserve_rng_state,故对应回传 None 即可。return (None, None) + grads

这里实际上就是在原始的操作和整体的计算图之间增加了一个中间层,用于信息的交互:

  1. 原始模型的数据传输到被包裹的指标层的时候,数据进入 checkpoint 的 forward() 中,被 checkpoint 进行检查和记录后,再送入指标层中;
  2. 指标层在非梯度模式下执行前向流传。该模式下,新创建的 tensor 都是不会记录梯度信息的;
  3. 指标层的后果通过 checkpoint 的前向流传输入,送入模型后续的其余构造中;
  4. 执行反向流传,损失求导,链式回传,计算梯度;
  5. 回传回来的对应于 checkpoint 输入的梯度被送入其对应的反向流传函数,即 checkpoint 的 backward()
  6. 梯度送入 checkpoint 中后,须要进一步将梯度回传到指标层的输出上。因为在 checkpoint 的 forward 中指标层自身前向流传是处于非梯度状态下,所以回传门路上短少指标层中操作的梯度子图。于是为了获取这部分信息,须要先梯度状态下对指标层进行一次前向流传,通过将回传回来的梯度和指标层的输入一起执行 torch.autograd.backward(outputs_with_grad, args_with_grad),从而取得对应输出的梯度信息。
  7. 将对应指标操作输出的梯度信息依照 checkpoint 自身 Function 的 backward 的需要,应用 None 对其余辅助参数的梯度占位后进行返回。
  8. 返回的对应于其余模块的输出量的梯度,被沿着反向流传的门路送入对应操作的 backward 中,一层一层回传累加到各个叶子节点上。

定义好操作后,进行一个简略的包装,同时解决一下默认参数,补充了更粗疏的文档:

def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model
    
    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.
    
    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.
    这一段具体介绍了 checkpoint 的核心技术,也就是在非梯度模式下执行指标操作的前向流传,只保留输出和构造参数,省去了两头激活的保留。反向流传时在梯度模式下从新计算这些激活,重建这部分反向图,进而实现了梯度的失常回传。The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.
    因为 checkpoint 的 backward 实现的逻辑中,间接遍历指标操作的输入(会被自定转换成元组类型)并确定那些须要回流梯度的输入。如果输入中蕴含其余的非 tensor 构造,就会导致在遍历过程中这些输入被疏忽掉。不过也的确,这样间接简化解决尽管使得灵活性降落,然而却也防止了代码过于简单。.. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.
    
    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.
        尽量保障指标操作在反向计算期间和前向期间的操作的一致性。因为在 checkpoint 会在反向中从新计算一次前向,这可能会带来一些因为无奈检测到的不确定因素而造成的与惯例版本的差别。.. warning::
        If checkpointed segment contains tensors detached from the computational
        graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
        error. This is because `checkpoint` makes all the outputs require
        gradients which causes issues when a tensor is defined to have no
        gradient in the model. To circumvent this, detach the tensors outside of
        the `checkpoint` function.
        不要在指标操作中蕴含 detach 或者非梯度模式的解决。** 在我的理论测试中仿佛并没有这个问题?** 或者这里应该看一下 pytorch 提供的测试案例。.. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients. At least one of the outputs needs to have
        :code:`requires_grad=True` as well.
        要保障至多有一个输出是 requires_grad 的,这样才能够保障这部分操作能够被记录梯度。也要保障输入至多有一个须要计算梯度。Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments:" + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)

利用案例

Checkpoint for Sequential

PyTorch 源码中给了一个很间接的利用案例,就是将 checkpoint 利用于 Sequential 搭建起来的模型。依照分段数 segments 指定的,将模型划分为多段。

def checkpoint_sequential(functions, segments, input, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning:
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments:" + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children()) 
        # 获取 Sequential 的子模块,这里应用 children 办法,仅获取最外层

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile(为什么?仿佛加上也是能够的)end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        # 迭代式的将各个子模块汇合应用 checkpoint 包装并前向流传。input = checkpoint(run_function(start, end, functions), input,
                           preserve_rng_state=preserve)
    # 残余的构造不再应用 checkpoint
    return run_function(end + 1, len(functions) - 1, functions)(input)

参考链接

  • Checkpoint 源码:https://github.com/pytorch/pytorch/blob/master/torch/utils/checkpoint.py
  • PyTorch 的 Autograd – xiaopl 的文章 – 知乎 https://zhuanlan.zhihu.com/p/69294347
  • PyTorch 源码解读之 torch.autograd:梯度计算详解 – OpenMMLab 的文章 – 知乎 https://zhuanlan.zhihu.com/p/321449610
  • 浅谈 PyTorch 中的 tensor 及应用 – xiaopl 的文章 – 知乎 https://zhuanlan.zhihu.com/p/67184419
  • https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc
  • https://pytorch.org/tutorials/beginner/introyt/autogradyt_tutorial.html
正文完
 0