Update modeling_esm_plusplus.py
modeling_esm_plusplus.py  +9 -7
@@ -537,10 +537,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             batch_size: Batch size for processing
             max_len: Maximum sequence length
             full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
-            full_precision: Whether to cast to full precision (float32) before storage
+            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
             pooling_type: Type of pooling ('mean' or 'cls')
             num_workers: Number of workers for data loading, 0 for the main process
-            sql: Whether to store embeddings in SQLite database
+            sql: Whether to store embeddings in SQLite database - will be stored in float32
             sql_db_path: Path to SQLite database

         Returns:
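The arguments above document the model's dataset-embedding helper. A minimal usage sketch follows, assuming the helper is exposed as embed_dataset on the loaded model and using an illustrative checkpoint id; neither name appears in this hunk, so treat both as assumptions.

# Hedged usage sketch: embed_dataset and the checkpoint id are assumed, not shown in the diff.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    'Synthyra/ESMplusplus_small',  # illustrative checkpoint id
    trust_remote_code=True,
).eval()

sequences = ['MSEQVENCE', 'MPRTEIN']
embeddings = model.embed_dataset(
    sequences,
    batch_size=2,
    max_len=512,
    full_embeddings=False,       # pooled per-sequence vectors instead of residue-wise matrices
    full_precision=True,         # cast to float32 before storing in the returned dict
    pooling_type='mean',         # or 'cls'
    num_workers=0,
    sql=False,                   # True streams results into SQLite instead of returning a dict
    sql_db_path='embeddings.db',
)
# embeddings maps each input sequence to its embedding, as in the dict path shown further below.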
@@ -553,12 +553,12 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         device = self.device

         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-            if full_precision:
-                residue_embeddings = residue_embeddings.float()
             if full_embeddings:
                 return residue_embeddings
-
-
+            elif pooling_type == 'mean':
+                return self.mean_pooling(residue_embeddings, attention_mask)
+            else:
+                return residue_embeddings[:, 0, :]

         if sql:
             import sqlite3
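For the pooled path, the new branches delegate to self.mean_pooling for 'mean' and take the first position for 'cls'. Below is a minimal, standalone sketch of mask-weighted mean pooling, assuming the repository's helper follows this standard recipe; it is illustrative rather than the exact implementation.

import torch

def mean_pooling(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # residue_embeddings: (batch, seq_len, hidden), attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)                   # zero out padding, sum over residues
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # number of real tokens per sequence
    return summed / counts                                            # (batch, hidden)

# 'cls' pooling simply takes the first position: residue_embeddings[:, 0, :]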
@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 seqs = sequences[i * batch_size:(i + 1) * batch_size]
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
-                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
                 embeddings = get_embeddings(residue_embeddings, attention_mask)

                 for seq, emb in zip(seqs, embeddings):
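The .float() cast on the SQLite path is marked "required for sql". A plausible reason is that each embedding is serialized to raw bytes (for example through NumPy) before being written as a BLOB, and NumPy has no bfloat16 dtype, so half-precision activations would not round-trip. The sketch below illustrates that pattern with a hypothetical table schema, not the repository's actual one.

import sqlite3
import numpy as np
import torch

conn = sqlite3.connect('embeddings.db')
c = conn.cursor()
c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)')

seq = 'MSEQVENCE'
emb = torch.randn(960).to(torch.bfloat16)       # stand-in for one pooled embedding
blob = emb.float().cpu().numpy().tobytes()      # float32 makes the NumPy round trip well defined
c.execute('INSERT OR REPLACE INTO embeddings VALUES (?, ?)', (seq, blob))
conn.commit()

# Reading back: a fixed float32 layout means the element width is known at load time.
row = c.execute('SELECT embedding FROM embeddings WHERE sequence = ?', (seq,)).fetchone()
restored = torch.from_numpy(np.frombuffer(row[0], dtype=np.float32).copy())
conn.close()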
@@ -596,6 +596,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
                 residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                if full_precision:
+                    residue_embeddings = residue_embeddings.float()
                 embeddings = get_embeddings(residue_embeddings, attention_mask)
                 for seq, emb in zip(seqs, embeddings):
                     embeddings_dict[seq] = emb
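On the in-memory path the cast is now applied only when full_precision is set, so embeddings can stay in the model's half precision by default. The sketch below only illustrates the storage trade-off; the shapes and hidden size are made up.

import torch

pooled = torch.randn(1000, 960).to(torch.bfloat16)     # e.g. 1000 pooled embeddings, hidden size 960 (illustrative)
full = pooled.float()                                   # what full_precision=True would keep instead

bytes_half = pooled.numel() * pooled.element_size()     # 2 bytes per value
bytes_full = full.numel() * full.element_size()         # 4 bytes per value
print(f'bfloat16: {bytes_half / 1e6:.2f} MB, float32: {bytes_full / 1e6:.2f} MB')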