add batch_size attribute to VariableCache

#15
by itlevy - opened
Files changed (1) hide show
  1. variable_cache.py +3 -2
variable_cache.py CHANGED
@@ -34,18 +34,19 @@ class VariableCache(Cache_4_44_2, Cache):
34
 
35
  def __init__(
36
  self,
 
37
  config: DeciLMConfig,
38
  batch_size: int = None,
39
  max_cache_len: int = None,
40
- device: torch.device = None,
41
  dtype: torch.dtype = torch.float32,
42
  max_batch_size: Optional[int] = None,
43
  **kwargs: Any,
44
  ) -> None:
45
  Cache_4_44_2.__init__(self)
46
 
47
- self.config = config
48
  self.max_batch_size = batch_size or max_batch_size
 
49
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
50
  self.dtype = dtype
51
 
 
34
 
35
  def __init__(
36
  self,
37
+ *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
38
  config: DeciLMConfig,
39
  batch_size: int = None,
40
  max_cache_len: int = None,
 
41
  dtype: torch.dtype = torch.float32,
42
  max_batch_size: Optional[int] = None,
43
  **kwargs: Any,
44
  ) -> None:
45
  Cache_4_44_2.__init__(self)
46
 
47
+ self.config = deepcopy(config)
48
  self.max_batch_size = batch_size or max_batch_size
49
+ self.batch_size = self.max_batch_size
50
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
51
  self.dtype = dtype
52