syzymon commited on
Commit
caf6d75
·
1 Parent(s): 5d90917

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +267 -1
README.md CHANGED
@@ -26,4 +26,270 @@ model-index:
26
  type: pass@1
27
  value: 0.12
28
  verified: false
29
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  type: pass@1
27
  value: 0.12
28
  verified: false
29
+ ---
30
+
31
+ # LongLLaMA: Focused Transformer Training for Context Scaling
32
+
33
+ <div align="center">
34
+
35
+
36
+ <a href="https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_instruct_colab.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg"></a>
37
+
38
+
39
+ </div>
40
+
41
+ <div align="center">
42
+
43
+ [TLDR](#TLDR) | [Overview](#Overview) | [Usage](#Usage) | [LongLLaMA performance](#LongLLaMA-performance) | [Authors](#Authors) | [Citation](#Citation) | [License](#License) | [Acknowledgments](#Acknowledgments)
44
+
45
+ </div>
46
+
47
+ ## TLDR
48
+ This repo contains [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct) that is for **research purposes only**.
49
+
50
+ LongLLaMA is built upon the foundation of [OpenLLaMA](https://github.com/openlm-research/open_llama) and fine-tuned using the [Focused Transformer (FoT)](https://arxiv.org/abs/2307.03170) method. We release a smaller 3B base variant (not instruction tuned) of the LongLLaMA model on a permissive license (Apache 2.0) and inference code supporting longer contexts on [Hugging Face](https://huggingface.co/syzymon/long_llama_3b). Our model weights can serve as the drop-in replacement of LLaMA in existing implementations (for short context up to 2048 tokens). Additionally, we provide evaluation results and comparisons against the original OpenLLaMA models. Stay tuned for further updates.
51
+
52
+
53
+
54
+
55
+ ## Overview
56
+
57
+ ### Base models
58
+ [Focused Transformer: Contrastive Training for Context Scaling](https://arxiv.org/abs/2307.03170) (FoT) presents a simple method for endowing language models with the ability to handle context consisting possibly of millions of tokens while training on significantly shorter input. FoT permits a subset of attention layers to access a memory cache of (key, value) pairs to extend the context length. The distinctive aspect of FoT is its training procedure, drawing from contrastive learning. Specifically, we deliberately expose the memory attention layers to both relevant and irrelevant keys (like negative samples from unrelated documents). This strategy incentivizes the model to differentiate keys connected with semantically diverse values, thereby enhancing their structure. This, in turn, makes it possible to extrapolate the effective context length much beyond what is seen in training.
59
+
60
+
61
+ **LongLLaMA** is an [OpenLLaMA](https://github.com/openlm-research/open_llama) model finetuned with the FoT method,
62
+ with three layers used for context extension. **Crucially, LongLLaMA is able to extrapolate much beyond the context length seen in training: $8k$. E.g., in the passkey retrieval task, it can handle inputs of length $256k$**.
63
+
64
+ <div align="center">
65
+
66
+ | | [LongLLaMA-3B](https://huggingface.co/syzymon/long_llama_3b) | [LongLLaMA-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_v1_1) | LongLLaMA-7B<br />*(coming soon)*| LongLLaMA-13B<br />*(coming soon)*|
67
+ |----------------|----------|----------|-----------|-----------|
68
+ | Source model | [OpenLLaMA-3B](https://huggingface.co/openlm-research/open_llama_3b_easylm) | [OpenLLaMA-3Bv2](https://huggingface.co/openlm-research/open_llama_3b_v2_easylm) | - | - |
69
+ | Source model tokens | 1T | 1 T | - | - |
70
+ | Fine-tuning tokens | 10B | 5B | - | -|
71
+ | Memory layers | 6, 12, 18 | 6, 12, 18 | - | -|
72
+
73
+ </div>
74
+
75
+ ### Instruction/Chat tuning
76
+
77
+ In the [fine_tuning](fine_tuning) subfolder we provide the code that was used to create [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct), an instruction-tuned version of [LongLLaMA-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_v1_1). We used [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) (instructions) and [zetavg/ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) (chat) datasets for tuning.
78
+
79
+
80
+ ## Usage
81
+
82
+ See also:
83
+ * [Colab with LongLLaMA-Instruct-3Bv1.1](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_instruct_colab.ipynb).
84
+ * [Colab with an example usage of base LongLLaMA](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_colab.ipynb).
85
+ ### Requirements
86
+ ```
87
+ pip install --upgrade pip
88
+ pip install transformers==4.30 sentencepiece accelerate
89
+ ```
90
+
91
+ ### Loading model
92
+ ```python
93
+ import torch
94
+ from transformers import LlamaTokenizer, AutoModelForCausalLM
95
+
96
+ tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
97
+ model = AutoModelForCausalLM.from_pretrained("syzymon/long_llama_3b",
98
+ torch_dtype=torch.float32,
99
+ trust_remote_code=True)
100
+ ```
101
+
102
+ ### Input handling and generation
103
+ LongLLaMA uses the Hugging Face interface, the long input given to the model will be
104
+ split into context windows and loaded into the memory cache.
105
+ ```python
106
+ prompt = "My name is Julien and I like to"
107
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
108
+ outputs = model(input_ids=input_ids)
109
+ ```
110
+ During the model call, one can provide the parameter `last_context_length` (default $1024$), which specifies the number of tokens left in the last context window. Tuning this parameter can improve generation as the first layers do not have access to memory. See details in [How LongLLaMA handles long inputs](#How-LongLLaMA-handles-long-inputs).
111
+
112
+ ```python
113
+ generation_output = model.generate(
114
+ input_ids=input_ids,
115
+ max_new_tokens=256,
116
+ num_beams=1,
117
+ last_context_length=1792,
118
+ do_sample=True,
119
+ temperature=1.0,
120
+ )
121
+ print(tokenizer.decode(generation_output[0]))
122
+ ```
123
+
124
+ ### Additional configuration
125
+ LongLLaMA has several other parameters:
126
+ * `mem_layers` specifies layers endowed with memory (should be either an empty list or a list of all memory layers specified in the description of the checkpoint).
127
+ * `mem_dtype` allows changing the type of memory cache
128
+ * `mem_attention_grouping` can trade off speed for reduced memory usage.
129
+ When equal to `(4, 2048)`, the memory layers will process at most $4*2048$ queries at once ($4$ heads and $2048$ queries for each head).
130
+
131
+ ```python
132
+ import torch
133
+ from transformers import LlamaTokenizer, AutoModelForCausalLM
134
+
135
+ tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
136
+ model = AutoModelForCausalLM.from_pretrained(
137
+ "syzymon/long_llama_3b", torch_dtype=torch.float32,
138
+ mem_layers=[],
139
+ mem_dtype='bfloat16',
140
+ trust_remote_code=True,
141
+ mem_attention_grouping=(4, 2048),
142
+ )
143
+ ```
144
+
145
+
146
+ ### Drop-in use with LLaMA code
147
+ LongLLaMA checkpoints can also be used as a drop-in replacement for LLaMA checkpoints in [Hugging Face implementation of LLaMA](https://huggingface.co/docs/transformers/main/model_doc/llama), but in this case, they will be limited to the original context length of $2048$.
148
+
149
+ ```python
150
+ from transformers import LlamaTokenizer, LlamaForCausalLM
151
+ import torch
152
+
153
+ tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
154
+ model = LlamaForCausalLM.from_pretrained("syzymon/long_llama_3b", torch_dtype=torch.float32)
155
+ ```
156
+
157
+
158
+ ### How LongLLaMA handles long inputs
159
+ Inputs over $2048$ tokens are automatically split into windows $w_1, \ldots, w_m$. The first $m-2$ windows contain $2048$ tokens each, $w_{m-1}$ has no more than $2048$ tokens, and $w_m$ contains the number of tokens specified by `last_context_length`. The model processes the windows one by one extending the memory cache after each. If `use_cache` is `True`, the last window will not be loaded to the memory cache but to the local (generation) cache.
160
+
161
+ The memory cache stores $(key, value)$ pairs for each head of the specified memory layers `mem_layers`. In addition to this, it stores attention masks.
162
+
163
+ If `use_cache=True` (which is the case in generation), LongLLaMA will use two caches: the memory cache for the specified layers and the local (generation) cache for all layers. When the local cache exceeds $2048$ elements, its content is moved to the memory cache for the memory layers.
164
+
165
+ For simplicity, context extension is realized with a memory cache and full attention in this repo. Replacing this simple mechanism with a KNN search over an external database is possible with systems like [Faiss](https://github.com/facebookresearch/faiss). This potentially would enable further context length scaling. We leave this as a future work.
166
+
167
+
168
+ ## LongLLaMA performance
169
+ We present some illustrative examples of LongLLaMA results. Refer to our paper [Focused Transformer: Contrastive Training for Context Scaling](https://arxiv.org/abs/2307.03170) for more details.
170
+
171
+ We manage to achieve good performance on the passkey retrieval task from [Landmark Attention: Random-Access Infinite Context Length for Transformers](https://arxiv.org/abs/2305.16300). The code for generating the prompt and running the model is located in `examples/passkey.py`.
172
+
173
+ <p align="center" width="100%">
174
+ <img src="assets/plot_passkey.png" alt="LongLLaMA" style="width: 70%; min-width: 300px; display: block; margin: auto;">
175
+ </p>
176
+
177
+ Our LongLLaMA 3B model also shows improvements when using long context on two downstream tasks, TREC question classification and WebQS question answering.
178
+ <div align="center">
179
+
180
+
181
+ | Context/Dataset | TREC | WebQS |
182
+ | --- | --- | --- |
183
+ | $2K$ | 67.0 | 21.2 |
184
+ | $4K$ | 71.6 | 21.4 |
185
+ | $6K$ | 72.9 | 22.2 |
186
+ | $8K$ | **73.3** | **22.4** |
187
+
188
+ </div>
189
+
190
+ LongLLaMA retains performance on tasks that do not require long context. We provide a comparison with OpenLLaMA
191
+ on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) in the zero-shot setting.
192
+ <div align="center">
193
+
194
+ | Task/Metric | OpenLLaMA-3B | LongLLaMA-3B |
195
+ |----------------|----------|-----------|
196
+ | anli_r1/acc | 0.33 | 0.32 |
197
+ | anli_r2/acc | 0.32 | 0.33 |
198
+ | anli_r3/acc | 0.35 | 0.35 |
199
+ | arc_challenge/acc | 0.34 | 0.34 |
200
+ | arc_challenge/acc_norm | 0.37 | 0.37 |
201
+ | arc_easy/acc | 0.69 | 0.68 |
202
+ | arc_easy/acc_norm | 0.65 | 0.63 |
203
+ | boolq/acc | 0.68 | 0.68 |
204
+ | hellaswag/acc | 0.49 | 0.48 |
205
+ | hellaswag/acc_norm | 0.67 | 0.65 |
206
+ | openbookqa/acc | 0.27 | 0.28 |
207
+ | openbookqa/acc_norm | 0.40 | 0.38 |
208
+ | piqa/acc | 0.75 | 0.73 |
209
+ | piqa/acc_norm | 0.76 | 0.75 |
210
+ | record/em | 0.88 | 0.87 |
211
+ | record/f1 | 0.89 | 0.87 |
212
+ | rte/acc | 0.58 | 0.60 |
213
+ | truthfulqa_mc/mc1 | 0.22 | 0.24 |
214
+ | truthfulqa_mc/mc2 | 0.35 | 0.38 |
215
+ | wic/acc | 0.48 | 0.50 |
216
+ | winogrande/acc | 0.62 | 0.60 |
217
+ | Avg score | 0.53 | 0.53 |
218
+
219
+ </div>
220
+
221
+ Starting with v1.1 models we have decided to use [EleutherAI](https://github.com/EleutherAI) implementation of [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) wit a slight modification, that adds `<bos>` token at beginning of input sequence. The results are provided in the table below.
222
+
223
+ <div align="center">
224
+
225
+ | description | LongLLaMA-3B | OpenLLaMA-3Bv2 | LongLLaMA-3Bv1.1 | LongLLaMA-Instruct-3Bv1.1 |
226
+ |:-----------------------|:--------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|:--------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------|
227
+ | anli_r1/acc | 0.32 | 0.33 | 0.31 | 0.33 |
228
+ | anli_r2/acc | 0.33 | 0.35 | 0.33 | 0.35 |
229
+ | anli_r3/acc | 0.35 | 0.38 | 0.35 | 0.38 |
230
+ | arc_challenge/acc | 0.34 | 0.33 | 0.32 | 0.36 |
231
+ | arc_challenge/acc_norm | 0.37 | 0.36 | 0.36 | 0.37 |
232
+ | arc_easy/acc | 0.67 | 0.68 | 0.68 | 0.7 |
233
+ | arc_easy/acc_norm | 0.63 | 0.63 | 0.63 | 0.63 |
234
+ | boolq/acc | 0.68 | 0.67 | 0.66 | 0.77 |
235
+ | hellaswag/acc | 0.48 | 0.53 | 0.52 | 0.52 |
236
+ | hellaswag/acc_norm | 0.65 | 0.7 | 0.69 | 0.68 |
237
+ | openbookqa/acc | 0.28 | 0.28 | 0.28 | 0.28 |
238
+ | openbookqa/acc_norm | 0.38 | 0.39 | 0.37 | 0.41 |
239
+ | piqa/acc | 0.73 | 0.77 | 0.77 | 0.78 |
240
+ | piqa/acc_norm | 0.75 | 0.78 | 0.77 | 0.77 |
241
+ | record/em | 0.87 | 0.87 | 0.86 | 0.85 |
242
+ | record/f1 | 0.88 | 0.88 | 0.87 | 0.86 |
243
+ | rte/acc | 0.6 | 0.53 | 0.62 | 0.7 |
244
+ | truthfulqa_mc/mc1 | 0.24 | 0.22 | 0.21 | 0.25 |
245
+ | truthfulqa_mc/mc2 | 0.38 | 0.35 | 0.35 | 0.4 |
246
+ | wic/acc | 0.5 | 0.5 | 0.5 | 0.54 |
247
+ | winogrande/acc | 0.6 | 0.66 | 0.63 | 0.65 |
248
+ | Avg score | 0.53 | 0.53 | 0.53 | 0.55 |
249
+
250
+ </div>
251
+
252
+
253
+ We also provide the results on human-eval. We cut the generated text after either
254
+ * `"\ndef "`
255
+ * `"\nclass "`
256
+ * `"\nif __name__"`
257
+
258
+ <div align="center">
259
+
260
+ | | OpenLLaMA-3Bv2 | LongLLaMA-3Bv1.1 | LongLLaMA-Instruct-3Bv1.1 |
261
+ | - | - | - | - |
262
+ | pass@1| 0.09| 0.12 | 0.12 |
263
+
264
+ </div>
265
+
266
+ ## Authors
267
+ - [Szymon Tworkowski](https://scholar.google.com/citations?user=1V8AeXYAAAAJ&hl=en)
268
+ - [Konrad Staniszewski](https://scholar.google.com/citations?user=CM6PCBYAAAAJ)
269
+ - [Mikołaj Pacek](https://scholar.google.com/citations?user=eh6iEbQAAAAJ&hl=en&oi=ao)
270
+ - [Henryk Michalewski](https://scholar.google.com/citations?user=YdHW1ycAAAAJ&hl=en)
271
+ - [Yuhuai Wu](https://scholar.google.com/citations?user=bOQGfFIAAAAJ&hl=en)
272
+ - [Piotr Miłoś](https://scholar.google.pl/citations?user=Se68XecAAAAJ&hl=pl&oi=ao)
273
+
274
+
275
+ ## Citation
276
+ To cite this work please use
277
+ ```bibtex
278
+ @misc{tworkowski2023focused,
279
+ title={Focused Transformer: Contrastive Training for Context Scaling},
280
+ author={Szymon Tworkowski and Konrad Staniszewski and Mikołaj Pacek and Yuhuai Wu and Henryk Michalewski and Piotr Miłoś},
281
+ year={2023},
282
+ eprint={2307.03170},
283
+ archivePrefix={arXiv},
284
+ primaryClass={cs.CL}
285
+ }
286
+ ```
287
+
288
+
289
+ ## License
290
+ The code and base models checkpoints are licensed under [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0).
291
+ The instruction/chat tuned models are for research purposes only.
292
+ Some of the examples use external code (see headers of files for copyright notices and licenses).
293
+
294
+ ## Acknowledgments
295
+ We gratefully acknowledge the TPU Research Cloud program, which was instrumental to our research by providing significant computational resources. We are also grateful to Xinyang Geng and Hao Liu for releasing [OpenLLaMA](https://github.com/openlm-research/open_llama) checkpoints and the [EasyLM](https://github.com/young-geng/EasyLM) library.