Since `transformers` v4.56.0` the dictionary `ALL_STATIC_CACHE_IMPLEMENTATIONS` replaced the existing dictionary
Browse files- modeling_decilm.py +3 -2
modeling_decilm.py
CHANGED
|
@@ -27,7 +27,8 @@ import torch.utils.checkpoint
|
|
| 27 |
from torch import nn
|
| 28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
from transformers import GenerationConfig
|
| 30 |
-
from transformers.generation.utils import
|
|
|
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 33 |
from transformers.utils import (
|
|
@@ -810,7 +811,7 @@ class DeciLMPreTrainedModel(PreTrainedModel):
|
|
| 810 |
# DeciLM-specific code
|
| 811 |
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
|
| 812 |
generation_config.cache_implementation = "variable"
|
| 813 |
-
|
| 814 |
return generation_config, model_kwargs
|
| 815 |
|
| 816 |
|
|
|
|
| 27 |
from torch import nn
|
| 28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
from transformers import GenerationConfig
|
| 30 |
+
from transformers.generation.utils import GenerationMixin, GenerateOutput
|
| 31 |
+
from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
|
| 32 |
from transformers.modeling_utils import PreTrainedModel
|
| 33 |
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 34 |
from transformers.utils import (
|
|
|
|
| 811 |
# DeciLM-specific code
|
| 812 |
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
|
| 813 |
generation_config.cache_implementation = "variable"
|
| 814 |
+
ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
|
| 815 |
return generation_config, model_kwargs
|
| 816 |
|
| 817 |
|