MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

Tianle Cai <[email protected]> (Princeton University, Together AI), Yuhong Li (University of Illinois Urbana-Champaign), Zhengyang Geng (Carnegie Mellon University), Hongwu Peng (University of Connecticut), Jason D. Lee (Princeton University), Deming Chen (University of Illinois Urbana-Champaign), Tri Dao (Princeton University, Together AI) (2024)

Paper Information
arXiv ID: 2401.10774
Venue: International Conference on Machine Learning
Domain: Natural Language Processing, Machine Learning

Abstract

Large Language Models (LLMs) employ autoregressive decoding that requires sequential computation, with each step reliant on the previous one's output. This creates a bottleneck, as each step necessitates moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present MEDUSA, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, MEDUSA constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, MEDUSA substantially reduces the number of decoding steps required. We present two levels of fine-tuning procedures for MEDUSA to meet the needs of different use cases: MEDUSA-1, in which the MEDUSA heads are fine-tuned directly on top of a frozen backbone LLM, enabling lossless inference acceleration; and MEDUSA-2, in which the heads are fine-tuned together with the backbone LLM, enabling better prediction accuracy and higher speedup but requiring a special training recipe that preserves the model's capabilities. Moreover, we propose several extensions that improve or expand the utility of MEDUSA, including self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate MEDUSA on models of various sizes, observing speedups of 2.3x to 2.8x without compromising generation quality.
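
As a hedged illustration of the head design described above, the PyTorch-style sketch below attaches one small residual feed-forward block per decoding head to the backbone's final hidden states, with each head predicting a token further ahead. The class names, the `backbone` interface, and the default of four heads are assumptions made for this sketch rather than details taken from the authors' released code.

```python
# Minimal sketch of MEDUSA-style decoding heads (illustrative, not the official code).
import torch
import torch.nn as nn


class MedusaHead(nn.Module):
    """One extra decoding head: a residual feed-forward block plus a vocab projection."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Residual update of the backbone's last hidden state, then logits for the
        # token this head is responsible for (k + 1 positions ahead for head k).
        h = hidden_states + self.act(self.proj(hidden_states))
        return self.lm_head(h)


class MedusaModel(nn.Module):
    """Backbone LLM plus several MEDUSA heads producing look-ahead logits in parallel."""

    def __init__(self, backbone: nn.Module, hidden_size: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        self.backbone = backbone  # frozen for MEDUSA-1, trained jointly for MEDUSA-2
        self.heads = nn.ModuleList(
            [MedusaHead(hidden_size, vocab_size) for _ in range(num_heads)]
        )

    def forward(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
        # Assumes `backbone` returns the final hidden states for the input sequence.
        hidden = self.backbone(input_ids)
        return [head(hidden) for head in self.heads]
```

Under MEDUSA-1 only the heads would receive gradients while the backbone stays frozen; MEDUSA-2 fine-tunes both together, which the paper notes requires a training recipe that preserves the backbone's original capabilities.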

Summary

This paper introduces MEDUSA, a framework designed to accelerate inference in Large Language Models (LLMs) by utilizing multiple decoding heads that allow for concurrent token prediction. Traditional autoregressive decoding is limited by sequential computation, which MEDUSA seeks to overcome through a tree-based attention mechanism that enables parallel evaluations of multiple candidate continuations. The method incorporates two distinct fine-tuning strategies (MEDUSA-1 using a frozen backbone model, and MEDUSA-2 involving joint training of the backbone and the MEDUSA heads). The results indicate significant inference speedups of 2.3 to 2.8 times without sacrificing generation quality. Furthermore, the authors propose additional techniques like a typical acceptance scheme to enhance efficiency and self-distillation for training dataset generation in the absence of labeled data. The experiments validate the effectiveness of MEDUSA across different models and categories, illustrating notable performance improvements, particularly for coding and extraction tasks.
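
To make the decoding step concrete, the sketch below shows one way the candidate continuations mentioned above could be assembled from each head's top-k predictions and then checked against the backbone. In the paper this verification is batched with a tree-structured attention mask so shared prefixes are computed only once; the sequential `backbone_next_token_fn` helper here is a hypothetical stand-in for that parallel pass, and the top-k sizes are arbitrary.

```python
# Hedged sketch of candidate construction and greedy verification for one MEDUSA step.
from itertools import product

import torch


def build_candidates(head_logits: list[torch.Tensor], topk=(3, 2, 2)) -> list[list[int]]:
    """Combine top-k choices from each head into candidate continuations.

    head_logits: one [vocab]-shaped logit tensor per MEDUSA head, for the current position.
    """
    per_head_choices = [
        torch.topk(logits, k).indices.tolist() for logits, k in zip(head_logits, topk)
    ]
    # The Cartesian product over heads yields candidates of length len(head_logits);
    # arranged as a tree, candidates sharing a prefix share computation.
    return [list(cand) for cand in product(*per_head_choices)]


def verify_greedy(backbone_next_token_fn, prefix_ids: list[int], candidates: list[list[int]]) -> list[int]:
    """Return the longest candidate prefix that matches the backbone's greedy choices.

    backbone_next_token_fn(ids) -> int is a hypothetical helper; the real implementation
    scores all candidates in a single tree-masked forward pass instead of looping.
    """
    best: list[int] = []
    for cand in candidates:
        accepted, ids = [], list(prefix_ids)
        for tok in cand:
            if backbone_next_token_fn(ids) != tok:
                break
            accepted.append(tok)
            ids.append(tok)
        if len(accepted) > len(best):
            best = accepted
    return best  # tokens appended in this step, on top of the backbone's own next token
```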

Methods

This paper employs the following methods:

  • Multiple Decoding Heads
  • Tree-Based Attention
  • Self-Distillation
  • Typical Acceptance Scheme (see the sketch following this list)
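
The typical acceptance scheme listed above relaxes strict token-by-token matching: a candidate token is accepted when its probability under the backbone exceeds a threshold that loosens as the backbone's distribution becomes more uncertain (higher entropy). The sketch below follows that description; the epsilon and delta values are placeholders rather than the paper's tuned defaults.

```python
# Hedged sketch of an entropy-aware "typical acceptance" rule for candidate tokens.
import torch


def typical_accept(probs: torch.Tensor, token_id: int,
                   epsilon: float = 0.3, delta: float = 0.09) -> bool:
    """Accept token_id if its backbone probability clears an entropy-dependent threshold.

    probs: the backbone's next-token distribution at this position (shape [vocab]).
    epsilon, delta: placeholder hyperparameters, not the paper's tuned values.
    """
    entropy = -(probs * torch.log(probs.clamp_min(1e-10))).sum()
    threshold = torch.minimum(torch.tensor(epsilon), delta * torch.exp(-entropy))
    return probs[token_id].item() >= threshold.item()
```

Because the threshold never exceeds epsilon, confident backbone predictions still gate the candidates, while flatter distributions admit more of the heads' speculative tokens, which is how the scheme boosts the acceptance rate while maintaining generation quality.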

Models Used

  • Vicuna-7B
  • Vicuna-13B
  • Vicuna-33B
  • Zephyr-7B

Datasets

The following datasets were used in this research:

  • ShareGPT
  • UltraChat

Evaluation Metrics

  • Speedup
  • Generation quality

Results

  • 2.3x to 2.8x inference speedup
  • No degradation in generation quality (lossless acceleration with MEDUSA-1)
  • Largest gains in the coding and extraction categories

Limitations

The authors identified the following limitations:

  • Not specified

Technical Requirements

  • Number of GPUs: 1
  • GPU Type: NVIDIA A100

Keywords

LLMs, inference acceleration, parallel decoding, multi-head decoding, tree-based attention, fine-tuning
