update readme: add support for enflame S60 #713
Conversation

No description provided.
Summary of Changes

Hello @fuheaven, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates full support for the Enflame S60 GCU, a new hardware platform. It updates documentation, adds device-specific implementations for core deep learning operations such as attention and normalization, and provides ready-to-use configuration files and scripts for various models. The changes ensure that models run efficiently and correctly on the Enflame S60, expanding the range of supported hardware.
Code Review
This pull request adds comprehensive support for the Enflame S60 (GCU) platform. The changes are well-structured, including documentation updates, new configuration files, and platform-specific implementations for device handling, attention, layer normalization, and rotary position embeddings. The use of registries for loading platform-specific components is a good design choice for extensibility.
My review focuses on correctness, maintainability, and usability. I've identified a critical bug in the SDPA fallback implementation, significant code duplication that impacts maintainability, and minor usability issues in the new shell scripts. Addressing these points will improve the quality and robustness of the new platform support.
```python
def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, causal=False, dropout_p=0.0):
    """
    Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available.

    Args:
        q: [B*Lq, Nq, C] Query tensor (flattened batch)
        k: [B*Lk, Nk, C] Key tensor (flattened batch)
        v: [B*Lk, Nk, C] Value tensor (flattened batch)
        cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
        max_seqlen_q: Maximum sequence length in queries
        causal: Whether to apply causal mask
        dropout_p: Dropout probability

    Returns:
        Output tensor: [B*Lq, C] (flattened batch)
    """
    # Reshape from flattened format to batched format
    bs = cu_seqlens_q.shape[0] - 1

    # Reshape q, k, v to [B, L, Nq, C]
    q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1])
    k = k.reshape(bs, max_seqlen_q, k.shape[-2], k.shape[-1])
    v = v.reshape(bs, max_seqlen_q, v.shape[-2], v.shape[-1])
```
The _sdpa_fallback method incorrectly uses max_seqlen_q to reshape the key (k) and value (v) tensors. This is only correct for self-attention and will fail or produce incorrect results for cross-attention, where query and key/value sequence lengths can differ. The function should accept max_seqlen_kv and use it for reshaping k and v. Note that the call site at line 101 must also be updated to pass max_seqlen_kv (see the call-site sketch after the suggestion below).
```python
def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_kv, causal=False, dropout_p=0.0):
    """
    Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available.

    Args:
        q: [B*Lq, Nq, C] Query tensor (flattened batch)
        k: [B*Lk, Nk, C] Key tensor (flattened batch)
        v: [B*Lk, Nk, C] Value tensor (flattened batch)
        cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
        max_seqlen_q: Maximum sequence length in queries
        max_seqlen_kv: Maximum sequence length in keys/values
        causal: Whether to apply causal mask
        dropout_p: Dropout probability

    Returns:
        Output tensor: [B*Lq, C] (flattened batch)
    """
    # Reshape from flattened format to batched format
    bs = cu_seqlens_q.shape[0] - 1

    # Reshape q, k, v to [B, L, Nq, C]
    q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1])
    k = k.reshape(bs, max_seqlen_kv, k.shape[-2], k.shape[-1])
    v = v.reshape(bs, max_seqlen_kv, v.shape[-2], v.shape[-1])
```
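For the call-site change mentioned above, a hedged sketch (the surrounding variable names are assumptions, not taken from the actual file):

```python
# Hypothetical call site: pass the key/value max length explicitly
# so cross-attention reshapes correctly.
out = self._sdpa_fallback(q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_kv, causal=causal, dropout_p=dropout_p)
```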
```python
del weight_tensor


class LayerNormWeightTemplate(metaclass=ABCMeta):
```
The newly added LayerNormWeightTemplate contains a large amount of code that is duplicated from the existing RMSWeightTemplate. This includes methods like load, to_cuda, to_cpu, state_dict, and the various loading helpers. To improve maintainability and reduce redundancy, this shared logic should be refactored into a common base class from which both RMSWeightTemplate and LayerNormWeightTemplate can inherit.
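A minimal sketch of that refactor, assuming the shared methods named above (load, to_cuda, to_cpu, state_dict); the base-class name and attribute layout here are hypothetical:

```python
from abc import ABCMeta, abstractmethod


class NormWeightTemplate(metaclass=ABCMeta):
    # Hypothetical shared base: holds the logic currently duplicated
    # between RMSWeightTemplate and LayerNormWeightTemplate.
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.weight = None

    def load(self, weight_dict):
        # Shared loading helper; subclasses reuse it unchanged.
        self.weight = weight_dict[self.weight_name]

    def to_cuda(self, non_blocking=False):
        self.weight = self.weight.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def state_dict(self):
        return {self.weight_name: self.weight}

    @abstractmethod
    def apply(self, x):
        ...


class RMSWeightTemplate(NormWeightTemplate):
    def apply(self, x):
        ...  # RMSNorm-specific math stays here


class LayerNormWeightTemplate(NormWeightTemplate):
    def apply(self, x):
        ...  # LayerNorm-specific math stays here
```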
```python
# Try to get rope function from registry first (for platform-specific implementations)
if rope_type in ROPE_REGISTER:
    rope_class = ROPE_REGISTER[rope_type]
    self.rope_instance = rope_class()

    # Create a wrapper function that matches the expected signature
    def rope_wrapper(xq, xk, cos_sin_cache):
        return self.rope_instance.apply(xq, xk, cos_sin_cache)

    rope_func = rope_wrapper
else:
    # Fallback to hardcoded functions
    rope_func = rope_funcs.get(rope_type, apply_wan_rope_with_torch)
```
The logic to select the RoPE function is well-designed for extensibility. However, creating the rope_wrapper function inside the __init__ method on each instantiation can be slightly inefficient. Since the wrapper's logic is consistent, you could define it at the module level or as a static/regular method on the class to avoid redefining it repeatedly. This is a minor optimization but good practice.
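A sketch of that suggestion, moving the wrapper to a regular method so no closure is rebuilt on each instantiation (the enclosing class name here is hypothetical, shown only to illustrate the pattern):

```python
class RotaryEmbedding:
    # Hypothetical enclosing class for illustration.
    def __init__(self, rope_type):
        if rope_type in ROPE_REGISTER:
            self.rope_instance = ROPE_REGISTER[rope_type]()
            rope_func = self._apply_registry_rope  # bound method, defined once
        else:
            rope_func = rope_funcs.get(rope_type, apply_wan_rope_with_torch)
        self.rope_func = rope_func

    def _apply_registry_rope(self, xq, xk, cos_sin_cache):
        return self.rope_instance.apply(xq, xk, cos_sin_cache)
```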
```python
try:
    torch.cuda.set_device(dist.get_rank())
except Exception:
```
Catching a broad Exception can hide unexpected errors. It's better to catch specific exceptions that you anticipate, such as AttributeError or RuntimeError, to make the error handling more precise and avoid accidentally suppressing unrelated issues.
```python
except (AttributeError, RuntimeError):
    # If all else fails, just log a warning
```
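A fuller sketch of what that handler could look like; the logger object in scope is an assumption:

```python
try:
    torch.cuda.set_device(dist.get_rank())
except (AttributeError, RuntimeError) as e:
    # Only anticipated failures are swallowed; anything else propagates.
    logger.warning(f"Could not set CUDA device for this rank: {e}")
```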
```python
cu_seqlens_kv = cu_seqlens_kv.to(gcu_device)

# Ensure data types are half precision
import math
```
```python
DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "torch.bfloat16": torch.bfloat16,
    "torch.float16": torch.float16,
    "torch.float32": torch.float32,
}


@lru_cache(maxsize=None)
def GET_DTYPE():
    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
    assert RUNNING_FLAG in ["BF16", "FP16"]
    return DTYPE_MAP[RUNNING_FLAG]
```
The DTYPE_MAP dictionary and GET_DTYPE function are also defined in lightx2v_platform/ops/norm/norm_template.py. To avoid code duplication, these common utilities should be moved to a shared file (e.g., within lightx2v_platform/base/ or a new lightx2v_platform/ops/utils.py) and imported where needed.
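A minimal sketch of the consolidation, using one of the module paths suggested above (the exact path and import form are assumptions):

```python
# lightx2v_platform/ops/utils.py
import os
from functools import lru_cache

import torch

DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
    # ...plus the lowercase and "torch.*" aliases from the original map
}


@lru_cache(maxsize=None)
def GET_DTYPE():
    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
    assert RUNNING_FLAG in ["BF16", "FP16"]
    return DTYPE_MAP[RUNNING_FLAG]
```

Both norm_template.py and the file reviewed here would then share a single import, e.g. `from lightx2v_platform.ops.utils import DTYPE_MAP, GET_DTYPE`.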
```bash
lightx2v_path=
model_path=
```
The script requires lightx2v_path and model_path to be set, but they are left empty. This will cause the script to fail when executed. To improve usability, it's good practice to add a check that ensures these variables are set and exits with an informative error message if they are not.
```bash
lightx2v_path=
model_path=

if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi
```
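An equivalent, more compact guard uses bash parameter expansion, which aborts with a message if either variable is unset or empty; the same guard applies to the two scripts below:

```bash
: "${lightx2v_path:?must be set in the script}"
: "${model_path:?must be set in the script}"
```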
```bash
lightx2v_path=
model_path=
```
Same issue as the previous script: lightx2v_path and model_path are left empty, so the script will fail when executed. Add the same guard so it exits with an informative error.
```bash
lightx2v_path=
model_path=

if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi
```
```bash
lightx2v_path=
model_path=
```
Same issue as the previous two scripts: lightx2v_path and model_path are left empty, so the script will fail when executed. Add the same guard so it exits with an informative error.
```bash
lightx2v_path=
model_path=

if [ -z "${lightx2v_path}" ] || [ -z "${model_path}" ]; then
    echo "Error: lightx2v_path and model_path must be set in the script."
    exit 1
fi
```