---
base_model: meta-llama/Meta-Llama-3-8B
inference: true
model_type: llama
pipeline_tag: text-generation
tags:
- sparse
---

# SparseLlama-3-8B-pruned_50.2of4

This repo contains model files for a 2:4 (N:M) sparse [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model, pruned in one shot with [SparseGPT](https://arxiv.org/abs/2301.00774) and then retrained with [SquareHead](https://arxiv.org/abs/2310.06927) knowledge distillation while keeping the 2:4 sparsity mask fixed.
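
In a 2:4 pattern, every contiguous group of four weights along a row has at most two non-zero values, which corresponds to 50% sparsity. The sketch below illustrates the pattern on a single toy weight row. It uses plain magnitude pruning as a simplifying assumption and is not SparseGPT: SparseGPT chooses the mask using approximate second-order information and updates the surviving weights, and this model was then further trained with SquareHead distillation. The helper name `prune_2_of_4` is hypothetical.

```python
def prune_2_of_4(row):
    """Zero the 2 smallest-magnitude weights in every group of 4 consecutive weights.

    Simplified magnitude-based illustration only (not SparseGPT, which selects
    the mask with second-order information and adjusts the remaining weights).
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude weights in this group of 4
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return pruned


row = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
pruned = prune_2_of_4(row)
print(pruned)  # [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]

# Exactly half of the weights are zero: 2 of every 4
sparsity = pruned.count(0.0) / len(pruned)
print(f"sparsity: {sparsity:.0%}")  # sparsity: 50%
```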

**Note:** This is still a work in progress and subject to change. We expect to release new weights with even better accuracy soon.

## Running the model

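The model can be run naively in `transformers` for testing purposes. This treats the weights as ordinary dense matrices, so it does not benefit from the 2:4 sparsity: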
```python
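# pip install transformers accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "nm-testing/SparseLlama-3-8B-pruned_50.2of4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# torch_dtype="auto" loads the weights in the checkpoint's dtype instead of float32
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

input_text = "A poem about Machine Learning goes as follows:"
# Put the tokenized inputs on the same device as the model's first layers
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# Limit the length of the generated continuation
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))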
```
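
Once loaded in `transformers`, the pruned layers are ordinary dense tensors that contain zeros, so the 2:4 pattern can be checked directly. The following is a minimal sketch, assuming PyTorch and that all weights are materialized (not offloaded to the `meta` device by `device_map="auto"`); `is_2_of_4` and `report_2_of_4` are hypothetical helpers, not part of `transformers`.

```python
import torch
from torch import nn


def is_2_of_4(weight: torch.Tensor) -> bool:
    """True if every group of 4 consecutive weights along each row has at most 2 non-zeros."""
    if weight.shape[-1] % 4 != 0:
        return False
    # Row-major reshape: each row of `groups` is 4 consecutive weights of one input row
    groups = weight.reshape(-1, 4)
    return bool(((groups != 0).sum(dim=1) <= 2).all())


def report_2_of_4(model: nn.Module) -> dict:
    """Map each nn.Linear submodule name to whether its weight follows the 2:4 pattern."""
    return {
        name: is_2_of_4(module.weight.detach())
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    }
```

For example, after running the snippet above, `report = report_2_of_4(model)` followed by `[name for name, ok in report.items() if not ok]` lists the linear layers that do not follow the pattern; any layers left dense during pruning (possibly `lm_head`) would appear there. The check creates a temporary boolean copy of each weight, so it briefly uses extra memory.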

To take advantage of the 2:4 sparsity, install [nm-vllm](https://github.com/neuralmagic/nm-vllm) for faster inference and lower memory usage:
```bash
pip install "nm-vllm[sparse]" --extra-index-url https://pypi.neuralmagic.com/simple
```

```python
from vllm import LLM, SamplingParams

model = LLM("nm-testing/SparseLlama-3-8B-pruned_50.2of4", sparsity="semi_structured_sparse_w16a16")

prompt = "A poem about Machine Learning goes as follows:"
sampling_params = SamplingParams(max_tokens=100, temperature=0)

outputs = model.generate(prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```

## Evaluation Benchmark Results

Model evaluation results were obtained with [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), following the configuration of the [Open LLM Leaderboard](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard).

| Benchmark                                      | Meta-Llama-3-8B  | SparseLlama-3-8B-pruned_50.2of4<br>(this model) |
|:----------------------------------------------:|:-----------:|:-----------------------------:|
| [ARC-c](https://arxiv.org/abs/1911.01547)<br> 25-shot      | 59.47%       | 57.76%                         |
| [MMLU](https://arxiv.org/abs/2009.03300)<br> 5-shot       | 65.29%       | 60.44%                         |
| [HellaSwag](https://arxiv.org/abs/1905.07830)<br> 10-shot  |82.14%       | 79.97%                         |
| [WinoGrande](https://arxiv.org/abs/1907.10641)<br> 5-shot |77.27%       | 77.19%                         |
| [GSM8K](https://arxiv.org/abs/2110.14168)<br> 5-shot      |  44.81%       | 47.92%                         |
| [TruthfulQA](https://arxiv.org/abs/2109.07958)<br> 0-shot |  43.96%       | 41.02%                         |
| **Average<br>Accuracy**  | **62.16%**                    |              **60.72%**                                     |
| **Recovery**             | **100%**                     |              **97.68%**                                     |
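
The **Average Accuracy** and **Recovery** rows above are consistent with an unweighted mean over the six benchmarks, and with Recovery defined as the sparse model's average divided by the baseline average. The sketch below reproduces the reported values under that assumption; `average` and `recovery` are illustrative helper names.

```python
# Scores (percent) copied from the lm-evaluation-harness table
baseline = {"ARC-c": 59.47, "MMLU": 65.29, "HellaSwag": 82.14,
            "WinoGrande": 77.27, "GSM8K": 44.81, "TruthfulQA": 43.96}
sparse = {"ARC-c": 57.76, "MMLU": 60.44, "HellaSwag": 79.97,
          "WinoGrande": 77.19, "GSM8K": 47.92, "TruthfulQA": 41.02}


def average(scores: dict) -> float:
    """Unweighted mean accuracy across benchmarks."""
    return sum(scores.values()) / len(scores)


def recovery(sparse_scores: dict, baseline_scores: dict) -> float:
    """Sparse-model average as a percentage of the baseline average."""
    return 100 * average(sparse_scores) / average(baseline_scores)


print(f"baseline avg: {average(baseline):.2f}%")           # baseline avg: 62.16%
print(f"sparse avg:   {average(sparse):.2f}%")             # sparse avg:   60.72%
print(f"recovery:     {recovery(sparse, baseline):.2f}%")  # recovery:     97.68%
```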


Model evaluation results were obtained with the [Mosaic Eval Gauntlet](https://github.com/mosaicml/llm-foundry/blob/main/scripts/eval/local_data/EVAL_GAUNTLET.md), following the configuration of [Eval Gauntlet v0.3](https://github.com/mosaicml/llm-foundry/blob/main/scripts/eval/yamls/eval_gauntlet_v0.3.yaml).

| Benchmark                | Meta-Llama-3-8B  | SparseLlama-3-8B-pruned_50.2of4<br>(this model) |
|:------------------------:|:----------------:|:----------------------------------------------:|
| World Knowledge          | 58.08%       | 54.61%                         |
| Commonsense Reasoning    | 47.66%       | 47.62%                         |
| Language Understanding   | 71.13%       | 67.58%                         |
| Symbolic Problem Solving | 38.44%       | 32.15%                         |
| Reading Comprehension    | 57.48%       | 55.76%                         |
| **Average Accuracy**     | **54.70%**    |  **51.54%**                   |
| **Recovery**             | **100%** |  **94.22%**                   |


## Help

For further support and discussions on these models and AI in general, join [Neural Magic's Slack Community](https://join.slack.com/t/discuss-neuralmagic/shared_invite/zt-q1a1cnvo-YBoICSIw3L1dmQpjBeDurQ).

## Acknowledgment

This model is built with Meta Llama 3. For more details on its license, please check the model card of [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B).