Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from argparse import Namespace | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, List, Literal, Mapping, Optional, Union | |
| import pandas as pd | |
| from lightning.pytorch.callbacks import Checkpoint | |
| from lightning.pytorch.loggers import Logger | |
| from lightning.pytorch.utilities.parsing import AttributeDict | |
| from lightning_utilities.core.apply_func import apply_to_collection | |
| from omegaconf import DictConfig, ListConfig, OmegaConf | |
| from torch import Tensor | |
| from nemo.utils import logging | |
| try: | |
| from clearml import OutputModel, Task | |
| HAVE_CLEARML_LOGGER = True | |
| except (ImportError, ModuleNotFoundError): | |
| HAVE_CLEARML_LOGGER = False | |
| class ClearMLParams: | |
| project: Optional[str] = None | |
| task: Optional[str] = None | |
| connect_pytorch: Optional[bool] = False | |
| model_name: Optional[str] = None | |
| tags: Optional[List[str]] = None | |
| log_model: Optional[bool] = False | |
| log_cfg: Optional[bool] = False | |
| log_metrics: Optional[bool] = False | |
| class ClearMLLogger(Logger): | |
| def name(self) -> str: | |
| return self.clearml_task.name | |
| def version(self) -> str: | |
| return self.clearml_task.id | |
| def __init__( | |
| self, clearml_cfg: DictConfig, log_dir: str, prefix: str, save_best_model: bool, postfix: str = ".nemo" | |
| ) -> None: | |
| if not HAVE_CLEARML_LOGGER: | |
| raise ImportError( | |
| "Found create_clearml_logger is True." | |
| "But ClearML not found. Please see the README for installation instructions:" | |
| "https://github.com/allegroai/clearml" | |
| ) | |
| self.clearml_task = None | |
| self.clearml_model = None | |
| self.clearml_cfg = clearml_cfg | |
| self.path_nemo_model = os.path.abspath( | |
| os.path.expanduser(os.path.join(log_dir, "checkpoints", prefix + postfix)) | |
| ) | |
| self.save_best_model = save_best_model | |
| self.prefix = prefix | |
| self.previos_best_model_path = None | |
| self.last_metrics = None | |
| self.save_blocked = True | |
| self.project_name = os.getenv("CLEARML_PROJECT", clearml_cfg.project if clearml_cfg.project else "NeMo") | |
| self.task_name = os.getenv("CLEARML_TASK", clearml_cfg.task if clearml_cfg.task else f"Trainer {self.prefix}") | |
| tags = ["NeMo"] | |
| if clearml_cfg.tags: | |
| tags.extend(clearml_cfg.tags) | |
| self.clearml_task: Task = Task.init( | |
| project_name=self.project_name, | |
| task_name=self.task_name, | |
| auto_connect_frameworks={"pytorch": clearml_cfg.connect_pytorch}, | |
| output_uri=True, | |
| tags=tags, | |
| ) | |
| if clearml_cfg.model_name: | |
| model_name = clearml_cfg.model_name | |
| elif self.prefix: | |
| model_name = self.prefix | |
| else: | |
| model_name = self.task_name | |
| if clearml_cfg.log_model: | |
| self.clearml_model: OutputModel = OutputModel( | |
| name=model_name, task=self.clearml_task, tags=tags, framework="NeMo" | |
| ) | |
| def log_hyperparams(self, params, *args, **kwargs) -> None: | |
| if self.clearml_model and self.clearml_cfg.log_cfg: | |
| if isinstance(params, Namespace): | |
| params = vars(params) | |
| elif isinstance(params, AttributeDict): | |
| params = dict(params) | |
| params = apply_to_collection(params, (DictConfig, ListConfig), OmegaConf.to_container, resolve=True) | |
| params = apply_to_collection(params, Path, str) | |
| params = OmegaConf.to_yaml(params) | |
| self.clearml_model.update_design(config_text=params) | |
| def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: | |
| if self.clearml_model and self.clearml_cfg.log_metrics: | |
| metrics = { | |
| k: { | |
| "value": str(v.item() if type(v) == Tensor else v), | |
| "type": str(type(v.item() if type(v) == Tensor else v)), | |
| } | |
| for k, v in metrics.items() | |
| } | |
| self.last_metrics = metrics | |
| def log_table( | |
| self, | |
| key: str, | |
| columns: List[str] = None, | |
| data: List[List[Any]] = None, | |
| dataframe: Any = None, | |
| step: Optional[int] = None, | |
| ) -> None: | |
| table: Optional[Union[pd.DataFrame, List[List[Any]]]] = None | |
| if dataframe is not None: | |
| table = dataframe | |
| if columns is not None: | |
| table.columns = columns | |
| if data is not None: | |
| table = data | |
| assert len(columns) == len(table[0]), "number of column names should match the total number of columns" | |
| table.insert(0, columns) | |
| if table is not None: | |
| self.clearml_task.logger.report_table(title=key, series=key, iteration=step, table_plot=table) | |
| def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: | |
| if self.clearml_model: | |
| if self.save_best_model: | |
| if self.save_blocked: | |
| self.save_blocked = False | |
| return None | |
| if not os.path.exists(checkpoint_callback.best_model_path): | |
| return None | |
| if self.previos_best_model_path == checkpoint_callback.best_model_path: | |
| return None | |
| self.previos_best_model_path = checkpoint_callback.best_model_path | |
| self._log_model(self.path_nemo_model) | |
| def finalize(self, status: Literal["success", "failed", "aborted"] = "success") -> None: | |
| if status == "success": | |
| self.clearml_task.mark_completed() | |
| elif status == "failed": | |
| self.clearml_task.mark_failed() | |
| elif status == "aborted": | |
| self.clearml_task.mark_stopped() | |
| def _log_model(self, save_path: str) -> None: | |
| if self.clearml_model: | |
| if os.path.exists(save_path): | |
| self.clearml_model.update_weights( | |
| weights_filename=save_path, | |
| upload_uri=self.clearml_task.storage_uri or self.clearml_task._get_default_report_storage_uri(), | |
| auto_delete_file=False, | |
| is_package=True, | |
| ) | |
| if self.clearml_cfg.log_metrics and self.last_metrics: | |
| self.clearml_model.set_all_metadata(self.last_metrics) | |
| self.save_blocked = True | |
| else: | |
| logging.warning((f"Logging model enabled, but cant find .nemo file!" f" Path: {save_path}")) | |