Autograd

// torch/csrc/autograd/variable.h

/// Per-tensor autograd state (excerpt of the upstream struct; only the
/// members discussed in these notes are shown).
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  /// NOTE(review): presumably a user-visible name for the variable -- confirm.
  std::string name_;

  /// Gradient associated with the tensor.
  Variable grad_;
  /// Backward op that produced this tensor; shared ownership.
  std::shared_ptr<Node> grad_fn_;
  /// Gradient accumulator node; held weakly, so it must be lock()ed before
  /// use and may have expired.
  std::weak_ptr<Node> grad_accumulator_;

  /// Forward-mode gradient state.
  std::shared_ptr<ForwardGrad> fw_grad_;

  /// Hooks invoked during the backward pass.
  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;
};  // the terminating ';' was missing in the excerpt

AutogradMeta 中包含 autograd 所需的元素

  • grad_: Tensor 对应的 grad
  • grad_fn_: 反向 op
  • grad_accumulator_: 反向梯度累加器,Node 类型
  • cpp_hooks_list_, hooks_: 反向调用时的 hook

Node

Node

  • The most important method on Node is the call operator, which takes in a list of variables and produces a list of variables.
  • The precise size of these lists can be determined with num_inputs() and num_outputs().
  • Nodes are stitched together via their next_edge interface, which let you manipulate the set of outgoing edges of a Node.
  • You can add an edge with add_next_edge(), retrieve an edge with next_edge(index) and iterate over them via the next_edges() method.
// torch/csrc/autograd/function.h

/// List of outgoing edges of a Node.
using edge_list = std::vector<Edge>;

/// A vertex of the autograd graph. It is callable on a list of variables and
/// is linked to other nodes through `next_edges_`.
struct TORCH_API Node : std::enable_shared_from_this<Node> {
 public:
  /// Stores the sequence number, takes ownership of `next_edges`, and raises
  /// `topological_nr_` above that of every node those edges point to.
  explicit Node(
      uint64_t sequence_nr,
      edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr),
      next_edges_(std::move(next_edges)) {

    for (const Edge& edge: next_edges_) {
      update_topological_nr(edge);
    }
  }

  /// Call operator: forwards `inputs` to the subclass-provided apply().
  /// The steps preceding the call are elided in this excerpt.
  variable_list operator()(variable_list&& inputs) {
    ...
    return apply(std::move(inputs));
  }


  /// Ensures `topological_nr_` is strictly greater than the topological
  /// number of the node `edge` points to. Edges with a null function are
  /// ignored. The value is only ever raised here, never lowered.
  void update_topological_nr(const Edge& edge) {
    Node* node = edge.function.get();
    if (node) {
      auto topo_nr = node->topological_nr();
      if (topological_nr_ <= topo_nr) {
        topological_nr_ = topo_nr + 1;
      }
    }
  }

  /// Replaces the edge at `index`. Uses unchecked operator[], so `index`
  /// must already be a valid position in `next_edges_`.
  void set_next_edge(size_t index, Edge edge) {
    update_topological_nr(edge);
    next_edges_[index] = std::move(edge);
  }

  /// Appends a new outgoing edge.
  void add_next_edge(Edge edge) {
    update_topological_nr(edge);
    next_edges_.push_back(std::move(edge));
  }

  /// Replaces all outgoing edges, then updates the topological number for
  /// each new edge.
  void set_next_edges(edge_list&& next_edges) {
    next_edges_ = std::move(next_edges);
    for (const auto& next_edge : next_edges_) {
      update_topological_nr(next_edge);
    }
  }

  /// Returns the edge at `index` (unchecked access).
  const Edge& next_edge(size_t index) const noexcept {
    return next_edges_[index];
  }

  /// Read-only view of all outgoing edges.
  const edge_list& next_edges() const noexcept {
    return next_edges_;
  }

  /// Mutable access to the outgoing edges. Changes made through this
  /// reference do not go through update_topological_nr().
  edge_list& next_edges() noexcept {
    return next_edges_;
  }

 protected:
  /// The actual computation; must be implemented by subclasses.
  virtual variable_list apply(variable_list&& inputs) = 0;

  // Set once at construction.
  const uint64_t sequence_nr_;
  // Maintained by update_topological_nr(): greater than that of any node
  // reachable through an edge that was registered via the methods above.
  uint64_t topological_nr_ = 0;
  uint64_t thread_id_ = 0;
  // Outgoing edges of this node.
  edge_list next_edges_;
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  at::SmallVector<InputMetadata, 2> input_metadata_;
};

