# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager, nullcontext
from typing import Any

import torch


def avoid_bfloat16_autocast_context():
    """
    If the current autocast context is bfloat16, return an autocast context
    that targets float32 instead; otherwise return a no-op context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
        return torch.amp.autocast('cuda', dtype=torch.float32)
    else:
        return nullcontext()
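

# Illustrative usage sketch (not part of the original module): run a
# numerically sensitive op in float32 even inside a bfloat16 autocast region.
# `x` below is a hypothetical CUDA tensor.
#
#     with torch.amp.autocast('cuda', dtype=torch.bfloat16):
#         with avoid_bfloat16_autocast_context():
#             # ops in this block autocast to float32 instead of bfloat16
#             out = torch.nn.functional.layer_norm(x, x.shape[-1:])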


def avoid_float16_autocast_context():
    """
    If the current autocast context is float16, return an autocast context
    that targets bfloat16 when the GPU supports it (falling back to float32
    under TorchScript scripting/tracing); otherwise return a no-op context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return torch.amp.autocast('cuda', dtype=torch.float32)

        if torch.cuda.is_bf16_supported():
            return torch.amp.autocast('cuda', dtype=torch.bfloat16)
        else:
            return torch.amp.autocast('cuda', dtype=torch.float32)
    else:
        return nullcontext()
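

# Illustrative usage sketch (not part of the original module): inside a
# float16 autocast region, ops in the inner block run in bfloat16 when the
# GPU supports it, and in float32 otherwise. `logits` is a hypothetical tensor.
#
#     with torch.amp.autocast('cuda', dtype=torch.float16):
#         with avoid_float16_autocast_context():
#             out = torch.softmax(logits, dim=-1)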


def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
    """Cast a tensor from ``from_dtype`` to ``to_dtype``; other dtypes pass through."""
    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x


def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
    """Recursively cast tensors of ``from_dtype`` inside dicts and tuples."""
    if isinstance(x, torch.Tensor):
        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
    elif isinstance(x, dict):
        return {k: cast_all(v, from_dtype=from_dtype, to_dtype=to_dtype) for k, v in x.items()}
    elif isinstance(x, tuple):
        return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
    else:
        # Non-tensor leaves (ints, strings, None, ...) pass through unchanged.
        return x
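

# Illustrative usage sketch (not part of the original module): ``cast_all``
# recurses through dicts and tuples, touching only tensors of ``from_dtype``.
#
#     batch = {"x": torch.randn(2, 3, dtype=torch.float16), "lens": torch.tensor([3, 2])}
#     fp32_batch = cast_all(batch)  # "x" becomes float32; "lens" (int64) is untouched
#     assert fp32_batch["x"].dtype == torch.float32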


class CastToFloat(torch.nn.Module):
    """Wrap a module so that, under autocast, it always runs in float32."""

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        if torch.is_autocast_enabled() and x.dtype != torch.float32:
            # Disable autocast, upcast the input, and cast the result back
            # to the input's original dtype.
            with torch.amp.autocast(x.device.type, enabled=False):
                ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
        else:
            ret = self.mod.forward(x)
        return ret
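

# Illustrative usage sketch (not part of the original module, assumes a CUDA
# device): wrap a layer that should run in full precision so that, under
# autocast, its input is upcast to float32 and its output cast back.
#
#     bn = CastToFloat(torch.nn.BatchNorm1d(8)).cuda()
#     with torch.amp.autocast('cuda', dtype=torch.float16):
#         y = bn(torch.randn(4, 8, device='cuda').half())  # y.dtype == torch.float16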


class CastToFloatAll(torch.nn.Module):
    """Like ``CastToFloat``, but for modules taking multiple tensor arguments."""

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, *args):
        if torch.is_autocast_enabled():
            from_dtype = args[0].dtype
            with torch.amp.autocast(args[0].device.type, enabled=False):
                ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
            return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
        else:
            return self.mod.forward(*args)
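

# Illustrative usage sketch (not part of the original module): the dtype of
# the first argument decides what gets upcast and what the outputs revert to.
# ``Add`` and the tensors ``a``/``b`` are hypothetical.
#
#     class Add(torch.nn.Module):
#         def forward(self, a, b):
#             return a + b
#
#     add = CastToFloatAll(Add())
#     with torch.amp.autocast('cuda', dtype=torch.float16):
#         c = add(a.half(), b.half())  # computed in float32, returned as float16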


@contextmanager
def monkeypatched(object, name, patch):
    """Temporarily monkeypatch an attribute on an object, restoring it on exit."""
    pre_patched_value = getattr(object, name)
    setattr(object, name, patch)
    try:
        yield object
    finally:
        # Restore the original attribute even if the body raises.
        setattr(object, name, pre_patched_value)
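

# Illustrative usage sketch (not part of the original module): temporarily
# swap out an attribute for the duration of a block. ``my_softmax`` and
# ``run_export`` are hypothetical.
#
#     with monkeypatched(torch.nn.functional, 'softmax', my_softmax):
#         run_export()  # any code here sees the patched softmax
#     # torch.nn.functional.softmax is restored at this point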


def maybe_cast_to_type(x: Any, type_: type) -> Any:
    """Try to cast a value to ``type_``; if that fails, return the original value.

    Args:
        x (Any): The value to be cast.
        type_ (type): The type to cast to; must be callable.

    Returns:
        Any: The cast value, or the original value if casting fails.
    """
    try:
        return type_(x)
    except Exception:
        return x
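

# Illustrative usage sketch (not part of the original module):
#
#     maybe_cast_to_type("8", int)      # -> 8
#     maybe_cast_to_type("abc", int)    # -> "abc" (ValueError swallowed, original returned)
#     maybe_cast_to_type("1.5", float)  # -> 1.5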