TY - JOUR
T1 - MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
T2 - 41st International Conference on Machine Learning, ICML 2024
AU - Cai, Tianle
AU - Li, Yuhong
AU - Geng, Zhengyang
AU - Peng, Hongwu
AU - Lee, Jason D.
AU - Chen, Deming
AU - Dao, Tri
N1 - We extend our heartfelt gratitude to several individuals whose contributions were invaluable to this project: • Zhuohan Li, for his invaluable insights on LLM serving. If you haven't already, do check out Zhuohan's vLLM project-it's nothing short of impressive. • Shaojie Bai, for engaging in crucial discussions that helped shape the early phases of this work. • Denny Zhou, for introducing the truncation sampling scheme to Tianle and encouraging Tianle to explore the area of LLM serving. • Yanping Huang, for pointing out the memory-bandwidth-bound challenges associated with LLM serving to Tianle. • Lianmin Zheng, for clarifying the different training recipes used in different sizes of Vicuna models. Jason D. Lee acknowledges the support of the NSF CCF 2002272, NSF IIS 2107304, and NSF CAREER Award 2144994. Deming Chen acknowledges the support from the AMD Center of Excellence at UIUC.
PY - 2024
Y1 - 2024
N2 - Large Language Models (LLMs) employ autoregressive decoding that requires sequential computation, with each step reliant on the previous one's output. This creates a bottleneck as each step necessitates moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present MEDUSA, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, MEDUSA constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, MEDUSA substantially reduces the number of decoding steps required. We present two levels of fine-tuning procedures for MEDUSA to meet the needs of different use cases: MEDUSA-1: MEDUSA is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. MEDUSA-2: MEDUSA is fine-tuned together with the backbone LLM, enabling better prediction accuracy of MEDUSA heads and higher speedup but needing a special training recipe that preserves the model's capabilities. Moreover, we propose several extensions that improve or expand the utility of MEDUSA, including a self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate MEDUSA on models of various sizes and training procedures. Our experiments demonstrate that MEDUSA-1 can achieve over 2.2× speedup without compromising generation quality, while MEDUSA-2 further improves the speedup to 2.3-2.8×.
UR - https://www.scopus.com/pages/publications/85203829925
M3 - Conference article
AN - SCOPUS:85203829925
SN - 2640-3498
VL - 235
SP - 5209
EP - 5235
JO - Proceedings of Machine Learning Research
JF - Proceedings of Machine Learning Research
Y2 - 21 July 2024 through 27 July 2024
ER -