# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import lightning.pytorch as pl
import torch


def extract_dtypes(ckpt):
    """
    Extract the dtype of each entry in the input iterator.

    `ckpt` can be `module.named_parameters()` or `module.state_dict().items()`.
    """
    dtypes = {}
    for key, val in ckpt:
        if hasattr(val, 'dtype'):
            dtypes[key] = val.dtype
        elif hasattr(val, 'data') and hasattr(val.data, 'dtype'):
            # Handles a ShardedTensor populated with data.
            dtypes[key] = val.data.dtype
    return dtypes


def dtype_from_str(dtype):
    """
    Convert a precision string to the equivalent torch dtype.

    Unrecognized strings fall back to `torch.float32`.
    """
    assert isinstance(dtype, str)
    if dtype in ["float16", "fp16", "16", "16-mixed"]:
        return torch.float16
    elif dtype in ["bfloat16", "bf16", "bf16-mixed"]:
        return torch.bfloat16
    else:
        return torch.float32


def dtype_from_hf(config):
    """
    Extract the torch dtype from a Hugging Face config.
    """
    assert hasattr(config, 'torch_dtype'), "Expected config to have attr `torch_dtype`"
    torch_dtype = config.torch_dtype
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    elif isinstance(torch_dtype, str):
        return dtype_from_str(torch_dtype)
    else:
        raise ValueError(f"torch_dtype is not of type str/torch.dtype, got {type(torch_dtype)}")


def is_trainer_attached(model: pl.LightningModule):
    """
    Return True if a trainer is attached to the model.
    """
    return hasattr(model, 'trainer')


def get_automodel_from_trainer(trainer: pl.Trainer):
    """
    Extract the automodel from a PyTorch Lightning trainer instance.

    This function checks whether `trainer.model` is an automodel (e.g. `HFAutoModelForCausalLM`).
    It handles different distributed training strategies:

    - If neither DistributedDataParallel (DDP) nor Fully Sharded Data Parallel (FSDP) is used,
      `trainer.model` directly holds the automodel.
    - If DDP is used, `trainer.model.module` contains the actual automodel.
    - If FSDP is used, `trainer.model` still holds the automodel, with its inner model wrapped
      in an FSDP structure.

    Args:
        trainer (lightning.pytorch.Trainer): The PyTorch Lightning trainer instance.

    Returns:
        torch.nn.Module or None: The automodel if found, otherwise `None`.
    """
    # No DDP -> trainer.model holds:
    #   HFAutoModelForCausalLM(
    #     (model): ...
    # FSDP -> trainer.model holds:
    #   HFAutoModelForCausalLM(
    #     (model): FSDP(...
    if getattr(trainer.model, "is_hf_model", False):
        return trainer.model
    # DDP -> trainer.model holds:
    #   DistributedDataParallel(
    #     (module): HFAutoModelForCausalLM(...
    if hasattr(trainer.model, 'module') and getattr(trainer.model.module, "is_hf_model", False):
        return trainer.model.module
    return None
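

# ---------------------------------------------------------------------------
# Minimal usage sketch: exercises extract_dtypes and dtype_from_str on a plain
# torch module so the expected inputs and outputs are easy to see. The
# `_demo_module` name is hypothetical and exists only for this example; it is
# not part of the module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _demo_module = torch.nn.Linear(4, 2).to(torch.bfloat16)

    # named_parameters() yields (name, Parameter) pairs, which is the iterator
    # shape extract_dtypes expects.
    print(extract_dtypes(_demo_module.named_parameters()))
    # -> {'weight': torch.bfloat16, 'bias': torch.bfloat16}

    # Precision strings map onto torch dtypes; anything unrecognized falls
    # back to torch.float32.
    print(dtype_from_str("16-mixed"))    # torch.float16
    print(dtype_from_str("bf16-mixed"))  # torch.bfloat16
    print(dtype_from_str("64"))          # torch.float32 (fallback)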