# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Mapping, Sequence

import torch
from lightning.pytorch.plugins import HalfPrecision
from typing_extensions import override

_HAS_HYDRA = True

try:
    import hydra
    from omegaconf import DictConfig, OmegaConf
except ModuleNotFoundError:
    DictConfig = Mapping
    OmegaConf = None
    _HAS_HYDRA = False


def resolve_trainer_cfg(trainer_cfg: DictConfig) -> DictConfig:
    """
    Resolves and processes a trainer configuration.

    This function handles specific trainer configuration details:
    - For half precision setups, replaces precision settings with custom plugins
    - Instantiates strategy objects from mapping configurations
    - Instantiates custom callbacks from sequences

    Args:
        trainer_cfg: A DictConfig containing trainer configuration parameters

    Returns:
        The processed configuration, resolved into a plain container (dict)
    """
    if not _HAS_HYDRA:
        return trainer_cfg
    trainer_cfg = OmegaConf.to_container(trainer_cfg, resolve=True)

    # Avoids downcasting 'audio' tensors in 'true' half precision setups.
    precision = trainer_cfg.get("precision")
    if precision in ("fp16-true", "bf16-true"):
        trainer_cfg.pop("precision", None)
        trainer_cfg["plugins"] = [HalfPrecisionForAudio(precision)]

    # Allows customizable strategies (e.g., ModelParallelStrategy) in YAML configs.
    if (strategy := trainer_cfg.get("strategy", None)) is not None and isinstance(strategy, Mapping):
        trainer_cfg["strategy"] = hydra.utils.instantiate(strategy)

    # Allows adding custom callbacks (e.g., NsysCallback) from the YAML config.
    if (cbs := trainer_cfg.get("callbacks", None)) is not None and isinstance(cbs, Sequence):
        trainer_cfg["callbacks"] = [hydra.utils.instantiate(cb) for cb in cbs]

    return trainer_cfg
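
# Illustrative sketch (not part of the module): a trainer section of a Hydra YAML
# config that this resolver can handle. The class paths below are assumptions chosen
# for demonstration; any importable ``_target_`` works with ``hydra.utils.instantiate``.
#
#   trainer:
#     precision: bf16-true
#     strategy:
#       _target_: lightning.pytorch.strategies.ModelParallelStrategy
#     callbacks:
#       - _target_: lightning.pytorch.callbacks.LearningRateMonitor
#         logging_interval: step
#
# After ``resolve_trainer_cfg``, ``precision`` is replaced by a ``HalfPrecisionForAudio``
# plugin, and the ``strategy``/``callbacks`` mappings are instantiated into Lightning objects.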


class HalfPrecisionForAudio(HalfPrecision):
    """
    Adjusted PyTorch Lightning plugin for training with half precision.
    It avoids downcasting audio tensors to half precision when the mini-batch
    is a dict: values stored under keys containing the string 'audio' keep
    their original dtype.
    """

    @override
    def convert_input(self, data: Any) -> Any:
        """
        Converts input data to the appropriate precision format, preserving audio tensor precision.

        This method overrides the parent class implementation to avoid downcasting tensors
        with 'audio' in their dictionary keys. It processes input data recursively when
        encountering nested dictionaries.

        Args:
            data: The input data to convert (can be tensor, dict, or other types)

        Returns:
            The converted data with appropriate precision for each element
        """
        if not isinstance(data, dict):
            return super().convert_input(data)

        def _convert(value):
            if isinstance(value, dict):
                ans = {}
                for k, v in value.items():
                    # Tensors under keys containing 'audio' keep their original
                    # dtype; everything else is converted recursively.
                    if "audio" not in k or not torch.is_tensor(v):
                        v = _convert(v)
                    ans[k] = v
                return ans
            if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
                return value.to(self._desired_input_dtype)
            return value  # any other type

        return _convert(data)
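

if __name__ == "__main__":
    # Minimal illustrative sketch (not part of the module API): shows that
    # ``convert_input`` downcasts ordinary float tensors to the desired half
    # precision while leaving tensors stored under 'audio' keys untouched.
    plugin = HalfPrecisionForAudio("bf16-true")
    batch = {
        "audio": torch.randn(2, 16000),              # preserved as float32
        "audio_lens": torch.tensor([16000, 12000]),  # integer tensor, never cast
        "features": torch.randn(2, 80, 100),         # downcast to bfloat16
    }
    converted = plugin.convert_input(batch)
    print(converted["audio"].dtype)     # torch.float32
    print(converted["features"].dtype)  # torch.bfloat16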