lhallee committed on
Commit
3f722c7
·
verified ·
1 Parent(s): 1693b1b

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +7 -3
modeling_esm_plusplus.py CHANGED
@@ -619,9 +619,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
619
  Dictionary mapping sequences to embeddings, or None if sql=True
620
  """
621
  sequences = list(set([seq[:max_len] for seq in sequences]))
622
- sequences = sorted(sequences, key=len, reverse=True)
623
- dataset = ProteinDataset(sequences)
624
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
625
  device = self.device
626
 
627
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -636,6 +633,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
636
  else:
637
  raise ValueError(f"Invalid pooling type: {pooling_type}")
638
 
 
639
  if sql:
640
  import sqlite3
641
  conn = sqlite3.connect(sql_db_path)
@@ -646,6 +644,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
646
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
647
  print(f"Embedding {len(to_embed)} new sequences")
648
  if len(to_embed) > 0:
 
 
 
649
  with torch.no_grad():
650
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
651
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
@@ -668,6 +669,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
668
  return None
669
 
670
  embeddings_dict = {}
 
 
 
671
  with torch.no_grad():
672
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
673
  seqs = sequences[i * batch_size:(i + 1) * batch_size]
 
619
  Dictionary mapping sequences to embeddings, or None if sql=True
620
  """
621
  sequences = list(set([seq[:max_len] for seq in sequences]))
 
 
 
622
  device = self.device
623
 
624
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
633
  else:
634
  raise ValueError(f"Invalid pooling type: {pooling_type}")
635
 
636
+ sequences = list(set([seq[:max_len] for seq in sequences]))
637
  if sql:
638
  import sqlite3
639
  conn = sqlite3.connect(sql_db_path)
 
644
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
645
  print(f"Embedding {len(to_embed)} new sequences")
646
  if len(to_embed) > 0:
647
+ to_embed = sorted(to_embed, key=len, reverse=True)
648
+ dataset = ProteinDataset(to_embed)
649
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
650
  with torch.no_grad():
651
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
652
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
 
669
  return None
670
 
671
  embeddings_dict = {}
672
+ sequences = sorted(sequences, key=len, reverse=True)
673
+ dataset = ProteinDataset(sequences)
674
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
675
  with torch.no_grad():
676
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
677
  seqs = sequences[i * batch_size:(i + 1) * batch_size]