blewis-hir commited on
Commit
f4a9fe8
·
verified ·
1 Parent(s): 420ba7d

Since `transformers` v4.56.0` the dictionary `ALL_STATIC_CACHE_IMPLEMENTATIONS` replaced the existing dictionary

Browse files
Files changed (1) hide show
  1. modeling_decilm.py +3 -2
modeling_decilm.py CHANGED
@@ -27,7 +27,8 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
@@ -810,7 +811,7 @@ class DeciLMPreTrainedModel(PreTrainedModel):
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
  generation_config.cache_implementation = "variable"
813
- NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
814
  return generation_config, model_kwargs
815
 
816
 
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
+ from transformers.generation.utils import GenerationMixin, GenerateOutput
31
+ from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
34
  from transformers.utils import (
 
811
  # DeciLM-specific code
812
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
813
  generation_config.cache_implementation = "variable"
814
+ ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
815
  return generation_config, model_kwargs
816
 
817