Crystalcareai committed (verified)
Commit 3da56b6 · 1 Parent(s): dfec8fb

Update generate.py

Files changed (1)
  1. generate.py +6 -3
generate.py CHANGED
@@ -45,7 +45,9 @@ def custom_generate(
 ):
     device = input_ids.device
     with torch.no_grad():
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
+        batch_size = input_ids.shape[0]
+        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
+        generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
         for cur_token_idx in range(max_new_tokens):
             # Sample the next token
@@ -70,7 +72,7 @@ def custom_generate(
                 # Assign the new id to the last token
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
-                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
                                              device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
@@ -79,6 +81,7 @@ def custom_generate(
                 if attention_mask is not None:
                     attention_mask[answer_idx, last_token_idx + 1] = 1
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
                 if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                     finished_generating[answer_idx] = 1
@@ -93,7 +96,7 @@ def custom_generate(
             if streamer is not None:
                 streamer.put(new_ids_sampled)
 
-    return input_ids
+    return generated_token_ids
 
 
 def generate(
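
In short: custom_generate previously returned the full input_ids (prompt plus continuation), so callers had to slice off the prompt themselves; after this commit it returns a (batch_size, max_new_tokens) tensor pre-filled with pad_token_id that holds only the newly sampled ids. Below is a minimal self-contained sketch of that buffer pattern; the random sampling and token ids are illustrative stand-ins, not the repo's model code.

import torch

pad_token_id, eos_token_id = 0, 2
batch_size, max_new_tokens = 2, 8

# Pre-fill with pad_token_id so rows that finish early stay padded.
generated_token_ids = torch.full(
    (batch_size, max_new_tokens), pad_token_id, dtype=torch.long
)
finished_generating = torch.zeros(batch_size, dtype=torch.bool)

for cur_token_idx in range(max_new_tokens):
    for answer_idx in range(batch_size):
        if finished_generating[answer_idx]:
            continue
        # Stand-in for sampling from the model's logits.
        new_ids_sampled = torch.randint(1, 50, ()).item()
        generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
        if new_ids_sampled == eos_token_id:
            finished_generating[answer_idx] = True
    if finished_generating.all():
        break

# Only the new tokens come back; no prompt-length slicing needed.
print(generated_token_ids.shape)  # torch.Size([2, 8])

Pre-allocating the buffer also keeps rows aligned across the batch: a sequence that hits eos_token_id early simply leaves its remaining positions as padding.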