# python/paddle/io/dataloader/dataloader_iter.py
class _DataLoaderIterBase:
    """
    Iterator implementation of DataLoader, which loads and feeds mini-batch
    data according to the settings of the given dataloader.

    Args:
        loader (DataLoader): instance of `paddle.io.DataLoader`
    """
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._feed_list = loader.feed_list or []
        self._places = loader.places
        self._return_list = loader.return_list
        self._batch_sampler = loader.batch_sampler
        self._drop_last = loader.drop_last
        self._auto_collate_batch = loader.auto_collate_batch
        self._num_workers = loader.num_workers
        self._use_buffer_reader = loader.use_buffer_reader
        self._prefetch_factor = loader.prefetch_factor
        self._use_shared_memory = loader.use_shared_memory
        self._timeout = (
            loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
        )
        self._worker_init_fn = loader.worker_init_fn
        self._dataset_kind = loader.dataset_kind
        self._pin_memory = loader.pin_memory
        # LoDTensorBlockingQueue instance used by create_py_reader, plus a
        # thread that puts mini-batch data into self._blocking_queue.
        # Mini-batch data is fetched from:
        # 1. multi-process mode: the workers' result queue
        # 2. single-process mode: read directly in the main process
        self._blocking_queue = None
        self._thread = None
        self._thread_done_event = threading.Event()
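
# example (illustrative, not part of the file): which of the two iterator
# implementations below gets created is decided by DataLoader when iteration
# starts; a minimal sketch, assuming the public paddle.io.DataLoader API and
# dispatch on num_workers as in the upstream implementation:
#
#     loader = paddle.io.DataLoader(dataset, batch_size=16, num_workers=0)
#     it = iter(loader)   # single-process: _DataLoaderIterSingleProcess
#
#     loader = paddle.io.DataLoader(dataset, batch_size=16, num_workers=4)
#     it = iter(loader)   # multi-process: _DataLoaderIterMultiProcess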
# python/paddle/io/dataloader/dataloader_iter.py
class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
    """
    Single-process implementation of DataLoaderIter, loading data from
    loader.data in the main process.
    """

    def __init__(self, loader):
        super().__init__(loader)

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset, ...
        )
        # NOTE: _structure_infos is used to record the data structure of
        #     each batch so that the batch structure can be restored after
        #     reading Tensors from the blocking_queue in single-process
        #     mode. Since only a single process is used in this mode, the
        #     structures can be recorded sequentially in a list without
        #     tracking send and recv indices.
        self._structure_infos = []
        self._blocking_queue_capacity = self._prefetch_factor * len(
            self._places
        )

        self._init_thread()
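
    # A worked example of _blocking_queue_capacity (illustrative): with a
    # prefetch_factor of 2 (an assumed default from the DataLoader
    # signature) and a single place, the queue holds at most 2 * 1 = 2
    # pending mini-batches; with 4 places, 2 * 4 = 8.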
    def _init_thread(self):
        self._blocking_queue = core.init_lod_tensor_blocking_queue(...)
        self._reader = core.create_py_reader(...)

        self._thread = threading.Thread(
            target=self._thread_loop, args=(_current_expected_place(),)
        )
        self._thread.start()
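
    # The thread above is the producer half of a bounded producer/consumer
    # pipeline; the C++ blocking queue is the channel and __next__ is the
    # consumer. A minimal pure-Python sketch of the same pattern (an
    # analogy, not the real LoDTensorBlockingQueue API):
    #
    #     import queue, threading
    #
    #     blocking_queue = queue.Queue(maxsize=capacity)
    #
    #     def producer():
    #         for batch in fetcher:          # hypothetical iterable
    #             blocking_queue.put(batch)  # blocks when the queue is full
    #
    #     threading.Thread(target=producer, daemon=True).start()
    #     batch = blocking_queue.get()       # what read_next amounts to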
    def _thread_loop(self, legacy_expected_place):
        while not self._thread_done_event.is_set():
            try:
                indices = next(self._sampler_iter)
                batch = self._dataset_fetcher.fetch(
                    indices, self._thread_done_event
                )
            except StopIteration:
                # end-of-epoch handling elided in this excerpt
                return

            batch, structure = _flatten_batch(batch)
            self._structure_infos.append(structure)

            try:
                # pack as LoDTensorArray
                array = core.LoDTensorArray()
                for slot in batch:
                    if isinstance(slot, (paddle.Tensor, core.eager.Tensor)):
                        slot = slot.value().get_tensor()
                    elif not isinstance(slot, core.LoDTensor):
                        tmp = core.LoDTensor()
                        tmp.set(slot, core.CPUPlace())
                        slot = tmp
                    array.append(slot)

                self._blocking_queue.push(array)
            except Exception:
                # error handling (thread exit, queue close) elided here
                raise
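
    # _flatten_batch / _restore_batch (used above and in __next__) split a
    # possibly nested batch into a flat list of tensors plus a structure
    # record, so the flat list can travel through the blocking queue. A
    # simplified sketch of the idea (hypothetical helpers, not the real
    # implementation):
    #
    #     def flatten_batch(batch):
    #         # {'img': t1, 'label': t2} -> [t1, t2], ('dict', ['img', 'label'])
    #         keys = list(batch.keys())
    #         return [batch[k] for k in keys], ('dict', keys)
    #
    #     def restore_batch(tensors, structure):
    #         kind, keys = structure
    #         return dict(zip(keys, tensors))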
    def __next__(self):
        try:
            if in_dygraph_mode():
                data = core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
                data = _restore_batch(data, self._structure_infos.pop(0))
            else:
                if self._return_list:
                    data = self._reader.read_next_list()
                else:
                    data = self._reader.read_next()
            return data
        except StopIteration:
            # reader shutdown on StopIteration elided in this excerpt
            raise
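
# example (illustrative, not part of the file): single-process iteration;
# a sketch, assuming the public paddle.io API:
#
#     import paddle
#     from paddle.io import DataLoader, TensorDataset
#
#     ds = TensorDataset([paddle.randn([8, 4]), paddle.randn([8, 1])])
#     loader = DataLoader(ds, batch_size=2, num_workers=0)
#     for feats, labels in loader:   # each step is one __next__ call above
#         pass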
class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
    def __init__(self, loader):
        super().__init__(loader)

        self._data_queue = None

        self._outstanding_capacity = self._prefetch_factor * max(
            self._num_workers, len(self._places)
        )

        # see _try_put_indices
        self._thread_lock = threading.Lock()

        self._init_workers()
        for _ in range(self._outstanding_capacity):
            self._try_put_indices()

        self._init_thread()
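
    # A worked example of _outstanding_capacity (illustrative): with
    # prefetch_factor=2, num_workers=4 and one place, at most
    # 2 * max(4, 1) = 8 batches are in flight, so the loop above enqueues
    # 8 sets of indices before the reader thread starts.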
    def _init_workers(self):
        self._data_queue = multiprocessing.Queue()

        for i in range(self._num_workers):
            indices_queue = multiprocessing.Queue()
            self._indices_queues.append(indices_queue)
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(self._dataset, ...),
            )
            worker.start()
            self._workers.append(worker)
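
    # Schematically, each worker started above runs a loop like the
    # following (a simplified sketch of _worker_loop, not the real code):
    #
    #     def worker_loop(dataset, indices_queue, data_queue, ...):
    #         fetcher = _DatasetKind.create_fetcher(dataset, ...)
    #         while True:
    #             idx, indices = indices_queue.get()  # from _try_put_indices
    #             batch = fetcher.fetch(indices)
    #             batch, structure = _flatten_batch(batch)
    #             data_queue.put((idx, batch, structure))  # read by _get_data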
    def _init_thread(self):
        self._blocking_queue = core.init_lod_tensor_blocking_queue(
            core.Variable(), self._outstanding_capacity, len(self._places) > 1
        )
        self._reader = core.create_py_reader(self._blocking_queue, ...)

        self._thread = threading.Thread(
            target=self._thread_loop, args=(_current_expected_place(),)
        )
        self._thread.start()
    def _reset(self):
        # resume iteration in the following steps
        # 1. Resume workers and clear worker caches:
        #    put a _ResumeIteration into each worker's queue as the
        #    resume-iteration flag
        with self._thread_lock:
            self._resume_worker_cnt = self._num_workers
            for worker_id in range(self._num_workers):
                self._indices_queues[worker_id].put(_ResumeIteration())
                self._batches_outstanding += 1
        # all flags will be checked in _thread_loop, simply wait here
        while self._resume_worker_cnt > 0:
            time.sleep(0.5)

        # 2. clear blocking_queue caches:
        #    in order not to restart the thread, we just clear the
        #    blocking_queue caches instead of recreating the queue
        while self._blocking_queue.size() >= len(self._places):
            if in_dygraph_mode():
                data = core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
            else:
                if self._return_list:
                    self._reader.read_next_list()
                else:
                    data = self._reader.read_next()

        # 3. reset all states
        self._send_idx = 0
        self._rcvd_idx = 0
        self._batches_outstanding = 0
        self._task_infos = {}
        self._structure_infos = []

        # mark all workers as available
        self._worker_status = [True] * self._num_workers

        # 4. reset _sampler_iter and put prefetch indices to start the
        #    next epoch
        self._sampler_iter = iter(self._index_sampler)
        for _ in range(self._outstanding_capacity):
            self._try_put_indices()
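
    # _reset and _ResumeIteration serve persistent-workers mode, where
    # worker processes stay alive across epochs instead of being shut down
    # at each StopIteration. A usage sketch (assuming the public DataLoader
    # signature):
    #
    #     loader = paddle.io.DataLoader(
    #         ds, batch_size=2, num_workers=4, persistent_workers=True
    #     )
    #     for epoch in range(10):
    #         for batch in loader:  # later epochs reset, not respawn, workers
    #             pass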
    def _thread_loop(self, legacy_expected_place):
        while not self._thread_done_event.is_set():
            batch = self._get_data()
            if not self._thread_done_event.is_set():
                try:
                    # pack as LoDTensorArray
                    array = core.LoDTensorArray()
                    if self._use_shared_memory:
                        for tensor in batch:
                            array.append(tensor)
                    else:
                        # a LoDTensor not in shared memory is not
                        # serializable and cannot be created in workers
                        for slot in batch:
                            if isinstance(
                                slot, (paddle.Tensor, core.eager.Tensor)
                            ):
                                slot = slot.value().get_tensor()
                            elif not isinstance(slot, core.LoDTensor):
                                tmp = core.LoDTensor()
                                tmp.set(slot, core.CPUPlace())
                                slot = tmp
                            array.append(slot)

                    if not self._blocking_queue.push(array):
                        self._blocking_queue.close()
                except Exception:
                    # error handling elided in this excerpt
                    raise
    def _get_data(self):
        while not self._thread_done_event.is_set():
            # For IterableDataset, batch indices are generated infinitely
            # for each worker until it raises StopIteration. A worker that
            # raises StopIteration discards a set of batch indices that was
            # counted in _send_idx but will never increase _rcvd_idx, so we
            # check whether the worker is still alive here in order to skip
            # the discarded batch indices and advance _rcvd_idx.
            if self._dataset_kind == _DatasetKind.ITER:
                while self._rcvd_idx < self._send_idx:
                    info = self._task_infos[self._rcvd_idx]
                    if len(info) == 3 or self._worker_status[info[0]]:
                        break
                    del self._task_infos[self._rcvd_idx]
                    self._rcvd_idx += 1
                    self._batches_outstanding -= 1
                else:
                    # NOTE: when _rcvd_idx catches up with _send_idx, one of
                    #     the following holds:
                    #     1. all outstanding batches have been loaded and
                    #        stored in _blocking_queue
                    #     2. all data has been drained
                    #     We need to let the thread block on _data_queue.get
                    #     so it does not occupy CPU time that the model
                    #     could otherwise use.
                    # NOTE: in persistent-workers mode, do not check for
                    #     drained data here; simply fall through to reading
                    #     _data_queue to get _ResumeIteration
                    if not self._persistent_workers:
                        # NOTE: _rcvd_idx and _send_idx only track batches
                        #     among workers; even if those are drained,
                        #     there may still be data in the blocking queue
                        if self._batches_outstanding < len(self._places):
                            return None

            if (
                self._rcvd_idx in self._task_infos
                and len(self._task_infos[self._rcvd_idx]) == 3
            ):
                info = self._task_infos.pop(self._rcvd_idx)
                self._structure_infos.append(info[2])
                return info[1]
            try:
                # [avoid hang]: the main process may block at
                # _reader.read_next on KeyboardInterrupt, so we make the
                # following tradeoff:
                # 1. get data with a timeout, MP_STATUS_CHECK_INTERVAL (5s)
                #    by default; if blocked by KeyboardInterrupt, failed
                #    workers are detected and a RuntimeError is raised to
                #    quit the DataLoader in the timeout exception handling.
                # 2. if the get times out and all workers are still alive,
                #    continue and try to get data again
                data = self._data_queue.get(timeout=self._timeout)
            except Exception as e:
                # check whether the thread-done event was set while waiting
                if self._thread_done_event.is_set():
                    continue

                # check for failed workers
                failed_workers = []
                for i, w in enumerate(self._workers):
                    if self._worker_status[i] and not w.is_alive():
                        failed_workers.append(w)
                        self._shutdown_worker(i)
                if len(failed_workers) > 0:
                    self._exit_thread_unexpectedly()
                    pids = ', '.join(str(w.pid) for w in failed_workers)
                    raise RuntimeError(
                        "DataLoader {} workers exit unexpectedly, "
                        "pids: {}".format(len(failed_workers), pids)
                    )

                # get(timeout) calls _poll(timeout) and may raise IOError
                if isinstance(e, (queue.Empty, IOError)):
                    # continue on timeout to keep getting data from the queue
                    continue

                self._exit_thread_unexpectedly()
                logging.error(
                    "DataLoader reader thread failed({}) to read data from "
                    "workers' result queue.".format(e)
                )
                raise e
            else:
                if self._dataset_kind == _DatasetKind.ITER and isinstance(
                    data, _IterableDatasetStopIteration
                ):
                    # if a worker gets StopIteration, we shut down this
                    # worker; note that the batch indices that triggered
                    # the StopIteration are discarded, so the outstanding
                    # batch count should be decreased and another set of
                    # indices should be put for the other workers, which
                    # may still be working
                    if self._persistent_workers:
                        self._worker_status[data.worker_id] = False
                    else:
                        self._shutdown_worker(data.worker_id)
                        self._batches_outstanding -= 1
                    self._try_put_indices()
                    continue

                idx, batch, structure = data

                if (
                    isinstance(idx, _ResumeIteration)
                    and batch is None
                    and structure is None
                ):
                    return idx

                if isinstance(batch, _WorkerException):
                    self._exit_thread_unexpectedly()
                    batch.reraise()

                if idx == self._rcvd_idx:
                    del self._task_infos[idx]
                    self._structure_infos.append(structure)
                    return batch
                else:
                    self._task_infos[idx] += (batch, structure)
                    continue
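
    # Workers finish out of order, so a result whose idx is ahead of
    # _rcvd_idx is cached in _task_infos until its turn comes. A small
    # illustration of that reordering (hypothetical values):
    #
    #     _rcvd_idx = 0
    #     # worker 1 delivers batch 1 first:
    #     #   _task_infos[1] = (worker_id,) + (batch1, structure1)
    #     # worker 0 then delivers batch 0; idx == _rcvd_idx, so batch0 is
    #     # returned immediately. The next _get_data call finds the cached
    #     # 3-tuple for idx 1 (the len(info) == 3 check above) and returns
    #     # batch1 without touching _data_queue.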
    def _try_put_indices(self):
        assert (
            self._batches_outstanding <= self._outstanding_capacity
        ), "too many indices have been put to queue"
        # In multi-process mode for IterableDataset, _try_put_indices is
        # called both from the main process (our implementation has a
        # blocking queue, and blocking-queue reads happen in the main
        # process) and from the reader thread, which may cause the
        # following errors:
        # 1. "ValueError: generator already executing" in
        #    next(self._sampler_iter)
        # 2. re-entrance while increasing _send_idx
        # Add a lock for thread safety. Since _try_put_indices is a
        # lightweight function outside the data-reading pipeline, this
        # lock has almost no influence on performance.
        with self._thread_lock:
            try:
                indices = next(self._sampler_iter)
            except StopIteration:
                return

            for i in range(self._num_workers):
                worker_idx = next(self._workers_idx_cycle)
                if self._worker_status[worker_idx]:
                    break
            else:
                return

            self._indices_queues[worker_idx].put((self._send_idx, indices))
            self._task_infos[self._send_idx] = (worker_idx,)
            self._batches_outstanding += 1
            self._send_idx += 1
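
    # The for/else above scans at most num_workers candidates round-robin
    # and gives up when none are alive. _workers_idx_cycle is presumably an
    # itertools.cycle over worker ids; a sketch of that assumption:
    #
    #     import itertools
    #     self._workers_idx_cycle = itertools.cycle(range(self._num_workers))
    #     next(self._workers_idx_cycle)   # 0, 1, ..., n-1, 0, 1, ...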
    def __del__(self):
        self._try_shutdown_all()

    def _shutdown_on_exit(self):
        self._try_shutdown_all(1)
    def __next__(self):
        if in_profiler_mode():
            trace_event = profiler.RecordEvent(
                name="_DataLoaderIterMultiProcess",
                event_type=profiler.TracerEventType.Dataloader,
            )
            trace_event.begin()
        try:
            benchmark().check_if_need_record(self)
            benchmark().before_reader()

            # _batches_outstanding records the total number of batches
            # between _try_put_indices and data output; this value should
            # equal _outstanding_capacity while data is not drained. If
            # _batches_outstanding is less than the number of _places,
            # there is not enough data to generate the next output, so
            # close the blocking_queue and set _thread_done_event here;
            # py_reader will raise StopIteration, and workers and
            # indices_queues are ended in the StopIteration handling
            if self._batches_outstanding < len(self._places):
                if self._persistent_workers:
                    raise StopIteration
                else:
                    self._thread_done_event.set()
                    self._blocking_queue.close()
            if in_dygraph_mode():
                data = core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
                data = _restore_batch(data, self._structure_infos.pop(0))
            else:
                if self._return_list:
                    data = self._reader.read_next_list()
                    for i in range(len(data)):
                        data[i] = data[i]._move_to_list()
                    structs = [
                        self._structure_infos.pop(0)
                        for _ in range(len(self._places))
                    ]
                    data = [
                        _restore_batch(d, s) for d, s in zip(data, structs)
                    ]
                    # static graph mode organizes multi-device data in a
                    # list; if there is only 1 place (a single device),
                    # extract the data from the list to stay compatible
                    # with dygraph mode
                    if len(self._places) == 1:
                        data = data[0]
                else:
                    data = self._reader.read_next()
            self._on_output_batch()
            benchmark().after_reader()
            return data
        except StopIteration:
            if not self._persistent_workers:
                self._reader.shutdown()
                self._try_shutdown_all()
            raise
        finally:
            if in_profiler_mode():
                trace_event.end()
    def _on_output_batch(self):
        for _ in range(len(self._places)):
            self._batches_outstanding -= 1
            self._try_put_indices()
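
# example (illustrative, not part of the file): multi-process iteration;
# a sketch, assuming the public paddle.io API:
#
#     import paddle
#     from paddle.io import DataLoader, TensorDataset
#
#     ds = TensorDataset([paddle.randn([64, 4]), paddle.randn([64, 1])])
#     loader = DataLoader(ds, batch_size=8, num_workers=2)
#     for feats, labels in loader:
#         # each step pulls one batch through _get_data -> blocking queue
#         # -> __next__, while _on_output_batch refills the indices queues
#         pass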