Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +26 -21
modeling_quiet.py
CHANGED
@@ -1836,7 +1836,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1836 |
elif ahead_idx >= self.n_ahead - 1:
|
1837 |
if labels is not None: # we're in the talk phase
|
1838 |
cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
|
1839 |
-
# print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
|
1840 |
shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
|
1841 |
padding = torch.full_like(
|
1842 |
labels[..., :cur_talk_n],
|
@@ -1848,44 +1847,50 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1848 |
[shift_labels, padding],
|
1849 |
dim=-1
|
1850 |
)
|
1851 |
-
|
1852 |
-
# print((new_rm_tokens > self.vocab_size - 1).any().item())
|
1853 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|
1854 |
-
|
1855 |
-
# Now safely convert rm tokens to one-hot
|
1856 |
probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
|
|
|
1857 |
else:
|
1858 |
continue
|
|
|
1859 |
temperature = self.gumbel_temperature if self.training else 0.001
|
1860 |
prev_sample_probs = sample_probs
|
1861 |
sample_probs = probabilities_2d
|
|
|
1862 |
if ahead_idx < self.n_ahead - 1 and not skip_sampling:
|
1863 |
probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
|
1864 |
if self.gumbel_detach:
|
1865 |
probabilities_2d = probabilities_2d.detach()
|
1866 |
-
|
|
|
1867 |
# convert rm logits directly to embeddings
|
1868 |
contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
|
1869 |
contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
|
1870 |
contains_thought = contains_start or contains_end
|
1871 |
|
1872 |
-
|
1873 |
-
|
1874 |
-
|
1875 |
-
|
1876 |
-
|
1877 |
-
|
1878 |
-
|
1879 |
-
|
1880 |
-
|
1881 |
-
|
1882 |
-
|
|
|
|
|
|
|
|
|
1883 |
else:
|
1884 |
-
|
|
|
1885 |
else:
|
1886 |
-
|
1887 |
-
|
1888 |
-
|
|
|
1889 |
|
1890 |
if len(attention_mask.shape) == 2:
|
1891 |
breakpoint()
|
|
|
1836 |
elif ahead_idx >= self.n_ahead - 1:
|
1837 |
if labels is not None: # we're in the talk phase
|
1838 |
cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
|
|
|
1839 |
shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
|
1840 |
padding = torch.full_like(
|
1841 |
labels[..., :cur_talk_n],
|
|
|
1847 |
[shift_labels, padding],
|
1848 |
dim=-1
|
1849 |
)
|
|
|
|
|
1850 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|
|
|
|
|
1851 |
probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
|
1852 |
+
skip_sampling = True
|
1853 |
else:
|
1854 |
continue
|
1855 |
+
|
1856 |
temperature = self.gumbel_temperature if self.training else 0.001
|
1857 |
prev_sample_probs = sample_probs
|
1858 |
sample_probs = probabilities_2d
|
1859 |
+
|
1860 |
if ahead_idx < self.n_ahead - 1 and not skip_sampling:
|
1861 |
probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
|
1862 |
if self.gumbel_detach:
|
1863 |
probabilities_2d = probabilities_2d.detach()
|
1864 |
+
sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
|
1865 |
+
|
1866 |
# convert rm logits directly to embeddings
|
1867 |
contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
|
1868 |
contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
|
1869 |
contains_thought = contains_start or contains_end
|
1870 |
|
1871 |
+
# Flash Attention modification
|
1872 |
+
if self._attn_implementation == "flash_attention_2":
|
1873 |
+
probabilities_2d = probabilities_2d.view(batch_size, seq_len, -1)
|
1874 |
+
|
1875 |
+
if contains_thought:
|
1876 |
+
thought_id = self.start_token_id if contains_start else self.end_token_id
|
1877 |
+
cur_thought_embedding = start_embedding if contains_start else end_embedding
|
1878 |
+
|
1879 |
+
if self.use_reparam_for_thought_embeddings:
|
1880 |
+
inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
|
1881 |
+
inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
|
1882 |
+
if contains_start:
|
1883 |
+
sampled_start = inputs_embeds.clone().detach()
|
1884 |
+
else:
|
1885 |
+
sampled_end = inputs_embeds.clone().detach()
|
1886 |
else:
|
1887 |
+
inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
|
1888 |
+
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
1889 |
else:
|
1890 |
+
with torch.set_grad_enabled(not self.train_only_thinking_embedding):
|
1891 |
+
inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
|
1892 |
+
|
1893 |
+
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
1894 |
|
1895 |
if len(attention_mask.shape) == 2:
|
1896 |
breakpoint()
|