
DFlash VLM training support with SGLang backend#505

Open

Mandy3311 wants to merge 1 commit into sgl-project:main from Mandy3311:feat/dflash-vlm-sglang

Conversation

@Mandy3311
Contributor

Summary

This PR extends DFlash speculative decoding training to support Vision-Language Models (VLMs),
with Qwen3.5-VL as the primary validated target. It adds multimodal data ingestion,
SGLang multimodal request construction using MRotaryEmbedding, configurable weight key
paths for VLM architectures, and several correctness fixes to the base training loop.

Mixed Text + VLM Training

This PR supports training on a mixture of text-only and vision-language samples within
the same run. When --is-vlm is set and the dataset contains samples where image is
None, those samples are processed as text-only with empty tensor placeholders for
pixel_values/image_grid_thw to maintain a consistent HuggingFace Arrow schema across
parallel preprocessing shards.

--batch-size 1 is required for VLM and mixed training.

VlmDataCollatorWithPadding enforces batch_size=1 with an assertion. Each step is
either a single VLM sample (with pixel values forwarded to SGLang) or a single text
sample (pixel values are None, so the step falls back to the text-only SGLang path).
Batched multi-sample VLM inference is left as future work.
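A minimal sketch of the collator contract described above (the real VlmDataCollatorWithPadding lives in specforge/data/utils.py; the field handling here is illustrative, not the PR's exact code):

```python
class VlmDataCollatorWithPadding:
    """Illustrative sketch of the batch_size=1 contract; not the PR's actual code."""

    def __call__(self, features):
        # Batched multi-sample VLM inference is future work, hence the assertion.
        assert len(features) == 1, "VLM/mixed training requires --batch-size 1"
        sample = features[0]
        batch = {"input_ids": sample["input_ids"]}
        pixel_values = sample.get("pixel_values")
        if pixel_values is not None and len(pixel_values) > 0:
            # VLM sample: forward pixel data to the SGLang multimodal path.
            batch["pixel_values"] = pixel_values
            batch["image_grid_thw"] = sample["image_grid_thw"]
        else:
            # Text-only sample: None routes the step to the text-only SGLang path.
            batch["pixel_values"] = None
            batch["image_grid_thw"] = None
        return batch
```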

Changes

scripts/train_dflash.py

  • Add --embed-key / --lm-head-key to configure weight key paths; VLMs typically use
    model.language_model.embed_tokens.weight instead of the LLM default
  • Load AutoProcessor when --is-vlm is set; thread it through build_dataloader
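The new flags could be wired up roughly as follows; the default values here are assumptions, and only the VLM embed-key path string is taken from the PR description:

```python
import argparse

# Hypothetical flag wiring for the configurable weight-key options.
parser = argparse.ArgumentParser()
parser.add_argument("--embed-key", default="model.embed_tokens.weight",
                    help="weight key for the input embeddings, e.g. "
                         "model.language_model.embed_tokens.weight for VLMs")
parser.add_argument("--lm-head-key", default="lm_head.weight")
parser.add_argument("--is-vlm", action="store_true")

args = parser.parse_args(
    ["--is-vlm", "--embed-key", "model.language_model.embed_tokens.weight"]
)
```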

specforge/data/preprocessing.py

  • Fix multi-turn image injection: images are now attached only to the first user turn.
    Previously every user turn received the image, breaking multi-turn conversation formatting. Currently we do not support multiple images in one data item.
  • Remove the default system prompt prepended to all VLM conversations; the model's own
    chat template is responsible for system prompts.
  • Fix an AttributeError crash ('NoneType' object has no attribute 'startswith'):
    samples with image=None are now handled gracefully, processed as text-only with
    empty tensor placeholders for pixel_values/image_grid_thw to keep the HF Arrow
    schema consistent across shards.
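The multi-turn fix can be sketched like this; the function name and message shapes are illustrative, not the PR's actual code:

```python
def attach_image_to_first_user_turn(conversation, image):
    # Only the first user turn receives the image; later user turns stay
    # text-only. Multiple images per data item are not supported.
    out, attached = [], False
    for turn in conversation:
        turn = dict(turn)
        if turn["role"] == "user" and not attached and image is not None:
            turn["content"] = [
                {"type": "image", "image": image},
                {"type": "text", "text": turn["content"]},
            ]
            attached = True
        out.append(turn)
    return out
```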

specforge/data/utils.py

  • VlmDataCollatorWithPadding: enforce batch_size=1 with an assertion; pass
    pixel_values directly for VLM samples and None for text-only samples so the
    training loop falls back to the text-only SGLang path.

specforge/modeling/target/dflash_target_model.py

  • Add _build_vlm_reqs() to construct SGLang Req objects with full MultimodalInputs:
    • Slices pixel_values per sample using image_grid_thw patch counts (t×h×w)
    • Handles both single-image (1D) and multi-image (2D) image_grid_thw tensors
    • Computes mrope positions via MRotaryEmbedding.get_rope_index using the
      auto-detected vlm_model_type
    • Pads input tokens via MultiModalityDataPaddingPatternMultimodalTokens
  • Lift RadixCache creation out of _extend() into __init__ as self.dummy_tree_cache
    to avoid re-allocating the cache on every forward pass
  • Update generate_dflash_data() to branch on is_vlm_batch and dispatch to
    _build_vlm_reqs; refactor the text-only path to no longer maintain a data_cache
  • Add set_dflash_layers_to_capture() dispatch alongside the existing
    set_eagle3_layers_to_capture() in set_capture_layers()
  • Update HFDFlashTargetModel.generate_dflash_data() to forward pixel_values /
    image_grid_thw to the HF model via model_kwargs
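The per-sample pixel slicing that _build_vlm_reqs performs can be sketched with plain lists standing in for the real tensors (helper name and shapes are illustrative):

```python
def slice_pixel_values(pixel_values, image_grid_thw):
    # Each image occupies t*h*w patch rows of pixel_values. A 1-D grid
    # (single image) is normalized to the 2-D multi-image shape first.
    grids = image_grid_thw
    if grids and not isinstance(grids[0], (list, tuple)):
        grids = [grids]
    slices, offset = [], 0
    for t, h, w in grids:
        count = t * h * w
        slices.append(pixel_values[offset:offset + count])
        offset += count
    assert offset == len(pixel_values), "grid patch counts must cover all rows"
    return slices
```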

specforge/modeling/target/target_utils.py

  • Fall back to text_config for vocab_size, hidden_size, and pad_token_id when the
    top-level VLM config does not expose them directly

References

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the DFlash speculative decoding training framework by integrating robust support for Vision-Language Models (VLMs). It allows for seamless training on datasets containing both text-only and multimodal samples, addressing critical aspects of VLM data handling, such as image injection in conversations and consistent data schema management. The changes also include architectural adaptations for VLM-specific model configurations and optimized SGLang backend interactions for multimodal inputs, laying the groundwork for more versatile and efficient VLM training.

Highlights

  • VLM Training Support: Introduced comprehensive support for Vision-Language Model (VLM) training within DFlash speculative decoding, with Qwen3.5-VL as a primary validated target.
  • Mixed Text + VLM Training: Enabled mixed training on both text-only and multimodal (vision-language) samples within the same run, gracefully handling text-only samples with empty tensor placeholders for pixel values.
  • Configurable Weight Key Paths: Added configurable embed-key and lm-head-key arguments to train_dflash.py to support diverse VLM architectures' weight key paths (e.g., model.language_model.embed_tokens.weight).
  • Multi-turn Image Injection Fix: Implemented a fix in preprocessing.py to ensure images are injected only into the first user turn of multi-turn VLM conversations, preventing incorrect formatting.
  • Batch Size Enforcement: Enforced batch_size=1 for VLM and mixed training in VlmDataCollatorWithPadding due to current limitations in batched multi-sample VLM inference.
  • SGLang Multimodal Request Construction: Developed _build_vlm_reqs() in dflash_target_model.py to construct SGLang Req objects with MultimodalInputs, handling image slicing, mrope positions via MRotaryEmbedding, and multimodal token padding.
  • RadixCache Optimization: Refactored RadixCache creation in SGLangDFlashTargetModel to initialize it once in __init__ as self.dummy_tree_cache, avoiding re-allocation on every forward pass.
  • Dynamic Data Processing: Updated generate_dflash_data() to dynamically dispatch between VLM and text-only processing paths based on the presence of pixel_values and image_grid_thw.
  • VLM Config Attribute Fallback: Implemented fallback logic in target_utils.py to retrieve vocab_size, hidden_size, and pad_token_id from text_config when not directly available in the VLM's top-level configuration.


Changelog
  • scripts/train_dflash.py
    • Added datasets.load_from_disk import and logging configuration for SGLang.
    • Introduced command-line arguments: --embed-key, --lm-head-key, --is-vlm, --min-pixels, and --max-pixels for VLM configuration.
    • Refactored _load_raw_dataset to handle both directory and JSONL data paths.
    • Modified build_dataloader to accept and pass a processor argument, and is_vlm flag to dataset building.
    • Updated checkpoint loading logic to initialize start_epoch and global_step more robustly.
    • Integrated AutoProcessor loading when --is-vlm is enabled, using min_pixels and max_pixels arguments.
    • Passed processor to build_dataloader and used configurable embed_key/lm_head_key for TargetEmbeddingsAndHead initialization.
    • Adjusted generate_dflash_data call to pass pixel_values and image_grid_thw conditionally based on VLM training.
  • specforge/data/preprocessing.py
    • Removed the default system prompt prepended to all VLM conversations.
    • Implemented logic to inject images only into the first user turn in multi-turn conversations, preventing repeated image attachments.
    • Added handling for image=None samples, processing them as text-only with empty tensor placeholders for pixel values and image grid information to maintain schema consistency.
  • specforge/data/utils.py
    • Enforced batch_size=1 for VlmDataCollatorWithPadding with an assertion.
    • Modified __call__ to conditionally include pixel_values and image_grid_thw in the batch based on their presence and length, supporting text-only samples.
  • specforge/modeling/target/dflash_target_model.py
    • Imported various SGLang multimodal utilities, including MRotaryEmbedding, MultimodalDataItem, and MultimodalInputs.
    • Added _init_vlm_attributes method to detect and cache VLM-specific token IDs and vision configuration from the HuggingFace config.
    • Moved RadixCache initialization to __init__ as self.dummy_tree_cache to prevent re-allocation on every forward pass.
    • Introduced _build_vlm_reqs to construct SGLang Req objects for VLM inputs, including slicing pixel_values, computing mrope positions via MRotaryEmbedding, and multimodal token padding.
    • Updated set_capture_layers to dispatch to set_dflash_layers_to_capture or set_eagle3_layers_to_capture.
    • Refactored _extend to use self.dummy_tree_cache and updated model_runner.forward call to retrieve logits_output.
    • Modified generate_dflash_data to branch based on is_vlm_batch and use _build_vlm_reqs for VLM inputs, or a simplified path for text-only inputs.
    • Updated HFDFlashTargetModel.generate_dflash_data to accept and forward pixel_values and image_grid_thw via model_kwargs to the underlying HuggingFace model.
    • Removed enable_piecewise_cuda_graph from kwargs for SGLang backend initialization.
  • specforge/modeling/target/sglang_backend/model_runner.py
    • Removed unnecessary debug print statements related to parallel sizes during model initialization.
  • specforge/modeling/target/sglang_backend/patch.py
    • Corrected a descriptive comment in initialize_model_parallel regarding pipeline model-parallel groups from [b1, g3, g5, g7] to [g1, g3, g5, g7].
  • specforge/modeling/target/target_utils.py
    • Added logic in from_pretrained to fall back to config.text_config for vocab_size, hidden_size, and pad_token_id if they are not directly present in the main model configuration.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant enhancements to enable DFlash training for Vision-Language Models (VLMs), with a focus on the SGLang backend. The changes include adding multimodal data ingestion, constructing SGLang multimodal requests, and making weight key paths configurable for VLM architectures. The PR also contains several important correctness fixes and refactorings, such as ensuring images are only attached to the first user turn in multi-turn conversations and gracefully handling text-only samples during mixed VLM training. My review focuses on improving code clarity and robustness by addressing non-standard argument usage and adding explanatory comments for non-obvious code.

Comment on lines +397 to +399
ckpt_info = None # 预定义以防万一,虽然下面的逻辑更稳妥

# --- 步骤 2: 尝试获取 checkpoint 信息 ---


Severity: medium

These comments are in Chinese. For consistency and maintainability in this English-language codebase, please translate them to English.

Suggested change
ckpt_info = None # 预定义以防万一,虽然下面的逻辑更稳妥
# --- 步骤 2: 尝试获取 checkpoint 信息 ---
ckpt_info = None # Pre-define to be safe, although the logic below is more robust
# --- Step 2: Try to get checkpoint information ---

Comment on lines +438 to +444
processor = AutoProcessor.from_pretrained(
args.target_model_path,
min_pixels=args.min_pixels,
max_pixels=args.max_pixels,
trust_remote_code=args.trust_remote_code,
exist_ok=True
)


Severity: medium

The exist_ok=True argument passed to AutoProcessor.from_pretrained does not appear to be a standard argument for this Hugging Face Transformers method. While it might be ignored if trust_remote_code=True allows for custom arguments in the model's loading code, it's not guaranteed and could lead to unexpected behavior or errors with different models or library versions. It would be safer to remove this argument if it's not strictly required by the Qwen-VL model's custom code.

Comment on lines +450 to +451
if 'enable_piecewise_cuda_graph' in kwargs:
del kwargs['enable_piecewise_cuda_graph']


Severity: medium

The code explicitly deletes enable_piecewise_cuda_graph from kwargs before calling SGLangDFlashTargetModel.from_pretrained. This seems like a workaround for a potential issue, but there's no explanation. This argument is related to an SGLang inference optimization (CUDA graphs), which might not be compatible with the training-time hidden state generation being done here. Please add a comment explaining why this argument needs to be removed to improve code clarity and maintainability.
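As a concrete variant of that suggestion, the bare deletion could be replaced with dict.pop plus an explanatory comment. The stated reason is inferred from the review, not confirmed by the PR author, and the wrapper name is hypothetical:

```python
def strip_unsupported_backend_kwargs(kwargs):
    # enable_piecewise_cuda_graph enables an SGLang inference-time CUDA-graph
    # optimization; the training-time hidden-state generation here runs plain
    # forward passes, so the flag is dropped before backend construction.
    # dict.pop with a default also avoids the explicit membership check.
    kwargs.pop("enable_piecewise_cuda_graph", None)
    return kwargs
```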

@Mandy3311 Mandy3311 marked this pull request as draft March 17, 2026 11:59
@Mandy3311 Mandy3311 marked this pull request as ready for review March 17, 2026 12:06
@Mandy3311 Mandy3311 force-pushed the feat/dflash-vlm-sglang branch from c5455f2 to 8781f41 Compare March 19, 2026 02:18
Co-authored-by: hukongyi <hukongyi@cmbchina.com>
@Mandy3311 Mandy3311 force-pushed the feat/dflash-vlm-sglang branch from 8781f41 to 2ef8024 Compare March 19, 2026 02:19
