Hook

Summary

| Hook | When it fires | API |
| --- | --- | --- |
| Backward hook | after the backward gradient for the tensor has been computed | Tensor.register_hook(hook) |
| Forward post hook | after forward() has been called | Module.register_forward_hook(hook) |
| Forward pre hook | before forward() is called | Module.register_forward_pre_hook(hook) |
| Module backward hook | after the gradients w.r.t. the module inputs have been computed | Module.register_full_backward_hook(hook) |
| Node hook | autograd hook, after the graph node has run | Node.register_hook(hook) |
| Communication hook | when the dist.GradBucket passed to the hook is ready | DistributedDataParallel.register_comm_hook(state, hook) |

Backward hook

The backward (Tensor) hook is the most commonly used hook: it lets you inject logic right after the gradient for a tensor has been computed in the backward pass.

# hook signature
# hook(grad) -> Tensor or None
import torch

def hook(grad):
    # Returning a tensor replaces the gradient that continues to flow backward.
    return grad * 2

v = torch.tensor([0., 0., 0.], requires_grad=True)
h = v.register_hook(hook)  # double the gradient
v.backward(torch.tensor([1., 2., 3.]))
v.grad      # tensor([2., 4., 6.])
h.remove()  # removes the hook
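
Tensor hooks also work on non-leaf (intermediate) tensors, which normally expose no .grad after backward. The sketch below, with illustrative variable names, uses a hook to capture an intermediate gradient without calling retain_grad().

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2          # non-leaf tensor: its gradient is normally discarded
grads = {}

def save_grad(g):
    # Returning None leaves the gradient flowing through y unchanged.
    grads["y"] = g.clone()

h = y.register_hook(save_grad)
y.sum().backward()
grads["y"]         # tensor([1., 1., 1.])
x.grad             # tensor([2., 2., 2.])
h.remove()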

Module hook

register_forward_hook(hook)
hook(module, input, output) -> None or modified output
register_forward_pre_hook(hook)
hook(module, input) -> None or modified input
register_full_backward_hook(hook)
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
# The hook must not modify its argument tensors in place, but it may return new
# gradients that will be used in place of grad_input in subsequent computations.
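
A minimal sketch of how these three module hooks fire around a small nn.Linear layer; the layer, shapes, and printed messages are illustrative, not part of the original notes.

import torch
import torch.nn as nn

layer = nn.Linear(3, 2)

def forward_pre_hook(module, inputs):
    # Runs before forward(); returning a new tuple would replace the inputs.
    print("pre-forward, input shape:", inputs[0].shape)

def forward_hook(module, inputs, output):
    # Runs after forward(); returning a tensor would replace the output.
    print("forward, output shape:", output.shape)

def backward_hook(module, grad_input, grad_output):
    # Runs once the gradients w.r.t. the module inputs have been computed.
    print("backward, grad_output shape:", grad_output[0].shape)

h1 = layer.register_forward_pre_hook(forward_pre_hook)
h2 = layer.register_forward_hook(forward_hook)
h3 = layer.register_full_backward_hook(backward_hook)

layer(torch.randn(4, 3, requires_grad=True)).sum().backward()
for h in (h1, h2, h3):
    h.remove()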

Autograd Hook

A classic use of the autograd (Node) hook comes from the Horovod optimizer: it grabs each parameter's AccumulateGrad node and registers a hook on it, so the hook fires whenever that parameter's gradient is produced.

p_tmp = p.expand_as(p)  # view of p whose grad_fn leads back to p's AccumulateGrad node
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for p
grad_acc.register_hook(self._make_hook(p))  # fires every time p's gradient is accumulated
# Example
>>> p = torch.tensor([1,2,3], dtype=float, requires_grad=True)
>>> p
tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
>>> pt = p.expand_as(p)
>>> pt
tensor([1., 2., 3.], dtype=torch.float64, grad_fn=<ExpandBackward0>)
>>> pg = pt.grad_fn.next_functions[0][0]
>>> pg
<AccumulateGrad object at 0x7fa7ade03710>
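
A self-contained sketch of the same pattern outside Horovod: register a hook on a parameter's AccumulateGrad node so that custom logic (here just a print; Horovod launches an asynchronous allreduce) runs each time the gradient lands in p.grad. The helper name make_print_hook is illustrative; note that a reference to the AccumulateGrad node must be kept alive, which Horovod does by storing it on the optimizer.

import torch

p = torch.tensor([1., 2., 3.], requires_grad=True)

def make_print_hook(param):
    def hook(*unused):
        # Runs right after the gradient for `param` has been accumulated.
        print("grad ready:", param.grad)
    return hook

grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]  # AccumulateGrad node
grad_acc.register_hook(make_print_hook(p))               # keep grad_acc referenced

(p * 2).sum().backward()  # prints: grad ready: tensor([2., 2., 2.])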

DDP communication hook

The communication hook overrides DDP's built-in gradient allreduce: once the gradients in a bucket are ready, DDP calls the hook with that bucket instead of running the default allreduce, which makes it a natural place for gradient compression or custom communication schemes.

torch.nn.parallel.DistributedDataParallel.register_comm_hook(state, hook)
hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]

Example

def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    encoded_tensor = encode(bucket.buffer())  # encode gradients
    fut = torch.distributed.all_reduce(encoded_tensor, async_op=True).get_future()
    # Define the then callback to decode.
    def decode_callback(fut):
        decoded_tensor = decode(fut.value()[0])  # decode gradients
        return decoded_tensor
    return fut.then(decode_callback)
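
A minimal sketch of a complete hook and its registration, assuming an already-initialized process group and an already-constructed DistributedDataParallel model named ddp_model (both are assumptions, not part of the original notes). The hook reproduces DDP's default behavior: allreduce the flattened bucket and average by world size.

import torch
import torch.distributed as dist

def allreduce_avg_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    # Sum the flattened bucket gradients across ranks, then divide by world size.
    fut = dist.all_reduce(bucket.buffer(), async_op=True).get_future()
    return fut.then(lambda f: f.value()[0] / dist.get_world_size())

# ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
# ddp_model.register_comm_hook(state=None, hook=allreduce_avg_hook)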

Reference

  • https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html
  • https://pytorch.org/docs/stable/ddp_comm_hooks.html