Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
@@ -162,14 +162,13 @@ class DalleBartDecoder(nn.Module):
|
|
162 |
print(tracemalloc.get_traced_memory())
|
163 |
|
164 |
for i in range(self.layer_count):
|
165 |
-
decoder_state, attention_state[i] = self.layers[i].forward(
|
166 |
decoder_state,
|
167 |
encoder_state,
|
168 |
attention_state[i],
|
169 |
attention_mask,
|
170 |
token_index
|
171 |
)
|
172 |
-
del decoder_state
|
173 |
print(tracemalloc.get_traced_memory())
|
174 |
decoder_state = self.final_ln(decoder_state)
|
175 |
logits = self.lm_head(decoder_state)
|
|
|
162 |
print(tracemalloc.get_traced_memory())
|
163 |
|
164 |
for i in range(self.layer_count):
|
165 |
+
del decoder_state, attention_state[i] = self.layers[i].forward(
|
166 |
decoder_state,
|
167 |
encoder_state,
|
168 |
attention_state[i],
|
169 |
attention_mask,
|
170 |
token_index
|
171 |
)
|
|
|
172 |
print(tracemalloc.get_traced_memory())
|
173 |
decoder_state = self.final_ln(decoder_state)
|
174 |
logits = self.lm_head(decoder_state)
|