metadata
base_model: meta-llama/Meta-Llama-3-8B
inference: true
model_type: llama
pipeline_tag: text-generation
tags:
  - sparse

SparseLlama-3-8B-pruned_50.2of4

This repo contains model files for a 2:4 (N:M) sparse Meta-Llama-3-8B model, in which at least two of every four consecutive weights are zero. The model was pruned in one shot with SparseGPT and then retrained with SquareHead knowledge distillation while maintaining the 2:4 sparsity mask.
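
To make the 2:4 pattern concrete, here is a minimal pure-Python sketch. It is not SparseGPT: the hypothetical helper `magnitude_mask_2of4` uses plain magnitude pruning as a stand-in, and `is_2of4_sparse` (also a hypothetical name) only checks that each aligned group of four values holds at least two zeros. Real weights are tensors rather than lists; lists are used here to keep the example self-contained.

```python
from typing import List, Sequence


def magnitude_mask_2of4(row: Sequence[float]) -> List[float]:
    """Zero the two smallest-magnitude weights in every group of four.

    Illustrative magnitude pruning only; this is NOT SparseGPT.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = list(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # Sort the group's indices by |weight| and zero the two smallest.
        for idx in sorted(group, key=lambda i: abs(row[i]))[:2]:
            pruned[idx] = 0.0
    return pruned


def is_2of4_sparse(row: Sequence[float]) -> bool:
    """Return True if every aligned group of four has at least two zeros."""
    if len(row) % 4 != 0:
        return False
    return all(
        sum(1 for w in row[start:start + 4] if w == 0.0) >= 2
        for start in range(0, len(row), 4)
    )


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
pruned = magnitude_mask_2of4(weights)
print(pruned)
print(is_2of4_sparse(weights), is_2of4_sparse(pruned))
```

The same check could in principle be applied row by row to a linear layer's weight matrix to confirm that the mask survived retraining, assuming the groups are aligned along the input dimension.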

Note: This is still a work in progress and subject to change. We expect to release new weights with even better accuracy soon.

Running the model

The model can be run directly in transformers for testing purposes. This path does not take advantage of the 2:4 sparsity:

# pip install transformers accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4")
model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", device_map="auto")

input_text = "A poem about Machine Learning goes as follows:"
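# Move inputs to the model's device; hard-coding "cuda" can mismatch device_map="auto".
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# Without max_new_tokens, generate() may fall back to a short default length.
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))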

To take advantage of the 2:4 sparsity, install nm-vllm for fast inference and low memory usage:

pip install nm-vllm[sparse] --extra-index-url https://pypi.neuralmagic.com/simple

from vllm import LLM, SamplingParams

model = LLM("nm-testing/SparseLlama-3-8B-pruned_50.2of4", sparsity="semi_structured_sparse_w16a16")

prompt = "A poem about Machine Learning goes as follows:"
sampling_params = SamplingParams(max_tokens=100, temperature=0)

outputs = model.generate(prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)

Evaluation Benchmark Results

These evaluation results were obtained with lm-evaluation-harness, following the configuration of the Open LLM Leaderboard.

| Benchmark | Meta-Llama-3-8B | SparseLlama-3-8B-pruned_50.2of4 (this model) |
|---|---|---|
| ARC-c (25-shot) | 59.47% | 57.76% |
| MMLU (5-shot) | 65.29% | 60.44% |
| HellaSwag (10-shot) | 82.14% | 79.97% |
| WinoGrande (5-shot) | 77.27% | 77.19% |
| GSM8K (5-shot) | 44.81% | 47.92% |
| TruthfulQA (0-shot) | 43.96% | 41.02% |
| Average Accuracy | 62.16% | 60.72% |
| Recovery | 100% | 97.68% |
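
The Average Accuracy and Recovery rows can be reproduced from the per-benchmark scores. The sketch below assumes that the average is an unweighted mean and that recovery is the sparse model's average divided by the dense baseline's average. The card does not state either definition, but both match the reported figures for this table. The variable names are illustrative.

```python
# Per-benchmark accuracies (%) copied from the Open LLM Leaderboard-style results above.
dense = {
    "ARC-c": 59.47,
    "MMLU": 65.29,
    "HellaSwag": 82.14,
    "WinoGrande": 77.27,
    "GSM8K": 44.81,
    "TruthfulQA": 43.96,
}
sparse = {
    "ARC-c": 57.76,
    "MMLU": 60.44,
    "HellaSwag": 79.97,
    "WinoGrande": 77.19,
    "GSM8K": 47.92,
    "TruthfulQA": 41.02,
}


def mean(values):
    """Unweighted arithmetic mean of an iterable of numbers."""
    values = list(values)
    return sum(values) / len(values)


dense_avg = mean(dense.values())
sparse_avg = mean(sparse.values())

# Assumed definition: recovery = sparse average / dense average, as a percentage.
recovery = 100.0 * sparse_avg / dense_avg

print(f"dense average:  {dense_avg:.2f}%")
print(f"sparse average: {sparse_avg:.2f}%")
print(f"recovery:       {recovery:.2f}%")
```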

These evaluation results were obtained with the Mosaic Eval Gauntlet, following the configuration of Eval Gauntlet v0.3.

| Benchmark | Meta-Llama-3-8B | SparseLlama-3-8B-pruned_50.2of4 (this model) |
|---|---|---|
| World Knowledge | 58.08% | 54.61% |
| Commonsense Reasoning | 47.66% | 47.62% |
| Language Understanding | 71.13% | 67.58% |
| Symbolic Problem Solving | 38.44% | 32.15% |
| Reading Comprehension | 57.48% | 55.76% |
| Average Accuracy | 54.70% | 51.54% |
| Recovery | 100% | 94.22% |

Help

For further support and for discussions about these models and AI in general, join Neural Magic's Slack Community.

Acknowledgment

This model is built with Meta Llama 3. For more details on its license, see the model card of Meta-Llama-3-8B.