Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from pathlib import Path | |
| from time import sleep | |
| import wget | |
| from lightning.pytorch.plugins.environments import LightningEnvironment | |
| from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry | |
| from nemo.utils import logging | |
def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str:
    """
    Helper function to download pre-trained weights from the cloud.

    Args:
        url: (str) URL of storage
        filename: (str) what to download. The request will be issued to url/filename
        subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
            be empty
        cache_dir: (str or Path) a cache directory where to download. If not present, this function will attempt
            to create it. If None (default), then it will be $HOME/.cache/torch/NeMo
        refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch

    Returns:
        If successful - local path (as a string) to the cached or downloaded file
        else - empty string (the download call returned but no file was written)

    Raises:
        ValueError: if every download attempt raised an exception.
    """
    if cache_dir is None:
        cache_location = Path.home() / ".cache/torch/NeMo"
    else:
        # Normalise to Path so that plain-string cache directories work with the joins below.
        cache_location = Path(cache_dir)

    destination = cache_location if subfolder is None else cache_location / subfolder
    os.makedirs(destination, exist_ok=True)

    destination_file = destination / filename

    if destination_file.exists():
        logging.info(f"Found existing object {destination_file}.")
        if refresh_cache:
            logging.info("Asked to refresh the cache.")
            logging.info(f"Deleting file: {destination_file}")
            os.remove(destination_file)
        else:
            logging.info(f"Re-using file from: {destination_file}")
            return str(destination_file)

    # download file
    wget_uri = url + filename
    logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")

    # NGC links do not work every time so we try and wait
    max_attempts = 3
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            wget.download(wget_uri, str(destination_file))
        except Exception as err:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still abort the retry loop.
            last_error = err
            logging.info(f"Download from cloud failed. Attempt {attempt} of {max_attempts}")
            sleep(0.05)
            continue
        # Always hand back a str, matching the cached-file branch and the annotation.
        return str(destination_file) if destination_file.exists() else ""

    raise ValueError("Not able to download url right now, please try again.") from last_error
class SageMakerDDPStrategy(DDPStrategy):
    """DDP strategy whose cluster environment is always built from the ``WORLD_SIZE`` / ``RANK`` env vars."""

    @property
    def cluster_environment(self):
        """Return a ``LightningEnvironment`` whose world size and global rank are read from ``os.environ``."""
        env = LightningEnvironment()
        # Read lazily on every call so the values reflect the environment at lookup time.
        env.world_size = lambda: int(os.environ["WORLD_SIZE"])
        env.global_rank = lambda: int(os.environ["RANK"])
        return env

    @cluster_environment.setter
    def cluster_environment(self, env):
        # prevents Lightning from overriding the Environment required for SageMaker
        pass
def initialize_sagemaker() -> None:
    """
    Helper function to initiate sagemaker with NeMo.
    This function installs libraries that NeMo requires for the ASR toolkit + initializes sagemaker ddp.

    Steps, in order:
        1. Register ``SageMakerDDPStrategy`` in the Lightning ``StrategyRegistry`` under the name ``smddp``.
        2. Patch ``torchmetrics.Metric.__hash__``.
        3. If both ``RANK`` and ``WORLD_SIZE`` are set in the environment, initialise the smdistributed
           process group, install the system libraries from local rank 0 only, and synchronise all
           processes on a barrier. Otherwise install the system libraries directly.
    """
    StrategyRegistry.register(
        name='smddp',
        strategy=SageMakerDDPStrategy,
        process_group_backend="smddp",
        find_unused_parameters=False,
    )

    def _install_system_libraries() -> None:
        # The exit status of the shell command is not checked.
        os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg')

    def _patch_torch_metrics() -> None:
        """
        Patches torchmetrics to not rely on internal state.
        This is because sagemaker DDP overrides the `__init__` function of the modules to do automatic-partitioning.
        """
        from torchmetrics import Metric

        def __new_hash__(self):
            # Hash on class name + object identity only.
            return hash((self.__class__.__name__, id(self)))

        Metric.__hash__ = __new_hash__

    _patch_torch_metrics()

    if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"):
        import smdistributed.dataparallel.torch.distributed as dist

        # has to be imported, as it overrides torch modules and such when DDP is enabled.
        import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

        dist.init_process_group()

        # Only local rank 0 installs; the original truthiness test selected every rank
        # *except* 0, so the main process never installed anything.
        if dist.get_local_rank() == 0:
            _install_system_libraries()
        dist.barrier()  # wait for main process
        return

    _install_system_libraries()