Why is the vocab size 64?

#1
by ekiefl - opened

The logits output shape is described as (batch_size, seq_len, vocab_size), which in the example below is (2, 11, 64). Why is the vocab size 64, though, when the tokenizer only has a vocab of 33?

Example in question:

from transformers import AutoModelForMaskedLM  # AutoModel also works

model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
tokenizer = model.tokenizer

print(len(tokenizer.vocab))  # 33

sequences = ["MPRTEIN", "MSEQWENCE"]
tokenized = tokenizer(sequences, padding=True, return_tensors="pt")

# tokenized['labels'] = tokenized['input_ids'].clone()  # for MLM training, mask some input_ids and set the unmasked positions of labels to -100 (see the sketch after this block)

output = model(**tokenized)  # get all hidden states with output_hidden_states=True
print(
    output.logits.shape
)  # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(
    output.last_hidden_state.shape
)  # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss)  # language modeling loss if you passed labels
# print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
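
For reference, here is a minimal sketch of the masking that the commented-out labels line refers to. The 15% masking rate and the all-<mask> corruption (no random-token/keep split) are illustrative assumptions, not necessarily the recipe ESM++ was trained with:

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ["MPRTEIN", "MSEQWENCE"]
tokenized = tokenizer(sequences, padding=True, return_tensors="pt")

input_ids = tokenized["input_ids"]
labels = input_ids.clone()

# Never mask special or padding tokens
special = torch.zeros_like(input_ids, dtype=torch.bool)
for tok_id in (tokenizer.cls_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id):
    if tok_id is not None:
        special |= input_ids == tok_id

# Pick ~15% of the remaining positions (with very short sequences you may want to force at least one)
mask = (torch.rand(input_ids.shape) < 0.15) & ~special

input_ids[mask] = tokenizer.mask_token_id  # corrupt the inputs
labels[~mask] = -100                       # loss is only computed on masked positions

tokenized["labels"] = labels
output = model(**tokenized)
print(output.loss)  # MLM loss over the masked positions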

Thanks in advance.

Synthyra org

I'm not sure why EvoScale chose 64 for the embedding layer / LM head. Purely speculating, it may have to do with the other token types in ESM3, or it could be because CUDA operations on dimensions divisible by 8 are faster. In regular inference, however, the additional tokens are not used and should not harm performance.
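
One quick way to sanity-check that is to look at how much probability mass the model assigns to the 31 padded slots. A minimal sketch, assuming the real token ids occupy the first len(tokenizer.vocab) indices of the logits:

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
tokenizer = model.tokenizer

tokenized = tokenizer(["MPRTEIN"], return_tensors="pt")
with torch.no_grad():
    logits = model(**tokenized).logits  # (1, seq_len, 64)

probs = logits.softmax(dim=-1)
extra_mass = probs[..., len(tokenizer.vocab):].sum(dim=-1)  # mass on the unused slots
print(extra_mass.max())  # should be tiny if those slots were never training targets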

lhallee changed discussion status to closed
Synthyra org

Gonna leave this open in case other people have the same question.

lhallee changed discussion status to open

Thanks @lhallee. You're totally right. Here's a small script I used to convince myself.

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
tokenizer = model.tokenizer

mask_pos = 100

sequences = ["L" * 200]
tokenized = tokenizer(sequences, padding=True, return_tensors="pt")
tokenized["input_ids"][0, mask_pos] = tokenizer.mask_token_id

output = model(**tokenized)

# Logit corresponding to L should be the argmax
assert output.logits[0, mask_pos].argmax().item() == tokenizer.vocab["L"]
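
If the padded slots ever get in the way (e.g. when mapping argmax indices back to tokens), they can simply be sliced off, again assuming the real token ids occupy the first len(tokenizer.vocab) indices:

vocab_size = len(tokenizer.vocab)  # 33
logits = output.logits[..., :vocab_size]  # drop the 31 padded slots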
