converting LM

#2
by jisx - opened

Hi, do you plan to convert LM to HF format?

Hey,

I was planning to, but I just realized that there's apparently no decoder-only T5 in HF's transformers.

Is that true? We might need to implement it as well.

I was also trying to convert it, but I couldn't find a decoder-only T5 (or decoder-only UL2) in HF either. The MADLAD paper says they use the same configuration as previous work [27, 52], but I can't find the configs for either of those works.

I had a look at the code and I think I know what needs to be done. I will try it later today.

convert_t5x_checkpoint_to_pytorch.py crashes for me even on a machine with 128 GB of RAM. It worked for the 10B-parameter model, but somehow this smaller model needs more RAM.

I can change the code to avoid loading everything into memory when I have time, but that will take a bit longer.
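For scale: roughly 8B parameters in float32 is about 8e9 × 4 bytes = 32 GB per full copy, so one plausible reason for the crash is a script that keeps several full copies alive at once (for example the original arrays plus the converted tensors). A minimal sketch of the "don't hold everything in memory" idea, assuming the checkpoint can be read one parameter at a time: group parameters into size-capped shards and load, rename and save one shard at a time. `plan_shards` and `convert_in_shards` are hypothetical helpers, and `load_param`, `rename` and `save_shard` are placeholder callables, not the actual functions in convert_t5x_checkpoint_to_pytorch.py.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def plan_shards(param_bytes: Sequence[Tuple[str, int]], max_shard_bytes: int) -> List[List[str]]:
    """Greedily pack parameter names into shards of at most max_shard_bytes.

    A single parameter larger than the budget ends up alone in its own shard.
    """
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name, size in param_bytes:
        # Start a new shard if adding this parameter would exceed the budget.
        if current and current_size + size > max_shard_bytes:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards


def convert_in_shards(
    param_bytes: Sequence[Tuple[str, int]],
    load_param: Callable[[str], Any],  # hypothetical: read one array from the T5X checkpoint
    rename: Callable[[str], str],  # hypothetical: map a T5X name to a PyTorch state-dict key
    save_shard: Callable[[int, Dict[str, Any]], None],  # hypothetical: write one shard file
    max_shard_bytes: int,
) -> int:
    """Convert shard by shard so peak memory is roughly one shard, not the whole model."""
    shards = plan_shards(param_bytes, max_shard_bytes)
    for index, names in enumerate(shards):
        shard = {rename(name): load_param(name) for name in names}
        save_shard(index, shard)
        del shard  # drop our reference before loading the next shard
    return len(shards)
```

In a real conversion, `save_shard` could write each shard with safetensors' `save_file`, and an index file would map each key to its shard, similar to what `save_pretrained` does when it shards large models.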

OK, looking at the gin file, they used the DecoderOnly model from flaxformer.

I've converted the weights with this Colab, but we still need modeling code in transformers to use them.

The weights are in https://huggingface.co/jbochi/madlad400-8b-lm

Some interesting things:

  • There's just one layer norm per block, since attention and MLP run in parallel (PARALLEL_LAYERS = True in the gin file)
  • There are no relative position embeddings (disabled in the gin file: t5_architecture.Decoder.shared_relative_position_bias_factory = None)
  • lm_head and token_embeddings have tied weights
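The three observations above can be combined in a small NumPy sketch of a decoder-only forward pass. It is illustrative only: single-head attention, a plain ReLU MLP, random weights and a final RMS norm are assumptions standing in for the real flaxformer DecoderOnly layers, and since the thread doesn't say how positions are encoded once the relative position bias is disabled, this sketch adds no positional signal at all. The function names are hypothetical.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    # T5-style layer norm: divide by the root mean square; no mean subtraction, no bias.
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * scale


def causal_attention(x, wq, wk, wv, wo):
    # Single-head stand-in for the real multi-head attention.
    # No relative position bias is added to the logits.
    q, k, v = x @ wq, x @ wk, x @ wv
    logits = q @ k.T / np.sqrt(q.shape[-1])
    seq_len = x.shape[0]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    logits = np.where(future, -1e9, logits)  # block attention to later positions
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v) @ wo


def mlp(x, w_in, w_out):
    # Plain ReLU MLP as a stand-in for the checkpoint's feed-forward variant.
    return np.maximum(x @ w_in, 0.0) @ w_out


def parallel_block(x, p):
    # One layer norm per block: attention and MLP both read the same normed input,
    # and their outputs are added to the residual together.
    h = rms_norm(x, p["ln_scale"])
    attn = causal_attention(h, p["wq"], p["wk"], p["wv"], p["wo"])
    return x + attn + mlp(h, p["w_in"], p["w_out"])


def forward(token_ids, embedding, blocks, final_scale):
    x = embedding[token_ids]  # token_embeddings lookup
    for p in blocks:
        x = parallel_block(x, p)
    x = rms_norm(x, final_scale)  # assumed final norm, as in T5 decoders
    return x @ embedding.T  # tied lm_head: reuse the embedding matrix


def init_block(rng, d_model, d_ff):
    # Random weights for illustration only.
    shapes = {"wq": (d_model, d_model), "wk": (d_model, d_model), "wv": (d_model, d_model),
              "wo": (d_model, d_model), "w_in": (d_model, d_ff), "w_out": (d_ff, d_model)}
    block = {name: rng.normal(size=shape) * 0.1 for name, shape in shapes.items()}
    block["ln_scale"] = np.ones(d_model)
    return block
```

A sequential T5 block would normalize twice (once before attention, once before the MLP), so finding a single layer norm per block in the checkpoint is consistent with the parallel layout. Tying means the checkpoint only needs to store one vocab × d_model matrix for both the input lookup and the output projection.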

Just out of curiosity, how do you plan to use the model?

Interesting! Thanks for converting!
I'm interested in its downstream performance on low-resource languages.

jisx changed discussion status to closed
