Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py CHANGED (+7 -3)
@@ -619,9 +619,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             Dictionary mapping sequences to embeddings, or None if sql=True
         """
         sequences = list(set([seq[:max_len] for seq in sequences]))
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         device = self.device

         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -636,6 +633,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             else:
                 raise ValueError(f"Invalid pooling type: {pooling_type}")

+        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -646,6 +644,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
+                to_embed = sorted(to_embed, key=len, reverse=True)
+                dataset = ProteinDataset(to_embed)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = to_embed[i * batch_size:(i + 1) * batch_size]
@@ -668,6 +669,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             return None

         embeddings_dict = {}
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         with torch.no_grad():
             for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                 seqs = sequences[i * batch_size:(i + 1) * batch_size]