Add LK Losses: Direct Acceptance Rate Optimization for Speculative Decoding #492

Open

MrShevan wants to merge 9 commits into sgl-project:main from MrShevan:feat/lk-losses
Conversation

@MrShevan commented Mar 6, 2026

Motivation

This PR integrates the LK objectives from the paper https://arxiv.org/pdf/2602.23881 into Eagle3 training so that the optimization target better matches speculative decoding behavior in practice.
It introduces the acceptance rate as an explicit optimization objective and as a tracked metric during both training and evaluation, fixes the Eagle3 acc metric under truncated draft-vocab mapping, and adds support for the Infinity-Instruct dataset used in our experiments and ablations.

Modifications

  • Added LK-loss CLI controls in scripts/train_eagle3.py:
    • --lk-loss-type with lambda (hybrid KL+LK) and alpha (log-acceptance) modes.
    • --kl-scale and --kl-decay for adaptive KL weighting in lambda mode.
    • Sanity checks for non-negative KL hyperparameters.
  • Added specforge/core/lk_loss.py with:
    • expected acceptance-rate computation,
    • masked/distributed acceptance-rate aggregation,
    • LK loss composition for lambda and alpha objectives.
  • Integrated LK loss and acceptance-rate computation into Eagle3 training paths in specforge/core/eagle3.py (standard and Qwen-VL), with adapter state updates in specforge/core/eagle3_adapters.py.
  • Added acceptance-rate tracking to training/eval logging in scripts/train_eagle3.py (acceptance_rate_{i} per TTT position), which is reported to the configured tracker (including W&B when enabled).
  • Fixed accuracy computation in Eagle3 to compare mapped draft predictions in target-vocab space (d2t) against target token IDs from full target logits, avoiding errors from truncated-vocab argmax comparison.
  • Added LK utility tests in tests/test_utils/test_lk_loss_utils.py.
  • Added dataset preparation support for nebius-llama31-8b-infinity-instruct in scripts/prepare_data.py, plus run_prepare_data.sh convenience entrypoint.
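To make the two loss modes concrete, here is a minimal sketch of the core computation. The function names mirror those listed above (expected acceptance rate, compute_lk_loss), but the bodies, signatures, and the exact decay schedule are assumptions for illustration; the PR's actual implementation lives in specforge/core/lk_loss.py:

```python
import torch

def expected_acceptance_rate(draft_logits, target_p):
    """Per-token expected acceptance rate under speculative sampling.

    A draft token x is accepted with probability min(1, p_target(x) / q_draft(x)),
    so the expected acceptance rate equals sum_x min(p(x), q(x)).
    """
    q = torch.softmax(draft_logits, dim=-1)
    return torch.minimum(target_p, q).sum(dim=-1)  # shape: (..., seq)

def compute_lk_loss(kl_loss, acc_rate, loss_type="lambda",
                    kl_scale=1.0, kl_decay=0.0, step=0):
    """Compose the final loss from a KL term and the acceptance rate.

    'alpha'  : pure log-acceptance objective.
    'lambda' : hybrid KL + LK, with a KL weight that decays over steps
               (the decay schedule here is an assumed placeholder).
    """
    eps = 1e-8
    if loss_type == "alpha":
        return -torch.log(acc_rate + eps).mean()
    if loss_type == "lambda":
        w = kl_scale / (1.0 + kl_decay * step)
        return w * kl_loss - acc_rate.mean()
    raise ValueError(f"unknown lk_loss_type: {loss_type}")
```

Note that maximizing the expected acceptance rate is not the same objective as minimizing KL: KL rewards matching the full target distribution, while the acceptance term only rewards overlapping mass, which is what speculative decoding actually pays for.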

Related Issues

#485

Accuracy Test

Passed tests: python -m unittest discover -s ./tests -p "test_*.py" -v
(screenshot of the passing test run attached)

Benchmark & Profiling

W&B training curves: (screenshots attached)

Benchmarking Performance with a vLLM Script - examples/offline_inference/spec_decode.py:
--temp {0;1} --dataset-path philschmid/mt-bench --enable-chunked-prefill --num-prompts 80

LK-lambda loss checkpoint (6 epochs, 120000 steps, 32 batch size, nebius-infinity-instruct dataset):

(Temperature: 0)
total_num_output_tokens: 16913
num_drafts: 4463
num_draft_tokens: 31241
num_accepted_tokens: 12524
mean acceptance length: 3.81
--------------------------------------------------
acceptance at token 0: 0.77
acceptance at token 1: 0.58
acceptance at token 2: 0.44
acceptance at token 3: 0.35
acceptance at token 4: 0.27
acceptance at token 5: 0.22
acceptance at token 6: 0.18

(Temperature: 1)
total_num_output_tokens: 17172
num_drafts: 4890
num_draft_tokens: 34230
num_accepted_tokens: 12358
mean acceptance length: 3.53
--------------------------------------------------
acceptance at token 0: 0.73
acceptance at token 1: 0.54
acceptance at token 2: 0.40
acceptance at token 3: 0.30
acceptance at token 4: 0.23
acceptance at token 5: 0.18
acceptance at token 6: 0.14

LK-alpha loss checkpoint (6 epochs, 120000 steps, 32 batch size, nebius-infinity-instruct dataset):

(Temperature: 0)
total_num_output_tokens: 17030
num_drafts: 4500
num_draft_tokens: 31500
num_accepted_tokens: 12592
mean acceptance length: 3.80
--------------------------------------------------
acceptance at token 0: 0.77
acceptance at token 1: 0.58
acceptance at token 2: 0.44
acceptance at token 3: 0.35
acceptance at token 4: 0.27
acceptance at token 5: 0.22
acceptance at token 6: 0.18

(Temperature: 1)
total_num_output_tokens: 17220
num_drafts: 4983
num_draft_tokens: 34881
num_accepted_tokens: 12290
mean acceptance length: 3.47
--------------------------------------------------
acceptance at token 0: 0.73
acceptance at token 1: 0.53
acceptance at token 2: 0.39
acceptance at token 3: 0.30
acceptance at token 4: 0.22
acceptance at token 5: 0.17
acceptance at token 6: 0.13

KL loss checkpoint (6 epochs, 120000 steps, 32 batch size, nebius-infinity-instruct dataset):

(Temperature: 0)
total_num_output_tokens: 17030
num_drafts: 4589
num_draft_tokens: 32123
num_accepted_tokens: 12501
mean acceptance length: 3.72
--------------------------------------------------
acceptance at token 0: 0.77
acceptance at token 1: 0.58
acceptance at token 2: 0.44
acceptance at token 3: 0.33
acceptance at token 4: 0.25
acceptance at token 5: 0.20
acceptance at token 6: 0.15

(Temperature: 1)
total_num_output_tokens: 17183
num_drafts: 5212
num_draft_tokens: 36484
num_accepted_tokens: 12025
mean acceptance length: 3.31
--------------------------------------------------
acceptance at token 0: 0.73
acceptance at token 1: 0.52
acceptance at token 2: 0.37
acceptance at token 3: 0.26
acceptance at token 4: 0.19
acceptance at token 5: 0.14
acceptance at token 6: 0.10

Reference model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B

(Temperature: 0)
total_num_output_tokens: 16934
num_drafts: 4942
num_draft_tokens: 34594
num_accepted_tokens: 12024
mean acceptance length: 3.43
--------------------------------------------------
acceptance at token 0: 0.75
acceptance at token 1: 0.53
acceptance at token 2: 0.39
acceptance at token 3: 0.28
acceptance at token 4: 0.21
acceptance at token 5: 0.16
acceptance at token 6: 0.12
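As a sanity check on the counters above: each draft round yields one token from the target model plus however many draft tokens were accepted, so the reported mean acceptance length is 1 + num_accepted_tokens / num_drafts:

```python
# Reproduce "mean acceptance length" from the raw counters reported above.
# Each draft round contributes 1 target-model token plus the accepted
# draft tokens, so mean acceptance length = 1 + accepted / drafts.
runs = {
    "lk-lambda @ T=0": (12524, 4463),  # (num_accepted_tokens, num_drafts)
    "lk-alpha  @ T=0": (12592, 4500),
    "kl        @ T=0": (12501, 4589),
    "reference @ T=0": (12024, 4942),
}
for name, (accepted, drafts) in runs.items():
    print(f"{name}: {1 + accepted / drafts:.2f}")
```

These reproduce the reported 3.81, 3.80, 3.72, and 3.43 respectively.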

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Eagle3 training framework by incorporating LK (Likelihood-Kernel) loss objectives, which directly optimize for the acceptance rate in speculative decoding. This change aims to align the training target more closely with practical speculative decoding behavior, leading to improved performance. Additionally, it rectifies an existing accuracy metric calculation bug and expands dataset compatibility to include the Infinity-Instruct dataset.

Highlights

  • LK Loss Integration: Integrated LK (Likelihood-Kernel) loss objectives into Eagle3 training, allowing for direct optimization of speculative decoding acceptance rates. This includes new CLI controls (--lk-loss-type, --kl-scale, --kl-decay) for configuring the loss.
  • Acceptance Rate Tracking: Introduced acceptance rate as an explicit optimization objective and a tracked metric during both training and evaluation, with logging for each TTT position.
  • Accuracy Metric Fix: Corrected the Eagle3 acc metric calculation to accurately compare mapped draft predictions in target-vocab space against target token IDs, resolving issues with truncated draft vocab mapping.
  • New Dataset Support: Added support for the nebius-llama31-8b-infinity-instruct dataset, including data preparation and processing functions.
  • New LK Loss Module: Created a new module specforge/core/lk_loss.py containing utility functions for computing expected acceptance rates, overall acceptance rates, and composing LK loss objectives.
  • Unit Tests: Included new unit tests for the LK loss utility functions to ensure their correctness and reliability.
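The accuracy fix can be sketched as follows. This is a minimal illustration that assumes d2t is a direct draft-index to target-index lookup table; the actual implementation in specforge/core/eagle3.py may differ (for instance, by storing per-index offsets rather than absolute IDs):

```python
import torch

def eagle3_acc(draft_logits, d2t, target_token_ids, loss_mask):
    """Draft-prediction accuracy, compared in target-vocab space.

    The draft head predicts over a truncated vocabulary, so its argmax
    cannot be compared to the target-vocab argmax directly; it must first
    be mapped through d2t (assumed here to be a draft-index -> target-index
    lookup), then compared against the target token IDs.
    """
    draft_ids = draft_logits.argmax(dim=-1)   # indices in the draft vocab
    mapped = d2t[draft_ids]                   # indices in the target vocab
    correct = (mapped == target_token_ids).float() * loss_mask
    return correct.sum() / loss_mask.sum().clamp(min=1)
```

Comparing truncated-vocab argmax IDs directly to full-vocab argmax IDs, as the old code effectively did, conflates two different index spaces and silently deflates the metric.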


Changelog
  • scripts/prepare_data.py
    • Added nebius-llama31-8b-infinity-instruct to the list of supported datasets.
    • Implemented process_nebius_infinity_instruct function to format the new dataset.
    • Integrated the new dataset processing into the main function.
  • scripts/train_dflash.py
    • Removed redundant imports for get_last_checkpoint, print_on_rank0, and print_with_rank from specforge.utils.
  • scripts/train_eagle3.py
    • Introduced new command-line arguments for LK loss configuration: --lk-loss-type, --kl-scale, and --kl-decay.
    • Added sanity checks to ensure kl-scale and kl-decay are non-negative.
    • Modified run_forward function signature and return values to include acceptance rates.
    • Updated record_metrcs to accept and log acceptance rates per TTT position.
    • Adjusted the main training loop to pass LK loss parameters to the Eagle3 model and log acceptance rates during training and evaluation.
  • specforge/core/eagle3.py
    • Imported compute_acceptance_rate and compute_lk_loss from the new lk_loss module.
    • Extended Eagle3Model and QwenVLOnlineEagle3Model constructors to accept lk_loss_type, kl_scale, and kl_decay parameters.
    • Refactored _acc_and_loss to compute acceptance rate and apply LK loss based on the configured type, also fixing accuracy calculation by comparing mapped draft predictions to target token IDs.
    • Modified _compute_target_p_padded and _compute_target_p to return target_p_on_draft and target_token_ids for accurate acceptance rate and accuracy calculations.
    • Updated forward methods to handle and return acceptance rates.
    • Revised _compute_metric_acc to use target_token_ids and d2t for accuracy.
  • specforge/core/eagle3_adapters.py
    • Added target_p_on_draft and target_token_ids fields to the StepState dataclass.
    • Updated step_view methods in SdpaLikeAdapter and UspAdapter to correctly slice and pass these new target-related tensors.
  • specforge/core/lk_loss.py
    • New file added.
    • Defined expected_acceptance_rate to calculate token-wise expected acceptance rates.
    • Implemented compute_acceptance_rate to derive overall acceptance rate from draft logits and target probabilities, with optional distributed reduction.
    • Provided compute_lk_loss to combine KL loss and acceptance rate into a final loss, supporting 'alpha' (log-acceptance) and 'lambda' (hybrid KL+LK) modes with adaptive KL weighting.
  • tests/test_utils/test_lk_loss_utils.py
    • New file added.
    • Contains unit tests for expected_acceptance_rate, compute_acceptance_rate, and compute_lk_loss to ensure correctness.
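The masked (and optionally distributed) aggregation mentioned in the changelog could look like the following sketch. The function name and signature are assumptions, not the PR's actual API:

```python
import torch
import torch.distributed as dist

def masked_acceptance_mean(acc_rate, loss_mask, distributed=False):
    """Aggregate per-token acceptance rates under a validity mask.

    Sums (not per-rank means) are all-reduced, so ranks holding different
    numbers of valid tokens contribute proportionally to the global mean.
    """
    total = (acc_rate * loss_mask).sum()
    count = loss_mask.sum()
    if distributed and dist.is_available() and dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return total / count.clamp(min=1)
```

Reducing sums rather than averaging per-rank means is the standard way to keep the statistic exact when padding differs across ranks.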
Activity
  • The pull request author, MrShevan, has provided a detailed motivation and list of modifications.
  • The PR includes accuracy tests and benchmark results, indicating thorough self-review and validation.
  • No external reviewer comments or activities are present in the provided context.

@gemini-code-assist (Contributor) left a comment

Code Review

This pull request introduces a significant enhancement by integrating LK losses for direct acceptance rate optimization in speculative decoding, aligning the training objective more closely with practical decoding behavior. The changes are comprehensive, including new command-line arguments, a dedicated LK loss module, and integration into the Eagle3 training and evaluation pipelines. Additionally, it provides a valuable fix for the accuracy metric calculation and adds support for a new dataset. The implementation is well-structured and includes corresponding unit tests. My review includes a few suggestions to refactor duplicated code blocks for improved maintainability and to remove a redundant line of code.

@MrShevan MrShevan marked this pull request as ready for review March 9, 2026 16:45
@gemini-code-assist (Contributor)

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@astralord

Hi, @shuaills @sleepcoo @FlamingoPg @FrankLeeeee,

Can you please clarify whether you plan to review/merge this PR?

Would appreciate any feedback,
Thanks!

@FrankLeeeee (Collaborator)

@astralord thanks for your contribution, and sorry for overlooking this. I will take on the review.
