Faisal AlKhateeb
committed on
Commit
·
b566562
1
Parent(s):
e94a265
update ALiBi with kv caching
Browse files- modeling_btlm.py +12 -3
modeling_btlm.py
CHANGED
@@ -74,9 +74,14 @@ class AlibiPositionEmbeddingLayer(nn.Module):
|
|
74 |
self,
|
75 |
seq_length,
|
76 |
key_length,
|
|
|
77 |
):
|
78 |
-
context_position = torch.arange(
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
relative_position = memory_position - context_position
|
81 |
relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
|
82 |
alibi = (self.slopes * -1.0).unsqueeze(1) * relative_position
|
@@ -946,7 +951,11 @@ class BTLMModel(BTLMPreTrainedModel):
|
|
946 |
|
947 |
if self.relative_pe is not None:
|
948 |
length = input_ids.shape[1]
|
949 |
-
|
|
|
|
|
|
|
|
|
950 |
else:
|
951 |
position_bias = None
|
952 |
|
|
|
74 |
self,
|
75 |
seq_length,
|
76 |
key_length,
|
77 |
+
cached_qk_len,
|
78 |
):
|
79 |
+
context_position = torch.arange(
|
80 |
+
cached_qk_len, cached_qk_len + seq_length, device=self.slopes.device
|
81 |
+
)[:, None]
|
82 |
+
memory_position = torch.arange(
|
83 |
+
key_length + cached_qk_len, device=self.slopes.device
|
84 |
+
)[None, :]
|
85 |
relative_position = memory_position - context_position
|
86 |
relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
|
87 |
alibi = (self.slopes * -1.0).unsqueeze(1) * relative_position
|
|
|
951 |
|
952 |
if self.relative_pe is not None:
|
953 |
length = input_ids.shape[1]
|
954 |
+
cached_kv_length = 0
|
955 |
+
cached_kv = past_key_values[0]
|
956 |
+
if cached_kv is not None:
|
957 |
+
cached_kv_length = cached_kv[0].shape[-2]
|
958 |
+
position_bias = self.relative_pe(length, length, cached_kv_length)
|
959 |
else:
|
960 |
position_bias = None
|
961 |
|