# PyTorch Patch
## summary

This library demonstrates a way to patch/hack PyTorch by monkey-patching its functions. The technique is worth borrowing in many other scenarios, such as fault tolerance and elastic training.
```python
patch_apex()
patch_torch_classes()               # [torch, torch.Tensor, torch.nn.functional, torch.distributed]
patch_torch_nn_forward_functions()  # [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]
```

Call graph of the three entry points:

```
patch_apex                       --> patch_apex_c + patch_apex_pyt
patch_apex_c                     --> patchClass --> add_wrapper
patch_apex_pyt                   --> patch_apex_module --> patch_apex_class --> add_wrapper
patch_torch_classes              --> patchClass --> add_wrapper
patch_torch_nn_forward_functions --> add_wrapper
```
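Before diving into the PyProf code, here is a minimal, self-contained sketch of the core idea; the `add_nvtx_wrapper` name and the `torch.matmul` target are illustrative, not part of PyProf:

```python
import torch
import torch.cuda.nvtx as nvtx  # NVTX bindings ship with CUDA builds of PyTorch

def add_nvtx_wrapper(mod, fn_name):
    """Replace mod.fn_name with a wrapper that brackets each call in an NVTX range."""
    func = getattr(mod, fn_name)          # keep a reference to the original

    def wrapper(*args, **kwargs):
        nvtx.range_push(f"{mod.__name__}.{fn_name}")
        try:
            return func(*args, **kwargs)  # delegate to the original function
        finally:
            nvtx.range_pop()

    setattr(mod, fn_name, wrapper)        # swap the attribute in place

add_nvtx_wrapper(torch, "matmul")
x = torch.randn(4, 4)
y = torch.matmul(x, x)                    # now emits an NVTX range per call
```

The `try/finally` keeps push/pop balanced even when the wrapped function raises; PyProf's own `add_wrapper` below pops only after a successful call.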
## patch_apex
```python
import importlib
import importlib.util

def patch_apex():
    patch_apex_c()
    patch_apex_pyt()

def patch_apex_c():
    if importlib.util.find_spec("amp_C") is not None:
        import amp_C
        patchClass(amp_C)
    # The remaining apex CUDA extension modules are guarded and patched
    # the same way:
    # fused_adam_cuda
    # fused_lamb_cuda
    # fused_layer_norm_cuda
    # distributed_lamb_cuda
    # xentropy_cuda
    # mlp_cuda
```
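The elided body of `patch_apex_c` presumably repeats the same guard-then-patch pattern for each extension listed in the comment; a hedged sketch of that as a loop (the function name `patch_apex_c_sketch` is mine):

```python
import importlib
import importlib.util

def patch_apex_c_sketch():
    # Patch whichever apex CUDA extension modules are actually installed
    for name in ["amp_C", "fused_adam_cuda", "fused_lamb_cuda",
                 "fused_layer_norm_cuda", "distributed_lamb_cuda",
                 "xentropy_cuda", "mlp_cuda"]:
        if importlib.util.find_spec(name) is not None:
            patchClass(importlib.import_module(name))
```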
```python
def patch_apex_pyt():
    if importlib.util.find_spec("apex") is not None:
        patch_apex_module("apex.amp")
        patch_apex_module("apex.contrib.groupbn")
        patch_apex_module("apex.contrib.multihead_attn")
        patch_apex_module("apex.contrib.optimizers")
        patch_apex_module("apex.contrib.sparsity")
        patch_apex_module("apex.contrib.xentropy")
        patch_apex_module("apex.fp16_utils")
        patch_apex_module("apex.mlp")
        patch_apex_module("apex.multi_tensor_apply")
        patch_apex_module("apex.optimizers")
        patch_apex_module("apex.parallel")
```
```python
import inspect as ins

def patch_apex_module(modstr):
    """
    Patch all forward/backward/step functions in classes of the given apex module.
    """
    if importlib.util.find_spec(modstr) is not None:
        mod = importlib.import_module(modstr)
        for _, v in ins.getmembers(mod):
            # Make sure we don't patch random other modules that are
            # merely imported by the target module
            if is_same_module_or_submodule(mod, ins.getmodule(v)):
                if ins.isclass(v):
                    patch_apex_class(v)
```
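`is_same_module_or_submodule` is referenced but not shown above. A plausible implementation (the body here is an assumption based on the call site) compares dotted module names:

```python
def is_same_module_or_submodule(orig, incoming):
    # Assumed helper: `incoming` belongs to `orig` if it is the same module
    # or its dotted name extends orig's (e.g. apex.amp -> apex.amp.frontend)
    if incoming is None:
        return False
    if incoming is orig:
        return True
    return incoming.__name__.startswith(orig.__name__ + ".")
```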
```python
def patch_apex_class(cls):
    """
    Patch all forward/backward/step functions in the given apex class.
    """
    for f in cls.__dict__:
        if ins.isfunction(cls.__dict__[f]):
            if f in ["forward", "backward", "step"]:
                add_wrapper(cls, f)
```

Iterating over `cls.__dict__` rather than `dir(cls)` restricts the patch to functions defined on the class itself, so inherited methods are not wrapped a second time.
## patch_torch_classes
```python
def patchClass(cls):
    # Despite the name, `cls` may be a module (torch, torch.nn.functional,
    # torch.distributed) or a class (torch.Tensor)
    for f in dir(cls):
        if isfunc(cls, f):
            add_wrapper(cls, f)

def patch_torch_classes():
    """Monkey-patch all classes in torch"""
    for cls in [torch, torch.Tensor, torch.nn.functional, torch.distributed]:
        patchClass(cls)
```
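`isfunc` is also not shown above; a simplified sketch of the filter it must apply (the real PyProf version additionally skips an explicit ignore list of dunder attributes):

```python
import inspect as ins

def isfunc(mod, f):
    # Assumed, simplified filter: skip private/dunder attributes, keep
    # anything in the usual function-like categories
    attr = getattr(mod, f)
    if f.startswith("_"):
        return False
    return (ins.ismethod(attr) or ins.isfunction(attr)
            or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr))
```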
## patch_torch_nn_forward_functions
```python
def patch_torch_nn_forward_functions():
    """Monkey-patch all forward functions in torch.nn libraries"""
    for cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM,
                torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]:
        if isfunc(cls, 'forward'):
            add_wrapper(cls, 'forward')
```
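A hypothetical usage sketch, assuming the whole library (including the marker helpers below) has been imported: apply the patches once at startup, and every subsequent call emits NVTX markers visible in tools such as Nsight Systems:

```python
patch_torch_classes()
patch_torch_nn_forward_functions()

lstm = torch.nn.LSTM(input_size=8, hidden_size=16)
out, _ = lstm(torch.randn(5, 3, 8))  # forward() is now wrapped
```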
## add_wrapper
```python
import traceback

import torch
import torch.cuda.nvtx as nvtx

def add_wrapper(mod, fn_name):
    # Get a pointer to the original function
    func = getattr(mod, fn_name)

    # Check if the mod has a string representation
    # and is not a Script or Traced module (used by JIT)
    # yapf: disable
    s = hasattr(mod, "extra_repr") and (type(mod) is not torch.jit.ScriptModule
                                        ) and (type(mod) is not torch.jit.TopLevelTracedModule)
    # yapf: enable

    def wrapper_func(*args, **kwargs):
        # Extract the stack trace
        stack = traceback.extract_stack()
        # Push trace marker
        nvtx.range_push(traceMarker(stack))
        # Push module marker
        if s:
            m = modMarker(mod, fn_name, args)
            nvtx.range_push(m)
        # Create and push argument marker
        cadena = argMarker(mod, fn_name, args, kwargs)
        nvtx.range_push(cadena)
        # Call the original function
        result = func(*args, **kwargs)
        # Pop argument marker
        nvtx.range_pop()
        # Pop module marker
        if s:
            nvtx.range_pop()
        # Pop trace marker
        nvtx.range_pop()
        return result

    setattr(mod, fn_name, wrapper_func)
```
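The marker helpers `traceMarker`, `modMarker`, and `argMarker` are not shown above. Hedged, simplified sketches of what they encode (PyProf serializes richer JSON-like dictionaries):

```python
import traceback
import torch

def traceMarker(stack):
    # "file:line" for each frame; stack[:-1] drops the wrapper's own frame
    return str([f"{f.filename}:{f.lineno}" for f in stack[:-1]])

def modMarker(mod, fn_name, args):
    # For module forward() calls, args[0] is the nn.Module instance
    return str({"mod": mod.__name__, "strRepr": args[0].extra_repr()})

def argMarker(mod, fn_name, args, kwargs):
    # Record tensor shapes/dtypes so GPU kernels can be attributed back
    # to the Python call site
    def describe(a):
        if isinstance(a, torch.Tensor):
            return {"shape": tuple(a.shape), "dtype": str(a.dtype)}
        return repr(a)
    return str({"mod": mod.__name__, "op": fn_name,
                "args": [describe(a) for a in args],
                "kwargs": {k: describe(v) for k, v in kwargs.items()}})
```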
## Reference
- https://github.com/NVIDIA/PyProf.git