可以看到

  • Node 创建时接收一组 Edge(next_edges),Node 中保存了连接情况和需要执行的方法。
  • Node 本身是 callable object, operator() 会调用纯虚函数 apply,由子类重写(override)实现。
  • add_next_edge 方法可以追加 Edge,set_next_edge 可以替换指定位置的 Edge,set_next_edges 可以整体替换所有 Edge。

Edge

// torch/csrc/autograd/edge.h

/// A (function, input_nr) pair: identifies one particular input of a Node.
struct Edge {
  /// Default edge: null function, input number 0.
  Edge() noexcept : function(nullptr), input_nr(0) {}

  /// Builds an edge pointing at input `input_nr_` of `function_`; the
  /// shared_ptr is moved into the edge.
  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function(std::move(function_)), input_nr(input_nr_) {}

  // Required for use in associative containers.
  // Two edges are equal when they point at the same Node object (pointer
  // comparison) and the same input number.
  bool operator==(const Edge& other) const noexcept {
    return this->function == other.function && this->input_nr == other.input_nr;
  }

  bool operator!=(const Edge& other) const noexcept {
    return !(*this == other);
  }

  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;

  /// The identifier of a particular input to the function.
  uint32_t input_nr;
};

网络构建

先看简单的例子,

import torch
# Two tensors created directly by the user, both tracking gradients.
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
# c and d are produced by ops, so they carry a grad_fn.
c = torch.add(a, b)
d = torch.mul(a, c)
# d = a * (a + b), hence dd/da = 2a + b = 4 and dd/db = a = 1.
d.backward()
print(f"a grad:{a.grad} grad_fn:{a.grad_fn}")
print(f"b grad:{b.grad} grad_fn:{b.grad_fn}")
print(f"c grad:{c.grad} grad_fn:{c.grad_fn}")
print(f"d grad:{d.grad} grad_fn:{d.grad_fn}")
# Output captured from one run (object addresses will differ between runs):
'''
a grad:4.0 grad_fn:None
b grad:1.0 grad_fn:None
c grad:None grad_fn:<AddBackward0 object at 0x7f6862dc76d0>
d grad:None grad_fn:<MulBackward0 object at 0x7f6862dc76d0>
'''

以上代码构建的网络如图所示

1

下面解析详细构建过程.

torch.add

torch/csrc/autograd/generated/ 目录需要 build 生成

// torch/csrc/autograd/generated/VariableType_2.cpp

// @generated by torchgen/gen.py from VariableType.cpp

/// Autograd wrapper for aten::add.Tensor: builds the backward node, runs the
/// forward op via redispatch, then attaches the backward node to the result.
at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& other_ = unpack(other, "other", 1);
  auto _any_requires_grad = compute_requires_grad( self, other );
  
  // The (void) casts silence unused-variable warnings.
  (void)_any_requires_grad;
  auto _any_has_forward_grad_result = (isFwGradDefined(self) || isFwGradDefined(other));
  (void)_any_has_forward_grad_result;
  std::shared_ptr<AddBackward0> grad_fn;
  if (_any_requires_grad) {
    // Step 1: build the backward node. It is owned by a shared_ptr with
    // `deleteNode` as its custom deleter.
    grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
    // Link it to the gradient edges of both inputs.
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    grad_fn->other_scalar_type = other.scalar_type();
    grad_fn->alpha = alpha;
    grad_fn->self_scalar_type = self.scalar_type();
  }

  // Step 2: run the forward computation on the unpacked tensors with the key
  // set masked by `after_autograd_keyset`.
  // NOTE(review): the guard presumably keeps the nested call from re-entering
  // the autograd kernels -- confirm.
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);
  })();
  auto result = std::move(_tmp);

  // Step 3: associate the backward node with the result (only when one was
  // created above).
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }
  return result;
}
  • 构建反向节点 AddBackward0
  • 计算 at::redispatch::add,结果保存至 result
  • 关联 AddBackward0 和 result

首先计算逻辑 add 的调用是自动生成的

// torch/include/ATen/RedispatchFunctions.h

// @generated by torchgen/gen.py from RedispatchFunctions.h

namespace at {
namespace redispatch {
    // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    // Thin inline wrapper: forwards all arguments, including the explicit
    // dispatch key set, to add_Tensor::redispatch. `alpha` defaults to 1.
    inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
        return at::_ops::add_Tensor::redispatch(dispatchKeySet, self, other, alpha);
    }
}
}

对应的算子声明(call / redispatch 入口)如下

// torch/include/ATen/ops/add_ops.h

// @generated by torchgen/gen.py from Operator.h

