JMalott commited on
Commit
0adf088
·
1 Parent(s): 8190b2f

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -162,14 +162,13 @@ class DalleBartDecoder(nn.Module):
162
  print(tracemalloc.get_traced_memory())
163
 
164
  for i in range(self.layer_count):
165
- decoder_state, attention_state[i] = self.layers[i].forward(
166
  decoder_state,
167
  encoder_state,
168
  attention_state[i],
169
  attention_mask,
170
  token_index
171
  )
172
- del decoder_state
173
  print(tracemalloc.get_traced_memory())
174
  decoder_state = self.final_ln(decoder_state)
175
  logits = self.lm_head(decoder_state)
 
162
  print(tracemalloc.get_traced_memory())
163
 
164
  for i in range(self.layer_count):
165
+ del decoder_state, attention_state[i] = self.layers[i].forward(
166
  decoder_state,
167
  encoder_state,
168
  attention_state[i],
169
  attention_mask,
170
  token_index
171
  )
 
172
  print(tracemalloc.get_traced_memory())
173
  decoder_state = self.final_ln(decoder_state)
174
  logits = self.lm_head(decoder_state)