From f80be8605cfe61cfe8650230d3c943e7387aa0e7 Mon Sep 17 00:00:00 2001 From: gss <2783977641@qq.com> Date: Fri, 27 Feb 2026 14:34:05 +0800 Subject: [PATCH 1/3] Add Qwen3-VL-8B DFlash training support --- configs/qwen3-vl-8b-dflash.json | 54 +++ examples/run_qwen3_vl_8b_dflash_online.sh | 40 +++ scripts/train_dflash.py | 332 ++++++++++++++++-- specforge/core/__init__.py | 3 +- specforge/core/dflash.py | 159 ++++++++- specforge/data/template.py | 10 + specforge/modeling/draft/dflash.py | 66 +++- .../modeling/target/dflash_target_model.py | 145 +++++++- specforge/modeling/target/target_utils.py | 99 +++++- 9 files changed, 845 insertions(+), 63 deletions(-) create mode 100644 configs/qwen3-vl-8b-dflash.json create mode 100644 examples/run_qwen3_vl_8b_dflash_online.sh diff --git a/configs/qwen3-vl-8b-dflash.json b/configs/qwen3-vl-8b-dflash.json new file mode 100644 index 000000000..dff83fab7 --- /dev/null +++ b/configs/qwen3-vl-8b-dflash.json @@ -0,0 +1,54 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 16, + "bos_token_id": 151643, + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [ + 3, + 9, + 17, + 25, + 33 + ] + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ], + "rope_type": "default" + }, + "rope_theta": 5000000, + "tie_word_embeddings": false, + "transformers_version": "4.57.1", + "use_cache": true, + "vocab_size": 151936 +} diff --git a/examples/run_qwen3_vl_8b_dflash_online.sh b/examples/run_qwen3_vl_8b_dflash_online.sh new file mode 100644 index 000000000..ef2e81308 --- /dev/null +++ b/examples/run_qwen3_vl_8b_dflash_online.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-8} +ATTENTION_BACKEND=${2:-flex_attention} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-16} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path Qwen/Qwen3-VL-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen3-vl-8b-dflash.json \ + --target-model-backend hf \ + --is-vlm \ + --trust-remote-code \ + --train-data-path $ROOT_DIR/cache/dataset/allava4v-mix-20k_train.localimg_regen.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --min-pixels 50176 \ + --max-pixels 1003520 \ + --output-dir $ROOT_DIR/outputs/qwen3-vl-8b-allava4v20k-dflash \ + --cache-dir $ROOT_DIR/cache \ + --num-epochs 6 \ + --batch-size 2 \ + --learning-rate 6e-4 \ + --warmup-ratio 0.04 \ + --max-grad-norm 1.0 \ + --max-length 4096 \ + --num-draft-layers 5 \ + --chat-template qwen3-vl \ + --attention-backend $ATTENTION_BACKEND \ + --block-size 16 \ + --num-anchors 512 \ + --loss-decay-gamma 7.0 \ + --log-interval 50 \ + --save-interval 1000 \ No newline at end of file diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 
bc1a21d05..726496fd5 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -3,6 +3,7 @@ """DFlash Training Script.""" import argparse +import copy import logging import math import os @@ -18,16 +19,17 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoProcessor, AutoTokenizer from datasets import load_dataset from specforge.args import SGLangBackendArgs, TrackerArgs -from specforge.core.dflash import OnlineDFlashModel +from specforge.core.dflash import OnlineDFlashModel, QwenVLOnlineDFlashModel from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders from specforge.distributed import destroy_distributed, get_dp_group, init_distributed from specforge.modeling.draft.dflash import DFlashDraftModel from specforge.modeling.target.dflash_target_model import ( DFlashTargetModel, + HFDFlashTargetModel, get_dflash_target_model, ) from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead @@ -39,6 +41,103 @@ print_with_rank, ) +QWEN3_VL_MODEL_TYPES = {"qwen3_vl", "qwen3_vl_moe"} + + +def _resolve_target_num_hidden_layers(target_config) -> int: + # For VLM configs (e.g. Qwen3-VL), top-level num_hidden_layers may refer to + # vision stack depth. DFlash target layers must follow language model depth. + if hasattr(target_config, "text_config") and hasattr( + target_config.text_config, "num_hidden_layers" + ): + return target_config.text_config.num_hidden_layers + if hasattr(target_config, "num_hidden_layers"): + return target_config.num_hidden_layers + raise ValueError( + f"Cannot infer num_target_layers from config type {type(target_config)}" + ) + + +def _build_target_layer_ids( + num_target_layers: int, + num_draft_layers: int, + start_layer: int = 1, + end_layer: Optional[int] = None, +) -> list[int]: + """Build evenly spaced target layer ids.""" + if num_draft_layers <= 0: + raise ValueError("num_draft_layers must be positive.") + + if end_layer is None: + end_layer = num_target_layers - 3 + + max_layer_idx = num_target_layers - 1 + start_layer = max(0, min(start_layer, max_layer_idx)) + end_layer = max(0, min(end_layer, max_layer_idx)) + + if end_layer < start_layer: + raise ValueError( + f"Invalid layer range: start_layer={start_layer}, end_layer={end_layer}" + ) + + if num_draft_layers == 1: + midpoint = num_target_layers // 2 + return [max(start_layer, min(midpoint, end_layer))] + + span = end_layer - start_layer + return [ + int(start_layer + (i * span) / (num_draft_layers - 1)) + for i in range(num_draft_layers) + ] + + +def _resolve_draft_config(target_config): + model_type = getattr(target_config, "model_type", None) + if model_type in QWEN3_VL_MODEL_TYPES and hasattr(target_config, "text_config"): + draft_config = copy.deepcopy(target_config.text_config) + for attr_name in ( + "dflash_config", + "block_size", + "rope_scaling", + "rope_theta", + "max_position_embeddings", + ): + if not hasattr(target_config, attr_name): + continue + if not hasattr(draft_config, attr_name) or getattr( + draft_config, attr_name + ) is None: + setattr(draft_config, attr_name, getattr(target_config, attr_name)) + return draft_config + return copy.deepcopy(target_config) + + +def _resolve_target_weight_keys(target_config) -> Tuple[str, str]: + model_type = getattr(target_config, "model_type", None) + if model_type in QWEN3_VL_MODEL_TYPES: + return 
"model.language_model.embed_tokens.weight", "lm_head.weight" + return "model.embed_tokens.weight", "lm_head.weight" + + +def _ensure_layer_types(draft_config) -> None: + if hasattr(draft_config, "layer_types") and draft_config.layer_types is not None: + return + + if not hasattr(draft_config, "num_hidden_layers"): + return + + num_hidden_layers = draft_config.num_hidden_layers + sliding_window = getattr(draft_config, "sliding_window", None) + max_window_layers = getattr(draft_config, "max_window_layers", num_hidden_layers) + if max_window_layers is None: + max_window_layers = num_hidden_layers + draft_config.layer_types = [ + "sliding_attention" + if sliding_window is not None and layer_idx >= max_window_layers + else "full_attention" + for layer_idx in range(num_hidden_layers) + ] + def parse_args(): parser = argparse.ArgumentParser(description="Train DFlash Draft Model") @@ -71,6 +170,11 @@ def parse_args(): model_group.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code" ) + model_group.add_argument( + "--is-vlm", + action="store_true", + help="Whether to enable VLM training mode. If not set, will auto-detect from target model config.", + ) model_group.add_argument( "--num-anchors", type=int, @@ -96,6 +200,18 @@ def parse_args(): type=int, default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)), ) + dataset_group.add_argument( + "--min-pixels", + type=int, + default=50176, + help="Minimum image pixels for VLM processor.", + ) + dataset_group.add_argument( + "--max-pixels", + type=int, + default=802816, + help="Maximum image pixels for VLM processor.", + ) training_group = parser.add_argument_group("training") training_group.add_argument("--num-epochs", type=int, default=6) @@ -142,38 +258,128 @@ def parse_args(): return parser.parse_args() -def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: +def build_models( + args, target_config=None, is_vlm: bool = False +) -> Tuple[DFlashTargetModel, DFlashDraftModel, AutoConfig]: """Build target model (backend wrapper) and draft model.""" print_on_rank0( f"Loading target model from {args.target_model_path} using {args.target_model_backend} backend" ) - target_model_kwargs = {} - if args.target_model_backend == "sglang": - target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + if target_config is None: + target_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + target_model_type = getattr(target_config, "model_type", None) + + if ( + args.target_model_backend == "hf" + and is_vlm + and target_model_type == "qwen3_vl" + and args.tp_size == 1 + ): + from transformers import Qwen3VLForConditionalGeneration + + # If you're using torch==2.9.1, please ensure you have cuDNN >= 9.15 installed to avoid a performance + # regression with Conv3D. You can run `pip install nvidia-cudnn-cu12==9.16.0.29` to immediately fix it. + target_model = HFDFlashTargetModel( + Qwen3VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=args.trust_remote_code, + ) + .eval() + .cuda(), + model_type=target_model_type, + ) + elif ( + args.target_model_backend == "hf" + and is_vlm + and target_model_type == "qwen3_vl_moe" + and args.tp_size == 1 + ): + from transformers import Qwen3VLMoeForConditionalGeneration + + # If you're using torch==2.9.1, please ensure you have cuDNN >= 9.15 installed to avoid a performance + # regression with Conv3D. 
You can run `pip install nvidia-cudnn-cu12==9.16.0.29` to immediately fix it. + target_model = HFDFlashTargetModel( + Qwen3VLMoeForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=args.trust_remote_code, + ) + .eval() + .cuda(), + model_type=target_model_type, + ) + else: + target_model_kwargs = {} + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + + target_model = get_dflash_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device="cuda" if args.target_model_backend == "hf" else None, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) - target_model = get_dflash_target_model( - pretrained_model_name_or_path=args.target_model_path, - backend=args.target_model_backend, - torch_dtype=torch.bfloat16, - device="cuda" if args.target_model_backend == "hf" else None, - trust_remote_code=args.trust_remote_code, - **target_model_kwargs, - ) + # Resolve before draft config mutations to avoid reading modified values. + target_num_layers = _resolve_target_num_hidden_layers(target_config) if args.draft_config_path: - draft_config = AutoConfig.from_pretrained(args.draft_config_path) + draft_config = AutoConfig.from_pretrained( + args.draft_config_path, trust_remote_code=args.trust_remote_code + ) + draft_config = _resolve_draft_config(draft_config) print_on_rank0(f"Loaded draft config from {args.draft_config_path}") else: - target_config = AutoConfig.from_pretrained(args.target_model_path) - draft_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config = _resolve_draft_config(target_config) draft_config.num_hidden_layers = args.num_draft_layers draft_config.block_size = args.block_size - draft_config.num_target_layers = target_config.num_hidden_layers print_on_rank0("Auto-generated draft config from target model") + # Always use target language model depth for capture layer mapping. + draft_config.num_target_layers = target_num_layers + _ensure_layer_types(draft_config) + if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None: draft_config.dflash_config = {} + elif not isinstance(draft_config.dflash_config, dict): + draft_config.dflash_config = dict(draft_config.dflash_config) + + model_type = getattr(target_config, "model_type", None) + if model_type in QWEN3_VL_MODEL_TYPES: + # Keep the original evenly spaced mapping, but force the first capture + # layer to skip Qwen3-VL deepstack layers (0-2). + recommended_layer_ids = _build_target_layer_ids( + num_target_layers=target_num_layers, + num_draft_layers=draft_config.num_hidden_layers, + ) + if recommended_layer_ids and recommended_layer_ids[0] < 3: + recommended_layer_ids[0] = 3 + if "target_layer_ids" not in draft_config.dflash_config: + draft_config.dflash_config["target_layer_ids"] = recommended_layer_ids + print_on_rank0( + "Qwen3-VL detected: default target_layer_ids set to " + f"{draft_config.dflash_config['target_layer_ids']} " + "(first layer forced to 3)." 
+ ) + elif ( + draft_config.dflash_config["target_layer_ids"] + and draft_config.dflash_config["target_layer_ids"][0] < 3 + ): + old_layer_ids = draft_config.dflash_config["target_layer_ids"] + new_layer_ids = list(old_layer_ids) + new_layer_ids[0] = 3 + draft_config.dflash_config["target_layer_ids"] = new_layer_ids + print_on_rank0( + "Qwen3-VL detected: overriding first target layer " + f"{old_layer_ids} -> {new_layer_ids} " + "to avoid deepstack train/serve mismatch." + ) draft_config._attn_implementation = args.attention_backend print_on_rank0(f"Using attention backend: {args.attention_backend}") @@ -191,10 +397,15 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}" ) - return target_model, draft_model + return target_model, draft_model, target_config -def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: +def build_dataloader( + args, + tokenizer, + is_vlm: bool = False, + processor: Optional[AutoProcessor] = None, +) -> Tuple[DataLoader, Optional[DataLoader]]: """Build train and eval dataloaders.""" import hashlib @@ -213,6 +424,8 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] chat_template=args.chat_template, max_length=args.max_length, is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, cache_dir=os.path.join(args.cache_dir, "processed_dataset"), cache_key=cache_key, num_proc=args.build_dataset_num_proc, @@ -233,6 +446,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] num_workers=args.dataloader_num_workers, shuffle=True, process_group=get_dp_group(), + is_vlm=is_vlm, ) eval_dataloader = None @@ -244,6 +458,8 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] chat_template=args.chat_template, max_length=args.max_length, is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, ) eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, @@ -251,6 +467,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] num_workers=args.dataloader_num_workers, shuffle=False, process_group=get_dp_group(), + is_vlm=is_vlm, ) return train_dataloader, eval_dataloader @@ -345,7 +562,27 @@ def main(): init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) print_with_rank("Initialized distributed") - target_model, draft_model = build_models(args) + target_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + detected_vlm = getattr(target_config, "model_type", None) in QWEN3_VL_MODEL_TYPES + is_vlm = args.is_vlm or detected_vlm + if detected_vlm and not args.is_vlm: + print_on_rank0( + "Detected Qwen3-VL target config; enabling VLM mode automatically." + ) + if is_vlm and args.target_model_backend != "hf": + raise ValueError( + "Real multimodal DFlash training currently supports only HF backend. " + "Please set --target-model-backend hf." 
+ ) + print_on_rank0( + f"Detected target model_type={getattr(target_config, 'model_type', None)}, is_vlm={is_vlm}" + ) + + target_model, draft_model, target_config = build_models( + args, target_config, is_vlm=is_vlm + ) draft_model_last_checkpoint = None if args.ckpt_dir is not None: @@ -384,7 +621,17 @@ def main(): f"step {resume_state['global_step']}" ) - tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + if is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + trust_remote_code=args.trust_remote_code, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + tokenizer = processor.tokenizer + else: + processor = None + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) if args.mask_token_id is not None: mask_token_id = args.mask_token_id @@ -400,22 +647,35 @@ def main(): draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}") - train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + train_dataloader, eval_dataloader = build_dataloader( + args, + tokenizer, + is_vlm=is_vlm, + processor=processor, + ) steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) total_steps = args.num_epochs * steps_per_epoch print_on_rank0(f"Total training steps: {total_steps}") print_on_rank0("Loading target embeddings and head...") + embed_key, lm_head_key = _resolve_target_weight_keys(target_config) + print_on_rank0( + f"Loading target embeddings/head with keys: embed='{embed_key}', head='{lm_head_key}'" + ) target_components = TargetEmbeddingsAndHead.from_pretrained( args.target_model_path, - embed_key="model.embed_tokens.weight", # Adjust if Qwen/Llama differs - lm_head_key="lm_head.weight", + embed_key=embed_key, + lm_head_key=lm_head_key, device="cuda", trust_remote_code=args.trust_remote_code, ) - dflash_model = OnlineDFlashModel( + dflash_model_cls = OnlineDFlashModel + if getattr(target_config, "model_type", None) == "qwen3_vl": + dflash_model_cls = QwenVLOnlineDFlashModel + print_on_rank0(f"Using DFlash wrapper: {dflash_model_cls.__name__}") + dflash_model = dflash_model_cls( draft_model=draft_model, target_lm_head=target_components.lm_head, target_embed_tokens=target_components.embed_tokens, @@ -482,15 +742,33 @@ def main(): input_ids = data["input_ids"].cuda() attention_mask = data["attention_mask"].cuda() loss_mask = data["loss_mask"].cuda() + target_kwargs = {} + if is_vlm: + if "pixel_values" in data: + target_kwargs["pixel_values"] = data["pixel_values"].cuda() + if "pixel_values_videos" in data: + target_kwargs["pixel_values_videos"] = data["pixel_values_videos"].cuda() + if "image_grid_thw" in data: + target_kwargs["image_grid_thw"] = data["image_grid_thw"].cuda() + if "video_grid_thw" in data: + target_kwargs["video_grid_thw"] = data["video_grid_thw"].cuda() + if "second_per_grid_ts" in data: + target_kwargs["second_per_grid_ts"] = data["second_per_grid_ts"].cuda() target_output = target_model.generate_dflash_data( - input_ids, attention_mask, loss_mask + input_ids, attention_mask, loss_mask, **target_kwargs ) hidden_states = target_output.hidden_states.cuda() # Ensure on GPU + position_ids = ( + target_output.position_ids.cuda() + if target_output.position_ids is not None + else None + ) loss, accuracy = dflash_model( input_ids=input_ids, hidden_states=hidden_states, loss_mask=loss_mask, + position_ids=position_ids, ) (loss / args.accumulation_steps).backward() diff --git 
a/specforge/core/__init__.py b/specforge/core/__init__.py index 1b45f4f7a..4a62e0dc0 100644 --- a/specforge/core/__init__.py +++ b/specforge/core/__init__.py @@ -1,8 +1,9 @@ -from .dflash import OnlineDFlashModel +from .dflash import OnlineDFlashModel, QwenVLOnlineDFlashModel from .eagle3 import OnlineEagle3Model, QwenVLOnlineEagle3Model __all__ = [ "OnlineDFlashModel", + "QwenVLOnlineDFlashModel", "OnlineEagle3Model", "QwenVLOnlineEagle3Model", ] diff --git a/specforge/core/dflash.py b/specforge/core/dflash.py index 83ac23a66..0b2c2b22a 100644 --- a/specforge/core/dflash.py +++ b/specforge/core/dflash.py @@ -92,7 +92,10 @@ def __init__( self._cached_bsz: Optional[int] = None def _sample_anchor_positions( - self, seq_len: int, loss_mask: torch.Tensor, device: torch.device + self, + seq_len: int, + loss_mask: torch.Tensor, + device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: """Randomly sample anchor positions per sample; returns (anchors, keep_mask).""" bs = self.block_size @@ -156,6 +159,61 @@ def _create_position_ids(self, anchor_positions: torch.Tensor) -> torch.Tensor: pos_ids = anchor_positions.unsqueeze(-1) + offsets return pos_ids.view(bsz, -1) + def _create_position_ids_from_context( + self, + context_position_ids: torch.Tensor, + anchor_positions: torch.Tensor, + seq_len: int, + ) -> torch.Tensor: + """ + Gather the draft block position ids from context position ids. + Text-only path expects [B, S] position ids. + """ + if context_position_ids.ndim != 2: + raise ValueError( + f"OnlineDFlashModel expects 2D context position ids, got ndim={context_position_ids.ndim}. " + "Use QwenVLOnlineDFlashModel for multimodal mRoPE position ids." + ) + device = anchor_positions.device + offsets = torch.arange(self.block_size, device=device).view(1, 1, -1) + gather_indices = (anchor_positions.unsqueeze(-1) + offsets).clamp(max=seq_len - 1) + bsz, n_blocks = anchor_positions.shape + gathered = torch.gather( + context_position_ids.unsqueeze(1).expand(-1, n_blocks, -1), + 2, + gather_indices, + ) + return gathered.reshape(bsz, -1) + + def _prepare_context_position_ids( + self, + position_ids: Optional[torch.Tensor], + *, + seq_len: int, + bsz: int, + device: torch.device, + ) -> torch.Tensor: + if position_ids is None: + return torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + + context_position_ids = position_ids.to(device=device).long() + if context_position_ids.ndim != 2: + raise ValueError( + f"OnlineDFlashModel expects position_ids with shape [B, S], got ndim={context_position_ids.ndim}. " + "Use QwenVLOnlineDFlashModel for multimodal mRoPE position ids." 
+ ) + expected_seq_len = context_position_ids.shape[-1] + if expected_seq_len != seq_len: + raise ValueError( + f"Position ids length mismatch: got {expected_seq_len}, expected {seq_len}" + ) + return context_position_ids + + def _concat_context_and_draft_position_ids( + self, context_position_ids: torch.Tensor, draft_position_ids: torch.Tensor + ) -> torch.Tensor: + return torch.cat([context_position_ids, draft_position_ids], dim=1) + def _create_noise_embed(self, input_ids, anchor_positions, block_keep_mask): bsz, seq_len = input_ids.shape n = anchor_positions.shape[1] @@ -186,6 +244,7 @@ def forward( input_ids: torch.Tensor, hidden_states: torch.Tensor, loss_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Parallel block-wise training forward pass.""" bsz, seq_len = input_ids.shape @@ -199,11 +258,23 @@ def forward( input_ids, anchor_positions, block_keep_mask ) - context_position_ids = ( - torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + context_position_ids = self._prepare_context_position_ids( + position_ids, + seq_len=seq_len, + bsz=bsz, + device=device, + ) + + draft_position_ids = self._create_position_ids_from_context( + context_position_ids=context_position_ids, + anchor_positions=anchor_positions, + seq_len=seq_len, + ) + + full_position_ids = self._concat_context_and_draft_position_ids( + context_position_ids=context_position_ids, + draft_position_ids=draft_position_ids, ) - draft_position_ids = self._create_position_ids(anchor_positions) - full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) dflash_attn_mask = create_dflash_block_mask( anchor_positions=anchor_positions, @@ -277,3 +348,81 @@ def forward( accuracy = correct.sum().float() / actual_token_count return loss, accuracy + + +class QwenVLOnlineDFlashModel(OnlineDFlashModel): + """Qwen-VL specific DFlash wrapper with multimodal mRoPE position id support.""" + + def _prepare_context_position_ids( + self, + position_ids: Optional[torch.Tensor], + *, + seq_len: int, + bsz: int, + device: torch.device, + ) -> torch.Tensor: + if position_ids is None: + return torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + + context_position_ids = position_ids.to(device=device).long() + if context_position_ids.ndim == 3 and context_position_ids.shape[0] != 3: + if context_position_ids.shape[1] == 3: + context_position_ids = context_position_ids.permute(1, 0, 2).contiguous() + else: + raise ValueError( + "Multimodal position_ids must have shape [3, B, S] or [B, 3, S]." + ) + if context_position_ids.ndim == 3 and not getattr( + self.draft_model, "use_interleaved_mrope", False + ): + context_position_ids = context_position_ids[0] + if context_position_ids.ndim not in (2, 3): + raise ValueError( + f"Unsupported position_ids ndim={context_position_ids.ndim}; expected 2 or 3." 
+ ) + expected_seq_len = context_position_ids.shape[-1] + if expected_seq_len != seq_len: + raise ValueError( + f"Position ids length mismatch: got {expected_seq_len}, expected {seq_len}" + ) + return context_position_ids + + def _create_position_ids_from_context( + self, + context_position_ids: torch.Tensor, + anchor_positions: torch.Tensor, + seq_len: int, + ) -> torch.Tensor: + device = anchor_positions.device + offsets = torch.arange(self.block_size, device=device).view(1, 1, -1) + gather_indices = (anchor_positions.unsqueeze(-1) + offsets).clamp(max=seq_len - 1) + + if context_position_ids.ndim == 2: + bsz, n_blocks = anchor_positions.shape + gathered = torch.gather( + context_position_ids.unsqueeze(1).expand(-1, n_blocks, -1), + 2, + gather_indices, + ) + return gathered.reshape(bsz, -1) + + if context_position_ids.ndim == 3: + _, bsz, _ = context_position_ids.shape + n_blocks = anchor_positions.shape[1] + expanded_context = context_position_ids.unsqueeze(2).expand( + -1, -1, n_blocks, -1 + ) + expanded_indices = gather_indices.unsqueeze(0).expand(3, -1, -1, -1) + gathered = torch.gather(expanded_context, 3, expanded_indices) + return gathered.reshape(3, bsz, -1) + + raise ValueError( + f"Unsupported position_ids ndim={context_position_ids.ndim}; expected 2 or 3." + ) + + def _concat_context_and_draft_position_ids( + self, context_position_ids: torch.Tensor, draft_position_ids: torch.Tensor + ) -> torch.Tensor: + if context_position_ids.ndim == 2: + return torch.cat([context_position_ids, draft_position_ids], dim=1) + return torch.cat([context_position_ids, draft_position_ids], dim=2) diff --git a/specforge/data/template.py b/specforge/data/template.py index 4803db9af..97918d272 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -127,6 +127,16 @@ def get_all_template_names(self) -> List[str]: ), ) +TEMPLATE_REGISTRY.register( + name="qwen3-vl", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + TEMPLATE_REGISTRY.register( name="phi3", template=ChatTemplate( diff --git a/specforge/modeling/draft/dflash.py b/specforge/modeling/draft/dflash.py index 0aea03fe1..5b754efe5 100644 --- a/specforge/modeling/draft/dflash.py +++ b/specforge/modeling/draft/dflash.py @@ -39,6 +39,64 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def get_rope_scaling_value(config: Qwen3Config, key: str, default=None): + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return default + if isinstance(rope_scaling, dict): + return rope_scaling.get(key, default) + return getattr(rope_scaling, key, default) + + +class Qwen3InterleavedMultiRotaryEmbedding(Qwen3RotaryEmbedding): + """Interleaved mRoPE for Qwen3-VL style multimodal position ids.""" + + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.mrope_section = get_rope_scaling_value( + config, "mrope_section", [24, 20, 20] + ) + + def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: + freqs_t = freqs[0] + for dim_idx, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim_idx] * 3 + idx_slice = slice(offset, length, 3) + freqs_t[..., idx_slice] = freqs[dim_idx, ..., idx_slice] + return freqs_t + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if 
position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + interleaved_freqs = self._apply_interleaved_mrope(freqs) + emb = torch.cat((interleaved_freqs, interleaved_freqs), dim=-1) + scaling = getattr(self, "attention_scaling", 1.0) + cos = emb.cos() * scaling + sin = emb.sin() * scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class Qwen3DFlashAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -228,7 +286,13 @@ def __init__(self, config) -> None: build_target_layer_ids(config.num_target_layers, config.num_hidden_layers), ) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3RotaryEmbedding(config) + self.use_interleaved_mrope = bool( + get_rope_scaling_value(config, "mrope_interleaved", False) + ) + if self.use_interleaved_mrope: + self.rotary_emb = Qwen3InterleavedMultiRotaryEmbedding(config) + else: + self.rotary_emb = Qwen3RotaryEmbedding(config) self.fc = nn.Linear( len(self.target_layer_ids) * config.hidden_size, config.hidden_size, diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py index 732f8e4a7..00165c48c 100644 --- a/specforge/modeling/target/dflash_target_model.py +++ b/specforge/modeling/target/dflash_target_model.py @@ -15,12 +15,14 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM from specforge.distributed import get_tp_group from .sglang_backend import SGLangRunner +QWEN3_VL_MODEL_TYPES = {"qwen3_vl", "qwen3_vl_moe"} + @dataclass class DFlashTargetOutput: @@ -28,6 +30,7 @@ class DFlashTargetOutput: input_ids: torch.Tensor # [batch, seq_len] attention_mask: torch.Tensor # [batch, seq_len] loss_mask: torch.Tensor # [batch, seq_len] + position_ids: Optional[torch.Tensor] = None # [batch, seq_len] or [3, batch, seq_len] class DFlashTargetModel(ABC): @@ -56,6 +59,11 @@ def generate_dflash_data( input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, ) -> DFlashTargetOutput: """Generate context hidden states for DFlash training.""" @@ -184,7 +192,23 @@ def generate_dflash_data( input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, ) -> DFlashTargetOutput: + if ( + pixel_values is not None + or pixel_values_videos is not None + or image_grid_thw is 
not None + or video_grid_thw is not None + or second_per_grid_ts is not None + ): + raise NotImplementedError( + "SGLangDFlashTargetModel does not yet support multimodal inputs. " + "Use HF backend for real VLM DFlash training." + ) sampling_params = SamplingParams(temperature=0, max_new_tokens=1) reqs, data_cache = [], [] @@ -220,13 +244,15 @@ def generate_dflash_data( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + position_ids=None, ) class HFDFlashTargetModel(DFlashTargetModel): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, model_type: Optional[str] = None): super().__init__() self.model = model + self.model_type = model_type @classmethod def from_pretrained( @@ -238,20 +264,55 @@ def from_pretrained( trust_remote_code: bool = True, **kwargs, ) -> "HFDFlashTargetModel": - - target_model = AutoModelForCausalLM.from_pretrained( + hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, - torch_dtype=torch_dtype, cache_dir=cache_dir, - output_hidden_states=True, trust_remote_code=trust_remote_code, - **kwargs, - ).eval() + ) + model_type = getattr(hf_config, "model_type", None) + + if model_type in QWEN3_VL_MODEL_TYPES: + if model_type == "qwen3_vl": + try: + from transformers import Qwen3VLForConditionalGeneration + except ImportError as exc: + raise ImportError( + "Qwen3VLForConditionalGeneration is unavailable. " + "Please upgrade transformers to a version with qwen3_vl support." + ) from exc + + model_cls = Qwen3VLForConditionalGeneration + else: + try: + from transformers import Qwen3VLMoeForConditionalGeneration + except ImportError as exc: + raise ImportError( + "Qwen3VLMoeForConditionalGeneration is unavailable. " + "Please upgrade transformers to a version with qwen3_vl_moe support." 
+ ) from exc + + model_cls = Qwen3VLMoeForConditionalGeneration + target_model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + trust_remote_code=trust_remote_code, + **kwargs, + ).eval() + else: + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + output_hidden_states=True, + trust_remote_code=trust_remote_code, + **kwargs, + ).eval() if device: target_model = target_model.to(device) - return cls(target_model) + return cls(target_model, model_type=model_type) @torch.no_grad() def generate_dflash_data( @@ -259,13 +320,66 @@ def generate_dflash_data( input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, ) -> DFlashTargetOutput: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - use_cache=False, - ) + target_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": True, + "use_cache": False, + } + if self.model_type in QWEN3_VL_MODEL_TYPES: + target_kwargs.update( + { + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + } + ) + + filtered_target_kwargs = {} + for key, value in target_kwargs.items(): + if key in { + "input_ids", + "attention_mask", + "output_hidden_states", + "use_cache", + } or value is not None: + filtered_target_kwargs[key] = value + + outputs = self.model(**filtered_target_kwargs) + if outputs.hidden_states is None: + raise ValueError( + "Target model did not return hidden states. Ensure output_hidden_states=True is supported." + ) + + position_ids = None + if self.model_type in QWEN3_VL_MODEL_TYPES: + target_inner_model = getattr(self.model, "model", None) + if target_inner_model is not None and hasattr( + target_inner_model, "get_rope_index" + ): + rope_kwargs = { + "input_ids": input_ids, + "image_grid_thw": image_grid_thw, + "attention_mask": attention_mask, + } + if video_grid_thw is not None: + rope_kwargs["video_grid_thw"] = video_grid_thw + if second_per_grid_ts is not None: + rope_kwargs["second_per_grid_ts"] = second_per_grid_ts + + filtered_rope_kwargs = { + key: value for key, value in rope_kwargs.items() if value is not None + } + position_ids, _ = target_inner_model.get_rope_index( + **filtered_rope_kwargs + ) # hidden_states[0] = embedding output; hidden_states[i+1] = layer i output offset = 1 @@ -282,6 +396,7 @@ def generate_dflash_data( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + position_ids=position_ids, ) diff --git a/specforge/modeling/target/target_utils.py b/specforge/modeling/target/target_utils.py index 6f7b1e101..610286956 100644 --- a/specforge/modeling/target/target_utils.py +++ b/specforge/modeling/target/target_utils.py @@ -40,9 +40,10 @@ def from_pretrained( ) -> "TargetEmbeddingsAndHead": # 1. 
Load Config - config = AutoConfig.from_pretrained( + full_config = AutoConfig.from_pretrained( model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code ) + config = cls._resolve_text_config(full_config) instance = cls(config) if embed_key is None: @@ -63,7 +64,11 @@ def from_pretrained( print(f"Warning: Snapshot download failed or path check failed: {e}") # 3. Handle Weight Tying - tie_weights = getattr(config, "tie_word_embeddings", False) + tie_weights = getattr( + full_config, + "tie_word_embeddings", + getattr(config, "tie_word_embeddings", False), + ) # 4. Load Weights instance._load_weights(local_model_path, embed_key, lm_head_key, tie_weights) @@ -75,28 +80,64 @@ def from_pretrained( return instance + @staticmethod + def _resolve_text_config(config): + if hasattr(config, "text_config") and hasattr(config.text_config, "hidden_size"): + return config.text_config + return config + + @staticmethod + def _candidate_suffixes(key_type: str): + if key_type == "embed": + return ( + "model.language_model.model.embed_tokens.weight", + "model.language_model.embed_tokens.weight", + "model.embed_tokens.weight", + "language_model.embed_tokens.weight", + "embed_tokens.weight", + ) + return ("lm_head.weight",) + + def _resolve_weight_key(self, available_keys, preferred_key: str, key_type: str): + if preferred_key in available_keys: + return preferred_key + + for suffix in self._candidate_suffixes(key_type): + matches = [k for k in available_keys if k.endswith(suffix)] + if matches: + # Prefer shortest full key for deterministic behavior. + return sorted(matches, key=len)[0] + return None + def _load_weights( self, model_path: str, embed_key: str, lm_head_key: str, tie_weights: bool ): index_files = glob.glob(os.path.join(model_path, "*.index.json")) weight_map = {} files_to_load = {} + resolved_embed_key = embed_key + resolved_lm_head_key = lm_head_key if not tie_weights else None if index_files: with open(index_files[0], "r") as f: index = json.load(f) weight_map = index.get("weight_map", {}) - if embed_key in weight_map: - files_to_load[embed_key] = weight_map[embed_key] - else: + resolved_embed_key = self._resolve_weight_key( + weight_map.keys(), embed_key, "embed" + ) + if resolved_embed_key is None: raise ValueError( f"Embedding key '{embed_key}' not found in weight map." ) + files_to_load[resolved_embed_key] = weight_map[resolved_embed_key] if not tie_weights: - if lm_head_key in weight_map: - files_to_load[lm_head_key] = weight_map[lm_head_key] + resolved_lm_head_key = self._resolve_weight_key( + weight_map.keys(), lm_head_key, "lm_head" + ) + if resolved_lm_head_key is not None: + files_to_load[resolved_lm_head_key] = weight_map[resolved_lm_head_key] else: print( f"Warning: {lm_head_key} not found. Ensure model doesn't use tied weights manually." @@ -109,9 +150,33 @@ def _load_weights( if not target_file: raise FileNotFoundError("No checkpoint found.") - files_to_load[embed_key] = os.path.basename(target_file) + if target_file.endswith(".safetensors"): + with safe_open(target_file, framework="pt") as f: + available_keys = list(f.keys()) + else: + full_state = torch.load(target_file, map_location="cpu") + available_keys = list(full_state.keys()) + del full_state + gc.collect() + + resolved_embed_key = self._resolve_weight_key( + available_keys, embed_key, "embed" + ) + if resolved_embed_key is None: + raise ValueError( + f"Embedding key '{embed_key}' not found in checkpoint file." 
+ ) + files_to_load[resolved_embed_key] = os.path.basename(target_file) if not tie_weights: - files_to_load[lm_head_key] = os.path.basename(target_file) + resolved_lm_head_key = self._resolve_weight_key( + available_keys, lm_head_key, "lm_head" + ) + if resolved_lm_head_key is not None: + files_to_load[resolved_lm_head_key] = os.path.basename(target_file) + else: + print( + f"Warning: {lm_head_key} not found. Ensure model doesn't use tied weights manually." + ) loaded_keys = set() @@ -123,7 +188,9 @@ def _load_weights( file_to_keys_map[full_path].append(key) for file_path, keys in file_to_keys_map.items(): - self._load_file_content(file_path, keys, embed_key, lm_head_key) + self._load_file_content( + file_path, keys, resolved_embed_key, resolved_lm_head_key + ) loaded_keys.update(keys) if tie_weights: @@ -132,9 +199,13 @@ def _load_weights( ) self.lm_head.weight = self.embed_tokens.weight - if embed_key not in loaded_keys: + if resolved_embed_key not in loaded_keys: raise RuntimeError("Failed to load embeddings.") - if not tie_weights and lm_head_key not in loaded_keys: + if ( + not tie_weights + and resolved_lm_head_key is not None + and resolved_lm_head_key not in loaded_keys + ): print( "Warning: LM Head weights were not found (and tie_weights is False). Head is random." ) @@ -144,7 +215,7 @@ def _load_file_content( file_path: str, keys_to_extract: list, target_embed_key: str, - target_head_key: str, + target_head_key: Optional[str], ): """Helper to load specific keys from a file""" print(f"Loading {keys_to_extract} from {os.path.basename(file_path)}...") @@ -171,7 +242,7 @@ def _load_file_content( if k == target_embed_key: self.embed_tokens.weight.data.copy_(tensor) print(" -> Loaded Embeddings") - elif k == target_head_key: + elif target_head_key is not None and k == target_head_key: if tensor.shape == self.lm_head.weight.data.shape: self.lm_head.weight.data.copy_(tensor) print(" -> Loaded LM Head") From c19b96378202124d84565591fcfd9975f075dca9 Mon Sep 17 00:00:00 2001 From: gss <2783977641@qq.com> Date: Sat, 28 Feb 2026 11:49:19 +0800 Subject: [PATCH 2/3] Update train_dflash.py --- scripts/train_dflash.py | 63 ++++++++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 726496fd5..ce5b81a73 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -280,8 +280,6 @@ def build_models( ): from transformers import Qwen3VLForConditionalGeneration - # If you're using torch==2.9.1, please ensure you have cuDNN >= 9.15 installed to avoid a performance - # regression with Conv3D. You can run `pip install nvidia-cudnn-cu12==9.16.0.29` to immediately fix it. target_model = HFDFlashTargetModel( Qwen3VLForConditionalGeneration.from_pretrained( pretrained_model_name_or_path=args.target_model_path, @@ -300,8 +298,6 @@ def build_models( ): from transformers import Qwen3VLMoeForConditionalGeneration - # If you're using torch==2.9.1, please ensure you have cuDNN >= 9.15 installed to avoid a performance - # regression with Conv3D. You can run `pip install nvidia-cudnn-cu12==9.16.0.29` to immediately fix it. 
target_model = HFDFlashTargetModel( Qwen3VLMoeForConditionalGeneration.from_pretrained( pretrained_model_name_or_path=args.target_model_path, @@ -416,20 +412,55 @@ def build_dataloader( f"{args.target_model_path}" ) cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + cache_dir = os.path.join(args.cache_dir, "processed_dataset") + dist_enabled = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if dist_enabled else 0 + world_size = dist.get_world_size() if dist_enabled else 1 train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] - train_eagle3_dataset = build_eagle3_dataset( - dataset=train_dataset, - tokenizer=tokenizer, - chat_template=args.chat_template, - max_length=args.max_length, - is_preformatted=args.is_preformatted, - is_vlm=is_vlm, - processor=processor, - cache_dir=os.path.join(args.cache_dir, "processed_dataset"), - cache_key=cache_key, - num_proc=args.build_dataset_num_proc, - ) + if world_size > 1: + if rank == 0: + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + dist.barrier() + else: + dist.barrier() + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + # Rank 0 has finished preprocessing at this point; other ranks only need cache reads. + num_proc=1, + ) + else: + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) min_loss_tokens = 2 * args.block_size original_size = len(train_eagle3_dataset) From aff0de1704fd62e4c4294b6b0efba2695c3f8da9 Mon Sep 17 00:00:00 2001 From: gss <2783977641@qq.com> Date: Sat, 28 Feb 2026 11:56:40 +0800 Subject: [PATCH 3/3] fix run_qwen3_vl_8b_dflash_online.sh --- examples/run_qwen3_vl_8b_dflash_online.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/run_qwen3_vl_8b_dflash_online.sh b/examples/run_qwen3_vl_8b_dflash_online.sh index ef2e81308..c02085416 100644 --- a/examples/run_qwen3_vl_8b_dflash_online.sh +++ b/examples/run_qwen3_vl_8b_dflash_online.sh @@ -21,13 +21,13 @@ torchrun \ --train-data-path $ROOT_DIR/cache/dataset/allava4v-mix-20k_train.localimg_regen.jsonl \ --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ --min-pixels 50176 \ - --max-pixels 1003520 \ + --max-pixels 802816 \ --output-dir $ROOT_DIR/outputs/qwen3-vl-8b-allava4v20k-dflash \ --cache-dir $ROOT_DIR/cache \ --num-epochs 6 \ --batch-size 2 \ - --learning-rate 6e-4 \ - --warmup-ratio 0.04 \ + --learning-rate 1e-4 \ + --warmup-ratio 0.08 \ --max-grad-norm 1.0 \ --max-length 4096 \ --num-draft-layers 5 \
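
For reference, a minimal standalone sketch of the capture-layer selection that `_build_target_layer_ids` in scripts/train_dflash.py implements, with the Qwen3-VL deepstack override from build_models applied on top; the numbers are illustrative and simply reproduce the target_layer_ids shipped in configs/qwen3-vl-8b-dflash.json.

def build_target_layer_ids(num_target_layers, num_draft_layers,
                           start_layer=1, end_layer=None):
    # Default end point mirrors the patch: stop 3 layers below the top.
    if end_layer is None:
        end_layer = num_target_layers - 3
    if num_draft_layers == 1:
        midpoint = num_target_layers // 2
        return [max(start_layer, min(midpoint, end_layer))]
    span = end_layer - start_layer
    return [int(start_layer + (i * span) / (num_draft_layers - 1))
            for i in range(num_draft_layers)]

layer_ids = build_target_layer_ids(num_target_layers=36, num_draft_layers=5)
# Evenly spaced over [1, 33]: [1, 9, 17, 25, 33]
if layer_ids[0] < 3:
    layer_ids[0] = 3  # skip Qwen3-VL deepstack layers 0-2
print(layer_ids)  # [3, 9, 17, 25, 33] -- matches configs/qwen3-vl-8b-dflash.json
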
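
A minimal sketch of how QwenVLOnlineDFlashModel gathers draft-block position ids from [3, B, S] mRoPE context position ids (the 3D path of `_create_position_ids_from_context`); the shapes, anchor values, and block size below are illustrative, not taken from a real batch, where anchors come from loss-mask sampling.

import torch

block_size, bsz, seq_len, n_blocks = 4, 1, 16, 2
# [3, B, S]: temporal/height/width ids; for a text-only row all three planes match.
context_pos = torch.arange(seq_len).view(1, 1, -1).expand(3, bsz, -1)
anchors = torch.tensor([[2, 9]])  # [B, n_blocks]

offsets = torch.arange(block_size).view(1, 1, -1)                      # [1, 1, bs]
gather_idx = (anchors.unsqueeze(-1) + offsets).clamp(max=seq_len - 1)  # [B, n, bs]

expanded_ctx = context_pos.unsqueeze(2).expand(-1, -1, n_blocks, -1)   # [3, B, n, S]
expanded_idx = gather_idx.unsqueeze(0).expand(3, -1, -1, -1)           # [3, B, n, bs]
draft_pos = torch.gather(expanded_ctx, 3, expanded_idx).reshape(3, bsz, -1)
# Each draft block inherits the mRoPE ids starting at its anchor position:
print(draft_pos[0])  # tensor([[ 2,  3,  4,  5,  9, 10, 11, 12]])
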
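
And a small sketch of the suffix-based fallback in TargetEmbeddingsAndHead._resolve_weight_key: when the configured key is missing (as with Qwen3-VL checkpoints, whose embeddings live under model.language_model), the loader matches available keys by suffix and prefers the shortest full match. The explicit suffixes parameter and the key set here are simplified stand-ins for the patch's per-key-type candidate lists.

def resolve_weight_key(available_keys, preferred_key, suffixes):
    if preferred_key in available_keys:
        return preferred_key
    for suffix in suffixes:
        matches = [k for k in available_keys if k.endswith(suffix)]
        if matches:
            return sorted(matches, key=len)[0]  # shortest key, deterministic
    return None

checkpoint_keys = {
    "model.language_model.embed_tokens.weight",
    "model.visual.patch_embed.proj.weight",
    "lm_head.weight",
}
print(resolve_weight_key(
    checkpoint_keys,
    "model.embed_tokens.weight",
    ("model.language_model.embed_tokens.weight", "model.embed_tokens.weight"),
))  # -> model.language_model.embed_tokens.weight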