Faisal AlKhateeb commited on
Commit
b566562
·
1 Parent(s): e94a265

update ALiBi with kv caching

Browse files
Files changed (1) hide show
  1. modeling_btlm.py +12 -3
modeling_btlm.py CHANGED
@@ -74,9 +74,14 @@ class AlibiPositionEmbeddingLayer(nn.Module):
74
  self,
75
  seq_length,
76
  key_length,
 
77
  ):
78
- context_position = torch.arange(seq_length, device=self.slopes.device)[:, None]
79
- memory_position = torch.arange(key_length, device=self.slopes.device)[None, :]
 
 
 
 
80
  relative_position = memory_position - context_position
81
  relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
82
  alibi = (self.slopes * -1.0).unsqueeze(1) * relative_position
@@ -946,7 +951,11 @@ class BTLMModel(BTLMPreTrainedModel):
946
 
947
  if self.relative_pe is not None:
948
  length = input_ids.shape[1]
949
- position_bias = self.relative_pe(length, length)
 
 
 
 
950
  else:
951
  position_bias = None
952
 
 
74
  self,
75
  seq_length,
76
  key_length,
77
+ cached_qk_len,
78
  ):
79
+ context_position = torch.arange(
80
+ cached_qk_len, cached_qk_len + seq_length, device=self.slopes.device
81
+ )[:, None]
82
+ memory_position = torch.arange(
83
+ key_length + cached_qk_len, device=self.slopes.device
84
+ )[None, :]
85
  relative_position = memory_position - context_position
86
  relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
87
  alibi = (self.slopes * -1.0).unsqueeze(1) * relative_position
 
951
 
952
  if self.relative_pe is not None:
953
  length = input_ids.shape[1]
954
+ cached_kv_length = 0
955
+ cached_kv = past_key_values[0]
956
+ if cached_kv is not None:
957
+ cached_kv_length = cached_kv[0].shape[-2]
958
+ position_bias = self.relative_pe(length, length, cached_kv_length)
959
  else:
960
  position_bias = None
961