Conversation

@fuheaven (Contributor) commented on Jan 6, 2026

No description provided.

@gemini-code-assist

Summary of Changes

Hello @fuheaven, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates full support for the Enflame S60 GCU, a new hardware platform, into the system. It involves updating documentation, adding device-specific implementations for core deep learning operations like attention and normalization, and providing ready-to-use configuration files and scripts for various models. The changes ensure that models can run efficiently and correctly on the Enflame S60, expanding the range of supported hardware.

Highlights

  • Enflame S60 (GCU) Support: Introduced comprehensive support for the Enflame S60 Graphics Computing Unit (GCU) platform, enabling models to leverage its specialized hardware capabilities.
  • Platform-Specific Optimizations: Implemented GCU-optimized versions of Flash Attention, Layer Normalization, and Rotary Positional Embedding (RoPE) to ensure efficient execution on the Enflame S60.
  • Configuration and Scripting: Added new configuration files and inference scripts for Qwen-Image I2I, WAN2.1 I2V, and WAN2.1 T2V models, tailored for the Enflame S60 platform, including distributed inference setups.
  • Registry Integration: Extended the platform registry to dynamically load platform-specific implementations for Layer Normalization and RoPE, enhancing modularity and extensibility.


@gemini-code-assist (bot) left a comment


Code Review

This pull request adds comprehensive support for the Enflame S60 (GCU) platform. The changes are well-structured, including documentation updates, new configuration files, and platform-specific implementations for device handling, attention, layer normalization, and rotary position embeddings. The use of registries for loading platform-specific components is a good design choice for extensibility.

My review focuses on correctness, maintainability, and usability. I've identified a critical bug in the SDPA fallback implementation, significant code duplication that impacts maintainability, and minor usability issues in the new shell scripts. Addressing these points will improve the quality and robustness of the new platform support.

Comment on lines +211 to +232
    def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available.
        Args:
            q: [B*Lq, Nq, C] Query tensor (flattened batch)
            k: [B*Lk, Nk, C] Key tensor (flattened batch)
            v: [B*Lk, Nk, C] Value tensor (flattened batch)
            cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
            max_seqlen_q: Maximum sequence length in queries
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B*Lq, C] (flattened batch)
        """
        # Reshape from flattened format to batched format
        bs = cu_seqlens_q.shape[0] - 1

        # Reshape q, k, v to [B, L, Nq, C]
        q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1])
        k = k.reshape(bs, max_seqlen_q, k.shape[-2], k.shape[-1])
        v = v.reshape(bs, max_seqlen_q, v.shape[-2], v.shape[-1])

critical

The _sdpa_fallback method incorrectly uses max_seqlen_q to reshape the key (k) and value (v) tensors. This is only correct for self-attention and will fail or produce incorrect results for cross-attention where query and key/value sequence lengths can differ. The function should accept max_seqlen_kv and use it for reshaping k and v. Note that the call site at line 101 must also be updated to pass max_seqlen_kv.

    def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_kv, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available.

        Args:
            q: [B*Lq, Nq, C] Query tensor (flattened batch)
            k: [B*Lk, Nk, C] Key tensor (flattened batch)
            v: [B*Lk, Nk, C] Value tensor (flattened batch)
            cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
            max_seqlen_q: Maximum sequence length in queries
            max_seqlen_kv: Maximum sequence length in keys/values
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B*Lq, C] (flattened batch)
        """
        # Reshape from flattened format to batched format
        bs = cu_seqlens_q.shape[0] - 1

        # Reshape q, k, v to [B, L, Nq, C]
        q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1])
        k = k.reshape(bs, max_seqlen_kv, k.shape[-2], k.shape[-1])
        v = v.reshape(bs, max_seqlen_kv, v.shape[-2], v.shape[-1])
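
For completeness, here is a minimal sketch of how the rest of the fallback could proceed once the key/value length is threaded through. It additionally assumes a cu_seqlens_kv argument is available for building a key/value padding mask and that query and key/value head counts match; neither is part of the original signature, so treat this as illustrative only:

import torch
import torch.nn.functional as F


def sdpa_fallback_sketch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, causal=False, dropout_p=0.0):
    """Illustrative continuation of the fallback: reshape, mask, attend, re-flatten."""
    bs = cu_seqlens_q.shape[0] - 1

    # [B*L, N, C] -> [B, N, L, C]; SDPA expects heads before the sequence dimension.
    q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1]).transpose(1, 2)
    k = k.reshape(bs, max_seqlen_kv, k.shape[-2], k.shape[-1]).transpose(1, 2)
    v = v.reshape(bs, max_seqlen_kv, v.shape[-2], v.shape[-1]).transpose(1, 2)

    # Derive per-sample key/value lengths from the cumulative lengths and mask out padding.
    kv_lens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]                                     # [B]
    kv_mask = torch.arange(max_seqlen_kv, device=k.device)[None, :] < kv_lens[:, None]   # [B, Lk]
    attn_mask = kv_mask[:, None, None, :]                                                # broadcasts to [B, N, Lq, Lk]

    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None if causal else attn_mask,  # SDPA disallows an explicit mask together with is_causal=True
        dropout_p=dropout_p,
        is_causal=causal,
    )

    # [B, N, Lq, C] -> [B*Lq, N*C] to match the flattened output layout.
    return out.transpose(1, 2).reshape(bs * max_seqlen_q, -1)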

del weight_tensor


class LayerNormWeightTemplate(metaclass=ABCMeta):

high

The newly added LayerNormWeightTemplate contains a large amount of code that is duplicated from the existing RMSWeightTemplate. This includes methods like load, to_cuda, to_cpu, state_dict, and the various loading helpers. To improve maintainability and reduce redundancy, this shared logic should be refactored into a common base class from which both RMSWeightTemplate and LayerNormWeightTemplate can inherit.
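
As a rough illustration of the suggested refactor (the base-class name and attribute names here are hypothetical; the real templates carry more state and helpers):

from abc import ABCMeta, abstractmethod


class _NormWeightTemplateBase(metaclass=ABCMeta):
    """Hypothetical shared base holding the duplicated load/device/state_dict logic."""

    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.weight = None

    def load(self, weight_dict):
        # Common state-dict lookup shared by both templates.
        self.weight = weight_dict[self.weight_name]

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        destination = {} if destination is None else destination
        destination[self.weight_name] = self.weight
        return destination

    @abstractmethod
    def apply(self, input_tensor):
        """Norm-specific forward pass implemented by each subclass."""


class RMSWeightTemplate(_NormWeightTemplateBase):
    def apply(self, input_tensor):
        ...  # RMSNorm-specific math


class LayerNormWeightTemplate(_NormWeightTemplateBase):
    def apply(self, input_tensor):
        ...  # LayerNorm-specific math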

Comment on lines +42 to +54
        # Try to get rope function from registry first (for platform-specific implementations)
        if rope_type in ROPE_REGISTER:
            rope_class = ROPE_REGISTER[rope_type]
            self.rope_instance = rope_class()

            # Create a wrapper function that matches the expected signature
            def rope_wrapper(xq, xk, cos_sin_cache):
                return self.rope_instance.apply(xq, xk, cos_sin_cache)

            rope_func = rope_wrapper
        else:
            # Fallback to hardcoded functions
            rope_func = rope_funcs.get(rope_type, apply_wan_rope_with_torch)

medium

The logic to select the RoPE function is well-designed for extensibility. However, creating the rope_wrapper function inside the __init__ method on each instantiation can be slightly inefficient. Since the wrapper's logic is consistent, you could define it at the module level or as a static/regular method on the class to avoid redefining it repeatedly. This is a minor optimization but good practice.
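
For instance, the registry branch could bind a method instead of defining a closure on every construction. A sketch follows; the class name is illustrative, while ROPE_REGISTER, rope_funcs, and apply_wan_rope_with_torch come from the quoted snippet and are assumed to be importable in that module:

class RopeDispatcher:
    """Illustrative holder for the selected RoPE implementation."""

    def __init__(self, rope_type):
        if rope_type in ROPE_REGISTER:
            # Platform-specific implementation from the registry.
            self.rope_instance = ROPE_REGISTER[rope_type]()
            self.rope_func = self._apply_registry_rope
        else:
            # Fallback to the hardcoded functions.
            self.rope_instance = None
            self.rope_func = rope_funcs.get(rope_type, apply_wan_rope_with_torch)

    def _apply_registry_rope(self, xq, xk, cos_sin_cache):
        # Defined once on the class, so no per-instance wrapper is created.
        return self.rope_instance.apply(xq, xk, cos_sin_cache)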

Comment on lines +92 to +94
            try:
                torch.cuda.set_device(dist.get_rank())
            except Exception:

medium

Catching a broad Exception can hide unexpected errors. It's better to catch specific exceptions that you anticipate, such as AttributeError or RuntimeError, to make the error handling more precise and avoid accidentally suppressing unrelated issues.

            except (AttributeError, RuntimeError):
                # If all else fails, just log a warning
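
A slightly fuller sketch of the narrowed handler with the warning made explicit (the logger setup here is illustrative, not taken from the patch):

import logging

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

try:
    torch.cuda.set_device(dist.get_rank())
except (AttributeError, RuntimeError) as exc:
    # Anticipated failures only, e.g. the rank exceeding the visible device
    # count or the CUDA shim being unavailable on this platform.
    logger.warning("Could not set device: %s", exc)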

        cu_seqlens_kv = cu_seqlens_kv.to(gcu_device)

        # Ensure data types are half precision
        import math

medium

import math should be at the top of the file, not inside a method. This follows standard Python style guides (like PEP 8) and avoids re-importing the module on every call.

Comment on lines +7 to +24
DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "torch.bfloat16": torch.bfloat16,
    "torch.float16": torch.float16,
    "torch.float32": torch.float32,
}


@lru_cache(maxsize=None)
def GET_DTYPE():
    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
    assert RUNNING_FLAG in ["BF16", "FP16"]
    return DTYPE_MAP[RUNNING_FLAG]

medium

The DTYPE_MAP dictionary and GET_DTYPE function are also defined in lightx2v_platform/ops/norm/norm_template.py. To avoid code duplication, these common utilities should be moved to a shared file (e.g., within lightx2v_platform/base/ or a new lightx2v_platform/ops/utils.py) and imported where needed.
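
A sketch of the shared module, using the lightx2v_platform/ops/utils.py option mentioned above (the exact path and import sites are up to the authors):

# lightx2v_platform/ops/utils.py (proposed shared location)
import os
from functools import lru_cache

import torch

DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "torch.bfloat16": torch.bfloat16,
    "torch.float16": torch.float16,
    "torch.float32": torch.float32,
}


@lru_cache(maxsize=None)
def GET_DTYPE():
    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
    assert RUNNING_FLAG in ["BF16", "FP16"]
    return DTYPE_MAP[RUNNING_FLAG]

Both norm_template.py and the file under review would then use "from lightx2v_platform.ops.utils import DTYPE_MAP, GET_DTYPE" instead of defining their own copies.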

Comment on lines +6 to +7
lightx2v_path=
model_path=

medium

The script requires lightx2v_path and model_path to be set, but they are left empty. This will cause the script to fail when executed. To improve usability, it's good practice to add a check that ensures these variables are set and exits with an informative error message if they are not.

Suggested change
lightx2v_path=
model_path=
if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi

Comment on lines +12 to +13
lightx2v_path=
model_path=

medium

The script requires lightx2v_path and model_path to be set, but they are left empty. This will cause the script to fail when executed. To improve usability, it's good practice to add a check that ensures these variables are set and exits with an informative error message if they are not.

Suggested change
lightx2v_path=
model_path=
if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi

Comment on lines +6 to +7
lightx2v_path=
model_path=

medium

The script requires lightx2v_path and model_path to be set, but they are left empty. This will cause the script to fail when executed. To improve usability, it's good practice to add a check that ensures these variables are set and exits with an informative error message if they are not.

Suggested change
lightx2v_path=
model_path=
if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi

@helloyongyang merged commit c38d132 into ModelTC:main on Jan 6, 2026
1 check passed