Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from contextlib import contextmanager | |
| from typing import Sequence | |
| import click | |
| import numpy as np | |
| from omegaconf import DictConfig, ListConfig, OmegaConf | |
| from nemo.collections.common.data.lhotse.cutset import get_parser_fn | |
def estimate_data_weights(
    input_cfgs: str | Sequence[str], output_cfg: str, temperature: Sequence[float], strategy: str
):
    """
    Read a YAML specification of datasets from INPUT_CFGS, compute their weights, and save the result in OUTPUT_CFG.

    The weight for each entry is determined by STRATEGY: the number of hours ("num_hours") or the number
    of examples ("num_examples") in a given dataset; group weights are the sum of their members' weights.
    If more than one config is provided as input, we will concatenate them and output a single merged config.
    Optionally, apply temperature re-weighting to balance the datasets (specify TEMPERATURE less than 1
    to relatively up-weight smaller datasets). Passing no TEMPERATURE values skips re-weighting;
    passing several applies them one per nesting level of the config.

    Args:
        input_cfgs: path(s) to YAML data configs; a single path string is also accepted.
        output_cfg: path where the merged, weighted config is written.
        temperature: temperature value(s) forwarded to ``parse_temperature``.
        strategy: statistic used as each dataset's weight (forwarded to ``count`` as ``weight_key``).
    """
    if isinstance(input_cfgs, str):
        # A single path: without this, iterating the string would yield one character at a time.
        input_cfgs = [input_cfgs]
    data = ListConfig([])
    for icfg in input_cfgs:
        loaded = OmegaConf.load(icfg)
        if isinstance(loaded, DictConfig):
            # A file holding a single entry (a mapping) rather than a list of entries;
            # extending with it would add its keys instead of the entry itself.
            data.append(loaded)
        else:
            data.extend(loaded)
    temperature = parse_temperature(temperature)
    validate(data)
    count(data, weight_key=strategy)
    aggregate_group_weights(data)
    reweight(data, temperature=temperature)
    OmegaConf.save(data, output_cfg)
def validate(entry: DictConfig | ListConfig, _level: int = 0) -> None:
    """
    Recursively check the structure of a YAML data config.

    Lists are walked element by element. Every other entry must define ``type``; entries with
    ``type: group`` must also define ``input_cfg``, whose sub-entries are validated one nesting
    level deeper.

    Args:
        entry: a list of entries or a single entry.
        _level: current nesting depth, used only in error messages.

    Raises:
        ValueError: if an entry is missing ``type``, or a group entry is missing ``input_cfg``.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            validate(subentry, _level + 1)
        return
    # Explicit raises rather than ``assert`` so that validation still runs under ``python -O``.
    if "type" not in entry:
        raise ValueError(
            f"Invalid YAML data config at nesting level {_level}: missing key 'type' in entry={entry}"
        )
    if entry.type == "group":
        if "input_cfg" not in entry:
            raise ValueError(
                f"Invalid YAML data config at nesting level {_level}: "
                f"group entry is missing key 'input_cfg' in entry={entry}"
            )
        for subentry in entry["input_cfg"]:
            validate(subentry, _level + 1)
def count(entry: DictConfig | ListConfig, weight_key: str) -> None:
    """
    Recursively compute and store a ``weight`` for every non-group entry of a data config, in place.

    Each leaf dataset is iterated once, with ``quick_iter_options`` applied. The pass accumulates
    the number of examples and the total ``duration`` of the examples that define one. The statistic
    selected by ``weight_key`` becomes the entry's ``weight``. Group entries are only recursed into
    here; their own weights are filled in by ``aggregate_group_weights``.

    Args:
        entry: a list of entries or a single entry.
        weight_key: ``"num_hours"`` or ``"num_examples"``.

    Raises:
        ValueError: if ``weight_key`` is unsupported (checked before any dataset is read).
        RuntimeError: if ``weight_key == "num_hours"`` and a dataset's total duration is zero.
    """
    supported_keys = ("num_hours", "num_examples")
    if weight_key not in supported_keys:
        # Fail fast: otherwise the bad key only surfaces as a KeyError after a full dataset pass.
        raise ValueError(f"Unsupported weight_key={weight_key!r}; expected one of {supported_keys}.")
    if isinstance(entry, ListConfig):
        for subentry in entry:
            count(subentry, weight_key=weight_key)
        return
    if entry.type == "group":
        for subentry in entry["input_cfg"]:
            count(subentry, weight_key=weight_key)
        return
    stats = {"num_hours": 0.0, "num_examples": 0}
    with quick_iter_options(entry):
        # The parser also returns an "is tarred" flag, which is not needed for counting.
        iterable, _ = get_parser_fn(entry.type)(entry)
        for example in iterable:
            if hasattr(example, "duration"):
                stats["num_hours"] += example.duration
            stats["num_examples"] += 1
    # Durations are presumably in seconds; convert the total to hours.
    stats["num_hours"] /= 3600.0
    # NOTE: this also triggers for a dataset with no examples at all.
    if weight_key == "num_hours" and stats[weight_key] == 0.0:
        raise RuntimeError(
            f"Cannot set weights based on 'num_hours': at least one dataset has examples without 'duration' property. "
            f"Details: {entry=}"
        )
    entry["weight"] = stats[weight_key]
def aggregate_group_weights(entry: DictConfig | ListConfig) -> None:
    """
    Recursively set each group's ``weight`` to the sum of its sub-entries' weights, in place.

    Non-group entries are left untouched; their ``weight`` is expected to be set already
    (by ``count``). Nested groups are always recomputed bottom-up. A ``weight`` already present
    on a nested group (e.g. in a config previously written by this script) therefore cannot
    shadow the freshly computed sum. This matches how top-level groups are handled.

    Args:
        entry: a list of entries or a single entry.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            aggregate_group_weights(subentry)
        return
    if entry.type != "group":
        return
    for subentry in entry["input_cfg"]:
        # Returns immediately for non-group sub-entries, so recursing unconditionally is safe.
        aggregate_group_weights(subentry)
    entry.weight = sum(subentry["weight"] for subentry in entry["input_cfg"])
