Crystalcareai committed on
Commit 91359bc · verified · 1 Parent(s): c0dd54c

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +4 -66
modeling_quiet.py CHANGED
@@ -48,7 +48,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_quiet import QuietConfig
-
+from .generate import generate
 import time
 from typing import Optional, List
 
@@ -1423,71 +1423,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-
-    def custom_generate(model, input_ids, attention_mask, max_length, streamer=None, **kwargs):
-        # Set up some variables
-        batch_size, seq_len = input_ids.shape
-        max_length = max_length if max_length is not None else model.config.max_length
-        max_new_tokens = max_length - seq_len
-        temperature = kwargs.get("temperature", 1.0)
-
-        with torch.no_grad():
-            for cur_token_idx in range(max_new_tokens):
-                # Run a forward pass to get the logits for the next token
-                outputs = model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    use_cache=True,
-                )
-
-                logits = outputs.logits[:, -1, :]
-
-                # Sample the next token from the logits
-                next_token_logits = logits / temperature
-                next_token_id = torch.multinomial(torch.nn.functional.softmax(next_token_logits, dim=-1), num_samples=1)
-
-                # Append the new token to the input sequence
-                input_ids = torch.cat([input_ids, next_token_id], dim=-1)
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)], dim=-1)
-
-                # Stream the new token if a streamer is provided
-                if streamer is not None:
-                    streamer.put(next_token_id)
-
-                # Check if the end token is generated for all sequences in the batch
-                if next_token_id.eq(model.config.eos_token_id).all():
-                    break
-
-        return input_ids
-
-
-    # Add this to QuietForCausalLM forward method to support custom generate
-
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids,
-        attention_mask=None,
-        max_length=None,
-        streamer=None,
-        **kwargs,
-    ):
-        # Prepare inputs
-        batch_size, seq_len = input_ids.shape
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        # Call the custom generate function
-        output_ids = custom_generate(
-            self,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_length=max_length,
-            streamer=streamer,
-            **kwargs,
-        )
-
-        return output_ids
+    def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
+        from .generate import generate
+        return generate(self, input_ids, attention_mask, max_length, temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
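For context, the thin wrapper added above delegates to a standalone function in the new generate module. A minimal sketch of what that function plausibly looks like, reconstructed from the deleted in-class loop and the wrapper's call signature — the actual generate.py in this repo may differ:

import torch

def generate(model, input_ids, attention_mask=None, max_length=None, temperature=1.0, streamer=None, **kwargs):
    # Hypothetical reconstruction of the removed sampling loop, now living
    # outside the class so it no longer shadows `self` with `model`.
    batch_size, seq_len = input_ids.shape
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    max_length = max_length if max_length is not None else model.config.max_length
    max_new_tokens = max_length - seq_len

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Forward pass to get logits for the next token.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
            next_token_logits = outputs.logits[:, -1, :] / temperature

            # Sample one token per sequence from the softmax distribution.
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            # Append the new token to the running sequence and extend the mask.
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)],
                dim=-1,
            )

            # Stream the new token if a streamer is provided.
            if streamer is not None:
                streamer.put(next_token_id)

            # Stop once every sequence in the batch has emitted EOS.
            if next_token_id.eq(model.config.eos_token_id).all():
                break

    return input_ids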
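A usage sketch of the refactored entry point, assuming the repo's custom modeling code is loaded with trust_remote_code=True; the model id below is a placeholder, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star-model"  # placeholder id, not verified
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, world", return_tensors="pt")
# Matches the new signature: generate(self, input_ids, attention_mask=None,
# max_length=None, temperature=1.0, **kwargs)
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
    temperature=0.7,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))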