namespace at {
namespace _ops {

struct TORCH_API add_Tensor {
  using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
  static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};

C++ API 入口

// torch/include/ATen/core/TensorBody.h

// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
// Method form of add: forwards to add_Tensor::call with *this as `self`.
// The const_cast only drops the constness of *this for the call.
inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
    return at::_ops::add_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}

// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// In-place variant: forwards to add__Tensor::call and returns its reference.
// The method is declared const, so const_cast supplies the non-const Tensor&
// that the call takes as `self`.
inline at::Tensor & Tensor::add_(const at::Tensor & other, const at::Scalar & alpha) const {
    return at::_ops::add__Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}

AddBackward0

// torch/csrc/autograd/generated/Functions.h

// @generated from ../tools/autograd/templates/Functions.h

/// Backward node for add. It only records the scalar types of both inputs
/// and `alpha`; no input tensors are saved.
struct TORCH_API AddBackward0 : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  /// Backward computation (declared here, defined elsewhere).
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "AddBackward0"; }
  // No saved variables, so there is nothing to release.
  void release_variables() override { }

  at::ScalarType other_scalar_type;
  at::Scalar alpha;
  at::ScalarType self_scalar_type;
};

TraceableFunction

// torch/csrc/autograd/function.h

/// Node subclass whose is_traceable() always returns true. The override is
/// `final`, so subclasses cannot change it.
struct TraceableFunction : public Node {
  using Node::Node;
  bool is_traceable() final {
    return true;
  }
};

collect_next_edges 根据两个输入找到节点的 Edges

// torch/csrc/autograd/function.h

/// Return the next edges of all the given variables, or tuples of variables.
/// The edges are gathered by a MakeNextFunctionList visitor, one entry per
/// variable, in argument order.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
  detail::MakeNextFunctionList make;
  make.apply(std::forward<Variables>(variables)...);
  // `next_edges` is a member of the local `make`, so the explicit std::move
  // is what turns this return into a move rather than a copy.
  return std::move(make.next_edges);
}

/// Argument visitor that accumulates one Edge per visited variable into
/// `next_edges`.
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
  edge_list next_edges;
  using IterArgs<MakeNextFunctionList>::operator();
  /// A defined variable contributes its gradient edge; an undefined one
  /// contributes a default-constructed Edge so that positions stay aligned
  /// with the arguments.
  void operator()(const Variable& variable) {
    if (variable.defined()) {
      next_edges.push_back(impl::gradient_edge(variable));
    } else {
      next_edges.emplace_back();
    }
  }
  // Bodies of the remaining overloads are elided in this excerpt.
  void operator()(const Variable* variable) { ... }
  void operator()(const c10::optional<Variable>& variable) { ... }
};

gradient_edge 会返回一条 Edge

// torch/csrc/autograd/variable.cpp

/// Returns the edge through which gradients for `self` should flow.
/// If `self` has a grad_fn, the edge points at that node using the tensor's
/// output number; otherwise it points at input 0 of the tensor's gradient
/// accumulator.
Edge gradient_edge(const Variable& self) {
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}

如果 self 是内部创建的(非叶子结点),即通过运算生成的,则返回 self 的 grad_fn 数据成员,否则(即用户创建的叶子结点)返回 AccumulateGrad 实例。

// torch/csrc/autograd/variable.cpp

/// Returns the gradient accumulator node for `self`, creating and caching it
/// on first use.
///
/// The accumulator is cached in `AutogradMeta::grad_accumulator_` as a
/// weak_ptr, so a previously created node is reused for as long as some
/// other owner keeps it alive.
///
/// NOTE(review): this excerpt is trimmed; the full upstream function likely
/// also performs validity checks and locking before this point -- confirm
/// against the original source.
std::shared_ptr<Node> grad_accumulator(const Variable& self) {
  auto autograd_meta = get_autograd_meta(self);

  // Reuse the cached accumulator if it is still alive. The excerpt assigned
  // to `result` without ever declaring it and skipped this lookup.
  auto result = autograd_meta->grad_accumulator_.lock();
  if (result) {
    return result;
  }

  // Build a Variable that refers to the same TensorImpl as `self`: bump the
  // refcount by hand, then wrap the raw pointer in an intrusive_ptr.
  // NOTE(review): presumably reclaim() adopts the pointer without another
  // increment, which is why the explicit incref comes first -- confirm.
  c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
  auto intrusive_from_this =
      c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(
      Variable(std::move(intrusive_from_this)));

  // Cache weakly so that the meta does not own the accumulator.
  autograd_meta->grad_accumulator_ = result;
  return result;
}

