Deep Gradient Compression
Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training
Yujun Lin et al., Tsinghua University, ICLR 2018
Gradient exchange requires high bandwidth, and distributed training suffers from:
- high latency
- low throughput
- poor connections
Since 99.9% of the gradient exchange in distributed SGD is redundant, the authors propose Deep Gradient Compression (DGC) to reduce the communication bandwidth.
To preserve accuracy,
- momentum correction
- local gradient clipping
- momentum factor masking (alleviates staleness)
- warm-up training
These techniques improve local gradient accumulation and overcome the staleness effect.
DGC
- pushes the gradient compression ratio up to 600×
- requires no change to the model structure
- no loss of accuracy
How
- only gradients larger than a threshold are transmitted
- the rest of the gradients are accumulated locally; local gradient accumulation is equivalent to increasing the batch size over time (see the sketch below)
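A minimal sketch of these two steps in plain PyTorch (no communication; the residual buffer, the 0.1% ratio, and the sparsify helper are illustrative, not the repo's API):

import torch

compress_ratio = 0.001       # keep roughly 0.1% of the gradient entries
residual = {}                # locally accumulated, not-yet-sent gradient per parameter

def sparsify(name, grad):
    # add the new gradient on top of whatever was not transmitted before
    acc = residual.setdefault(name, torch.zeros_like(grad))
    acc.add_(grad)
    flat = acc.view(-1)
    k = max(1, int(flat.numel() * compress_ratio))
    # element-wise top-k by magnitude: only these entries would be transmitted
    _, indices = torch.topk(flat.abs(), k)
    values = flat[indices].clone()
    # transmitted entries are removed from the residual; the rest keeps accumulating
    flat[indices] = 0
    return values, indices

values, indices = sparsify('layer0.weight', torch.randn(10000))
print(values.numel(), 'of', 10000, 'entries transmitted')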
DGC naively performs fine-grained (i.e., element-wise) top-k to select gradients, so communication suffers from an allgather data volume that grows with the number of nodes.
CSC modifies the process with coarse-grained sparsification: gradients are partitioned into chunks, the L1-norm of each chunk is allreduced, and only the chunks selected by those norms are allreduced, which gets rid of the allgather and solves the problem.
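A rough sketch of the coarse-grained idea, with the collectives simulated by summing over a list of per-node gradients; the chunk size and the number of selected chunks are made up for illustration:

import torch

num_nodes, numel, chunk_size, chunks_to_keep = 4, 4096, 256, 2
node_grads = [torch.randn(numel) for _ in range(num_nodes)]      # one local gradient per node
node_chunks = [g.view(-1, chunk_size) for g in node_grads]       # partition into chunks

# "allreduce" the per-chunk L1 norms (here: just a sum over the simulated nodes)
l1_norms = sum(chunks.abs().sum(dim=1) for chunks in node_chunks)

# every node derives the same selection from the reduced norms, so no allgather is needed
_, selected = torch.topk(l1_norms, chunks_to_keep)

# "allreduce" only the selected chunks; unselected chunks stay local and keep accumulating
reduced = sum(chunks[selected] for chunks in node_chunks) / num_nodes
print(selected.tolist(), reduced.shape)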
Code notes: deep-gradient-compression (GitHub)
Configuration
# configs/dgc/__init__.py
Training flow
# train.py
from dgc.horovod.optimizer import DistributedOptimizer
# from dgc.compression import DGCCompressor
compression = configs.train.compression()
# cpr_parameters is the set of parameters that DGC will compress
compression.initialize(cpr_parameters.items())
# from dgc.optim import DGCSGD
optimizer = configs.train.optimizer(model.parameters())
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=configs.train.num_batches_per_step,
op=hvd.Average
)
# basic training loop: zero_grad -> loss.backward -> optimizer.step
# note that several backward passes happen before each optimizer.step update
model.train()
for step, (inputs, targets) in enumerate(...):
optimizer.zero_grad()
# the inner loop accumulates gradients, which saves GPU memory compared to one large batch (a quick check of the loss scaling follows this snippet)
# this inner loop is crucial for understanding the synchronize logic inside the optimizer
for b in range(0, step_size, batch_size):
_inputs = inputs[b:b+batch_size]
_targets = targets[b:b+batch_size]
_outputs = model(_inputs)
_loss = criterion(_outputs, _targets)
_loss.mul_(_r_num_batches_per_step)
_loss.backward()
loss += _loss.item()
optimizer.step()
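A quick standalone check of why each partial loss is multiplied by _r_num_batches_per_step, which is presumably 1 / num_batches_per_step: summing the gradients of the scaled micro-batch losses reproduces the gradient of the mean loss over the whole large batch (toy model and data, not from the repo):

import torch

torch.manual_seed(0)
w = torch.randn(8, requires_grad=True)
x, y = torch.randn(16, 8), torch.randn(16)
criterion = torch.nn.MSELoss()

# gradient of the loss over one big batch of 16
criterion(x @ w, y).backward()
grad_big = w.grad.clone(); w.grad = None

# four micro-batches of 4, each loss scaled by 1/4, gradients accumulated in w.grad
for b in range(0, 16, 4):
    (criterion(x[b:b+4] @ w, y[b:b+4]) / 4).backward()
print(torch.allclose(grad_big, w.grad, atol=1e-6))   # True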
Optimizer
# dgc/horovod/optimizer.py
class _DistributedOptimizer(torch.optim.Optimizer):
def __init__(self, ...):
# at the end of initialization, register the communication hooks
self._register_hooks()
def _register_hooks(self):
for param_group in self.param_groups:
for p in param_group['params']:
if p.requires_grad:
# this registration code runs only once; the grad is zeroed here, not on every hook call
p.grad = p.data.new(p.size()).zero_()
self._requires_update.add(p)
# create a "ghost" tensor to reach the node where gradients accumulate locally and get synchronized across nodes until the update;
# p_tmp shares the same storage as p, so it costs no extra GPU memory
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
# register the key hook produced by _make_hook
grad_acc.register_hook(self._make_hook(p))
self._grad_accs.append(grad_acc)
def _make_hook(self, p):
# this hook has a counter, _allreduce_delay, whose value can differ per parameter p
# while the counter is non-zero, communication is skipped, so the grad keeps accumulating locally (this hook is what triggers communication)
# net effect: the hook body only communicates once every several calls
def hook(*ignore):
handle, ctx = None, None
self._allreduce_delay[p] -= 1
if self._allreduce_delay[p] == 0:
handle, ctx = self._allreduce_grad_async(p)
self._handles[p] = (handle, ctx)
return hook
# then the main flow: step
def step(self, closure=None):
self.synchronize()
return super(self.__class__, self).step(closure)
# step calls synchronize; there can be skip logic inside
def synchronize(self):
# handle the case where hook registration failed, or rather the hook was never called
# once a hook is called it adds an entry to self._handles
missing_p = self._requires_update - set(self._handles.keys())
for p in missing_p:
handle, ctx = self._allreduce_grad_async(p)
self._handles[p] = (handle, ctx)
# hooks whose handle is None skipped communication earlier, yet here they are not skipped after all?
# note that synchronize is called once per step, not once per backward pass
# in the training loop above, each step runs several backward passes, so each grad hook is called the matching number of times
# so by this point _allreduce_grad_async should have been called exactly once per parameter; if not, make up for it here
for p, (handle, ctx) in self._handles.items():
if handle is None:
handle, ctx = self._allreduce_grad_async(p)
self._handles[p] = (handle, ctx)
# this for loop consumes the results of the asynchronous communication
for p, (handle, ctx) in self._handles.items():
output = self._synchronize_(handle)
# reset the local accumulation counter
self._allreduce_delay[p] = self.backward_passes_per_step
# decompress and write back the gradient
p.grad.set_(self._compression.decompress(output, ctx))
# done; clean up
self._handles.clear()
# the asynchronous communication op; the core logic lives in compression
def _allreduce_grad_async(self, p):
name = self._parameter_names.get(p)
tensor_compressed, ctx = self._compression.compress(p.grad, name)
handle = self._communicate_(tensor_compressed, name=name, op=self.op)
return handle, ctx
- the hook is registered once but called many times, so self._handles keeps getting filled and can be cleared after each synchronize; a toy demo of the ghost tensor and the delayed hook follows
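A standalone toy reproducing the two tricks above: reach the AccumulateGrad node through an expand_as "ghost" tensor, and use a countdown so that communication fires only once every backward_passes_per_step backward passes. The print stands in for _allreduce_grad_async; everything else is hypothetical scaffolding:

import torch

p = torch.nn.Parameter(torch.randn(4))
backward_passes_per_step = 3
state = {'delay': backward_passes_per_step}

def hook(*ignore):
    state['delay'] -= 1
    if state['delay'] == 0:
        print('communicate: this is where _allreduce_grad_async would be called')
        state['delay'] = backward_passes_per_step   # the real code resets the counter in synchronize()

p_tmp = p.expand_as(p)                          # shares storage with p, costs no extra memory
grad_acc = p_tmp.grad_fn.next_functions[0][0]   # the AccumulateGrad node of p
grad_acc.register_hook(hook)                    # called once per backward pass

for _ in range(6):
    (p * torch.randn(4)).sum().backward()       # prints twice over six backward passes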
# dgc/compression.py
class DGCCompressor:
def __init__(self, ...):
self.attributes = {}
def initialize(self, named_parameters):
# the set of parameters DGC operates on
for name, param in named_parameters:
self.attributes[name] = (numel, shape, num_selects, num_samples, top_k_samples, sample_stride)
def _sparsify(self, tensor, name):
# select a sparse subset of the tensor to communicate
# the original implementation is more involved:
# first estimate a threshold from the top-k of a random sample of the gradient magnitudes
# then sparsify the full tensor with that threshold (a simplified sketch follows the DGCCompressor listing)
importance = tensor.abs()
mask = torch.ge(importance, threshold)
indices = mask.nonzero().view(-1)
num_indices = indices.numel()
# the real implementation loops here to make sure enough top-k elements are selected
indices = indices[:num_selects]
values = tensor[indices]
return values, indices, numel, shape, num_selects
def compress(self, tensor, name):
    if self.compress_ratio < 1.0 and name in self.attributes:
        # compress: compensate with the locally accumulated momentum/velocity, then sparsify
        tensor_compensated = self.memory.compensate(tensor, name, accumulate=True)
        values, indices, numel, shape, num_selects = self._sparsify(tensor_compensated, name)
        self.memory.update(name, (indices, ))
        # ctx carries what decompress needs; the compressed payload is the (values, indices) pair
        ctx = (name, numel, shape, values.dtype, indices.dtype, tensor.data.view(numel))
        return (values, indices), ctx
    else:
        ctx = (name, None, None, tensor.dtype, None, None)
        return tensor, ctx
def decompress(self, tensor, ctx):
name, numel, shape, vdtype, idtype, grad = ctx
if self.compress_ratio < 1.0 and name in self.attributes:
# here tensor is a tuple (values, indices)
# decompress
values, indices = tensor
# scatter the synchronized sparse values back into the dense gradient
# accumulate=True handles duplicate indices (different workers may select the same position)
grad.zero_().index_put_([indices], values, accumulate=True)
if self.op == Average:
grad.mul_(1. / self.world_size)
return grad.view(shape)
else:
return self.memory.compensate(tensor, name, accumulate=False)
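What _sparsify does, roughly: instead of a full top-k over a huge tensor, it samples a fraction of the entries, takes the top-k of the sample to estimate a magnitude threshold, and keeps every entry above it; the real code additionally loops until enough indices survive. A simplified sketch under those assumptions, not the repo's exact code:

import torch

def sparsify_by_sampled_threshold(tensor, compress_ratio=0.001, sample_ratio=0.01):
    flat = tensor.view(-1)
    numel = flat.numel()
    importance = flat.abs()

    # 1) estimate the threshold from a random sample of the magnitudes
    num_samples = max(1, int(numel * sample_ratio))
    samples = importance[torch.randint(numel, (num_samples,))]
    k_sample = max(1, int(num_samples * compress_ratio))
    threshold = torch.topk(samples, k_sample).values.min()

    # 2) apply the estimated threshold to the full tensor
    indices = torch.nonzero(importance >= threshold, as_tuple=False).view(-1)

    # 3) if the estimate let too many entries through, tighten with an exact top-k on the survivors
    num_selects = max(1, int(numel * compress_ratio))
    if indices.numel() > num_selects:
        top = torch.topk(importance[indices], num_selects).indices
        indices = indices[top]
    return flat[indices], indices

values, indices = sparsify_by_sampled_threshold(torch.randn(100000))
print(values.numel(), 'entries selected')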
# optimizer _communicate_
def communicate(self, tensor_compressed, name, op):
# two branches
if self.compress_ratio < 1.0 and name in self.attributes:
# DGC branch: tensor_compressed is a tuple, and each node selects different top-k indices
# so allgather is used to exchange them; each node then decompresses and updates locally (a toy simulation follows the next snippet)
return [allgather_async_(t, name=f'{name}.t{e}')
for e, t in enumerate(tensor_compressed)]
else:
# plain branch: allreduce the full tensor directly
return allreduce_async_(tensor_compressed, name=name, op=op)
# optimizer _synchronize_
def synchronize(self, handle):
# from horovod.torch.mpi_ops import synchronize as synchronize_
if isinstance(handle, (tuple, list)):
return [synchronize_(h) for h in handle]
else:
return synchronize_(handle)
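To see why the DGC branch uses allgather rather than allreduce, here is a toy round trip with a Python list standing in for the collective (no Horovod involved): each node compresses its gradient to its own (values, indices) pair, the pairs are gathered, and every node rebuilds and averages a dense gradient with index_put_(accumulate=True), as decompress does:

import torch

world_size, numel, k = 2, 16, 3
node_grads = [torch.randn(numel) for _ in range(world_size)]

def compress(grad):
    _, indices = torch.topk(grad.abs(), k)      # each node picks its own top-k positions
    return grad[indices], indices

gathered = [compress(g) for g in node_grads]    # stands in for allgather of (values, indices)

# what every node does after synchronization
dense = torch.zeros(numel)
for values, indices in gathered:
    # accumulate=True: if two nodes picked the same position, their values add up
    dense.index_put_([indices], values, accumulate=True)
dense.mul_(1.0 / world_size)                    # op == Average
print(dense)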
To preserve accuracy, the paper introduces the following compensation strategies:
- momentum correction
- local gradient clipping
- momentum factor masking (alleviates staleness)
- warm-up training
The first three strategies are implemented in the Memory class.
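Momentum correction in one line of notation (per worker, per parameter): instead of applying momentum to the already-sparsified gradient, keep the momentum buffer u and the accumulated velocity v locally,

u_t = m * u_{t-1} + g_t,    v_t = v_{t-1} + u_t,    transmit sparsify(v_t)

which is what compensate() below computes (mmt is u, velocities is v). After transmission, update() clears v at the sent positions, and momentum factor masking additionally zeroes u there.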
# dgc/memory.py
class DGCSGDMemory(Memory):
def compensate(self, grad, name, accumulate=True):
if self.gradient_clipping is not None:
grad = self.gradient_clipping(grad)
mmt = self.momentums[name]
if accumulate:
# Momentum Correction
vec = self.velocities[name]
if self.nesterov:
mmt.add_(grad).mul_(self.momentum)
vec.add_(mmt).add_(grad)
else:
mmt.mul_(self.momentum).add_(grad)
vec.add_(mmt)
return vec
else:
if self.nesterov:
mmt.add_(grad).mul_(self.momentum)
return mmt.add(grad)
else:
mmt.mul_(self.momentum).add_(grad)
return mmt.clone() # TODO: save this clone
def update(self, name, ctx):
indices = ctx[0]
if self.momentum_masking:
self.momentums[name].view(-1).index_fill_(0, indices, 0)
self.velocities[name].view(-1).index_fill_(0, indices, 0)
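A condensed re-implementation of the compensate -> sparsify -> update cycle for a single parameter, to see how momentum correction and momentum factor masking interact (plain SGD momentum, made-up sizes, a simple top-k instead of the sampled threshold; not the repo's classes):

import torch

momentum, compress_ratio, momentum_masking = 0.9, 0.1, True
numel = 100
mmt = torch.zeros(numel)    # u: momentum buffer       (self.momentums[name])
vec = torch.zeros(numel)    # v: accumulated velocity  (self.velocities[name])

def compensate(grad):
    # momentum correction: run momentum locally and accumulate the velocity
    mmt.mul_(momentum).add_(grad)
    vec.add_(mmt)
    return vec

def update(indices):
    # clear what has been transmitted; masking also clears the momentum there
    if momentum_masking:
        mmt.index_fill_(0, indices, 0)
    vec.index_fill_(0, indices, 0)

for step in range(5):
    v = compensate(torch.randn(numel))
    _, indices = torch.topk(v.abs(), int(numel * compress_ratio))   # what this step would transmit
    update(indices)

print(float(vec.abs().sum()))   # whatever was not transmitted keeps accumulating locally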