Elastic
Elastic launch
Continuing from the normal launch path described earlier, elastic mode is launched via gloo.
def _run_elastic(args):
    settings = elastic_settings.ElasticSettings(discovery=discover_hosts, ...)
    return gloo_run_elastic(settings, env,
                            args.run_func if args.run_func else args.command,
                            executable)
from horovod.runner.gloo_run import gloo_run, gloo_run_elastic
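The discovery object passed into ElasticSettings is what tells the driver which hosts are currently available and how many slots each one offers; the built-in HostDiscoveryScript simply wraps a user-provided script. As a rough sketch (assuming the HostDiscovery interface in horovod.runner.elastic.discovery; verify the exact interface against your Horovod version), a custom discovery source could look like this:

# Illustrative sketch only, not Horovod code: a discovery source that re-reads a
# "hostname:slots" file on every polling round of the driver.
from horovod.runner.elastic.discovery import HostDiscovery

class StaticFileDiscovery(HostDiscovery):
    def __init__(self, path):
        self._path = path

    def find_available_hosts_and_slots(self):
        hosts = {}
        with open(self._path) as f:
            for line in f:
                if line.strip():
                    hostname, slots = line.strip().split(':')
                    hosts[hostname] = int(slots)
        return hosts  # dict: hostname -> slot count

The driver's discovery thread (ElasticDriver._discover_hosts, shown below) polls this object periodically.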
# horovod/runner/gloo_run.py
def gloo_run_elastic(settings, env, command_or_func, executable) -> Optional[Any]:
    rendezvous = RendezvousServer(settings.verbose)
    return launch_gloo_elastic(command_or_func, exec_command, settings, env,
                               get_common_interfaces, rendezvous, executable)
def launch_gloo_elastic(command_or_func, exec_command, settings, env,
                        get_common_interfaces, rendezvous, executable):
    driver = ElasticDriver(rendezvous, settings.discovery,
                           settings.min_num_proc, settings.max_num_proc,
                           timeout=settings.elastic_timeout,
                           reset_limit=settings.reset_limit,
                           cooldown_range=settings.cooldown_range,
                           verbose=settings.verbose)
    handler = create_rendezvous_handler(driver)
    global_rendezv_port = rendezvous.start(handler)
    driver.wait_for_available_slots(settings.num_proc)

    nics = get_common_interfaces(driver)
    server_ip = network.get_driver_ip(nics)

    run_command = get_run_command(command_or_func, server_ip, nics,
                                  global_rendezv_port, elastic=True)
    create_worker = _create_elastic_worker_fn(exec_command, run_command, env, event)

    driver.start(settings.num_proc, create_worker)
    res = driver.get_results()
    driver.stop()
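get_run_command prefixes the user command with the environment variables each worker needs to find the rendezvous server and to use gloo as the controller. A hedged sketch of roughly what that environment contains (the exact variable set and names differ slightly between Horovod versions, so treat this as an assumption):

# Approximate environment injected in front of the worker command (sketch only).
run_envs = {
    'HOROVOD_GLOO_RENDEZVOUS_ADDR': '10.0.0.1',   # IP of the driver's RendezvousServer
    'HOROVOD_GLOO_RENDEZVOUS_PORT': '12345',      # global_rendezv_port from above
    'HOROVOD_CONTROLLER': 'gloo',
    'HOROVOD_CPU_OPERATIONS': 'gloo',
    'HOROVOD_GLOO_IFACE': 'eth0',
    'NCCL_SOCKET_IFNAME': 'eth0',
    'HOROVOD_ELASTIC': '1',                       # marks the worker as elastic
}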
Driver
ElasticDriver runs two kinds of threads:

- one thread loops in _discover_hosts to monitor the available hosts and report any updates;
- one thread per worker slot runs run_worker to launch that worker process.

Note that the driver itself does not block: after starting the workers it returns to the launch function, which then waits for the results. Concretely, driver.get_results() blocks in a while loop until every worker has reported back.
# horovod/runner/elastic/driver.py
class ElasticDriver(object):
    def __init__(self, rendezvous, ...):
        self._worker_registry = WorkerStateRegistry(self, self._host_manager, reset_limit=reset_limit)
        self._results = ResultsRecorder()
        self._discovery_thread = threading.Thread(target=self._discover_hosts)

    def start(self, num_proc, create_worker_fn):
        self._activate_workers(num_proc)

    def get_results(self):
        return self._results.get_results()

    def register_worker_server(self, host, slot, addresses, secret_key):
        self._worker_clients[(host, slot)] = WorkerNotificationClient(
            addresses, secret_key, self._verbose)

    def _activate_workers(self, min_num_proc):
        current_hosts = self.wait_for_available_slots(min_num_proc)
        pending_slots = self._update_host_assignments(current_hosts)
        self._worker_registry.reset(self.world_size())
        self._start_worker_processes(pending_slots)

    def _discover_hosts(self):
        while not self._shutdown.is_set():
            self._wait_hosts_cond.acquire()
            try:
                update_res = self._host_manager.update_available_hosts()
                if update_res != HostUpdateResult.no_update:
                    self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
                    self._wait_hosts_cond.notify_all()
            finally:
                self._wait_hosts_cond.release()
            self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)

    def _notify_workers_host_changes(self, current_hosts, update_res):
        coordinator_client.notify_hosts_updated(timestamp, update_res)

    def _start_worker_processes(self, pending_slots):
        for slot_info in pending_slots:
            self._start_worker_process(slot_info)

    def _start_worker_process(self, slot_info):
        def run_worker():
            res = create_worker_fn(slot_info, [shutdown_event, host_event])
            exit_code, timestamp = res
            self._handle_worker_exit(slot_info, exit_code, timestamp)

        thread = threading.Thread(target=run_worker)
        thread.daemon = True
        thread.start()
        self._results.expect(thread)
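As noted above, driver.get_results() blocks until all expected workers have reported. ResultsRecorder's implementation is not shown here; a simplified, hypothetical sketch of the expect/collect pattern it follows (illustrative only, not Horovod's code) would be:

# Hypothetical, simplified results recorder illustrating the blocking pattern.
import queue
import threading

class SimpleResultsRecorder:
    def __init__(self):
        self._expected = 0
        self._results = queue.Queue()
        self._lock = threading.Lock()

    def expect(self, worker_thread):
        # called once per worker thread started by the driver
        with self._lock:
            self._expected += 1

    def add_result(self, key, value):
        # called when a worker finishes (e.g. from _handle_worker_exit)
        self._results.put((key, value))

    def get_results(self):
        # blocks until every expected worker has reported an exit code
        results = {}
        while len(results) < self._expected:
            key, value = self._results.get()
            results[key] = value
        return results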
Usage
import torch
import torch.nn.functional as F
import torch.optim as optim
import horovod.torch as hvd

hvd.init()
torch.cuda.set_device(hvd.local_rank())

dataset = ...
model = ...

optimizer = optim.SGD(model.parameters(), lr * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

@hvd.elastic.run
def train(state):
    batch_offset = state.batch
    for state.epoch in range(state.epoch, epochs):
        for state.batch in range(state.batch, batches_per_epoch):
            data, target = get_random_batch()

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            if state.batch % batches_per_commit == 0:
                state.commit()
        state.batch = 0

def on_state_reset():
    # adjust learning rate on reset
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * hvd.size()

state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])
train(state)
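A script like the one above is launched with horovodrun in elastic mode, e.g. (flag spellings vary by version; older releases use --min-np/--max-np, newer ones --min-num-proc/--max-num-proc, matching the settings.min_num_proc/max_num_proc fields seen earlier):

horovodrun -np 8 --min-np 4 --max-np 12 --host-discovery-script discover_hosts.sh python train.py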
Implementation
The training loop function is wrapped in a decorator that catches failures and drives the worker's own recovery; in addition, through the notification_manager, the other workers are also informed that a failure or host change has occurred.
# horovod/common/elastic.py
notification_manager = WorkerNotificationManager()
def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                try:
                    if not skip_sync:
                        state.sync()
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
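The reset argument to run_fn is framework specific. For PyTorch it essentially shuts Horovod down and initializes it again, which is what triggers the new round of rendezvous discussed further below. A simplified sketch of that wiring (based on horovod/torch/elastic; consult the source for the exact code):

# Simplified sketch of how hvd.elastic.run is assembled for PyTorch (details may differ).
from horovod.common.elastic import run_fn
from horovod.torch.mpi_ops import init, shutdown

def run(func):
    def reset():
        shutdown()   # tear down the current communication context
        init()       # re-initialize: the background loop performs a new rendezvous
    return run_fn(func, reset)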
# horovod/common/elastic.py
class State(object):
    def register_reset_callbacks(self, callbacks):
        pass

    def on_reset(self):
        self._host_messages = queue.Queue()
        self.reset()
        for callback in self._reset_callbacks:
            callback()

    def on_hosts_updated(self, timestamp, update_res):
        self._host_messages.put((timestamp, update_res))

    def commit(self):
        self.save()
        self.check_host_updates()

    def check_host_updates(self):
        # drain queued host-update messages; only a newer update triggers an interrupt
        last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
        all_update = HostUpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last_updated_timestamp:
                last_updated_timestamp = timestamp
                all_update |= update
        if last_updated_timestamp > prev_timestamp:
            # skip_sync is True only when hosts were merely removed
            raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

    def save(self):
        pass

    def restore(self):
        pass

    def sync(self):
        pass

    def reset(self):
        pass
class ObjectState(State):
    def restore(self):
        self._set_attrs()

    def sync(self):
        if self._saved_state:
            self._saved_state = self._bcast_object(self._saved_state)
            self._set_attrs()
# horovod/torch/elastic/state.py
class TorchState(ObjectState):
    def restore(self):
        for handler in self._handlers.values():
            handler.restore()
        super(TorchState, self).restore()
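TorchState keeps one handler per tracked object (model, optimizer, and so on); each handler knows how to save, restore, and sync that object. A simplified, illustrative handler for a model (not Horovod's exact class, which lives in horovod/torch/elastic/state.py) might look like:

# Illustrative sketch of a model handler; Horovod's real handler classes differ in detail.
import copy
import horovod.torch as hvd

class ModelStateHandler:
    def __init__(self, model):
        self.model = model
        self._saved_state = copy.deepcopy(model.state_dict())

    def save(self):
        self._saved_state = copy.deepcopy(self.model.state_dict())

    def restore(self):
        # roll back to the last committed snapshot after a HorovodInternalError
        self.model.load_state_dict(self._saved_state)

    def sync(self):
        # make every worker agree by broadcasting from the (new) rank 0
        hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
        self.save()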
After hvd.init(), a background thread keeps running to carry out the actual communication. Whenever the set of hosts changes, the communication groups have to be rebuilt; the concrete implementation is shown below.
// horovod/common/operations.cc
bool RunLoopOnce(HorovodGlobalState& state) {
  if (state.dynamic_process_sets) {
    // Initialize any newly added process set that has been registered by all
    // Horovod processes and remove a process set that has been marked for
    // removal by all Horovod processes.
    if (state.control_operation == LibType::GLOO) {
      state.process_set_table.InitializeRegisteredAndRemoveMarkedIfReady(
          global_gloo_context);
    }
  }
}
Horovod then reinitializes its context, performing a new round of rendezvous. The code below is responsible for rebuilding the communication groups (process sets).
// horovod/common/process_set.cc
template <class Context>
int32_t ProcessSetTable::InitializeRegisteredAndRemoveMarkedIfReady_(
    const Context& global_context, const Status& removal_status) {
  if (registered_count_agreement) {
    for (auto id : Ids()) {
      bool newly_registered = Get(id).Initialize(global_context);
    }
  }
  if (id_to_be_removed_agreement) {
    id_to_process_set_[id_to_be_removed_].Finalize(removal_status);
    DeregisterProcessSet(id_to_be_removed_);
  }
}

ProcessSetTable::ProcessSetTable() {
  auto process_set_id = RegisterProcessSet();
}

int32_t ProcessSetTable::RegisterProcessSet(std::vector<int> global_ranks) {
  // ProcessSet
  ids_.push_back(id);
}

bool ProcessSet::Initialize(const GlooContext& global_gloo_context) {
  gloo_context.InitializeForProcessSet(global_gloo_context,
                                       registered_global_ranks);
}

// horovod/common/gloo/gloo_context.cc
void GlooContext::InitializeForProcessSet(const GlooContext& global_context,
                                          const std::vector<int>& registered_ranks) {
  ctx = Rendezvous(HOROVOD_GLOO_GLOBAL_PREFIX + process_set_suffix,
                   rendezvous_addr_env, rendezvous_port, rank, size, dev,
                   timeout_);
}
The call chain from the Python process-set API down to the Gloo controller:
# horovod/common/process_sets.py
def add_process_set(process_set: Union[ProcessSet, Sequence[int]]) -> ProcessSet:
    process_set_id = _basics._add_process_set_impl(process_set.ranks)

# horovod/common/basics.py
class HorovodBasics(object):
    def _add_process_set_impl(self, ranks: Sequence[int]) -> Optional[int]:
        result = int(self.MPI_LIB_CTYPES.horovod_add_process_set(
            (ctypes.c_int * nrank)(*ranks), ctypes.c_int(nrank)))
// horovod/common/operations.cc
int horovod_add_process_set(const int* ranks, int nrank) {
  ProcessSet* process_set = nullptr;
  {
    process_set = &GetProcessSetOrAddUnitialized(
        ranks && nrank > 0 ? std::vector<int>(ranks, ranks + nrank)
                           : std::vector<int>(),
        id);
  }
}

ProcessSet& GetProcessSetOrAddUnitialized(std::vector<int> ranks, int& id) {
  id = horovod_global.process_set_table.FindId(ranks);
  id = horovod_global.process_set_table.RegisterProcessSet(std::move(ranks));
  auto& process_set = horovod_global.process_set_table.Get(id);
  EnrichProcessSetWithGlooController(process_set);
}

void EnrichProcessSetWithGlooController(ProcessSet& process_set) {
  process_set.controller.reset(new GlooController(
      process_set.response_cache, process_set.tensor_queue,
      horovod_global.timeline, horovod_global.parameter_manager,
      process_set.group_table, horovod_global.timeline_controller,
      process_set.gloo_context));
}
Workflow Summary
- The hvd.elastic.run decorator wraps the training loop and catches HorovodInternalError / HostsUpdatedInterrupt.
- On such an error, Horovod reinitializes its context and performs a new round of rendezvous.
- The state is then restored and synchronized by broadcasting from the new worker 0.
Reference
- https://horovod.readthedocs.io/en/stable/elastic_include.html