其中 AccumulateGrad 中的 Variable(即 at::Tensor)指向 self 的 TensorImpl,用于更新聚合梯度:

// torch/csrc/autograd/functions/accumulate_grad.h

/// Node that holds the Variable whose gradient it accumulates.
struct TORCH_API AccumulateGrad : public Node {
  explicit AccumulateGrad(Variable variable_);
  /// Accumulation step (declared here, defined elsewhere).
  variable_list apply(variable_list&& grads) override;
  /// The variable this node accumulates gradients into.
  Variable variable;
};

set_history

// torch/csrc/autograd/functions/utils.h

/// Registers `variable` as an output of `grad_fn`.
/// For a defined tensor, its metadata is added to the node and the tensor's
/// gradient edge is set to (grad_fn, output_nr). For an undefined tensor,
/// only placeholder metadata is recorded. `grad_fn` must be non-null.
inline void set_history(
    at::Tensor& variable,
    const std::shared_ptr<Node>& grad_fn) {
  AT_ASSERT(grad_fn);
  if (variable.defined()) {
    // add_input_metadata returns the slot index used as the output number.
    auto output_nr = grad_fn->add_input_metadata(variable);
    impl::set_gradient_edge(variable, {grad_fn, output_nr});
  } else {
    grad_fn->add_input_metadata(Node::undefined_input());
  }
}

/// Vector overload: applies the single-tensor set_history to each element in
/// order, so output numbers follow the order of `variables`.
inline void set_history(
    std::vector<Variable>&& variables,
    const std::shared_ptr<Node>& grad_fn) {
  for (auto& variable : variables) {
    set_history(variable, grad_fn);
  }
}

set_gradient_edge 设置 Tensor 的 grad_fn_ 和 output_nr_.

// torch/csrc/autograd/variable.cpp

/// Stores `edge` on the tensor's autograd metadata: the edge's function
/// becomes `grad_fn_` and its input number becomes `output_nr_`.
void set_gradient_edge(const Variable& self, Edge edge) {
  // NOTE(review): presumably creates the metadata if it is missing -- confirm.
  auto* meta = materialize_autograd_meta(self);
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
  // For a tensor with backward-view metadata, also record its current
  // version.
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    diff_view_meta->set_attr_version(self._version());
  }
}
// torch/csrc/autograd/generated/python_variable_methods.cpp

/// Python binding for Tensor.add: parses the Python arguments, picks the
/// matching signature, and calls Tensor::add with the GIL released.
///
/// Returns a new reference to the wrapped result. If no signature case
/// matches, returns None instead of flowing off the end of the function.
static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  const Tensor& self = THPVariable_Unpack(self_);
  static PythonArgParser parser({
    "add(Scalar alpha, Tensor other)|deprecated",
    "add(Tensor other, *, Scalar alpha=1)",
  }, /*traceable=*/true);

  ParsedArgs<2> parsed_args;
  auto _r = parser.parse(self_, args, kwargs, parsed_args);
  // Give __torch_function__ overrides the first chance to handle the call.
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
  }
  // _r.idx is the index of the signature that matched, in the order the
  // signatures were passed to the parser above.
  switch (_r.idx) {
    case 0: {
      // [deprecated] aten::add(Tensor self, Scalar alpha, Tensor other) -> Tensor

      // The deprecated form takes alpha before other; both cases end up in
      // the same self.add(other, alpha) call.
      auto dispatch_add = [](const at::Tensor & self, const at::Scalar & alpha, const at::Tensor & other) -> at::Tensor {
        // Release the GIL while the op runs.
        pybind11::gil_scoped_release no_gil;
        return self.add(other, alpha);
      };
      return wrap(dispatch_add(self, _r.scalar(0), _r.tensor(1)));
    }
    case 1: {
      // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

      auto dispatch_add = [](const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) -> at::Tensor {
        // Release the GIL while the op runs.
        pybind11::gil_scoped_release no_gil;
        return self.add(other, alpha);
      };
      return wrap(dispatch_add(self, _r.tensor(0), _r.scalar(1)));
    }
  }
  // The excerpt had no return here, so reaching this point was undefined
  // behaviour for a function that returns PyObject*.
  Py_RETURN_NONE;
}
// aten/src/ATen/core/dispatch/Dispatcher.h

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
// Looks up the kernel registered for `op` under `currentDispatchKeySet` and
// calls it, passing the operator handle and key set along with the arguments.
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  // Debug builds only: when dispatch tracing is enabled, print the operator
  // name and the highest-priority key, indented by the nesting depth.
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
      auto nesting_value = dispatch_trace_nesting_value();
      for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
      std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
