Update modeling_atomformer.py
Browse files- modeling_atomformer.py +2 -2
modeling_atomformer.py
CHANGED
@@ -2516,7 +2516,7 @@ class AtomformerEncoder(nn.Module):
|
|
2516 |
for blk in self.blocks:
|
2517 |
input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)
|
2518 |
|
2519 |
-
return input_embeds
|
2520 |
|
2521 |
|
2522 |
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore
|
@@ -2550,7 +2550,7 @@ class AtomformerModel(AtomformerPreTrainedModel):
|
|
2550 |
) -> torch.Tensor:
|
2551 |
"""Forward function call for the transformer model."""
|
2552 |
output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
|
2553 |
-
return output
|
2554 |
|
2555 |
|
2556 |
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
|
|
|
2516 |
for blk in self.blocks:
|
2517 |
input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask)
|
2518 |
|
2519 |
+
return input_embeds, pos_embeds
|
2520 |
|
2521 |
|
2522 |
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore
|
|
|
2550 |
) -> torch.Tensor:
|
2551 |
"""Forward function call for the transformer model."""
|
2552 |
output: torch.Tensor = self.encoder(input_ids, coords, attention_mask)
|
2553 |
+
return output[0][:, :-1]
|
2554 |
|
2555 |
|
2556 |
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
|