subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Sequence
import click
import numpy as np
from omegaconf import DictConfig, ListConfig, OmegaConf
from nemo.collections.common.data.lhotse.cutset import get_parser_fn
@click.command()
@click.argument("input_cfgs", type=click.Path(exists=True, dir_okay=False), nargs=-1)
@click.argument("output_cfg", type=click.Path())
@click.option(
    "-t",
    "--temperature",
    type=float,
    default=None,
    multiple=True,
    help="Temperature for re-weighting datasets. 1 is a neutral value. "
    "Lower temperature over-samples smaller datasets, and vice versa. "
    "Can be specified multiple times to apply a different temperature to each group level in the YAML config.",
)
@click.option(
    "-s",
    "--strategy",
    type=click.Choice(["num_hours", "num_examples"]),
    default="num_hours",
    help="Strategy for choosing weights for each dataset.",
)
def estimate_data_weights(
    input_cfgs: tuple[str, ...], output_cfg: str, temperature: tuple[float, ...], strategy: str
):
    """
    Read a YAML specification of datasets from INPUT_CFGS, compute their weights, and save the result in OUTPUT_CFG.

    The weight of each dataset is determined by STRATEGY: the total number of hours (default)
    or the number of examples. Each group's weight is the sum of its members' weights.
    If more than one config is provided as input, we will concatenate them and output a single merged config.
    Optionally, apply temperature re-weighting to balance the datasets (specify TEMPERATURE less than 1).
    """
    # With nargs=-1 followed by a required argument, a single path is consumed as OUTPUT_CFG
    # and INPUT_CFGS ends up empty; fail loudly instead of writing an empty config.
    if not input_cfgs:
        raise click.UsageError("At least one input config must be provided before OUTPUT_CFG.")
    data = ListConfig([])
    for icfg in input_cfgs:
        data.extend(OmegaConf.load(icfg))
    temperature = parse_temperature(temperature)
    validate(data)
    count(data, weight_key=strategy)
    aggregate_group_weights(data)
    reweight(data, temperature=temperature)
    OmegaConf.save(data, output_cfg)
def validate(entry: DictConfig | ListConfig, _level: int = 0):
    """
    Recursively check that every entry of a YAML data config declares a ``type``.

    Lists are traversed element-wise; entries with ``type: group`` are traversed through their
    ``input_cfg`` list. ``_level`` counts the nesting depth and is only used in error messages.

    Raises:
        ValueError: if an entry is not a mapping with a ``type`` key, or a group has no ``input_cfg``.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            validate(subentry, _level + 1)
        return
    # Explicit raise instead of ``assert``: assertions are removed when Python runs with -O.
    # The isinstance check avoids a substring match if a stray string ends up in the list.
    if not isinstance(entry, DictConfig) or "type" not in entry:
        raise ValueError(f"Invalid YAML data config at nesting level {_level}: missing key 'type' in entry={entry}")
    if entry.type == "group":
        if "input_cfg" not in entry:
            raise ValueError(
                f"Invalid YAML data config at nesting level {_level}: group entry is missing key 'input_cfg' in entry={entry}"
            )
        for subentry in entry["input_cfg"]:
            validate(subentry, _level + 1)
def count(entry: DictConfig | ListConfig, weight_key: str) -> None:
    """
    Recursively compute statistics for every leaf dataset and store one of them as its ``weight``.

    Lists and ``type: group`` entries are traversed recursively. Every other entry is parsed with
    the parser returned by ``get_parser_fn(entry.type)`` and iterated once to accumulate the total
    duration (converted from seconds to hours) and the number of examples.

    Args:
        entry: a list of entries or a single dataset/group entry; leaf entries are modified in place.
        weight_key: which statistic becomes the weight: ``"num_hours"`` or ``"num_examples"``.

    Raises:
        RuntimeError: if ``weight_key == "num_hours"`` and some example has no ``duration``,
            or the dataset's total duration is zero.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            count(subentry, weight_key=weight_key)
        return
    if entry.type == "group":
        for subentry in entry["input_cfg"]:
            count(subentry, weight_key=weight_key)
        return
    total_seconds = 0.0
    num_examples = 0
    num_missing_duration = 0
    with quick_iter_options(entry):
        # The second return value (whether the data is tarred) is not needed for counting.
        iterable, _ = get_parser_fn(entry.type)(entry)
        for example in iterable:
            duration = getattr(example, "duration", None)
            if duration is None:
                num_missing_duration += 1
            else:
                total_seconds += duration
            num_examples += 1
    stats = {"num_hours": total_seconds / 3600.0, "num_examples": num_examples}
    if weight_key == "num_hours":
        # A partial lack of durations would silently under-count hours, so reject it too.
        if num_missing_duration > 0:
            raise RuntimeError(
                f"Cannot set weights based on 'num_hours': at least one dataset has examples without 'duration' property "
                f"({num_missing_duration} out of {num_examples} examples). "
                f"Details: {entry=}"
            )
        if stats["num_hours"] == 0.0:
            raise RuntimeError(
                f"Cannot set weights based on 'num_hours': the dataset has zero total duration "
                f"({num_examples} examples). "
                f"Details: {entry=}"
            )
    entry["weight"] = stats[weight_key]
