Crystalcareai commited on
Commit
c24856d
·
verified ·
1 Parent(s): fba2fba

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +0 -9
generate.py CHANGED
@@ -1,10 +1,4 @@
1
  import torch
2
- from transformers.generation.utils import (
3
- GenerationMixin,
4
- validate_stopping_criteria,
5
- StoppingCriteriaList,
6
- )
7
- from transformers import TextStreamer
8
 
9
 
10
  def custom_generate(
@@ -50,7 +44,6 @@ def custom_generate(
50
  if max_new_tokens is None:
51
  max_new_tokens = 50 # Default value if not specified
52
  for cur_token_idx in range(max_new_tokens):
53
- # Sample the next token
54
  new_ids = self(
55
  input_ids[~finished_generating],
56
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
@@ -95,14 +88,12 @@ def custom_generate(
95
  if streamer is not None:
96
  streamer.put(new_ids_sampled)
97
 
98
- # Create a named tuple to match the expected output format
99
  from collections import namedtuple
100
  GenerateOutput = namedtuple("GenerateOutput", ["sequences", "scores", "attentions", "hidden_states"])
101
 
102
  # Convert the generated token IDs to a tensor
103
  generated_token_ids_tensor = input_ids
104
 
105
- # Create the GenerateOutput named tuple
106
  output = GenerateOutput(
107
  sequences=generated_token_ids_tensor,
108
  scores=None,
 
1
  import torch
 
 
 
 
 
 
2
 
3
 
4
  def custom_generate(
 
44
  if max_new_tokens is None:
45
  max_new_tokens = 50 # Default value if not specified
46
  for cur_token_idx in range(max_new_tokens):
 
47
  new_ids = self(
48
  input_ids[~finished_generating],
49
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
 
88
  if streamer is not None:
89
  streamer.put(new_ids_sampled)
90
 
 
91
  from collections import namedtuple
92
  GenerateOutput = namedtuple("GenerateOutput", ["sequences", "scores", "attentions", "hidden_states"])
93
 
94
  # Convert the generated token IDs to a tensor
95
  generated_token_ids_tensor = input_ids
96
 
 
97
  output = GenerateOutput(
98
  sequences=generated_token_ids_tensor,
99
  scores=None,