# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
import tempfile

import torch
import torch.distributed as dist

from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

try:
    from megatron.core import parallel_state

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False


def initialize_distributed(args, backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    logging.info(
        f'Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}'
    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Build the TCP init method and initialize the process group.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method)
    return local_rank, rank, world_size
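

# A minimal usage sketch (illustrative, not part of the original module):
# driving initialize_distributed() under a launcher such as `torchrun`,
# which sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process.
# The argparse flag and the name `_example_initialize_distributed` are
# hypothetical.
def _example_initialize_distributed():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None)
    args = parser.parse_args()
    local_rank, rank, world_size = initialize_distributed(args, backend='nccl')
    logging.info(f'rank {rank}/{world_size} ready on device {torch.cuda.current_device()}')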


def gather_objects(partial_results_list, main_rank=None):
    """
    Collect objects (e.g., results) from all GPUs.
    Useful for inference over multiple GPUs with DDP.
    Use main_rank to specify which rank will be used to gather results.
    This allows execution to continue on the main_rank only after the gather.

    Args:
        partial_results_list: list of partial results from each GPU
        main_rank: rank of the main process to collect results from all GPUs (useful for collecting results on a target rank)

    Example:
        predictions = gather_objects(predictions, main_rank=0)

        # all but rank 0 will return None
        if predictions is None:
            return

        # from here on, only rank 0 should continue
        pickle.dump(predictions, open(output_fname, "wb"))
    """
    # do not fail when Megatron-Core is unavailable or DDP is not initialized
    if not HAVE_MEGATRON_CORE or not parallel_state.is_initialized():
        return partial_results_list

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    # return input when no DDP is used
    if world_size == 1:
        return partial_results_list

    gathered_results = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(gathered_results, partial_results_list)

    # return None to non-main ranks
    if main_rank is not None and rank != main_rank:
        return None

    # flatten the per-rank result lists into a single list
    results_list = []
    for r in gathered_results:
        results_list.extend(r)
    return results_list
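

# A minimal sketch expanding the docstring example above (illustrative, not
# part of the original module): each data-parallel rank passes its partial
# predictions in, and only rank 0 writes the merged result. The names
# `_example_gather_predictions` and `output_fname` are hypothetical.
def _example_gather_predictions(predictions, output_fname='predictions.pkl'):
    import pickle

    predictions = gather_objects(predictions, main_rank=0)
    # all ranks except main_rank=0 get None back and stop here
    if predictions is None:
        return
    with open(output_fname, 'wb') as f:
        pickle.dump(predictions, f)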


@contextlib.contextmanager
def temporary_directory():
    """Create a shared temporary directory across ranks in a distributed setup.

    This function assumes that the distributed setup has already been
    correctly initialized. It is intended to be used only in a single-node
    setup so that all ranks can access the directory created."""
    if is_global_rank_zero():
        tmp_dir = [tempfile.TemporaryDirectory()]
    else:
        tmp_dir = [None]
    dist.broadcast_object_list(tmp_dir)
    yield tmp_dir[0].name
    # We use the barrier below to make sure that rank zero won't exit
    # and delete tmp_dir while other ranks may still be using it
    dist.barrier()
    if is_global_rank_zero():
        tmp_dir[0].cleanup()
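

# A minimal sketch (illustrative, not part of the original module): each rank
# writes into the shared directory yielded by temporary_directory(); the
# barrier on context exit guarantees rank zero does not clean up while other
# ranks are still inside the `with` block. Assumes torch.distributed is
# initialized on a single node; `_example_shared_scratch` is hypothetical.
def _example_shared_scratch():
    with temporary_directory() as tmp_dir:
        rank = dist.get_rank()
        with open(os.path.join(tmp_dir, f'rank_{rank}.txt'), 'w') as f:
            f.write(f'hello from rank {rank}\n')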


def webdataset_split_by_workers(src):
    """
    Shard-splitting stage for webdataset>=0.2.6.
    Makes sure that each DataLoader worker gets a different subset of the dataset.
    """
    # group = torch.distributed.group.WORLD
    # rank = torch.distributed.get_rank(group=group)
    # world_size = torch.distributed.get_world_size(group=group)
    worker_info = torch.utils.data.get_worker_info()
    worker = 0
    num_workers = 1
    if worker_info is not None:
        worker = worker_info.id
        num_workers = worker_info.num_workers
    if num_workers > 1:
        # round-robin: worker i takes every num_workers-th element starting at i
        yield from list(src)[worker::num_workers]
    else:
        yield from src
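

# A minimal sketch of plugging webdataset_split_by_workers into a WebDataset
# pipeline (illustrative, not part of the original module): placed between the
# shard list and sample decoding, it hands each DataLoader worker a disjoint
# round-robin subset of shards. Assumes the `webdataset` package (>=0.2.6) is
# installed; the shard pattern and `_example_webdataset_loader` are hypothetical.
def _example_webdataset_loader(shard_pattern='shards/data-{000000..000009}.tar'):
    import webdataset as wds

    dataset = wds.DataPipeline(
        wds.SimpleShardList(shard_pattern),
        webdataset_split_by_workers,
        wds.tarfile_to_samples(),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=4)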