def aggregate_group_weights(entry: DictConfig | ListConfig) -> None:
    """
    Recursively set each group's ``weight`` to the sum of its members' weights, in place.

    Non-group entries are expected to already carry a ``weight`` (``count`` assigns one to every
    leaf dataset). Nested groups are always recomputed bottom-up, just like top-level groups,
    so any group weight left over in the input YAML cannot disagree with freshly counted leaves.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            aggregate_group_weights(subentry)
        return
    if entry.type != "group":
        return
    members = entry["input_cfg"]
    for member in members:
        aggregate_group_weights(member)
    entry.weight = sum(member["weight"] for member in members)
def reweight(entry: DictConfig | ListConfig, temperature: None | float | list[float]) -> None:
    """
    Apply temperature re-weighting to sibling entries, level by level, in place.

    At each list/group level, sibling weights w_i are replaced by ``temperature_reweighting``
    applied with that level's temperature. Children are re-weighted before their own level.

    Args:
        entry: a list of entries or a single entry; non-group dict entries are left unchanged.
        temperature: ``None`` or an empty sequence disables re-weighting. A single number applies
            the same temperature at every nesting level. A sequence applies its first element to
            this level and the rest to deeper levels; levels beyond its length are left unchanged.
            Note that 0.0 is a valid temperature (it produces uniform weights) and is not skipped.
    """
    if temperature is None:
        return
    if isinstance(temperature, Sequence):
        if len(temperature) == 0:
            return
        current, *deeper = temperature
    else:
        # A scalar temperature is propagated unchanged to every nesting level.
        current = deeper = temperature
    if isinstance(entry, ListConfig):
        siblings = entry
    elif entry.type == "group":
        siblings = entry["input_cfg"]
    else:
        return
    for sibling in siblings:
        reweight(sibling, temperature=deeper)
    new_weights = temperature_reweighting([sibling.weight for sibling in siblings], temperature=current)
    for sibling, new_weight in zip(siblings, new_weights):
        sibling.weight = new_weight
def temperature_reweighting(weights: list[float], temperature: float = 1.0):
    """
    Return temperature-scaled, normalized weights: w_i ^ T / sum_j(w_j ^ T).

    For positive weights, T = 1 only normalizes, T < 1 flattens the distribution (relatively
    up-weighting small entries), T > 1 sharpens it, and T = 0 gives uniform weights.

    Args:
        weights: non-negative weights; an empty input returns an empty list.
        temperature: exponent applied to each weight before normalization.

    Returns:
        A list of Python floats summing to 1.

    Raises:
        ValueError: if the scaled weights cannot be normalized (their sum is zero or not finite,
            e.g. all weights are zero), which would otherwise silently produce NaN weights.
    """
    # Force a float dtype: numpy refuses integer arrays raised to negative integer powers.
    scaled = np.asarray(weights, dtype=float) ** temperature
    if scaled.size == 0:
        return []
    total = scaled.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            f"Cannot normalize weights {list(weights)} with temperature={temperature}: "
            f"the scaled weights sum to {total}."
        )
    return (scaled / total).tolist()
@contextmanager
def quick_iter_options(entry: DictConfig):
    """
    Temporarily set ``metadata_only=True`` and ``force_finite=True`` on a dataset entry.

    The flags are consumed by the parser that ``count`` obtains from ``get_parser_fn``; presumably
    they make iteration cheap and finite (verify against NeMo's lhotse cutset parsers).
    On exit, including when the body raises, each key is restored to the value it had before
    entering, or removed if it was not present.
    """
    overrides = {"metadata_only": True, "force_finite": True}
    # Remember user-provided values so they are restored rather than deleted.
    previous = {key: entry[key] for key in overrides if key in entry}
    for key, value in overrides.items():
        entry[key] = value
    try:
        yield entry
    finally:
        for key in overrides:
            if key in previous:
                entry[key] = previous[key]
            else:
                del entry[key]
def parse_temperature(value: list[float]) -> float | list[float] | None:
match value:
case 0:
return None
case 1:
return value[0]
case _:
return value
if __name__ == "__main__":
    # Script entry point; the click command decorator handles command-line parsing.
    estimate_data_weights()