Accelerating LLM Inference with Staged Speculative Decoding
Abstract
Recent advances in large language models (LLMs) illustrate their diverse capabilities. We propose a novel algorithm, staged speculative decoding, to accelerate LLM inference in small-batch, on-device scenarios. We address the low arithmetic intensity of small-batch inference by improving upon previous work in speculative decoding. First, we restructure the speculative batch as a tree, which reduces generation costs and increases the expected tokens per batch. Second, we add a second stage of speculative decoding. Taken together, we reduce single-batch decoding latency by 3.16x with a 762M-parameter GPT-2-L model while perfectly preserving output quality.
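The claim that output quality is perfectly preserved relies on the acceptance rule used in speculative sampling. The sketch below shows that rule for a single draft token. It assumes the target and draft distributions `p` and `q` are available as plain probability lists. The function names are hypothetical. This is the generic single-token rule from prior speculative decoding work, not the tree-structured or staged variant described here.

```python
import random
from typing import List, Sequence, Tuple


def residual_distribution(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, p - q): where to resample after a draft token is rejected."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    return [d / total for d in diff]


def verify_draft_token(
    draft_token: int, p: Sequence[float], q: Sequence[float], rng: random.Random
) -> Tuple[int, bool]:
    """Accept a token drawn from the draft distribution q with probability
    min(1, p[x] / q[x]); on rejection, resample from the residual. Either way
    the returned token is distributed exactly according to the target p."""
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token, True
    weights = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=weights, k=1)[0], False


# Usage with made-up distributions over a 3-token vocabulary.
rng = random.Random(0)
p = [0.6, 0.3, 0.1]  # hypothetical target-model probabilities
q = [0.3, 0.3, 0.4]  # hypothetical draft-model probabilities
N = 100_000
counts = [0] * len(p)
for _ in range(N):
    x = rng.choices(range(len(q)), weights=q, k=1)[0]  # draft proposes x ~ q
    token, _accepted = verify_draft_token(x, p, q, rng)
    counts[token] += 1
print([c / N for c in counts])  # empirical frequencies, close to p
```

Rejected draws are resampled from the normalized positive part of `p - q`, so the target distribution is recovered exactly. A weaker draft model lowers only the acceptance rate, not the output quality.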
Community
Interesting
Interesting work! I have a couple of questions about Section 3.1:
- What does it mean to move the "compute from end of very long sequences to the beginning"?
- In the same paragraph, does it mean something like a beam search over the likely second and third tokens? For example, with a beam of size 3, a tree structure would give 3 possible sequences.
Thanks
By moving compute from the end of very long sequences to the beginning, I mean that for a fixed batch size, you would rather spend that compute on more probable completions than on less probable ones. Let's say I have a batch size of 16. Standard speculative decoding would structure the tree as a single path:
#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
By contrast, I prefer to structure the tree with many branches. Each short branch has a higher cumulative probability of being accepted than the far end of a single long branch:
#+#+#-#-#-#-#
 | -#-#
 +#-#-#-#
 |
 -#-#
 |
 -#
In the above diagram, we've taken 4 branches at the first choice (the second token), and then 2 branches from the most likely node for the third token. (The first token is always already known, since it is the last output of the previous batch.)
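The sketch below makes the tree construction from the diagram concrete and compares it with the single path. It encodes both 16-node layouts as parent arrays. It then builds an ancestor-only attention mask, one common way to verify every branch in a single forward pass, though the source does not confirm that this is how the paper evaluates the tree. Finally, it estimates the expected number of accepted draft tokens. The encoding, the function names, and especially the per-rank match probabilities are hypothetical illustrations, not code or numbers from the paper. The independence assumption is a deliberate simplification.

```python
from typing import List

# Hypothetical encoding of the 16-node tree in the diagram. parents[i] is node i's
# parent (-1 for the root, the token already known from the previous batch);
# ranks[i] says which draft candidate (1 = most likely) node i is among its siblings.
TREE_PARENTS = [-1, 0, 1, 2, 3, 4, 5, 1, 7, 0, 9, 10, 11, 0, 13, 0]
TREE_RANKS = [0, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 3, 1, 4]

# Single-path layout: every node is the top-1 child of the previous node.
CHAIN_PARENTS = [-1] + list(range(15))
CHAIN_RANKS = [0] + [1] * 15


def ancestors(parents: List[int], i: int) -> List[int]:
    """Return node i and all of its ancestors, root first."""
    path = []
    while i != -1:
        path.append(i)
        i = parents[i]
    return path[::-1]


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """mask[i][j] is True iff node j is node i or one of its ancestors, so each
    node attends only to the prefix along its own branch."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in ancestors(parents, i):
            mask[i][j] = True
    return mask


def expected_accepted(parents: List[int], ranks: List[int], match_prob: List[float]) -> float:
    """Expected draft tokens accepted under greedy verification, ASSUMING the target's
    next token equals the k-th ranked draft candidate with probability
    match_prob[k - 1], independently at each position. By linearity of expectation,
    this is the sum over non-root nodes of P(every token on the path to it matches)."""
    total = 0.0
    for i in range(1, len(parents)):
        prob = 1.0
        for j in ancestors(parents, i)[1:]:  # skip the already-known root
            prob *= match_prob[ranks[j] - 1]
        total += prob
    return total


MATCH_PROB = [0.7, 0.15, 0.06, 0.03]  # made-up per-rank match rates
chain = expected_accepted(CHAIN_PARENTS, CHAIN_RANKS, MATCH_PROB)
tree = expected_accepted(TREE_PARENTS, TREE_RANKS, MATCH_PROB)
print(f"chain: {chain:.2f} tokens, tree: {tree:.2f} tokens")  # chain: 2.32, tree: 2.75
```

With these illustrative rates, the tree's short, high-probability branches yield more expected accepted tokens than one long chain in the same 16 slots. Real gains depend on the actual draft and target models.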
It is indeed very much like a beam search: you actually get a free beam search when you do tree-structured speculative decoding. For the paper I ignore that, since I focus more on the latency and memory bandwidth improvements and only aim to match the target distribution, but it is also nice that you get a free improvement in quality :)
Hope this helps!
I am also very interested in this work! But I am still somewhat confused about how the tree structure is constructed. Could you please give a more detailed illustration or a reference? Thanks!