Develop

The Horovod objects listed here are the interfaces used for secondary development (building extensions on top of Horovod).

Horovod Objects

// horovod/common/common.h

enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };
enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED, INVALID_ARGUMENT, IN_PROGRESS };
enum DeviceType { CPU, GPU };

// GPU event wrapper: the event plus the stream it was recorded on
struct Event {
  std::shared_ptr<gpuEvent_t> event;
  uint64_t event_idx;
  gpuStream_t stream = nullptr;
};

class Status {
public:
  Status();
  static Status OK();
  static Status UnknownError(const std::string& message);
  static Status PreconditionError(const std::string& message);
  static Status Aborted(const std::string& message);
  static Status InvalidArgument(const std::string& message);
  static Status InProgress();
  bool ok() const;
  bool in_progress() const;
  StatusType type() const;
  const std::string& reason() const;
  Event event;

private:
  StatusType type_ = StatusType::OK;
  std::string reason_;
  Status(StatusType type, std::string reason);
};
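
For reference, a minimal sketch of how Status is typically produced and consumed. CheckCount and Report are hypothetical helpers; only the Status calls come from the declaration above.

#include <cstdint>
#include <iostream>

#include "horovod/common/common.h"

using namespace horovod::common;

// Hypothetical helper, not part of Horovod.
Status CheckCount(int64_t num_elements) {
  if (num_elements < 0) {
    return Status::InvalidArgument("negative element count");
  }
  return Status::OK();
}

void Report(int64_t num_elements) {
  Status status = CheckCount(num_elements);
  if (!status.ok() && !status.in_progress()) {
    std::cerr << status.reason() << std::endl;  // human-readable error message
  }
}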

class TensorShape {
public:
  TensorShape() : shape_() {}
  TensorShape(std::vector<int64_t> vec) : shape_(vec) {}
  void AddDim(int64_t dim);
  void AppendShape(TensorShape& other);

  std::string DebugString() const;
  int dims() const;
  int64_t dim_size(int idx) const;
  int64_t num_elements() const;
  const std::vector<int64_t>& to_vector() const;

  inline bool operator==(const TensorShape& rhs) const {
    return shape_ == rhs.shape_;
  }

  inline bool operator!=(const TensorShape& rhs) const {
    return shape_ != rhs.shape_;
  }

private:
  std::vector<int64_t> shape_;
};
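
A small usage sketch for TensorShape, based only on the declaration above:

#include <cstdint>

#include "horovod/common/common.h"

using horovod::common::TensorShape;

void ShapeExample() {
  TensorShape shape;                 // empty shape, dims() == 0
  shape.AddDim(32);                  // e.g. batch dimension
  shape.AddDim(128);                 // e.g. feature dimension
  int64_t n = shape.num_elements();  // 32 * 128 = 4096
  TensorShape same({32, 128});
  bool equal = (shape == same);      // true: comparison is element-wise on dims
  (void)n;
  (void)equal;
}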

class ReadyEvent {};
class ReadyEventList {};

class PersistentBuffer {
public:
  virtual const void* AccessData(std::shared_ptr<OpContext> context) const = 0;
  virtual ~PersistentBuffer() = default;
};

class Tensor {
public:
  virtual const DataType dtype() const = 0;
  virtual const TensorShape shape() const = 0;
  virtual const void* data() const = 0;
  virtual int64_t size() const = 0;
  virtual ~Tensor() = default;
};

class OpContext {
public:
  // These allocators are fully synchronous, unlike TensorFlow counterparts.
  virtual Status
  AllocatePersistent(int64_t size,
                     std::shared_ptr<PersistentBuffer>* tensor) = 0;
  virtual Status AllocateOutput(TensorShape shape,
                                std::shared_ptr<Tensor>* tensor,
                                std::shared_ptr<ReadyEvent>* event = nullptr) = 0;
  virtual Status AllocateOutput(int output_index, TensorShape shape,
                                std::shared_ptr<Tensor>* tensor,
                                std::shared_ptr<ReadyEvent>* event = nullptr) {
    if (output_index == 0) {
      return AllocateOutput(std::move(shape), tensor);
    } else {
      throw std::logic_error("output_index != 0 not supported");
    }
  }
  virtual Status AllocateZeros(int64_t num_elements, DataType dtype,
                                std::shared_ptr<Tensor>* tensor) = 0;
  virtual Framework framework() const = 0;
  virtual ~OpContext() = default;
};
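
Each framework adapter (TensorFlow, PyTorch, MXNet) implements Tensor and OpContext on top of its own tensor type. As a hedged illustration, a trivial CPU-only Tensor that owns its buffer could look like the sketch below; MyTensor is purely hypothetical, and real adapters wrap framework tensors instead of allocating their own memory.

#include <cstdint>
#include <vector>

#include "horovod/common/common.h"

using namespace horovod::common;

// Hypothetical, simplified Tensor implementation for illustration only.
class MyTensor : public Tensor {
public:
  MyTensor(DataType dtype, TensorShape shape, size_t element_size)
      : dtype_(dtype), shape_(shape),
        buffer_(static_cast<size_t>(shape.num_elements()) * element_size) {}

  const DataType dtype() const override { return dtype_; }
  const TensorShape shape() const override { return shape_; }
  const void* data() const override { return buffer_.data(); }
  // size() reports the buffer size in bytes.
  int64_t size() const override { return static_cast<int64_t>(buffer_.size()); }

private:
  DataType dtype_;
  TensorShape shape_;
  std::vector<uint8_t> buffer_;
};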

// A callback to call after the communication completes. Since the
// allreduce and allgather ops are asynchronous, this callback is what resumes
// computation after the reduction is completed.
using StatusCallback = std::function<void(const Status&)>;

// Table storing Tensors to be reduced, keyed by unique name.
// This table contains everything necessary to do the distributed operation.
struct TensorTableEntry {
  std::string tensor_name;
  std::shared_ptr<OpContext> context;
  std::shared_ptr<Tensor> tensor;
  std::shared_ptr<Tensor> output;
  // Identifier for the subset of Horovod processes partaking in this operation.
  int32_t process_set_id = 0;
  // Root rank for broadcast operation (relative to process set).
  int root_rank = 0;
  // List of events indicating that data is ready.
  ReadyEventList ready_event_list;
  // GPU to do reduction on, or CPU_DEVICE_ID in case of CPU.
  int device = CPU_DEVICE_ID;
  // A callback to call with the status.
  StatusCallback callback;
  // If we build with NVTX support: A range marking the start
  // and end of the distributed op for this tensor (may be
  // shared by multiple tensors).
  SharedNvtxOpRange nvtx_op_range;

  // Alltoall splits (if tensor is for an Alltoall operation)
  // Note: splits are stored in TensorTableEntry to avoid N^2
  // storage complexity of collecting all worker split arrays
  // on coordinator rank.
  std::vector<int32_t> splits;
  std::shared_ptr<Tensor> received_splits;

  void FinishWithCallback(const Status& status);
};
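
The callback is how the background thread returns control to the framework once the collective finishes. A hedged sketch of the completion path (CompleteEntry is hypothetical; only TensorTableEntry::FinishWithCallback and Status come from the declarations above):

#include "horovod/common/common.h"

using namespace horovod::common;

// Hypothetical: finalize an entry after its collective operation has run.
void CompleteEntry(TensorTableEntry& entry, bool succeeded) {
  Status status = succeeded
      ? Status::OK()
      : Status::Aborted("collective failed for " + entry.tensor_name);
  // FinishWithCallback invokes the stored StatusCallback with the final status.
  entry.FinishWithCallback(status);
}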

Message

Request

// horovod/common/message.h

// A Request is the message body that a non-zero-rank worker sends to the
// rank-0 worker, i.e. the coordinator.
class Request {
public:
  enum RequestType {
    ALLREDUCE = 0,
    ALLGATHER = 1,
    BROADCAST = 2,
    JOIN = 3,
    ADASUM = 4,
    ALLTOALL = 5,
    BARRIER = 6,
    REDUCESCATTER = 7
  };

  static const std::string& RequestType_Name(RequestType value);
  int32_t request_rank() const;
  void set_request_rank(int32_t value);
  RequestType request_type() const;
  void set_request_type(RequestType value);
  DataType tensor_type() const;
  void set_tensor_type(DataType value);
  const std::string& tensor_name() const;
  void set_tensor_name(const std::string& value);
  int32_t root_rank() const;
  void set_root_rank(int32_t value);
  int32_t device() const;
  void set_device(int32_t value);
  int32_t group_id() const;
  void set_group_id(int32_t value);
  const std::vector<int64_t>& tensor_shape() const;
  void set_tensor_shape(const std::vector<int64_t>& value);
  void add_tensor_shape(int64_t value);
  double prescale_factor() const;
  double postscale_factor() const;
  void set_prescale_factor(double prescale_factor);
  void set_postscale_factor(double postscale_factor);
  static void ParseFromBytes(Request& request, const uint8_t* input);
  static void SerializeToString(const Request& request, std::string& output);

private:
  int32_t request_rank_ = 0;
  RequestType request_type_ = RequestType::ALLREDUCE;
  DataType tensor_type_ = DataType::HOROVOD_UINT8;
  int32_t root_rank_ = 0;
  int32_t device_ = 0;
  int32_t group_id_ = NULL_GROUP_ID;
  std::string tensor_name_;
  std::vector<int64_t> tensor_shape_;
  double prescale_factor_ = 1.0;
  double postscale_factor_ = 1.0;
};

class RequestList {
public:
  const std::vector<Request>& requests() const;
  void set_requests(const std::vector<Request>& value);
  void add_request(const Request& value);
  void emplace_request(Request&& value);
  bool shutdown() const;
  void set_shutdown(bool value);
  static void ParseFromBytes(RequestList& request_list, const uint8_t* input);
  static void SerializeToString(const RequestList& request_list, std::string& output);

private:
  std::vector<Request> requests_;
  bool shutdown_ = false;
};
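
A sketch of how a worker might build and serialize its request list before sending it to the coordinator; the transport itself (e.g. an MPI/Gloo send to rank 0) is outside the scope of these classes and is omitted here.

#include <cstdint>
#include <string>
#include <vector>

#include "horovod/common/message.h"

using namespace horovod::common;

// Hypothetical helper: wrap a single allreduce request for one tensor.
std::string BuildAllreducePayload(int32_t rank, const std::string& name,
                                  const std::vector<int64_t>& shape) {
  Request req;
  req.set_request_rank(rank);
  req.set_request_type(Request::ALLREDUCE);
  req.set_tensor_name(name);
  req.set_tensor_shape(shape);

  RequestList list;
  list.add_request(req);

  std::string payload;
  RequestList::SerializeToString(list, payload);
  return payload;  // bytes to send to the coordinator (rank 0)
}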

Response

// A Response is the message body that the rank-0 worker, i.e. the coordinator,
// sends to the non-zero-rank workers.
class Response {
public:
  enum ResponseType {
    ALLREDUCE = 0,
    ALLGATHER = 1,
    BROADCAST = 2,
    JOIN = 3,
    ADASUM = 4,
    ALLTOALL = 5,
    BARRIER = 6,
    REDUCESCATTER = 7,
    ERROR = 8
  };

  static const std::string& ResponseType_Name(ResponseType value);
  ResponseType response_type() const;
  void set_response_type(ResponseType value);
  const std::vector<std::string>& tensor_names() const;
  DataType tensor_type() const;
  void set_tensor_type(DataType value);
  const std::string tensor_names_string() const;
  void set_tensor_names(const std::vector<std::string>& value);
  void add_tensor_name(const std::string& value);
  void add_tensor_name(std::string&& value);
  const std::string& error_message() const;
  void set_error_message(const std::string& value);
  const std::vector<int32_t>& devices() const;
  void set_devices(const std::vector<int32_t>& value);
  void add_device(int32_t value);
  const std::vector<int64_t>& tensor_sizes() const;
  void set_tensor_sizes(const std::vector<int64_t>& value);
  void add_tensor_size(int64_t value);
  void add_allgather_response(const Response& response);
  double prescale_factor() const;
  double postscale_factor() const;
  void set_prescale_factor(double prescale_factor);
  void set_postscale_factor(double postscale_factor);
  int last_joined_rank() const;
  void set_last_joined_rank(int value);
  static void ParseFromBytes(Response& response, const uint8_t* input);
  static void SerializeToString(const Response& response, std::string& output);

private:
  ResponseType response_type_ = ResponseType::ALLREDUCE;
  std::vector<std::string> tensor_names_;
  DataType tensor_type_ = DataType::HOROVOD_UINT8;
  std::string error_message_;
  std::vector<int32_t> devices_;
  std::vector<int64_t> tensor_sizes_;
  double prescale_factor_ = 1.0;
  double postscale_factor_ = 1.0;
  int last_joined_rank_ = -1;
};

class ResponseList {
public:
  const std::vector<Response>& responses() const;
  void set_responses(const std::vector<Response>& value);
  void add_response(const Response& value);
  void add_response(Response&& value);
  void emplace_response(Response&& value);
  bool shutdown() const;
  void set_shutdown(bool value);
  static void ParseFromBytes(ResponseList& response_list, const uint8_t* input);
  static void SerializeToString(const ResponseList& response_list, std::string& output);

private:
  std::vector<Response> responses_;
  bool shutdown_ = false;
};
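
Correspondingly, a sketch of the coordinator side: translate a fully negotiated Request into a Response and serialize the batch that gets broadcast back to all workers (hypothetical glue code, not the actual controller logic).

#include <string>
#include <vector>

#include "horovod/common/message.h"

using namespace horovod::common;

// Hypothetical: build a Response for one negotiated allreduce tensor.
Response MakeAllreduceResponse(const Request& req) {
  Response resp;
  resp.set_response_type(Response::ALLREDUCE);
  resp.add_tensor_name(req.tensor_name());
  resp.set_tensor_type(req.tensor_type());
  resp.add_device(req.device());
  return resp;
}

std::string SerializeResponses(const std::vector<Response>& ready) {
  ResponseList list;
  for (const auto& r : ready) {
    list.add_response(r);
  }
  std::string payload;
  ResponseList::SerializeToString(list, payload);
  return payload;  // bytes broadcast from the coordinator to the other workers
}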

ResponseCache

// horovod/common/response_cache.h

struct TensorParams {
  DataType dtype;
  std::vector<int64_t> shape;
  int32_t device;
};

// LRU cache of Responses
class ResponseCache {
public:
  ResponseCache() = default;
  ResponseCache(const ResponseCache&) = delete;

  enum CacheState { MISS = 0, HIT = 1, INVALID = 2 };

  void clear();

  void set_capacity(uint32_t capacity);

  uint32_t capacity() const;

  size_t num_active_bits() const;

  CacheState cached(const Request& message) const;

  CacheState cached(const Response& response, const TensorParams& params,
                    bool joined = false) const;

  void put(const Response& response, TensorQueue& tensor_queue,
           bool joined = false);

  const Response& get_response(uint32_t cache_bit);

  const Response& peek_response(uint32_t cache_bit) const;

  uint32_t peek_cache_bit(const Request& message) const;

  uint32_t peek_cache_bit(const std::string& tensor_name) const;

  std::vector<uint32_t> list_all_bits() const;

  void erase_response(uint32_t cache_bit);

  void update_cache_bits();

private:
  void put_(const Response& response, TensorParams& params,
            bool joined = false);

  uint32_t capacity_ = 0;

  // List containing cached entries. Each entry in the cache is a pair
  // of a Response and a TensorParams struct.
  std::list<std::pair<Response, TensorParams>> cache_;

  // Vector of iterators to cache entries. Indexed by cache bit.
  std::vector<std::list<std::pair<Response, TensorParams>>::iterator>
      cache_iters_;

  // Lookup table mapping tensor names to assigned cache bits.
  std::unordered_map<std::string, uint32_t> tensor_name_to_bit_;

  bool bits_outdated_ = false;

  bool print_warning_ = true;
};
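
A hedged sketch of the intended lookup flow on a worker: if a Request hits the cache, the previously negotiated Response can be reused instead of going through coordination again. TryServeFromCache is hypothetical; the real controller flow also handles INVALID entries and the TensorQueue.

#include <cstdint>

#include "horovod/common/response_cache.h"

using namespace horovod::common;

// Hypothetical helper: serve a request from the cache if possible.
bool TryServeFromCache(ResponseCache& cache, const Request& req,
                       Response* out) {
  if (cache.cached(req) == ResponseCache::HIT) {
    uint32_t bit = cache.peek_cache_bit(req);
    // get_response (unlike peek_response) also refreshes the LRU position.
    *out = cache.get_response(bit);
    return true;
  }
  return false;
}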

// Helper class to coordinate cache and state information
// across workers. Uses global controller operations on a bit vector
// for cheaper coordination.
class CacheCoordinator {
public:
  explicit CacheCoordinator(size_t num_active_bits_);

  void record_hit(uint32_t bit);

  void record_invalid_bit(uint32_t bit);

  void erase_hit(uint32_t bit);

  void set_should_shut_down(bool should_shut_down);

  void set_uncached_in_queue(bool uncached_in_queue);

  const std::set<uint32_t>& cache_hits() const;

  const std::set<uint32_t>& invalid_bits() const;

  const std::set<uint32_t>& timeline_bits() const;

  bool should_shut_down() const;

  bool uncached_in_queue() const;

  // Method to sync state and bit sets across workers.
  void sync(std::shared_ptr<Controller> controller, bool timeline_enabled);

private:
  enum StatusBit {
    SHOULD_SHUT_DOWN = 0,
    UNCACHED_IN_QUEUE = 1,
    INVALID_IN_QUEUE = 2
  };

  // Number of active bits in the cache. Required to size the
  // bitvector identically across workers.
  size_t num_active_bits_;

  // Set of cache hit bits. After sync(), contains only common
  // cache hit bits across workers.
  std::set<uint32_t> cache_hits_;

  // Set of invalid bits. After sync(), contains only common
  // invalid bits across workers.
  std::set<uint32_t> invalid_bits_;

  // Set of bits for timeline handling. After sync(), contains bits
  // where at least one worker recorded a cache hit. This indicates
  // that the timeline negotiation phase should be started/continued.
  std::set<uint32_t> timeline_bits_;

  // States used externally in cycle loop.
  bool should_shut_down_ = false;
  bool uncached_in_queue_ = false;

  // State used internally to trigger second bit vector communication
  // to sync invalid bits.
  bool invalid_in_queue_ = false;

  std::vector<long long> bitvector_;

  bool synced_ = false;
};
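
Finally, a sketch of how CacheCoordinator is meant to be used each cycle: record local cache hits, call sync() to intersect the bit sets across workers, then act only on bits every worker agrees on. The wiring to the Controller and the handling of invalid/timeline bits are simplified here.

#include <cstdint>
#include <memory>
#include <vector>

#include "horovod/common/response_cache.h"

using namespace horovod::common;

// Hypothetical per-cycle routine; `controller` is the process set's Controller.
void CoordinateCachedRequests(ResponseCache& cache,
                              std::shared_ptr<Controller> controller,
                              const std::vector<Request>& local_requests,
                              bool timeline_enabled) {
  CacheCoordinator coordinator(cache.num_active_bits());

  for (const auto& req : local_requests) {
    if (cache.cached(req) == ResponseCache::HIT) {
      coordinator.record_hit(cache.peek_cache_bit(req));
    } else {
      coordinator.set_uncached_in_queue(true);
    }
  }

  // Syncs the bit sets across workers via controller operations on a bit
  // vector; afterwards cache_hits() holds only bits that hit on every worker.
  coordinator.sync(controller, timeline_enabled);

  for (uint32_t bit : coordinator.cache_hits()) {
    const Response& resp = cache.peek_response(bit);
    (void)resp;  // the cached response can be executed from here
  }
}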