Venue
International Conference on Machine Learning
Domain
Natural Language Processing, Machine Learning
Large Language Models (LLMs) employ autoregressive decoding that requires sequential computation, with each step reliant on the previous one's output. This creates a bottleneck, as each step necessitates moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present MEDUSA, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, MEDUSA constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, MEDUSA substantially reduces the number of decoding steps required. We present two levels of fine-tuning procedures for MEDUSA to meet the needs of different use cases. MEDUSA-1: MEDUSA is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. MEDUSA-2: MEDUSA is fine-tuned together with the backbone LLM, enabling better prediction accuracy of the MEDUSA heads and higher speedup, but requiring a special training recipe that preserves the model's capabilities. Moreover, we propose several extensions that improve or expand the utility of MEDUSA, including a self-distillation procedure to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate MEDUSA on models of various sizes, observing 2.3 to 2.8 times speedups in inference without sacrificing generation quality.
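To make the extra-decoding-head idea concrete, below is a minimal sketch of additional heads stacked on a backbone LLM, assuming a PyTorch backbone that exposes its final hidden states. The class and parameter names (MedusaHead, hidden_dim, num_heads) and the exact per-head layer composition are illustrative assumptions, not the paper's released implementation.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head: a small residual block followed by a vocabulary
    projection. Head k reads the backbone's hidden state at position t and
    predicts the token at position t + k + 1."""
    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # Residual connection keeps the head close to the backbone representation.
        return self.lm_head(hidden + self.act(self.proj(hidden)))

class MedusaModel(nn.Module):
    """Backbone LLM plus K extra heads. For MEDUSA-1 the backbone is frozen and
    only the heads are trained; for MEDUSA-2 both are trained jointly."""
    def __init__(self, backbone: nn.Module, hidden_dim: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        # The backbone is assumed to return final hidden states of shape (batch, seq, hidden_dim).
        self.backbone = backbone
        self.medusa_heads = nn.ModuleList(
            [MedusaHead(hidden_dim, vocab_size) for _ in range(num_heads)]
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)
        logits = [head(hidden) for head in self.medusa_heads]
        # logits[k][:, t] holds head k's prediction for position t + k + 1.
        return torch.stack(logits, dim=0)  # (num_heads, batch, seq, vocab)
```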
This paper introduces MEDUSA, a framework designed to accelerate inference in Large Language Models (LLMs) by utilizing multiple decoding heads that allow for concurrent token prediction. Traditional autoregressive decoding is limited by sequential computation, which MEDUSA seeks to overcome through a tree-based attention mechanism that enables parallel evaluations of multiple candidate continuations. The method incorporates two distinct fine-tuning strategies (MEDUSA-1 using a frozen backbone model, and MEDUSA-2 involving joint training of the backbone and the MEDUSA heads). The results indicate significant inference speedups of 2.3 to 2.8 times without sacrificing generation quality. Furthermore, the authors propose additional techniques like a typical acceptance scheme to enhance efficiency and self-distillation for training dataset generation in the absence of labeled data. The experiments validate the effectiveness of MEDUSA across different models and categories, illustrating notable performance improvements, particularly for coding and extraction tasks.
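The tree-based attention step can be pictured with a small mask-construction sketch: candidate continuations are arranged as a tree, and each candidate token attends only to itself and its ancestors (in addition to the shared prefix), so several continuations can be verified in a single forward pass without interfering with one another. The parent-pointer encoding used here is an illustrative assumption, not the paper's exact data layout.

```python
import torch

def build_tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """Build a boolean attention mask for a tree of candidate tokens.

    parents[i] is the index of node i's parent within the tree, or -1 for a
    node that attaches directly to the current prefix. Node i may attend only
    to itself and its ancestors, which keeps the candidate continuations
    independent while they are scored in one forward pass.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True   # node i attends to ancestor j (and to itself)
            j = parents[j]
    return mask

# Example: two candidate continuations sharing their first token:
#   prefix -> t0 -> t1   and   prefix -> t0 -> t2
print(build_tree_attention_mask([-1, 0, 0]))
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True, False,  True]])
```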
This paper employs the following methods:
- Multiple Decoding Heads
- Tree-Based Attention
- Self-Distillation
- Typical Acceptance Scheme (see the sketch after this list)
The following models were evaluated:
- Vicuna-7B
- Vicuna-13B
- Vicuna-33B
- Zephyr-7B
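The typical acceptance scheme listed above can be sketched as an entropy-dependent threshold: a candidate token is accepted when its probability under the backbone model exceeds roughly min(epsilon, delta * exp(-H)), where H is the entropy of the backbone's predictive distribution, so peaked distributions demand higher-probability candidates. The function below is a hedged sketch; the exact threshold form and the default values of epsilon and delta are assumptions for illustration.

```python
import torch

def typical_acceptance(probs: torch.Tensor, candidate_id: int,
                       epsilon: float = 0.09, delta: float = 0.3) -> bool:
    """Accept a candidate token if its backbone probability exceeds
    min(epsilon, delta * exp(-H)), where H is the entropy of the backbone's
    predictive distribution over the vocabulary. Low-entropy (peaked)
    distributions require more probable candidates; flatter distributions
    accept more freely. Default epsilon/delta values are illustrative."""
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum()
    threshold = min(epsilon, delta * torch.exp(-entropy).item())
    return probs[candidate_id].item() >= threshold

# Example: a fairly confident distribution over a 4-token vocabulary.
probs = torch.tensor([0.7, 0.2, 0.05, 0.05])
print(typical_acceptance(probs, candidate_id=0))  # True: 0.7 clears the threshold
print(typical_acceptance(probs, candidate_id=2))  # False: 0.05 falls below it
```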
The following results were reported in this research:
- 2.3 to 2.8 times speedup in inference
- No loss in generation quality
- Enhanced performance in coding and extraction categories
The experiments were run on the following hardware:
- Number of GPUs: 1
- GPU Type: NVIDIA A100
Keywords
LLMs, inference acceleration, parallel decoding, multi-head decoding, tree-based attention, fine-tuning