Crystalcareai
committed on
Update modeling_quiet.py
modeling_quiet.py +4 -66
modeling_quiet.py
CHANGED
@@ -48,7 +48,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_quiet import QuietConfig
-
+from .generate import generate
 import time
 from typing import Optional, List
 
@@ -1423,71 +1423,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-
-
-
-    batch_size, seq_len = input_ids.shape
-    max_length = max_length if max_length is not None else model.config.max_length
-    max_new_tokens = max_length - seq_len
-    temperature = kwargs.get("temperature", 1.0)
-
-    with torch.no_grad():
-        for cur_token_idx in range(max_new_tokens):
-            # Run a forward pass to get the logits for the next token
-            outputs = model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                use_cache=True,
-            )
-
-            logits = outputs.logits[:, -1, :]
-
-            # Sample the next token from the logits
-            next_token_logits = logits / temperature
-            next_token_id = torch.multinomial(torch.nn.functional.softmax(next_token_logits, dim=-1), num_samples=1)
-
-            # Append the new token to the input sequence
-            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)], dim=-1)
-
-            # Stream the new token if a streamer is provided
-            if streamer is not None:
-                streamer.put(next_token_id)
-
-            # Check if the end token is generated for all sequences in the batch
-            if next_token_id.eq(model.config.eos_token_id).all():
-                break
-
-    return input_ids
-
-
-    # Add this to QuietForCausalLM forward method to support custom generate
-
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids,
-        attention_mask=None,
-        max_length=None,
-        streamer=None,
-        **kwargs,
-    ):
-        # Prepare inputs
-        batch_size, seq_len = input_ids.shape
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        # Call the custom generate function
-        output_ids = custom_generate(
-            self,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_length=max_length,
-            streamer=streamer,
-            **kwargs,
-        )
-
-        return output_ids
+    def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
+        from .generate import generate
+        return generate(self, input_ids, attention_mask, max_length, temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
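The delegated generate function lives in a separate generate.py module that is not part of this diff. As a rough sketch only, assuming that module preserves the behavior of the removed inline loop (temperature-scaled multinomial sampling with an early stop once every sequence emits EOS), it might look like:

# generate.py -- hypothetical sketch; the actual module is not shown in this commit.
import torch

@torch.no_grad()
def generate(model, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
    batch_size, seq_len = input_ids.shape
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    max_length = max_length if max_length is not None else model.config.max_length

    for _ in range(max_length - seq_len):
        # Forward pass for next-token logits.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Temperature-scaled multinomial sampling, as in the removed loop.
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1)

        # Extend the sequence and its attention mask by one token.
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)

        # Stop once every sequence in the batch has generated EOS.
        if next_token_id.eq(model.config.eos_token_id).all():
            break

    return input_ids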
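With generation now routed through the model's own generate method, calling it follows the usual pattern for custom Hugging Face modeling code. A hypothetical usage sketch (the repository id and prompt below are placeholders, not from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/quiet-model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Crystalcareai/quiet-model", trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=64, temperature=0.7)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))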