lhallee committed · verified
Commit c58309c · Parent(s): 2af14cc

Update modeling_esm_plusplus.py

Files changed (1)
  1. modeling_esm_plusplus.py +9 -7
modeling_esm_plusplus.py CHANGED
@@ -537,10 +537,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             batch_size: Batch size for processing
             max_len: Maximum sequence length
             full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
-            full_precision: Whether to cast to full precision (float32) before storage
+            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
             pooling_type: Type of pooling ('mean' or 'cls')
             num_workers: Number of workers for data loading, 0 for the main process
-            sql: Whether to store embeddings in SQLite database
+            sql: Whether to store embeddings in SQLite database - will be stored in float32
             sql_db_path: Path to SQLite database
 
         Returns:
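For context, a minimal usage sketch of the parameters documented above. The method name embed_dataset, the checkpoint id, and loading via trust_remote_code are assumptions not shown in this diff; only the keyword arguments mirror the documented parameters.

import torch
from transformers import AutoModelForMaskedLM

# Assumed checkpoint id and remote-code loading; adjust to your setup.
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
model = model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

sequences = ['MPRTEINSEQWENCE', 'MSEQWENCE']
embeddings = model.embed_dataset(  # assumed method name
    sequences,
    batch_size=2,
    max_len=512,
    full_embeddings=False,   # pooled (per-sequence) embeddings
    full_precision=True,     # cast to float32 before dict storage
    pooling_type='mean',
    num_workers=0,
    sql=False,               # set True to stream embeddings into SQLite (stored as float32)
    sql_db_path='embeddings.db',
)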
@@ -553,12 +553,12 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-            if full_precision:
-                residue_embeddings = residue_embeddings.float()
             if full_embeddings:
                 return residue_embeddings
-            return (self.mean_pooling(residue_embeddings, attention_mask) if pooling_type == 'mean'
-                    else residue_embeddings[:, 0, :])
+            elif pooling_type == 'mean':
+                return self.mean_pooling(residue_embeddings, attention_mask)
+            else:
+                return residue_embeddings[:, 0, :]
 
         if sql:
             import sqlite3
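To make the pooling choice concrete, here is a standalone sketch of the two branches the refactored helper selects between: a masked mean over residue embeddings versus taking the first (CLS/BOS) token. self.mean_pooling itself is not shown in this diff, so the masked mean below is an assumption about its behavior.

import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Masked mean over the sequence dimension: (batch, seq_len, hidden) -> (batch, hidden)
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)
    return (residue_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

def cls_pool(residue_embeddings: torch.Tensor) -> torch.Tensor:
    # First-token embedding, the 'cls' option: (batch, seq_len, hidden) -> (batch, hidden)
    return residue_embeddings[:, 0, :]

x = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(mean_pool(x, mask).shape, cls_pool(x).shape)  # torch.Size([2, 8]) torch.Size([2, 8])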
@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 seqs = sequences[i * batch_size:(i + 1) * batch_size]
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
-                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
                 embeddings = get_embeddings(residue_embeddings, attention_mask)
 
                 for seq, emb in zip(seqs, embeddings):
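The .float() cast means whatever is written to SQLite is float32. Assuming the embeddings are serialized as raw bytes (e.g. emb.cpu().numpy().tobytes()) into a table named embeddings with (sequence, embedding) columns, neither of which is shown in this diff, they could be read back like this:

import sqlite3
import numpy as np
import torch

# Assumed schema: embeddings(sequence TEXT PRIMARY KEY, embedding BLOB) holding raw float32 bytes.
conn = sqlite3.connect('embeddings.db')
row = conn.execute(
    'SELECT embedding FROM embeddings WHERE sequence = ?',
    ('MPRTEINSEQWENCE',),
).fetchone()
conn.close()

if row is not None:
    # frombuffer returns a read-only view; copy before handing it to torch.
    emb = torch.from_numpy(np.frombuffer(row[0], dtype=np.float32).copy())
    print(emb.shape, emb.dtype)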
@@ -596,6 +596,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
                 residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                if full_precision:
+                    residue_embeddings = residue_embeddings.float()
                 embeddings = get_embeddings(residue_embeddings, attention_mask)
                 for seq, emb in zip(seqs, embeddings):
                     embeddings_dict[seq] = emb
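In the in-memory (dict) path, full_precision now controls a single upcast of the hidden states before they are pooled and stored. A minimal illustration of the effect; the hidden size and the half-precision starting dtype are placeholders, not taken from this diff:

import torch

residue_embeddings = torch.randn(2, 12, 960, dtype=torch.float16)  # placeholder dtype/shape

full_precision = True
if full_precision:
    # Mirrors the added branch: upcast once before embeddings are kept in the dict.
    residue_embeddings = residue_embeddings.float()

assert residue_embeddings.dtype == torch.float32  # float32 uses ~2x the memory of float16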
 