// aten/src/ATen/native/BinaryOps.cpp

// Scalar overload: wraps `other` with wrapped_scalar_tensor() and forwards
// to the tensor-tensor at::add.
Tensor add(const Tensor& self, const Scalar& other, const Scalar& alpha) {
  return at::add(self, wrapped_scalar_tensor(other), alpha);
}

// In-place Scalar overload: wraps `other` with wrapped_scalar_tensor() and
// forwards to the tensor-tensor self.add_, returning its reference.
Tensor& add_(Tensor& self, const Scalar& other, const Scalar& alpha) {
  return self.add_(wrapped_scalar_tensor(other), alpha);
}

torch.mul

流程类似

// torch/csrc/autograd/generated/VariableType_0.cpp

/// Autograd wrapper for aten::mul.Tensor: builds the backward node and saves
/// the inputs it needs, runs the forward op via redispatch, then attaches
/// the backward node to the result.
at::Tensor mul_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other) {
  auto& self_ = unpack(self, "self", 0);
  auto& other_ = unpack(other, "other", 1);
  auto _any_requires_grad = compute_requires_grad( self, other );

  // The (void) casts silence unused-variable warnings.
  (void)_any_requires_grad;
  auto _any_has_forward_grad_result = (isFwGradDefined(self) || isFwGradDefined(other));
  (void)_any_has_forward_grad_result;
  std::shared_ptr<MulBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<MulBackward0>(new MulBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    // The gradient w.r.t. one factor of a product is the other factor, so
    // `self` is saved only when the gradient for input 1 (`other`) is
    // needed, and `other` only when the gradient for input 0 is needed.
    if (grad_fn->should_compute_output(1)) {
      grad_fn->self_ = SavedVariable(self, false);
    }
    grad_fn->other_scalar_type = other.scalar_type();
    grad_fn->self_scalar_type = self.scalar_type();
    if (grad_fn->should_compute_output(0)) {
      grad_fn->other_ = SavedVariable(other, false);
    }
  }
  // Run the forward computation on the unpacked tensors with the key set
  // masked by `after_autograd_keyset`.
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::mul(ks & c10::after_autograd_keyset, self_, other_);
  })();
  auto result = std::move(_tmp);

  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }

  // The excerpt used `result_new_fw_grad_opt` without declaring it. It is
  // declared empty here; the forward-mode AD computation that would fill it
  // in is elided from this excerpt, so the branch below is not taken as
  // written.
  c10::optional<at::Tensor> result_new_fw_grad_opt = c10::nullopt;

  if (result_new_fw_grad_opt.has_value() && result_new_fw_grad_opt.value().defined() && result.defined()) {
    // The hardcoded 0 here will need to be updated once we support multiple levels.
    result._set_fw_grad(result_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
  }
  return result;
}

不同的是,因为乘法的求导和输入有关,所以我们在构建 MulBackward0 的时候需要把输入保存下来,即代码中的 SavedVariable 用于保存实例.

// torch/csrc/autograd/generated/Functions.h 

/// Backward node for mul. Unlike AddBackward0 it keeps both inputs as
/// SavedVariables.
struct TORCH_API MulBackward0 : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  /// Backward computation (declared here, defined elsewhere).
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "MulBackward0"; }
  /// Drops the data held by both saved inputs while holding `mutex_`.
  /// NOTE(review): `mutex_` is not declared in this excerpt; presumably it
  /// is inherited from Node -- confirm.
  void release_variables() override {
    std::lock_guard<std::mutex> lock(mutex_);
    self_.reset_data();
    other_.reset_data();
  }

  SavedVariable self_;
  at::ScalarType other_scalar_type;
  at::ScalarType self_scalar_type;
  SavedVariable other_;

};

使用 SavedVariable 来保存前向 Var 的数据区而不影响其管理反向 Op 的生命周期:

// torch/csrc/autograd/saved_variable.h

/// Snapshot of a variable kept by a backward node (excerpt; constructors and
/// other members are not shown).
class TORCH_API SavedVariable {
 public:
  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

 private:
  // The saved tensor.
  at::Tensor data_;

  std::shared_ptr<ForwardGrad> fw_grad_;

  // Held weakly, so this reference alone does not keep the node alive.
  std::weak_ptr<Node> weak_grad_fn_;
  c10::VariableVersion version_counter_;

  // NOTE(review): presumably the tensor's version at save time, used to
  // detect later in-place modification -- confirm.
  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  std::unique_ptr<SavedVariableHooks> hooks_;
  // Unlike weak_grad_fn_ above, this is an owning reference.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;
};

Lifecycle

1