| r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter | |
| To support these two classes, in `./_utils` we define many utility methods and | |
| functions to be run in multiprocessing. E.g., the data loading worker loop is | |
| in `./_utils/worker.py`. | |
| """ | |
| import threading | |
| import itertools | |
| import warnings | |
| import multiprocessing as python_multiprocessing | |
| import torch | |
| import torch.multiprocessing as multiprocessing | |
| from torch._utils import ExceptionWrapper | |
| import queue  # standard-library queue: provides queue.Queue (pin_memory data queue) and queue.Empty | |
| from torch._six import string_classes | |
| from torch.utils.data.dataset import IterableDataset | |
| from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler | |
| from torch.utils.data import _utils | |
| from .my_data_worker import worker_loop | |
| __all__ = ["MyDataLoader"] | |
| get_worker_info = _utils.worker.get_worker_info | |
| # This function used to be defined in this file. However, it was moved to | |
| # _utils/collate.py. Although it is rather hard to access this from user land | |
| # (one has to explicitly directly `import torch.utils.data.dataloader`), there | |
| # probably is user code out there using it. This aliasing maintains BC in this | |
| # aspect. | |
| default_collate = _utils.collate.default_collate | |
| class _DatasetKind(object): | |
| Map = 0 | |
| Iterable = 1 | |
| @staticmethod | |
| def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): | |
| if kind == _DatasetKind.Map: | |
| return _utils.fetch._MapDatasetFetcher( | |
| dataset, auto_collation, collate_fn, drop_last | |
| ) | |
| else: | |
| return _utils.fetch._IterableDatasetFetcher( | |
| dataset, auto_collation, collate_fn, drop_last | |
| ) | |
| class _InfiniteConstantSampler(Sampler): | |
| r"""Analogous to ``itertools.repeat(None, None)``. | |
| Used as sampler for :class:`~torch.utils.data.IterableDataset`. | |
| Arguments: | |
| data_source (Dataset): dataset to sample from | |
| """ | |
| def __init__(self): | |
| super(_InfiniteConstantSampler, self).__init__(None) | |
| def __iter__(self): | |
| while True: | |
| yield None | |
| class MyDataLoader(object): | |
| r""" | |
| Data loader. Combines a dataset and a sampler, and provides an iterable over | |
| the given dataset. | |
| The :class:`~torch.utils.data.DataLoader` supports both map-style and | |
| iterable-style datasets with single- or multi-process loading, customizing | |
| loading order and optional automatic batching (collation) and memory pinning. | |
| See :py:mod:`torch.utils.data` documentation page for more details. | |
| Arguments: | |
| dataset (Dataset): dataset from which to load the data. | |
| batch_size (int, optional): how many samples per batch to load | |
| (default: ``1``). | |
| shuffle (bool, optional): set to ``True`` to have the data reshuffled | |
| at every epoch (default: ``False``). | |
| sampler (Sampler, optional): defines the strategy to draw samples from | |
| the dataset. If specified, :attr:`shuffle` must be ``False``. | |
| batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of | |
| indices at a time. Mutually exclusive with :attr:`batch_size`, | |
| :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. | |
| num_workers (int, optional): how many subprocesses to use for data | |
| loading. ``0`` means that the data will be loaded in the main process. | |
| (default: ``0``) | |
| collate_fn (callable, optional): merges a list of samples to form a | |
| mini-batch of Tensor(s). Used with batched loading from a | |
| map-style dataset. | |
| pin_memory (bool, optional): If ``True``, the data loader will copy Tensors | |
| into CUDA pinned memory before returning them. If your data elements | |
| are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, | |
| see the example below. | |
| drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, | |
| if the dataset size is not divisible by the batch size. If ``False`` and | |
| the size of dataset is not divisible by the batch size, then the last batch | |
| will be smaller. (default: ``False``) | |
| timeout (numeric, optional): if positive, the timeout value for collecting a batch | |
| from workers. Should always be non-negative. (default: ``0``) | |
| worker_init_fn (callable, optional): If not ``None``, this will be called on each | |
| worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as | |
| input, after seeding and before data loading. (default: ``None``) | |
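| multiprocessing_context (str or multiprocessing.context.BaseContext, optional): | |
| if ``None``, the default multiprocessing context of the current platform is | |
| used; otherwise, a start-method name (one of | |
| ``multiprocessing.get_all_start_methods()``, e.g. ``'spawn'``) or a context | |
| object. Only meaningful when ``num_workers > 0``. (default: ``None``) | |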
| .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` | |
| cannot be an unpicklable object, e.g., a lambda function. See | |
| :ref:`multiprocessing-best-practices` on more details related | |
| to multiprocessing in PyTorch. | |
| .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used. | |
| When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, | |
| ``len(dataset)`` (if implemented) is returned instead, regardless | |
| of multi-process loading configurations, because PyTorch trusts | |
| user :attr:`dataset` code to correctly handle multi-process | |
| loading and avoid duplicate data. See `Dataset Types`_ for more | |
| details on these two types of datasets and how | |
| :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_. | |
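| Example (a minimal usage sketch; ``SimpleCustomBatch`` and ``collate_wrapper`` are | |
| illustrative names, and the custom-type pinning assumes the stock | |
| ``torch.utils.data._utils.pin_memory`` helper, which calls ``.pin_memory()`` on | |
| objects that define it):: | |
| >>> import torch | |
| >>> from torch.utils.data import TensorDataset | |
| >>> ds = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1)) | |
| >>> loader = MyDataLoader(ds, batch_size=4, shuffle=True) | |
| >>> for (batch,) in loader: | |
| ...     print(batch.shape) | |
| >>> # Custom batch type that works with ``pin_memory=True``. | |
| >>> class SimpleCustomBatch: | |
| ...     def __init__(self, samples): | |
| ...         self.inp = torch.stack([s[0] for s in samples], 0) | |
| ...     def pin_memory(self): | |
| ...         self.inp = self.inp.pin_memory() | |
| ...         return self | |
| >>> def collate_wrapper(samples): | |
| ...     return SimpleCustomBatch(samples) | |
| >>> pinned = MyDataLoader(ds, batch_size=4, collate_fn=collate_wrapper, | |
| ...                       pin_memory=True) | |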
| """ | |
| __initialized = False | |
| def __init__( | |
| self, | |
| dataset, | |
| batch_size=1, | |
| shuffle=False, | |
| sampler=None, | |
| batch_sampler=None, | |
| num_workers=0, | |
| collate_fn=None, | |
| pin_memory=False, | |
| drop_last=False, | |
| timeout=0, | |
| worker_init_fn=None, | |
| multiprocessing_context=None, | |
| ): | |
| torch._C._log_api_usage_once("python.data_loader") | |
| if num_workers < 0: | |
| raise ValueError( | |
| "num_workers option should be non-negative; " | |
| "use num_workers=0 to disable multiprocessing." | |
| ) | |
| if timeout < 0: | |
| raise ValueError("timeout option should be non-negative") | |
| self.dataset = dataset | |
| self.num_workers = num_workers | |
| self.pin_memory = pin_memory | |
| self.timeout = timeout | |
| self.worker_init_fn = worker_init_fn | |
| self.multiprocessing_context = multiprocessing_context | |
| # Arg-check dataset related before checking samplers because we want to | |
| # tell users that iterable-style datasets are incompatible with custom | |
| # samplers first, so that they don't learn that this combo doesn't work | |
| # after spending time fixing the custom sampler errors. | |
| if isinstance(dataset, IterableDataset): | |
| self._dataset_kind = _DatasetKind.Iterable | |
| # NOTE [ Custom Samplers and `IterableDataset` ] | |
| # | |
| # `IterableDataset` does not support custom `batch_sampler` or | |
| # `sampler` since the key is irrelevant (unless we support | |
| # generator-style dataset one day...). | |
| # | |
| # For `sampler`, we always create a dummy sampler. This is an | |
| # infinite sampler even when the dataset may have an implemented | |
| # finite `__len__` because in multi-process data loading, naive | |
| # settings will return duplicated data (which may be desired), and | |
| # thus using a sampler with length matching that of dataset will | |
| # cause data loss (you may have duplicates of the first couple | |
| # batches, but never see anything afterwards). Therefore, | |
| # `IterableDataset` always uses an infinite sampler, an instance of | |
| # `_InfiniteConstantSampler` defined above. | |
| # | |
| # A custom `batch_sampler` essentially only controls the batch size. | |
| # However, it is unclear how useful it would be since an iterable-style | |
| # dataset can handle that within itself. Moreover, it is pointless | |
| # in multi-process data loading as the assignment order of batches | |
| # to workers is an implementation detail so users can not control | |
| # how to batchify each worker's iterable. Thus, we disable this | |
| # option. If this turns out to be useful in future, we can re-enable | |
| # this, and support custom samplers that specify the assignments to | |
| # specific workers. | |
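| # For reference, a worker-aware `IterableDataset` typically shards itself | |
| # inside `__iter__` using `get_worker_info()` so that each worker replica | |
| # yields a disjoint slice. Sketch only; `MyIterable`, `start`, and `end` | |
| # are hypothetical names, not part of this module: | |
| # | |
| #     class MyIterable(IterableDataset): | |
| #         def __init__(self, start, end): | |
| #             self.start, self.end = start, end | |
| #         def __iter__(self): | |
| #             info = get_worker_info() | |
| #             if info is None:  # single-process loading: yield the full range | |
| #                 lo, hi = self.start, self.end | |
| #             else:  # split the range across workers to avoid duplicates | |
| #                 per_worker = (self.end - self.start) // info.num_workers | |
| #                 lo = self.start + info.id * per_worker | |
| #                 hi = self.end if info.id == info.num_workers - 1 else lo + per_worker | |
| #             return iter(range(lo, hi)) | |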
| if shuffle is not False: | |
| raise ValueError( | |
| "DataLoader with IterableDataset: expected unspecified " | |
| "shuffle option, but got shuffle={}".format(shuffle) | |
| ) | |
| elif sampler is not None: | |
| # See NOTE [ Custom Samplers and IterableDataset ] | |
| raise ValueError( | |
| "DataLoader with IterableDataset: expected unspecified " | |
| "sampler option, but got sampler={}".format(sampler) | |
| ) | |
| elif batch_sampler is not None: | |
| # See NOTE [ Custom Samplers and IterableDataset ] | |
| raise ValueError( | |
| "DataLoader with IterableDataset: expected unspecified " | |
| "batch_sampler option, but got batch_sampler={}".format( | |
| batch_sampler | |
| ) | |
| ) | |
| else: | |
| self._dataset_kind = _DatasetKind.Map | |
| if sampler is not None and shuffle: | |
| raise ValueError("sampler option is mutually exclusive with " "shuffle") | |
| if batch_sampler is not None: | |
| # auto_collation with custom batch_sampler | |
| if batch_size != 1 or shuffle or sampler is not None or drop_last: | |
| raise ValueError( | |
| "batch_sampler option is mutually exclusive " | |
| "with batch_size, shuffle, sampler, and " | |
| "drop_last" | |
| ) | |
| batch_size = None | |
| drop_last = False | |
| elif batch_size is None: | |
| # no auto_collation | |
| if shuffle or drop_last: | |
| raise ValueError( | |
| "batch_size=None option disables auto-batching " | |
| "and is mutually exclusive with " | |
| "shuffle, and drop_last" | |
| ) | |
| if sampler is None: # give default samplers | |
| if self._dataset_kind == _DatasetKind.Iterable: | |
| # See NOTE [ Custom Samplers and IterableDataset ] | |
| sampler = _InfiniteConstantSampler() | |
| else: # map-style | |
| if shuffle: | |
| sampler = RandomSampler(dataset) | |
| else: | |
| sampler = SequentialSampler(dataset) | |
| if batch_size is not None and batch_sampler is None: | |
| # auto_collation without custom batch_sampler | |
| batch_sampler = BatchSampler(sampler, batch_size, drop_last) | |
| self.batch_size = batch_size | |
| self.drop_last = drop_last | |
| self.sampler = sampler | |
| self.batch_sampler = batch_sampler | |
| if collate_fn is None: | |
| if self._auto_collation: | |
| collate_fn = _utils.collate.default_collate | |
| else: | |
| collate_fn = _utils.collate.default_convert | |
| self.collate_fn = collate_fn | |
| self.__initialized = True | |
| self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ] | |
| @property | |
| def multiprocessing_context(self): | |
| return self.__multiprocessing_context | |
| @multiprocessing_context.setter | |
| def multiprocessing_context(self, multiprocessing_context): | |
| if multiprocessing_context is not None: | |
| if self.num_workers > 0: | |
| if not multiprocessing._supports_context: | |
| raise ValueError( | |
| "multiprocessing_context relies on Python >= 3.4, with " | |
| "support for different start methods" | |
| ) | |
| if isinstance(multiprocessing_context, string_classes): | |
| valid_start_methods = multiprocessing.get_all_start_methods() | |
| if multiprocessing_context not in valid_start_methods: | |
| raise ValueError( | |
| ( | |
| "multiprocessing_context option " | |
| "should specify a valid start method in {}, but got " | |
| "multiprocessing_context={}" | |
| ).format(valid_start_methods, multiprocessing_context) | |
| ) | |
| multiprocessing_context = multiprocessing.get_context( | |
| multiprocessing_context | |
| ) | |
| if not isinstance( | |
| multiprocessing_context, python_multiprocessing.context.BaseContext | |
| ): | |
| raise ValueError( | |
| ( | |
| "multiprocessing_context option should be a valid context " | |
| "object or a string specifying the start method, but got " | |
| "multiprocessing_context={}" | |
| ).format(multiprocessing_context) | |
| ) | |
| else: | |
| raise ValueError( | |
| ( | |
| "multiprocessing_context can only be used with " | |
| "multi-process loading (num_workers > 0), but got " | |
| "num_workers={}" | |
| ).format(self.num_workers) | |
| ) | |
| self.__multiprocessing_context = multiprocessing_context | |
| def __setattr__(self, attr, val): | |
| if self.__initialized and attr in ( | |
| "batch_size", | |
| "batch_sampler", | |
| "sampler", | |
| "drop_last", | |
| "dataset", | |
| ): | |
| raise ValueError( | |
| "{} attribute should not be set after {} is " | |
| "initialized".format(attr, self.__class__.__name__) | |
| ) | |
| super(MyDataLoader, self).__setattr__(attr, val) | |
| def __iter__(self): | |
| if self.num_workers == 0: | |
| return _SingleProcessDataLoaderIter(self) | |
| else: | |
| return _MultiProcessingDataLoaderIter(self) | |
| @property | |
| def _auto_collation(self): | |
| return self.batch_sampler is not None | |
| @property | |
| def _index_sampler(self): | |
| # The actual sampler used for generating indices for `_DatasetFetcher` | |
| # (see _utils/fetch.py) to read data at each time. This would be | |
| # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. | |
| # We can't change `.sampler` and `.batch_sampler` attributes for BC | |
| # reasons. | |
| if self._auto_collation: | |
| return self.batch_sampler | |
| else: | |
| return self.sampler | |
| def __len__(self): | |
| if self._dataset_kind == _DatasetKind.Iterable: | |
| # NOTE [ IterableDataset and __len__ ] | |
| # | |
| # For `IterableDataset`, `__len__` could be inaccurate when one naively | |
| # does multi-processing data loading, since the samples will be duplicated. | |
| # However, no real use case should actually rely on that behavior, so | |
| # it should count as a user error. We should generally trust user | |
| # code to do the proper thing (e.g., configure each replica differently | |
| # in `__iter__`), and give us the correct `__len__` if they choose to | |
| # implement it (this will still throw if the dataset does not implement | |
| # a `__len__`). | |
| # | |
| # To provide a further warning, we track if `__len__` was called on the | |
| # `DataLoader`, save the returned value in `self._IterableDataset_len_called`, and warn | |
| # if the iterator ends up yielding more than this number of samples. | |
| length = self._IterableDataset_len_called = len(self.dataset) | |
| return length | |
| else: | |
| return len(self._index_sampler) | |
| class _BaseDataLoaderIter(object): | |
| def __init__(self, loader): | |
| self._dataset = loader.dataset | |
| self._dataset_kind = loader._dataset_kind | |
| self._IterableDataset_len_called = loader._IterableDataset_len_called | |
| self._auto_collation = loader._auto_collation | |
| self._drop_last = loader.drop_last | |
| self._index_sampler = loader._index_sampler | |
| self._num_workers = loader.num_workers | |
| self._pin_memory = loader.pin_memory and torch.cuda.is_available() | |
| self._timeout = loader.timeout | |
| self._collate_fn = loader.collate_fn | |
| self._sampler_iter = iter(self._index_sampler) | |
| self._base_seed = torch.empty((), dtype=torch.int64).random_().item() | |
| self._num_yielded = 0 | |
| def __iter__(self): | |
| return self | |
| def _next_index(self): | |
| return next(self._sampler_iter) # may raise StopIteration | |
| def _next_data(self): | |
| raise NotImplementedError | |
| def __next__(self): | |
| data = self._next_data() | |
| self._num_yielded += 1 | |
| if ( | |
| self._dataset_kind == _DatasetKind.Iterable | |
| and self._IterableDataset_len_called is not None | |
| and self._num_yielded > self._IterableDataset_len_called | |
| ): | |
| warn_msg = ( | |
| "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " | |
| "samples have been fetched. " | |
| ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded) | |
| if self._num_workers > 0: | |
| warn_msg += ( | |
| "For multiprocessing data-loading, this could be caused by not properly configuring the " | |
| "IterableDataset replica at each worker. Please see " | |
| "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." | |
| ) | |
| warnings.warn(warn_msg) | |
| return data | |
| next = __next__ # Python 2 compatibility | |
| def __len__(self): | |
| return len(self._index_sampler) | |
| def __getstate__(self): | |
| # TODO: add limited pickling support for sharing an iterator | |
| # across multiple threads for HOGWILD. | |
| # Probably the best way to do this is by moving the sample pushing | |
| # to a separate thread and then just sharing the data queue | |
| # but signalling the end is tricky without a non-blocking API | |
| raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) | |
| class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): | |
| def __init__(self, loader): | |
| super(_SingleProcessDataLoaderIter, self).__init__(loader) | |
| assert self._timeout == 0 | |
| assert self._num_workers == 0 | |
| self._dataset_fetcher = _DatasetKind.create_fetcher( | |
| self._dataset_kind, | |
| self._dataset, | |
| self._auto_collation, | |
| self._collate_fn, | |
| self._drop_last, | |
| ) | |
| def _next_data(self): | |
| index = self._next_index() # may raise StopIteration | |
| data = self._dataset_fetcher.fetch(index) # may raise StopIteration | |
| if self._pin_memory: | |
| data = _utils.pin_memory.pin_memory(data) | |
| return data | |
| class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): | |
| r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" | |
| # NOTE [ Data Loader Multiprocessing Shutdown Logic ] | |
| # | |
| # Preliminary: | |
| # | |
| # Our data model looks like this (queues are indicated with curly brackets): | |
| # | |
| # main process || | |
| # | || | |
| # {index_queue} || | |
| # | || | |
| # worker processes || DATA | |
| # | || | |
| # {worker_result_queue} || FLOW | |
| # | || | |
| # pin_memory_thread of main process || DIRECTION | |
| # | || | |
| # {data_queue} || | |
| # | || | |
| # data output \/ | |
| # | |
| # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if | |
| # `pin_memory=False`. | |
| # | |
| # | |
| # Terminating multiprocessing logic requires very careful design. In | |
| # particular, we need to make sure that | |
| # | |
| # 1. The iterator gracefully exits the workers when its last reference is | |
| # gone or it is depleted. | |
| # | |
| # In this case, the workers should be gracefully exited because the | |
| # main process may still need to continue to run, and we want the cleanup | |
| # code in the workers to be executed (e.g., releasing GPU memory). | |
| # Naturally, we implement the shutdown logic in `__del__` of | |
| # DataLoaderIterator. | |
| # | |
| # We delay the discussion on the logic in this case until later. | |
| # | |
| # 2. The iterator exits the workers when the loader process and/or worker | |
| # processes exits normally or with error. | |
| # | |
| # We set all workers and `pin_memory_thread` to have `daemon=True`. | |
| # | |
| # You may ask, why can't we make the workers non-daemonic, and | |
| # gracefully exit using the same logic as we have in `__del__` when the | |
| # iterator gets deleted (see 1 above)? | |
| # | |
| # First of all, `__del__` is **not** guaranteed to be called when | |
| # the interpreter exits. Even if it is called, by the time it executes, | |
| # many Python core library resources may already be freed, and even | |
| # simple things like acquiring an internal lock of a queue may hang. | |
| # Therefore, in this case, we actually need to prevent `__del__` from | |
| # being executed, and rely on the automatic termination of daemonic | |
| # children. Thus, we register an `atexit` hook that sets a global flag | |
| # `_utils.python_exit_status`. Since `atexit` hooks are executed in the | |
| # reverse order of registration, we are guaranteed that this flag is | |
| # set before library resources we use are freed. (Hooks freeing those | |
| # resources are registered at importing the Python core libraries at | |
| # the top of this file.) So in `__del__`, we check if | |
| # `_utils.python_exit_status` is set or `None` (freed), and perform | |
| # no-op if so. | |
| # | |
| # Another problem with `__del__` is also related to the library cleanup | |
| # calls. When a process ends, it shuts all its daemonic children | |
| # down with a SIGTERM (instead of joining them without a timeout). | |
| # Similarly for threads, but by a different mechanism. This fact, | |
| # together with a few implementation details of multiprocessing, forces | |
| # us to make workers daemonic. All of our problems arise when a | |
| # DataLoader is used in a subprocess, and are caused by multiprocessing | |
| # code which looks more or less like this: | |
| # | |
| # try: | |
| # your_function_using_a_dataloader() | |
| # finally: | |
| # multiprocessing.util._exit_function() | |
| # | |
| # The joining/termination mentioned above happens inside | |
| # `_exit_function()`. Now, if `your_function_using_a_dataloader()` | |
| # throws, the stack trace stored in the exception will prevent the | |
| # frame which uses `DataLoaderIter` from being freed. If the frame has any | |
| # reference to the `DataLoaderIter` (e.g., in a method of the iter), | |
| # its `__del__`, which starts the shutdown procedure, will not be | |
| # called. That, in turn, means that workers aren't notified. Attempting | |
| # to join in `_exit_function` will then result in a hang. | |
| # | |
| # For context, `_exit_function` is also registered as an `atexit` call. | |
| # So it is unclear to me (@ssnl) why this is needed in a finally block. | |
| # The code dates back to 2008 and there is no comment on the original | |
| # PEP 371 or patch https://bugs.python.org/issue3050 (containing both | |
| # the finally block and the `atexit` registration) that explains this. | |
| # | |
| # Another choice is to just shutdown workers with logic in 1 above | |
| # whenever we see an error in `next`. This isn't ideal because | |
| # a. It prevents users from using try-catch to resume data loading. | |
| # b. It doesn't prevent hanging if users have references to the | |
| # iterator. | |
| # | |
| # 3. All processes exit if any of them die unexpectedly by fatal signals. | |
| # | |
| # As shown above, the workers are set as daemonic children of the main | |
| # process. However, automatic cleaning-up of such child processes only | |
| # happens if the parent process exits gracefully (e.g., not via fatal | |
| # signals like SIGKILL). So we must ensure that each process will exit | |
| # even if the process that should send/receive data to/from it is | |
| # killed, i.e., | |
| # | |
| # a. A process won't hang when getting from a queue. | |
| # | |
| # Even with carefully designed data dependencies (i.e., a `put()` | |
| # always corresponding to a `get()`), hanging on `get()` can still | |
| # happen when data in queue is corrupted (e.g., due to | |
| # `cancel_join_thread` or unexpected exit). | |
| # | |
| # For child exit, we set a timeout whenever we try to get data | |
| # from `data_queue`, and check the workers' status on each timeout | |
| # and error. | |
| # See `_MultiProcessingDataLoaderIter._get_data()` and | |
| # `_MultiProcessingDataLoaderIter._try_get_data()` for details. | |
| # | |
| # Additionally, for child exit on non-Windows platforms, we also | |
| # register a SIGCHLD handler (which is not supported on Windows) on | |
| # the main process, which checks if any of the workers fail in the | |
| # (Python) handler. This is more efficient and faster in detecting | |
| # worker failures, compared to only using the above mechanism. | |
| # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. | |
| # | |
| # For `.get()` calls where the sender(s) is not the workers, we | |
| # guard them with timeouts, and check the status of the sender | |
| # when timeout happens: | |
| # + in the workers, the `_utils.worker.ManagerWatchdog` class | |
| # checks the status of the main process. | |
| # + if `pin_memory=True`, when getting from `pin_memory_thread`, | |
| # check `pin_memory_thread` status periodically until `.get()` | |
| # returns or we see that `pin_memory_thread` has died. | |
| # | |
| # b. A process won't hang when putting into a queue; | |
| # | |
| # We use `mp.Queue` which has a separate background thread to put | |
| # objects from an unbounded buffer array. The background thread is | |
| # daemonic and usually automatically joined when the process | |
| # exits. | |
| # | |
| # However, in case that the receiver has ended abruptly while | |
| # reading from the pipe, the join will hang forever. Therefore, | |
| # for both `worker_result_queue` (worker -> main process/pin_memory_thread) | |
| # and each `index_queue` (main process -> worker), we use | |
| # `q.cancel_join_thread()` in sender process before any `q.put` to | |
| # prevent this automatic join. | |
| # | |
| # Moreover, having all queues call `cancel_join_thread` makes | |
| # implementing graceful shutdown logic in `__del__` much easier. | |
| # It won't need to get from any queue, which would also need to be | |
| # guarded by periodic status checks. | |
| # | |
| # Nonetheless, `cancel_join_thread` must only be called when the | |
| # queue is **not** going to be read from or written to by another | |
| # process, because it may hold onto a lock or leave corrupted data | |
| # in the queue, leading other readers/writers to hang. | |
| # | |
| # `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does | |
| # a blocking `put` if the queue is full. So there is no above | |
| # problem, but we do need to wrap the `put` in a loop that breaks | |
| # not only upon success, but also when the main process stops | |
| # reading, i.e., is shutting down. | |
| # | |
| # | |
| # Now let's get back to 1: | |
| # how we gracefully exit the workers when the last reference to the | |
| # iterator is gone. | |
| # | |
| # To achieve this, we implement the following logic along with the design | |
| # choices mentioned above: | |
| # | |
| # `workers_done_event`: | |
| # A `multiprocessing.Event` shared among the main process and all worker | |
| # processes. This is used to signal the workers that the iterator is | |
| # shutting down. After it is set, they will not send processed data to | |
| # queues anymore, and only wait for the final `None` before exiting. | |
| # `done_event` isn't strictly needed. I.e., we can just check for `None` | |
| # from the input queue, but it allows us to skip wasting resources | |
| # processing data if we are already shutting down. | |
| # | |
| # `pin_memory_thread_done_event`: | |
| # A `threading.Event` for a similar purpose to that of | |
| # `workers_done_event`, but is for the `pin_memory_thread`. The reason | |
| # that separate events are needed is that `pin_memory_thread` reads from | |
| # the output queue of the workers. But the workers, upon seeing that | |
| # `workers_done_event` is set, only want to see the final `None`, and are | |
| # not required to flush all data in the output queue (e.g., a worker may call | |
| # `cancel_join_thread` on that queue if its `IterableDataset` iterator | |
| # happens to exhaust coincidentally, which is out of the control of the | |
| # main process). Thus, since we will exit `pin_memory_thread` before the | |
| # workers (see below), two separate events are used. | |
| # | |
| # NOTE: In short, the protocol is that the main process will set these | |
| # `done_event`s and then send the corresponding processes/threads a `None`, | |
| # and that they may exit at any time after receiving the `None`. | |
| # | |
| # NOTE: Using `None` as the final signal is valid, since normal data will | |
| # always be a 2-tuple with the 1st element being the index of the data | |
| # transferred (different from dataset index/key), and the 2nd being | |
| # either the dataset key or the data sample (depending on which part | |
| # of the data model the queue is at). | |
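| # Concretely (see `_try_put_index` below and `worker_loop` in | |
| # `my_data_worker`), the main process puts `(self._send_idx, index)` into a | |
| # worker's `index_queue`, and the worker replies on `worker_result_queue` | |
| # with `(task_idx, data_or_ExceptionWrapper)`, so a bare `None` is | |
| # unambiguous as a shutdown signal. | |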
| # | |
| # [ worker processes ] | |
| # While loader process is alive: | |
| # Get from `index_queue`. | |
| # If something is received, | |
| # Check `workers_done_event`. | |
| # If set, continue to next iteration | |
| # i.e., keep getting until see the `None`, then exit. | |
| # Otherwise, process data: | |
| # If is fetching from an `IterableDataset` and the iterator | |
| # is exhausted, send an `_IterableDatasetStopIteration` | |
| # object to signal iteration end. The main process, upon | |
| # receiving such an object, will send `None` to this | |
| # worker and not use the corresponding `index_queue` | |
| # anymore. | |
| # If timed out, | |
| # Continue to the next iteration regardless of whether | |
| # `workers_done_event` is set (we still need to see the final `None`). | |
| # (outside loop) | |
| # If `workers_done_event` is set, (this can be False with `IterableDataset`) | |
| # `data_queue.cancel_join_thread()`. (Everything is ending here: | |
| # main process won't read from it; | |
| # other workers will also call | |
| # `cancel_join_thread`.) | |
| # | |
| # [ pin_memory_thread ] | |
| # # No need to check main thread. If this thread is alive, the main loader | |
| # # thread must be alive, because this thread is set as daemonic. | |
| # While `pin_memory_thread_done_event` is not set: | |
| # Get from `worker_result_queue`. | |
| # If timed out, continue to get in the next iteration. | |
| # Otherwise, process data. | |
| # While `pin_memory_thread_done_event` is not set: | |
| # Put processed data to `data_queue` (a `queue.Queue` with blocking put) | |
| # If timed out, continue to put in the next iteration. | |
| # Otherwise, break, i.e., continue with the outer loop. | |
| # | |
| # NOTE: we don't check the status of the main thread because | |
| # 1. if the process is killed by fatal signal, `pin_memory_thread` | |
| # ends. | |
| # 2. in other cases, either the cleaning-up in __del__ or the | |
| # automatic exit of daemonic thread will take care of it. | |
| # This won't busy-wait either because `.get(timeout)` does not | |
| # busy-wait. | |
| # | |
| # [ main process ] | |
| # In the DataLoader Iter's `__del__` | |
| # b. Exit `pin_memory_thread` | |
| # i. Set `pin_memory_thread_done_event`. | |
| # ii. Put `None` in `worker_result_queue`. | |
| # iii. Join the `pin_memory_thread`. | |
| # iv. `worker_result_queue.cancel_join_thread()`. | |
| # | |
| # c. Exit the workers. | |
| # i. Set `workers_done_event`. | |
| # ii. Put `None` in each worker's `index_queue`. | |
| # iii. Join the workers. | |
| # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. | |
| # | |
| # NOTE: (c) is better placed after (b) because it may leave corrupted | |
| # data in `worker_result_queue`, which `pin_memory_thread` | |
| # reads from, in which case the `pin_memory_thread` can only | |
| # exit via a timeout, which is slow. Nonetheless, the same thing | |
| # happens if a worker is killed by signal at unfortunate times, | |
| # but in other cases, we are better off having a non-corrupted | |
| # `worker_result_queue` for `pin_memory_thread`. | |
| # | |
| # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) | |
| # can be omitted | |
| # | |
| # NB: `done_event`s aren't strictly needed. E.g., we can just check for | |
| # `None` from `index_queue`, but it allows us to skip wasting resources | |
| # processing indices already in `index_queue` if we are already shutting | |
| # down. | |
| def __init__(self, loader): | |
| super(_MultiProcessingDataLoaderIter, self).__init__(loader) | |
| assert self._num_workers > 0 | |
| if loader.multiprocessing_context is None: | |
| multiprocessing_context = multiprocessing | |
| else: | |
| multiprocessing_context = loader.multiprocessing_context | |
| self._worker_init_fn = loader.worker_init_fn | |
| self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) | |
| self._worker_result_queue = multiprocessing_context.Queue() | |
| self._worker_pids_set = False | |
| self._shutdown = False | |
| self._send_idx = 0 # idx of the next task to be sent to workers | |
| self._rcvd_idx = 0 # idx of the next task to be returned in __next__ | |
| # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). | |
| # map: task idx => (worker_id,)       if data isn't fetched yet (outstanding) | |
| #                  (worker_id, data)  if data is already fetched (out-of-order) | |
| self._task_info = {} | |
| # always equal to count(v for v in task_info.values() if len(v) == 1) | |
| self._tasks_outstanding = 0 | |
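| # Illustrative example (values are hypothetical): with 2 workers, right after | |
| # priming the prefetch loop we may have | |
| #     _task_info == {0: (0,), 1: (1,), 2: (0,), 3: (1,)} | |
| #     _send_idx == 4, _rcvd_idx == 0, _tasks_outstanding == 4 | |
| # If task 1 arrives from its worker before task 0, it is stored out-of-order | |
| # as `_task_info[1] = (worker_id, data)` until task 0 has been yielded. | |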
| self._workers_done_event = multiprocessing_context.Event() | |
| self._index_queues = [] | |
| self._workers = [] | |
| # A list of booleans representing whether each worker still has work to | |
| # do, i.e., not having exhausted its iterable dataset object. It always | |
| # contains all `True`s if not using an iterable-style dataset | |
| # (i.e., if kind != Iterable). | |
| self._workers_status = [] | |
| for i in range(self._num_workers): | |
| index_queue = multiprocessing_context.Queue() | |
| # index_queue.cancel_join_thread() | |
| w = multiprocessing_context.Process( | |
| target=worker_loop, | |
| args=( | |
| self._dataset_kind, | |
| self._dataset, | |
| index_queue, | |
| self._worker_result_queue, | |
| self._workers_done_event, | |
| self._auto_collation, | |
| self._collate_fn, | |
| self._drop_last, | |
| self._base_seed + i, | |
| self._worker_init_fn, | |
| i, | |
| self._num_workers, | |
| ), | |
| ) | |
| w.daemon = True | |
| # NB: Process.start() actually takes some time as it needs to | |
| # start a process and pass the arguments over via a pipe. | |
| # Therefore, we only add a worker to the self._workers list after | |
| # it has started, so that we do not call .join() if the program dies | |
| # before it starts, and __del__ tries to join but will get: | |
| # AssertionError: can only join a started process. | |
| w.start() | |
| self._index_queues.append(index_queue) | |
| self._workers.append(w) | |
| self._workers_status.append(True) | |
| if self._pin_memory: | |
| self._pin_memory_thread_done_event = threading.Event() | |
| self._data_queue = queue.Queue() | |
| pin_memory_thread = threading.Thread( | |
| target=_utils.pin_memory._pin_memory_loop, | |
| args=( | |
| self._worker_result_queue, | |
| self._data_queue, | |
| torch.cuda.current_device(), | |
| self._pin_memory_thread_done_event, | |
| ), | |
| ) | |
| pin_memory_thread.daemon = True | |
| pin_memory_thread.start() | |
| # Similar to workers (see comment above), we only register | |
| # pin_memory_thread once it is started. | |
| self._pin_memory_thread = pin_memory_thread | |
| else: | |
| self._data_queue = self._worker_result_queue | |
| _utils.signal_handling._set_worker_pids( | |
| id(self), tuple(w.pid for w in self._workers) | |
| ) | |
| _utils.signal_handling._set_SIGCHLD_handler() | |
| self._worker_pids_set = True | |
| # prime the prefetch loop | |
| for _ in range(2 * self._num_workers): | |
| self._try_put_index() | |
| def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): | |
| # Tries to fetch data from `self._data_queue` once for a given timeout. | |
| # This can also be used as inner loop of fetching without timeout, with | |
| # the sender status as the loop condition. | |
| # | |
| # This raises a `RuntimeError` if any worker died unexpectedly. This error | |
| # can come from either the SIGCHLD handler in `_utils/signal_handling.py` | |
| # (only for non-Windows platforms), or the manual check below on errors | |
| # and timeouts. | |
| # | |
| # Returns a 2-tuple: | |
| # (bool: whether data was successfully fetched, Any: the data if successful, else None) | |
| try: | |
| data = self._data_queue.get(timeout=timeout) | |
| return (True, data) | |
| except Exception as e: | |
| # At timeout and error, we manually check whether any worker has | |
| # failed. Note that this is the only mechanism for Windows to detect | |
| # worker failures. | |
| failed_workers = [] | |
| for worker_id, w in enumerate(self._workers): | |
| if self._workers_status[worker_id] and not w.is_alive(): | |
| failed_workers.append(w) | |
| self._shutdown_worker(worker_id) | |
| if len(failed_workers) > 0: | |
| pids_str = ", ".join(str(w.pid) for w in failed_workers) | |
| raise RuntimeError( | |
| "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str) | |
| ) | |
| if isinstance(e, queue.Empty): | |
| return (False, None) | |
| raise | |
| def _get_data(self): | |
| # Fetches data from `self._data_queue`. | |
| # | |
| # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, | |
| # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` | |
| # in a loop. This is the only mechanism to detect worker failures for | |
| # Windows. For other platforms, a SIGCHLD handler is also used for | |
| # worker failure detection. | |
| # | |
| # If `pin_memory=True`, we also need to check whether `pin_memory_thread` | |
| # has died at each timeout. | |
| if self._timeout > 0: | |
| success, data = self._try_get_data(self._timeout) | |
| if success: | |
| return data | |
| else: | |
| raise RuntimeError( | |
| "DataLoader timed out after {} seconds".format(self._timeout) | |
| ) | |
| elif self._pin_memory: | |
| while self._pin_memory_thread.is_alive(): | |
| success, data = self._try_get_data() | |
| if success: | |
| return data | |
| else: | |
| # while condition is false, i.e., pin_memory_thread died. | |
| raise RuntimeError("Pin memory thread exited unexpectedly") | |
| # In this case, `self._data_queue` is a `queue.Queue`. But we don't | |
| # need to call `.task_done()` because we don't use `.join()`. | |
| else: | |
| while True: | |
| success, data = self._try_get_data() | |
| if success: | |
| return data | |
| def _next_data(self): | |
| while True: | |
| # If the worker responsible for `self._rcvd_idx` has already ended | |
| # and was unable to fulfill this task (due to exhausting an `IterableDataset`), | |
| # we try to advance `self._rcvd_idx` to find the next valid index. | |
| # | |
| # This part needs to run in the loop because both the `self._get_data()` | |
| # call and `_IterableDatasetStopIteration` check below can mark | |
| # extra worker(s) as dead. | |
| while self._rcvd_idx < self._send_idx: | |
| info = self._task_info[self._rcvd_idx] | |
| worker_id = info[0] | |
| if ( | |
| len(info) == 2 or self._workers_status[worker_id] | |
| ): # has data or is still active | |
| break | |
| del self._task_info[self._rcvd_idx] | |
| self._rcvd_idx += 1 | |
| else: | |
| # no valid `self._rcvd_idx` is found (i.e., didn't break) | |
| self._shutdown_workers() | |
| raise StopIteration | |
| # Now `self._rcvd_idx` is the batch index we want to fetch | |
| # Check if the next sample has already been generated | |
| if len(self._task_info[self._rcvd_idx]) == 2: | |
| data = self._task_info.pop(self._rcvd_idx)[1] | |
| return self._process_data(data) | |
| assert not self._shutdown and self._tasks_outstanding > 0 | |
| idx, data = self._get_data() | |
| self._tasks_outstanding -= 1 | |
| if self._dataset_kind == _DatasetKind.Iterable: | |
| # Check for _IterableDatasetStopIteration | |
| if isinstance(data, _utils.worker._IterableDatasetStopIteration): | |
| self._shutdown_worker(data.worker_id) | |
| self._try_put_index() | |
| continue | |
| if idx != self._rcvd_idx: | |
| # store out-of-order samples | |
| self._task_info[idx] += (data,) | |
| else: | |
| del self._task_info[idx] | |
| return self._process_data(data) | |
| def _try_put_index(self): | |
| assert self._tasks_outstanding < 2 * self._num_workers | |
| try: | |
| index = self._next_index() | |
| except StopIteration: | |
| return | |
| for _ in range(self._num_workers): # find the next active worker, if any | |
| worker_queue_idx = next(self._worker_queue_idx_cycle) | |
| if self._workers_status[worker_queue_idx]: | |
| break | |
| else: | |
| # not found (i.e., didn't break) | |
| return | |
| self._index_queues[worker_queue_idx].put((self._send_idx, index)) | |
| self._task_info[self._send_idx] = (worker_queue_idx,) | |
| self._tasks_outstanding += 1 | |
| self._send_idx += 1 | |
| def _process_data(self, data): | |
| self._rcvd_idx += 1 | |
| self._try_put_index() | |
| if isinstance(data, ExceptionWrapper): | |
| data.reraise() | |
| return data | |
| def _shutdown_worker(self, worker_id): | |
| # Mark a worker as having finished its work and dead, e.g., due to | |
| # exhausting an `IterableDataset`. This should be used only when this | |
| # `_MultiProcessingDataLoaderIter` is going to continue running. | |
| assert self._workers_status[worker_id] | |
| # Signal termination to that specific worker. | |
| q = self._index_queues[worker_id] | |
| # Indicate that no more data will be put on this queue by the current | |
| # process. | |
| q.put(None) | |
| # Note that we don't actually join the worker here, nor do we remove the | |
| # worker's pid from C side struct because (1) joining may be slow, and | |
| # (2) since we don't join, the worker may still raise error, and we | |
| # prefer capturing those, rather than ignoring them, even though they | |
| # are raised after the worker has finished its job. | |
| # Joining is deferred to `_shutdown_workers`, which is called when | |
| # all workers finish their jobs (e.g., `IterableDataset` replicas) or | |
| # when this iterator is garbage collected. | |
| self._workers_status[worker_id] = False | |
| def _shutdown_workers(self): | |
| # Called when shutting down this `_MultiProcessingDataLoaderIter`. | |
| # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on | |
| # the logic of this function. | |
| python_exit_status = _utils.python_exit_status | |
| if python_exit_status is True or python_exit_status is None: | |
| # See (2) of the note. If Python is shutting down, do no-op. | |
| return | |
| # Normal exit when last reference is gone / iterator is depleted. | |
| # See (1) and the second half of the note. | |
| if not self._shutdown: | |
| self._shutdown = True | |
| try: | |
| # Exit `pin_memory_thread` first because exiting workers may leave | |
| # corrupted data in `worker_result_queue` which `pin_memory_thread` | |
| # reads from. | |
| if hasattr(self, "_pin_memory_thread"): | |
| # Use hasattr in case error happens before we set the attribute. | |
| self._pin_memory_thread_done_event.set() | |
| # Send something to pin_memory_thread in case it is waiting | |
| # so that it can wake up and check `pin_memory_thread_done_event` | |
| self._worker_result_queue.put((None, None)) | |
| self._pin_memory_thread.join() | |
| self._worker_result_queue.close() | |
| # Exit workers now. | |
| self._workers_done_event.set() | |
| for worker_id in range(len(self._workers)): | |
| # Get number of workers from `len(self._workers)` instead of | |
| # `self._num_workers` in case we error before starting all | |
| # workers. | |
| if self._workers_status[worker_id]: | |
| self._shutdown_worker(worker_id) | |
| for w in self._workers: | |
| w.join() | |
| for q in self._index_queues: | |
| q.cancel_join_thread() | |
| q.close() | |
| finally: | |
| # Even though all this function does is putting into queues that | |
| # we have called `cancel_join_thread` on, weird things can | |
| # happen when a worker is killed by a signal, e.g., hanging in | |
| # `Event.set()`. So we need to guard this with SIGCHLD handler, | |
| # and remove pids from the C side data structure only at the | |
| # end. | |
| # | |
| # FIXME: Unfortunately, for Windows, we are missing a worker | |
| # error detection mechanism here in this function, as it | |
| # doesn't provide a SIGCHLD handler. | |
| if self._worker_pids_set: | |
| _utils.signal_handling._remove_worker_pids(id(self)) | |
| self._worker_pids_set = False | |
| def __del__(self): | |
| self._shutdown_workers() | |