---
license: apache-2.0
datasets:
  - lambada
language:
  - en
library_name: transformers
pipeline_tag: text-generation
tags:
  - text-generation-inference
  - causal-lm
  - int8
  - PyTorch
  - PostTrainingStatic
  - Intel® Neural Compressor
  - neural-compressor
---

INT8 GPT-J 6B

Model Description

GPT-J 6B is a transformer model trained using Ben Wang's Mesh Transformer JAX. "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.

This INT8 PyTorch model was generated with intel-extension-for-transformers, using post-training static quantization from Intel® Neural Compressor.
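Post-training static quantization (the `PostTrainingStatic` tag) fixes quantization parameters ahead of time from a calibration pass, instead of computing activation ranges at inference time. The sketch below illustrates the generic per-tensor affine INT8 mapping behind that idea. It is an assumption-laden illustration only: the helper names (`calibrate`, `quantize`, `dequantize`) and the plain min/max calibration rule are hypothetical and are not the exact algorithm Intel® Neural Compressor applied to this model.

```python
# Illustrative per-tensor affine INT8 quantization with a min/max calibration step.
# Hypothetical helpers; NOT the Intel Neural Compressor implementation.
from typing import List, Tuple

QMIN, QMAX = -128, 127  # signed INT8 range


def calibrate(samples: List[float]) -> Tuple[float, int]:
    """Derive (scale, zero_point) once, offline, from observed calibration values."""
    lo = min(min(samples), 0.0)  # keep 0.0 inside the range so it maps exactly
    hi = max(max(samples), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(QMIN - lo / scale)
    zero_point = max(QMIN, min(QMAX, zero_point))
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int) -> int:
    """Map a float to INT8 using the precomputed (static) parameters, with clamping."""
    q = round(x / scale) + zero_point
    return max(QMIN, min(QMAX, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Recover an approximate float from its INT8 code."""
    return (q - zero_point) * scale


if __name__ == "__main__":
    scale, zp = calibrate([-1.0, 0.5, 2.0])
    for x in [-1.0, 0.5, 2.0]:
        q = quantize(x, scale, zp)
        print(x, q, dequantize(q, scale, zp))
```

Because `scale` and `zero_point` are computed once from calibration data, inference only needs integer rounding and clamping; values inside the calibrated range round-trip with an error of at most half a quantization step.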

| Package | Version |
| --- | --- |
| intel-extension-for-transformers | a4aba8ddb07c9b744b6ac106502ec059e0c47960 (git commit) |
| neural-compressor | 2.4.1 |
| torch | 2.1.0+cpu |
| intel-extension-for-pytorch | 2.1.0 |
| transformers | 4.32.0 |
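Since the INT8 model was produced with the specific package versions listed above, checking your environment first can save debugging time. The following is a minimal sketch using Python's standard `importlib.metadata`. The `EXPECTED` mapping and the helper names are illustrative. intel-extension-for-transformers is omitted because it is pinned to a git commit rather than a release version. The exact-string comparison is a deliberate simplification; for example, a torch build other than the CPU wheel reports a different local label than `+cpu`.

```python
# Compare installed package versions with the versions listed in this model card.
# Hypothetical helper; intel-extension-for-transformers is pinned to a commit, so it is skipped.
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

EXPECTED: Dict[str, str] = {
    "neural-compressor": "2.4.1",
    "torch": "2.1.0+cpu",
    "intel-extension-for-pytorch": "2.1.0",
    "transformers": "4.32.0",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_versions(expected: Dict[str, str] = EXPECTED) -> Dict[str, str]:
    """Build a status string per package: 'ok', 'missing (...)', or 'mismatch: ...'."""
    report: Dict[str, str] = {}
    for dist, want in expected.items():
        have = installed_version(dist)
        if have is None:
            report[dist] = f"missing (expected {want})"
        elif have != want:
            report[dist] = f"mismatch: {have} installed, {want} expected"
        else:
            report[dist] = "ok"
    return report


if __name__ == "__main__":
    for dist, status in check_versions().items():
        print(f"{dist}: {status}")
```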

Usage

Currently, the only supported workflow is to download the model files to your local machine first and then load the model from that local path.

  • Clone this model repository
```shell
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static
```
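  • Load the INT8 model
```python
from intel_extension_for_transformers.llm.evaluation.models import TSModelCausalLMForITREX

# Your saved path: the directory created by `git clone` (adjust if you cloned elsewhere)
model_dir = "./gpt-j-6B-pytorch-int8-static"

user_model = TSModelCausalLMForITREX.from_pretrained(
    model_dir,
    file_name="best_model.pt",  # INT8 model file inside the cloned repository
    trust_remote_code=False,  # Default is False
)
```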

Evaluation results

Accuracy of the FP32 baseline and the optimized INT8 GPT-J 6B model on the lambada_openai task, measured with lm_eval:

| Dtype | Dataset | Accuracy |
| --- | --- | --- |
| FP32 | lambada_openai | 0.6831 |
| INT8 | lambada_openai | 0.6835 |
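To put the two lambada_openai accuracy figures in perspective, the short calculation below computes the absolute difference and the INT8/FP32 ratio. It uses only the two numbers reported in the table; the card reports a single value per dtype with no run-to-run variance, so a difference this small should not be read as INT8 being genuinely more accurate.

```python
# Compare the FP32 and INT8 lambada_openai accuracies reported in the table.
fp32_acc = 0.6831
int8_acc = 0.6835

absolute_delta = int8_acc - fp32_acc  # difference in accuracy (as a fraction)
relative_ratio = int8_acc / fp32_acc  # INT8 accuracy as a fraction of FP32 accuracy

print(f"absolute change: {absolute_delta:+.4f}")
print(f"INT8/FP32 ratio: {relative_ratio:.4%}")
```

With these inputs the INT8 model retains about 100.06% of the FP32 accuracy, i.e. quantization caused no measurable accuracy loss on this task.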