---
license: apache-2.0
datasets:
- lambada
language:
- en
library_name: transformers
pipeline_tag: text-generation
tags:
- text-generation-inference
- causal-lm
- int8
- PyTorch
- PostTrainingStatic
- Intel® Neural Compressor
- neural-compressor
---
# INT8 GPT-J 6B

## Model Description
GPT-J 6B is a transformer model trained using Ben Wang's Mesh Transformer JAX. "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.
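Most of the model's size is its weights, so a quick estimate shows how much INT8 storage saves at this scale. The sketch assumes the nominal 6 × 10⁹ parameters and that every parameter is stored at the listed width. It ignores layers kept in higher precision, quantization scales and zero points, activations, and runtime overhead, so treat the numbers as rough.

```python
# Rough weight-storage estimate for a ~6B-parameter model.
# Assumption: every parameter is stored at the nominal width; this ignores
# layers left in higher precision, quantization scales/zero points,
# activations, and runtime overhead.
NUM_PARAMS = 6_000_000_000  # nominal "6B"

BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "int8": 1}


def weight_gib(num_params: int, dtype: str) -> float:
    """Return the weight storage in GiB (2**30 bytes) for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in ("fp32", "int8"):
    print(f"{dtype}: {weight_gib(NUM_PARAMS, dtype):.1f} GiB")
```

Under these assumptions the weights take about 22.4 GiB in FP32 and about 5.6 GiB in INT8, a 4× reduction.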
This INT8 PyTorch model was generated with intel-extension-for-transformers using post-training static quantization.
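The tags list this model as PostTrainingStatic. The sketch below illustrates what "static" means, using the generic min/max affine scheme. A scale and zero point are derived once from calibration data and then reused unchanged at inference. The function names and calibration values are hypothetical. This is not the recipe that intel-extension-for-transformers or Neural Compressor applied to GPT-J 6B.

```python
# Generic min/max post-training *static* quantization of one tensor.
# Hypothetical illustration only: the textbook affine scheme, not the
# recipe used to produce this model.

QMIN, QMAX = -128, 127  # signed INT8 range


def calibrate(samples):
    """Derive a fixed (scale, zero_point) from calibration data ahead of time."""
    lo = min(min(s) for s in samples)
    hi = max(max(s) for s in samples)
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # keep 0.0 inside the range
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # guard against an all-zero range
    zero_point = round(QMIN - lo / scale)
    return scale, zero_point


def quantize(xs, scale, zero_point):
    """Map floats to INT8 codes, clamping values outside the calibrated range."""
    return [max(QMIN, min(QMAX, round(x / scale) + zero_point)) for x in xs]


def dequantize(qs, scale, zero_point):
    """Map INT8 codes back to approximate float values."""
    return [(q - zero_point) * scale for q in qs]


# "Static": scale/zero_point come from calibration batches, then stay fixed.
calib = [[-1.0, 0.5, 2.0], [0.0, 1.5, -0.5]]
scale, zp = calibrate(calib)
q = quantize([0.25, -1.0, 2.0, 3.0], scale, zp)  # 3.0 lies outside the range
print(q, dequantize(q, scale, zp))
```

Values outside the calibrated range, such as 3.0 here, are clamped. For this reason the calibration data should resemble real inputs.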
The model was produced with the following software versions:

| Package | Version |
|---|---|
| intel-extension-for-transformers | commit a4aba8ddb07c9b744b6ac106502ec059e0c47960 |
| neural-compressor | 2.4.1 |
| torch | 2.1.0+cpu |
| intel-extension-for-pytorch | 2.1.0 |
| transformers | 4.32.0 |
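You can compare your environment with the versions listed above using `importlib.metadata` from the Python standard library. This is a sketch. The `find_mismatches` helper is hypothetical, and intel-extension-for-transformers is left out because the table pins it to a commit rather than a release.

```python
# Compare installed package versions with the ones listed in the table.
# intel-extension-for-transformers is pinned to a commit, so it is not checked.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {
    "neural-compressor": "2.4.1",
    "torch": "2.1.0+cpu",
    "intel-extension-for-pytorch": "2.1.0",
    "transformers": "4.32.0",
}


def find_mismatches(expected, lookup=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = lookup(pkg)
        except PackageNotFoundError:
            have = None  # package is not installed
        if have != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in find_mismatches(EXPECTED).items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```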
## Usage

Currently, the only supported workflow is to download the model first and then load it from the local copy. The model files are fetched from the Hugging Face Hub and stored on your machine.
- Clone this model repository:

```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static
```

- Load the INT8 model:

```python
from intel_extension_for_transformers.llm.evaluation.models import TSModelCausalLMForITREX

# Directory created by the `git clone` above; adjust if you cloned elsewhere.
model_dir = "./gpt-j-6B-pytorch-int8-static"
user_model = TSModelCausalLMForITREX.from_pretrained(
    model_dir,                  # local path to the downloaded model
    file_name="best_model.pt",  # INT8 model file in the repository
    trust_remote_code=False,    # default; set True only if you trust the repo's code
)
```
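If git-lfs is missing, `git clone` still succeeds. However, it leaves small text pointer files in place of the large weights, and loading then fails. The minimal check below assumes the clone sits in `./gpt-j-6B-pytorch-int8-static`. The hypothetical helper `is_lfs_pointer` only compares the first bytes of the file against the Git LFS pointer header.

```python
# Check that the clone fetched real weights rather than Git LFS pointer stubs.
from pathlib import Path

# Every Git LFS pointer file starts with this spec line.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path) -> bool:
    """Return True if `path` is a Git LFS pointer stub instead of real content."""
    with open(path, "rb") as f:
        return f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


# Assumption: the repository was cloned into the current directory.
model_file = Path("gpt-j-6B-pytorch-int8-static") / "best_model.pt"
if model_file.exists() and is_lfs_pointer(model_file):
    print("best_model.pt is an LFS pointer; run `git lfs pull` inside the clone.")
```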
## Evaluation results

The table reports the accuracy of the FP32 baseline and of the INT8 model of GPT-J 6B on the lambada_openai task, measured with lm_eval (lm-evaluation-harness):

| Dtype | Dataset | Accuracy |
|---|---|---|
| FP32 | lambada_openai | 0.6831 |
| INT8 | lambada_openai | 0.6835 |
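The arithmetic below computes the absolute and relative change of INT8 against FP32 from the table values. It uses only those two numbers and says nothing about run-to-run variance.

```python
# Relative change of INT8 accuracy versus the FP32 baseline (lambada_openai).
fp32_acc = 0.6831
int8_acc = 0.6835

abs_delta = int8_acc - fp32_acc   # absolute change in accuracy
rel_delta = abs_delta / fp32_acc  # change relative to the FP32 score
ratio = int8_acc / fp32_acc       # INT8/FP32 accuracy ratio

print(f"absolute: {abs_delta:+.4f}, relative: {rel_delta:+.4%}, ratio: {ratio:.4f}")
```

The INT8 score is 0.0004 higher in absolute terms, or about +0.06% relative to FP32.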