def reweight(entry: DictConfig | ListConfig, temperature: None | float | list[float]) -> None:
    """
    Recursively apply temperature re-weighting to the ``weight`` fields of a data config, in place.

    Args:
        entry: a list of entries or a single entry; non-group entries are left untouched.
        temperature: a falsy value disables re-weighting.
            - A scalar is used at this level and at every nested level.
            - A sequence supplies one value per level: its first item is used here and the
              remainder is handed to nested groups, so levels past the last item are not re-weighted.

    Nested groups are processed first. Then the children at this level get their weights replaced
    by ``temperature_reweighting`` of their current weights.
    """
    if not temperature:
        return
    if isinstance(entry, DictConfig) and entry.type != "group":
        return

    if isinstance(temperature, Sequence):
        current, *rest = temperature
    else:
        current = rest = temperature

    # A plain list of entries and a group's ``input_cfg`` are handled identically.
    children = entry if isinstance(entry, ListConfig) else entry["input_cfg"]
    for child in children:
        reweight(child, temperature=rest)

    rescaled = temperature_reweighting([child.weight for child in children], temperature=current)
    for child, new_weight in zip(children, rescaled):
        child.weight = new_weight
def temperature_reweighting(weights: list[float], temperature: float = 1.0):
    """
    Temperature-scale ``weights`` and renormalize them to sum to 1: ``w_i ** t / sum_j(w_j ** t)``.

    A temperature of 1 only normalizes. Values below 1 flatten the distribution, relatively
    up-weighting small entries; values above 1 sharpen it.

    Args:
        weights: weights to rescale (assumed non-negative, e.g. hours or example counts).
        temperature: exponent applied to each weight before normalization.

    Returns:
        A list of Python floats with the same length as ``weights`` (``[]`` for empty input).

    Raises:
        ValueError: if the scaled weights do not have a finite, positive sum (e.g. all weights
            are zero), which would otherwise silently produce NaN weights.
    """
    # Compute in float: integer counts raised to an integer temperature can overflow int64 silently.
    scaled = np.asarray(weights, dtype=float) ** temperature
    total = scaled.sum()
    if scaled.size and not (np.isfinite(total) and total > 0):
        raise ValueError(
            f"Cannot normalize weights={list(weights)} at temperature={temperature}: "
            f"the scaled weights sum to {total}."
        )
    return (scaled / total).tolist()
@contextmanager
def quick_iter_options(entry: "DictConfig"):
    """
    Temporarily set ``metadata_only=True`` and ``force_finite=True`` on ``entry``.

    Used as ``with quick_iter_options(entry): ...``; yields ``entry`` itself. On exit (including
    on an exception) each key is restored to the value it had before entering, or removed if it
    was absent. A user-provided setting is therefore never lost.

    Args:
        entry: a mutable mapping-like config entry (modified in place while the context is active).
    """
    overrides = {"metadata_only": True, "force_finite": True}
    previous = {key: entry[key] for key in overrides if key in entry}
    for key, value in overrides.items():
        entry[key] = value
    try:
        yield entry
    finally:
        for key in overrides:
            if key in previous:
                entry[key] = previous[key]
            else:
                del entry[key]
| def parse_temperature(value: list[float]) -> float | list[float] | None: | |
| match value: | |
| case 0: | |
| return None | |
| case 1: | |
| return value[0] | |
| case _: | |
| return value | |
if __name__ == "__main__":
    # ``estimate_data_weights`` is a plain function: calling it with no arguments raises TypeError.
    # Expose it on the command line through a thin click wrapper instead.
    @click.command(help=estimate_data_weights.__doc__)
    @click.argument("input_cfgs", nargs=-1, required=True, type=click.Path(exists=True, dir_okay=False))
    @click.argument("output_cfg", type=click.Path(dir_okay=False))
    @click.option(
        "-t",
        "--temperature",
        type=float,
        multiple=True,
        help="Re-weighting temperature (1.0 is neutral; values below 1 relatively up-weight smaller "
        "datasets). Repeat the option to give one value per nesting level of the config.",
    )
    @click.option(
        "-s",
        "--strategy",
        type=click.Choice(["num_hours", "num_examples"]),
        default="num_hours",
        show_default=True,
        help="Dataset statistic used as the weight.",
    )
    def _cli(input_cfgs, output_cfg, temperature, strategy):
        estimate_data_weights(input_cfgs, output_cfg, temperature, strategy)

    _cli()