arxiv:2401.10774

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

Published on Jan 19, 2024
Submitted by akhaliq on Jan 21, 2024
#1 Paper of the day
Authors:

Abstract

The inference process in Large Language Models (LLMs) is often limited due to the absence of parallelism in the auto-regressive decoding process, resulting in most operations being restricted by the memory bandwidth of accelerators. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa introduces only minimal overhead in terms of single-step latency while substantially reducing the number of decoding steps required.

We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases:

  • Medusa-1: Medusa is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration.
  • Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of Medusa heads and higher speedup but needing a special training recipe that preserves the backbone model's capabilities.

Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.
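The two mechanisms the abstract leans on, extra heads that draft several future tokens from one hidden state and an acceptance rule that decides how many drafted tokens to keep, can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the paper's released code: each head is modeled as a residual feed-forward block, softmax(W2 * (SiLU(W1 * h) + h)), applied to the backbone's last hidden state, with head k drafting the token k+1 positions after the backbone's own next-token prediction; and the typical-acceptance test is assumed to keep a drafted token while the backbone gives it probability above min(epsilon, delta * exp(-entropy)). The function names, toy sizes, random weights, and the epsilon/delta values are all hypothetical.

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def silu(x):
    # SiLU activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def medusa_heads(h, w1, w2):
    """Apply K hypothetical Medusa-style heads to one hidden state.

    h:  (d,)       last hidden state from the backbone
    w1: (K, d, d)  per-head projection (assumed residual feed-forward block)
    w2: (K, d, V)  per-head projection to the vocabulary
    Returns (K, V) probabilities; row k drafts the token k+1 positions
    after the backbone's own next-token prediction (assumption).
    """
    hidden = silu(np.einsum("d,kde->ke", h, w1)) + h  # residual connection
    logits = np.einsum("kd,kdv->kv", hidden, w2)
    return softmax(logits)


def typical_accept_length(backbone_probs, draft_tokens, epsilon=0.09, delta=0.3):
    """Count drafted tokens that pass an assumed typical-acceptance test.

    backbone_probs: (n, V) backbone distributions at each drafted position
    draft_tokens:   (n,)   tokens proposed for those positions
    A token is kept if its backbone probability exceeds
    min(epsilon, delta * exp(-entropy)); the first failure stops acceptance.
    epsilon and delta here are illustrative values only.
    """
    accepted = 0
    for p, tok in zip(backbone_probs, draft_tokens):
        entropy = -np.sum(p * np.log(p + 1e-12))
        threshold = min(epsilon, delta * np.exp(-entropy))
        if p[tok] <= threshold:
            break
        accepted += 1
    return accepted


# Toy usage: random weights, hidden size 8, vocabulary 16, three heads.
rng = np.random.default_rng(0)
d, V, K = 8, 16, 3
h = rng.normal(size=d)
probs = medusa_heads(h, rng.normal(size=(K, d, d)) * 0.1, rng.normal(size=(K, d, V)))
draft = probs.argmax(axis=-1)  # top-1 draft token from each head

# Hand-picked backbone distributions: the third drafted token gets only 0.05.
backbone_probs = np.array([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1], [0.9, 0.05, 0.05]])
draft_tokens = np.array([0, 1, 2])
accepted = typical_accept_length(backbone_probs, draft_tokens)
print(accepted)  # 2
```

Greedy verification would instead require each drafted token to equal the backbone's argmax; an entropy-dependent threshold relaxes that when the backbone itself is uncertain, which is the trade-off the abstract describes as boosting the acceptance rate while maintaining generation quality.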

Community

Unlocking Faster AI: Medusa's Multi-Head Decoding for LLMs

Links 🔗:

👉 Subscribe: https://www.youtube.com/@Arxflix
👉 Twitter: https://x.com/arxflix
👉 LMNT (Partner): https://lmnt.com/

By Arxflix

Awesome paper!

When evaluating candidate sequences, is there an optimization that computes only the next token of each candidate with the original model? I am confused: I initially thought that the total number of tokens to compute would simply be the sum of the top-K values used at each Medusa head. For example, given:

  • h1 predictions: [h11, h12, h13]
  • h2 predictions: [h21, h22, h23]

Wouldn't only 6 new tokens need to be computed, assuming that the tokens h11 to h13 can be cached and reused when generating tokens at the h2 head? That is, can we compute h11, add its value to the cached state, and then evaluate all candidate sequences stemming from h11 without recomputing the previous tokens (in this example, without recomputing h11 each time)?

I think I must be missing something basic!
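To make the counting in this question concrete, here is a minimal sketch, assuming candidates are the Cartesian product of each head's top-k tokens and are merged into a prefix tree so that shared prefixes are processed only once. Each tree node is one token in the verification forward pass, and the attention mask lets a node see only itself and its ancestors. Under this assumption, counting only the tokens drafted by the two heads, 3 h1 options and 3 h2 options give 3 + 9 = 12 nodes rather than 6: h21 placed after h11 and h21 placed after h12 sit under different prefixes, so they produce different hidden states and cache entries, while h11 itself is still computed once. The names `build_tree` and `tree_attention_mask` are hypothetical.

```python
from itertools import product


def build_tree(topk_per_head):
    """Merge the Cartesian product of per-head top-k tokens into a prefix tree.

    topk_per_head: list of lists, e.g. [["h11", "h12", "h13"], ["h21", "h22", "h23"]]
    Returns (nodes, parents): nodes[i] is the prefix (tuple of tokens) ending at
    tree node i; parents[i] is the index of its parent node, or -1 at depth 1.
    """
    nodes, parents, index = [], [], {}
    for depth in range(1, len(topk_per_head) + 1):
        # Every distinct prefix of length `depth` becomes one node.
        for prefix in product(*topk_per_head[:depth]):
            index[prefix] = len(nodes)
            nodes.append(prefix)
            parents.append(index[prefix[:-1]] if depth > 1 else -1)
    return nodes, parents


def tree_attention_mask(parents):
    """mask[i][j] is True when node i may attend to node j (itself or an ancestor)."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking each ancestor
            mask[i][j] = True
            j = parents[j]
    return mask


heads = [["h11", "h12", "h13"], ["h21", "h22", "h23"]]
nodes, parents = build_tree(heads)
mask = tree_attention_mask(parents)
print(len(nodes))  # 12: 3 depth-1 nodes + 9 depth-2 nodes

# "h21" appears once under each depth-1 token, each time with a different ancestor.
print([n for n in nodes if n[-1] == "h21"])
# [('h11', 'h21'), ('h12', 'h21'), ('h13', 'h21')]
```

Because the number of depth-d nodes is the product of the top-k sizes up to depth d, a full Cartesian-product tree grows quickly as heads are added, so in practice one would likely keep only a subset of its branches.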

Models citing this paper 0

Datasets citing this paper 0

Spaces citing this paper 0

Collections including this paper 18