transformer-lm-japanese-1.0b
This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).
Source Code
We modified Flax's 'lm1b' example to train on a Japanese dataset. The code is available on GitHub.
Our Blog Post
Model Details
Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
---|---|---|---|---|---|---|---|---|
transformer-lm-japanese-1.0b | 1.0B | 18 | 2048 | 16 | wiki40b/ja | 2.19GB | 4 days | 31.47 |
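The PPL column is perplexity. This assumes the standard token-level definition; the card does not say exactly how it was measured. Under that definition, perplexity is the exponential of the mean negative log-likelihood per token. A PPL of 31.47 therefore corresponds to an average cross-entropy of about 3.45 nats (about 4.98 bits) per token. The helper `perplexity` below is our own illustration, not part of the released code.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood), using natural-log probabilities."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# Convert the reported PPL back into an average per-token loss.
ppl = 31.47
nats_per_token = math.log(ppl)   # cross-entropy in nats
bits_per_token = math.log2(ppl)  # the same quantity in bits
print(f"{nats_per_token:.3f} nats/token, {bits_per_token:.3f} bits/token")
```

Because the loss is counted per token of the model's own SentencePiece vocabulary, perplexities are only directly comparable between models that share a tokenizer.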
Usage: FlaxAutoModelForCausalLM
Requirements:
pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99
# For CPU
pip install -U "jax[cpu]==0.4.31"
# For GPU
pip install -U "jax[cuda12]==0.4.31"
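After installing, you can confirm that Python actually sees the pinned versions by querying `importlib.metadata` from the standard library. This is a minimal sketch: the `EXPECTED` table only restates the pins above, and `installed_versions` is a helper name of our own.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions from the requirements above (transformers only has a lower bound).
EXPECTED = {
    "transformers": ">=4.39.0",
    "jax": "0.4.31",
    "flax": "0.8.3",
    "sentencepiece": "0.1.99",
}


def installed_versions(packages=EXPECTED):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, got in installed_versions().items():
        print(f"{name:15s} expected {EXPECTED[name]:10s} installed {got}")
```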
Note: Pass trust_remote_code=True to from_pretrained to load our custom model code.
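from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
# Load the tokenizer and model (custom code, hence trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
# Prompt: "The capital of Japan is,"
text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)
# Sample up to 100 new tokens using temperature and top-k sampling
output_ids = model.generate(
    token_ids,
    do_sample=True,
    temperature=0.6,
    top_k=20,
    max_new_tokens=100
)
# generate() returns an output object; .sequences holds the prompt followed by the generated token ids
output = tokenizer.decode(output_ids.sequences[0], skip_special_tokens=True)
print(output)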
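The do_sample, temperature, and top_k arguments control how each next token is drawn. The following pure-Python sketch illustrates the idea: divide the logits by the temperature, keep only the top_k highest-scoring tokens, apply softmax over those, and sample. It is not the transformers implementation, which does this inside generate with its Flax logits warpers (such as FlaxTemperatureLogitsWarper and FlaxTopKLogitsWarper). The function names here are our own.

```python
import math
import random
from typing import List, Optional, Sequence


def top_k_probs(logits: Sequence[float], temperature: float, top_k: int) -> List[float]:
    """Scale logits by 1/temperature, keep the top_k largest, softmax over them.

    Tokens outside the top_k get probability 0.0.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled = [x / temperature for x in logits]
    k = min(top_k, len(scaled))
    # Indices of the k largest scaled logits (stable sort: ties keep the lower index first).
    keep = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    m = max(scaled[i] for i in keep)  # subtract the max for numerical stability
    weights = {i: math.exp(scaled[i] - m) for i in keep}
    total = sum(weights.values())
    return [weights.get(i, 0.0) / total for i in range(len(scaled))]


def sample_next_token(logits: Sequence[float], temperature: float = 0.6,
                      top_k: int = 20, rng: Optional[random.Random] = None) -> int:
    """Draw one token id from the temperature-scaled, top-k-filtered distribution."""
    if rng is None:
        rng = random.Random()
    probs = top_k_probs(logits, temperature, top_k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.5, -1.0]  # toy scores for a 4-token vocabulary
    print(top_k_probs(logits, temperature=0.6, top_k=2))
    print(sample_next_token(logits, rng=random.Random(0)))
```

A temperature below 1.0 (such as 0.6) sharpens the distribution toward the highest-scoring tokens, while top_k=20 rules out everything outside the 20 best candidates. With do_sample=False, generate would instead use greedy decoding.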
We tested text generation with Python 3.10 on GCP using the following setup:
- GPU: NVIDIA L4 (x1)
- Machine type: g2-standard-16 (16 vCPUs, 64 GB memory)
- Disk: 256 GB
- OS: Ubuntu 22.04 LTS (x86_64)
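On a GPU machine like this one, you may want to confirm that the jax[cuda12] install is actually in use (that is, that the L4 appears as a GPU device) before loading the 1.0B model. The following is a small sketch built on jax.default_backend() and jax.devices(). The helper name jax_backend_summary is our own, not part of the card's code.

```python
def jax_backend_summary():
    """Return (backend name, list of device strings), or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    devices = [str(d) for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("JAX is not installed")
    else:
        backend, devices = summary
        print("backend:", backend)
        print("devices:", devices)
```

If this prints a CPU backend on a GPU machine, the CUDA-enabled JAX wheel was most likely not picked up.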
Dataset
- wiki40b/ja (2.19GB)
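The wiki40b text field encodes document structure with inline markers (_START_ARTICLE_, _START_SECTION_, _START_PARAGRAPH_, and _NEWLINE_). The card does not describe how these were handled during training. The following is one possible cleanup, offered as a hypothetical sketch: it drops the markers, turns _NEWLINE_ into line breaks, and keeps article and section titles as their own lines. The function name wiki40b_to_plain_text is our own.

```python
import re

# Structural markers that open an article, section, or paragraph in wiki40b text.
_SECTION_MARKERS = re.compile(r"_START_(?:ARTICLE|SECTION|PARAGRAPH)_")


def wiki40b_to_plain_text(raw: str) -> str:
    """Hypothetical cleanup: strip wiki40b markers and return non-empty lines."""
    text = raw.replace("_NEWLINE_", "\n")  # in-paragraph line breaks
    text = _SECTION_MARKERS.sub("", text)  # drop structural markers
    lines = [line.strip() for line in text.split("\n")]
    return "\n".join(line for line in lines if line)
```

If you read the examples as bytes (for instance from a TFDS numpy iterator), decode them with UTF-8 before calling the function.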
Tokenization
Author