transformer-lm-japanese-1.0b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).

Source Code

We've modified Flax's 'lm1b' example to train on a Japanese dataset. You can find the code on GitHub.

Our Blog Post

Model Details

| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|---|---|---|---|---|---|---|---|---|
| transformer-lm-japanese-1.0b | 1.0B | 18 | 2048 | 16 | wiki40b/ja | 2.19GB | 4 days | 31.47 |
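
Two quantities can be derived from the figures above. The short sketch below computes the per-head dimension and converts the reported perplexity into an average loss. It assumes that PPL is a per-token perplexity defined as the exponential of the mean cross-entropy in nats; the model card does not state this explicitly, so treat the loss values as an interpretation rather than a reported result.

```python
import math

# Values copied from the model details above.
model_dim = 2048
num_heads = 16
perplexity = 31.47

# Each attention head works on an equal slice of the model dimension.
head_dim = model_dim // num_heads

# Assumption: perplexity = exp(mean cross-entropy per token, in nats).
loss_nats = math.log(perplexity)
loss_bits = math.log2(perplexity)

print(f"head_dim = {head_dim}")
print(f"loss = {loss_nats:.3f} nats/token ({loss_bits:.3f} bits/token)")
```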

Usage: FlaxAutoModel

Requirements:

pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99

# For CPU
pip install -U "jax[cpu]==0.4.31"

# For GPU
pip install -U "jax[cuda12]==0.4.31"
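
To confirm that an environment matches the versions listed above, the installed package metadata can be read with the standard library. This is a minimal sketch: `PINNED` and `installed_versions` are illustrative names, not part of this repository, and the script only reports versions without comparing them.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions copied from the requirements above.
# transformers is a minimum version; the others are exact pins.
PINNED = {
    "transformers": "4.39.0",
    "jax": "0.4.31",
    "flax": "0.8.3",
    "sentencepiece": "0.1.99",
}

def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

for name, current in installed_versions(PINNED).items():
    print(f"{name}: installed={current}, listed={PINNED[name]}")
```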

Note: Set trust_remote_code=True to load our custom model.

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)

text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
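
The `do_sample`, `temperature` and `top_k` arguments above control how each next token is drawn. The sketch below illustrates the idea on a toy list of logits using only the standard library. It is not the implementation used by transformers; `sample_top_k` and the toy logits are hypothetical and exist only for illustration.

```python
import math
import random

def sample_top_k(logits, temperature=0.6, top_k=20, rng=random):
    """Sample one token id from raw logits with temperature and top-k filtering."""
    # Keep only the ids of the top_k highest logits.
    k = min(top_k, len(logits))
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

    # Divide by the temperature: values below 1 sharpen the distribution.
    scaled = [logits[i] / temperature for i in top_ids]

    # Softmax weights over the surviving logits (shifted by the max for stability).
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]

    # Draw one id in proportion to its weight.
    return rng.choices(top_ids, weights=weights, k=1)[0]

# Toy vocabulary of five tokens; the logits are made-up numbers.
toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
rng = random.Random(0)
samples = [sample_top_k(toy_logits, temperature=0.6, top_k=2, rng=rng) for _ in range(1000)]
print({token_id: samples.count(token_id) for token_id in sorted(set(samples))})
```

With `top_k=2`, only the two highest-scoring ids can ever be drawn, and a lower temperature shifts more of the probability mass onto the highest-scoring one.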

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

  • GPU Type: NVIDIA L4 (x 1)
  • Machine Type: g2-standard-16 (16 CPUs, 64GB Memory)
  • Disk: 256GB
  • OS: Ubuntu 22.04 LTS x86/64

Dataset

Tokenization

Author

Ryoichi Fukugawa
