Crystalcareai
commited on
Update generate.py
Browse files- generate.py +16 -2
generate.py
CHANGED
@@ -95,8 +95,22 @@ def custom_generate(
|
|
95 |
if streamer is not None:
|
96 |
streamer.put(new_ids_sampled)
|
97 |
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
def generate(
|
|
|
95 |
if streamer is not None:
|
96 |
streamer.put(new_ids_sampled)
|
97 |
|
98 |
+
# Create a named tuple to match the expected output format
|
99 |
+
from collections import namedtuple
|
100 |
+
GenerateOutput = namedtuple("GenerateOutput", ["sequences", "scores", "attentions", "hidden_states"])
|
101 |
+
|
102 |
+
# Convert the generated token IDs to a tensor
|
103 |
+
generated_token_ids_tensor = input_ids
|
104 |
+
|
105 |
+
# Create the GenerateOutput named tuple
|
106 |
+
output = GenerateOutput(
|
107 |
+
sequences=generated_token_ids_tensor,
|
108 |
+
scores=None,
|
109 |
+
attentions=None,
|
110 |
+
hidden_states=None
|
111 |
+
)
|
112 |
+
|
113 |
+
return output
|
114 |
|
115 |
|
116 |
def generate(
|