@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 16, 2025

Description

SequenceDescriptor's from_segment_ids_and_pos() accepts segment_ids and an optional segment_pos as input. This method is meant to serve as a convenience method that does two things:

  1. Stuff the segment_ids and segment_pos in a SequenceDescriptor object for TE to use downstream
  2. If the segment_pos is not passed, then calculate/extrapolate it

In its current form, the second functionality gives incorrect results for the THD + non-reordered and THD + reordered cases, because it naively uses an arange to calculate segment_pos. This can result in incorrect masking for these cases.
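To illustrate the failure mode described above, here is a minimal plain-Python sketch. It is not the Transformer Engine implementation: it assumes a single packed THD row in which segment ID 0 marks padding, and the helper names `naive_segment_pos` and `per_segment_pos` are hypothetical.

```python
def naive_segment_pos(segment_ids):
    """Positions from a plain arange over the row, ignoring segment boundaries."""
    return list(range(len(segment_ids)))


def per_segment_pos(segment_ids):
    """Positions that restart at 0 whenever a new segment begins.

    Assumes segment ID 0 marks padding; padded tokens get position 0.
    """
    positions = []
    prev_id = 0
    counter = 0
    for seg_id in segment_ids:
        if seg_id == 0:
            positions.append(0)
        else:
            # Continue counting inside a segment, restart when the ID changes.
            counter = counter + 1 if seg_id == prev_id else 0
            positions.append(counter)
        prev_id = seg_id
    return positions


# One packed THD row: a 3-token segment, a 2-token segment, then one pad token.
packed_row = [1, 1, 1, 2, 2, 0]
naive = naive_segment_pos(packed_row)
expected = per_segment_pos(packed_row)
```

Under the naive scheme the second segment's positions start at 3 rather than 0, which is the kind of mismatch that can lead to incorrect masking.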

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR makes a few changes:

  1. Passed two new args to from_segment_ids_and_pos(): is_thd and is_segment_ids_reordered. The only cases that this function can currently guarantee to support are BSHD with and without load balancing, and THD without load balancing.
  2. BSHD with load balancing is supported natively because the segment_ids and segment_pos are not reordered before being passed to from_segment_ids_and_pos(). However, if the segment_pos are reordered and passed to from_segment_ids_and_pos(), it will assert
  3. If a THD + reordered use case calls from_segment_ids_and_pos(), it will assert
  4. The fused attn tests were modified to account for these two new args
    • For THD fused attn non-CP tests, segment_pos=None is passed so as to exercise the newly added THD path in from_segment_ids_and_pos()
    • For THD fused attn CP tests, segment_pos is explicitly passed (as before, not a new change)
    • For BSHD fused attn CP tests, segment_pos=None is passed so as to exercise the default BSHD path to generate segment_pos (as before, not a new change)

Impact on user of the API:

  1. These two new args, is_thd and is_segment_ids_reordered, are not Optional, so existing callers of this API will hit a TypeError. This is a breaking change, but it is needed to ensure correct usage of this API
  2. The user is now expected to tell the API whether the layout is THD or BSHD and whether the segment_ids are reordered. The segment_ids passed are expected to be reordered only for THD load balancing. For all other cases the segment_ids should not be reordered
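A rough sketch of what the breaking change means for call sites. The function below is a hypothetical stand-in, not the real SequenceDescriptor method; it assumes the two new flags are required keyword-only arguments, as a later review comment in this thread suggests for an earlier revision.

```python
def from_segment_ids_and_pos_sketch(
    segment_ids, segment_pos=None, *, is_thd, is_segment_ids_reordered
):
    """Hypothetical stand-in that only records the arguments it was given."""
    return {
        "segment_ids": segment_ids,
        "segment_pos": segment_pos,
        "is_thd": is_thd,
        "is_segment_ids_reordered": is_segment_ids_reordered,
    }


segment_ids = [1, 1, 2, 2, 0]

# A call written against the old signature now fails with a TypeError.
try:
    from_segment_ids_and_pos_sketch(segment_ids)
    old_call_error = None
except TypeError as err:
    old_call_error = err

# The updated call states the layout and ordering explicitly.
descriptor = from_segment_ids_and_pos_sketch(
    segment_ids, is_thd=True, is_segment_ids_reordered=False
)
```

Because the flags have no defaults, the failure happens at call time rather than silently selecting one layout.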

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Dec 16, 2025
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani added the attention, jax, and bug (Something isn't working) labels Dec 16, 2025
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 17, 2025 02:16
@greptile-apps
Contributor

greptile-apps bot commented Dec 17, 2025

Greptile Summary

Fixed incorrect segment position calculation in SequenceDescriptor.from_segment_ids_and_pos() by adding proper THD support. Previously used naive arange that resulted in incorrect masking for THD cases.

Key changes:

  • Added two new required parameters: is_thd and is_segment_ids_reordered to from_segment_ids_and_pos()
  • Implemented proper THD segment position generation that detects segment boundaries and calculates positions relative to each segment start
  • Added assertions to prevent unsupported combinations (THD + reordered when segment_pos=None)
  • Updated all test call sites to pass the new required parameters

Breaking change: Existing users must now provide is_thd and is_segment_ids_reordered arguments, causing TypeError if not updated.

Confidence Score: 1/5

  • This PR cannot be merged - contains critical logic bug that will cause runtime failures
  • Lines 833-844 in attention.py contain contradictory assertions that will always fail when is_segment_ids_reordered=True. The first assertion requires not is_thd (line 833) while the second requires is_thd (line 839), making it impossible to satisfy both. This bug wasn't caught by tests because they always pass explicit segment_pos when is_segment_ids_reordered=True, avoiding the buggy code path.
  • Pay immediate attention to transformer_engine/jax/attention.py lines 833-844 - the contradictory assertions must be fixed before merge
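The pattern the bot describes can be reproduced in a few lines. This is a sketch with hypothetical function names and shortened messages, not the code from attention.py. It shows that two opposite asserts under one condition reject every reordered input regardless of is_thd, which is behaviorally the same as a single guard on is_segment_ids_reordered. Whether that counts as a bug depends on intent: the PR description says reordered inputs without segment_pos are meant to assert.

```python
def check_contradictory(is_thd, is_segment_ids_reordered):
    """Mirrors the reported pattern: two opposite asserts under one condition."""
    if is_segment_ids_reordered:
        assert not is_thd, "reordered THD is not supported"
        assert is_thd, "reordered BSHD is not supported"


def check_single_guard(is_thd, is_segment_ids_reordered):
    """One possible alternative: reject every reordered input with one assert."""
    assert not is_segment_ids_reordered, (
        "segment_pos must be passed explicitly when segment_ids are reordered"
        f" (is_thd={is_thd})"
    )


def raises_assertion(check, **kwargs):
    """Return True if calling check(**kwargs) raises AssertionError."""
    try:
        check(**kwargs)
    except AssertionError:
        return True
    return False


# With reordering on, the paired asserts fail for both layouts.
always_fails = all(
    raises_assertion(
        check_contradictory, is_thd=is_thd, is_segment_ids_reordered=True
    )
    for is_thd in (True, False)
)
# With reordering off, neither assert is reached.
passes_when_not_reordered = not raises_assertion(
    check_contradictory, is_thd=True, is_segment_ids_reordered=False
)
```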

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/jax/attention.py | Critical logic bug: contradictory assertions (lines 833-844) will always fail when is_segment_ids_reordered=True. Added THD segment position generation logic. |
| tests/jax/test_fused_attn.py | Updated tests to pass new required params is_thd and is_segment_ids_reordered. Minor comment inaccuracy on line 671. |

Sequence Diagram

sequenceDiagram
    participant User
    participant SequenceDescriptor
    participant generate_default_pos
    
    User->>SequenceDescriptor: from_segment_ids_and_pos(segment_ids, segment_pos=None, is_thd, is_segment_ids_reordered)
    
    alt segment_pos is None
        SequenceDescriptor->>SequenceDescriptor: Check if is_segment_ids_reordered
        
        alt is_segment_ids_reordered = True
            SequenceDescriptor->>SequenceDescriptor: assert not is_thd (line 833)
            SequenceDescriptor->>SequenceDescriptor: assert is_thd (line 839)
            Note right of SequenceDescriptor: BUG: Contradictory assertions!<br/>Will always fail
        end
        
        SequenceDescriptor->>generate_default_pos: Call for q_seg_ids
        
        alt is_thd = True
            generate_default_pos->>generate_default_pos: Find segment boundaries
            generate_default_pos->>generate_default_pos: Calculate segment_start_offsets
            generate_default_pos->>generate_default_pos: Compute positions relative to segment start
            generate_default_pos->>generate_default_pos: Apply padding mask
            generate_default_pos-->>SequenceDescriptor: Return THD segment_pos
        else is_thd = False (BSHD)
            generate_default_pos->>generate_default_pos: Use simple arange
            generate_default_pos-->>SequenceDescriptor: Return BSHD segment_pos
        end
        
        SequenceDescriptor->>generate_default_pos: Call for kv_seg_ids
        generate_default_pos-->>SequenceDescriptor: Return kv segment_pos
        
    else segment_pos provided
        SequenceDescriptor->>SequenceDescriptor: Expand to pair
    end
    
    SequenceDescriptor-->>User: Return SequenceDescriptor with segment_ids and segment_pos
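The four THD steps named in the diagram (find segment boundaries, compute segment start offsets, compute positions relative to the segment start, apply a padding mask) can be sketched in plain Python as follows. This is an illustration only and not the code in attention.py: the function name `thd_segment_pos` is hypothetical, segment ID 0 is assumed to mean padding, and a running maximum stands in for whatever vectorized operation the real implementation uses.

```python
from itertools import accumulate


def thd_segment_pos(seg_ids):
    """Per-segment positions for one packed row (segment ID 0 = padding)."""
    n = len(seg_ids)
    # 1. Find segment boundaries: a non-padding token whose ID differs from
    #    the previous token's ID (or that is the first token of the row).
    is_start = [
        seg_ids[i] != 0 and (i == 0 or seg_ids[i] != seg_ids[i - 1])
        for i in range(n)
    ]
    # 2. Segment start offsets: carry the index of the latest start forward
    #    with a running maximum.
    start_idx = [i if is_start[i] else 0 for i in range(n)]
    start_offsets = list(accumulate(start_idx, max))
    # 3. Position relative to the start of the enclosing segment.
    rel_pos = [i - start_offsets[i] for i in range(n)]
    # 4. Padding mask: padded tokens get position 0.
    return [p if seg_ids[i] != 0 else 0 for i, p in enumerate(rel_pos)]


row = [1, 1, 1, 2, 2, 2, 2, 0, 0]
positions = thd_segment_pos(row)
```

Each step is a whole-row operation, so the same structure maps onto array code without a per-token loop.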

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 841-842 (link)

    logic: q_seg_ids and kv_seg_ids are used here but not defined until line 847. This will cause a NameError at runtime when segment_pos is None.

2 files reviewed, 1 comment

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 829-832 (link)

    style: f-strings in assert messages won't be evaluated until assertion fails
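The comment refers to standard Python behavior: the message expression of an assert statement is evaluated only if the condition is false. A small self-contained demonstration, where the `describe` helper is hypothetical and exists only to record when the message gets built:

```python
evaluated = []


def describe(value):
    """Record that the message was built, then return it."""
    evaluated.append(value)
    return f"unexpected value: {value}"


# Passing assertion: the message expression is never evaluated.
assert 1 + 1 == 2, describe("first")
calls_after_pass = len(evaluated)

# Failing assertion: the message is built only now.
try:
    assert 1 + 1 == 3, describe("second")
except AssertionError as err:
    failure_message = str(err)
calls_after_fail = len(evaluated)
```

So an f-string in an assert message costs nothing on the passing path, but any error inside the f-string itself stays hidden until the assertion first fails.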

2 files reviewed, 1 comment

KshitijLakhani and others added 7 commits December 22, 2025 11:30
…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and lod balanced cases

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/incorrect-sequence-descr-from-seg-ids-only branch from 499f4de to 217ea58 Compare December 22, 2025 19:31
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 833-837 (link)

    logic: warning is shown even for the valid case (BSHD without load balancing)

    this will cause users to see warnings every time they correctly use segment_pos=None with BSHD layout

2 files reviewed, 1 comment

Kshitij Janardan Lakhani and others added 2 commits December 22, 2025 16:42
2. Add support for THD in from_segment_ids_and_pos()
3. Assert if load balanced segment_ids is passed to generate a segment_pos

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/jax/attention.py, line 880-881 (link)

    syntax: segment_ids is undefined, should be seg_ids (the parameter name on line 836)

  2. tests/jax/test_fused_attn.py, line 680-681 (link)

    syntax: is_thd and is_load_balanced are keyword-only parameters (note the * on line 800 in attention.py), so they must be passed as keyword arguments

  3. tests/jax/test_fused_attn.py, line 709-710 (link)

    syntax: is_thd and is_load_balanced are keyword-only parameters (note the * on line 800 in attention.py), so they must be passed as keyword arguments
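For readers unfamiliar with the bare `*` in a signature, here is a short sketch of the behavior comments 2 and 3 rely on. `make_descriptor` is a hypothetical stand-in; only the parameter names is_thd and is_load_balanced come from the review comment.

```python
def make_descriptor(segment_ids, segment_pos=None, *, is_thd, is_load_balanced):
    """Hypothetical stand-in; everything after the bare * is keyword-only."""
    return (is_thd, is_load_balanced)


seg = [1, 1, 2, 0]

# Passing the flags positionally is rejected before the body runs.
try:
    make_descriptor(seg, None, True, False)
    positional_error = None
except TypeError as err:
    positional_error = err

# Passing them by keyword works.
flags = make_descriptor(seg, None, is_thd=True, is_load_balanced=False)
```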

2 files reviewed, 3 comments

Kshitij Janardan Lakhani added 2 commits December 22, 2025 16:55
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani changed the title from "[JAX] Fix incorrect calculation of segment pos from segment ids" to "[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API" Dec 23, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. tests/jax/test_fused_attn.py, line 706-711 (link)

    logic: passing segment_pos=None with is_load_balanced=True will trigger assertion error on attention.py:830

    the BSHD + load balanced + SegmentIDs combination should either:

    1. pass explicit segment_pos (like THD does on line 677-678), or
    2. skip this format when load balanced

    currently only works because default mesh_shape=(1,1,1) means cp_size=1

  2. transformer_engine/jax/attention.py, line 794 (link)

    style: misleading TODO - THD layout (non load balanced) IS now supported by the implementation below (lines 837-878)

  3. transformer_engine/jax/attention.py, line 839-840 (link)

    style: assumption not enforced - if sequence starts with padding (seg_ids[0]=0), first_is_segment=True will incorrectly mark it as segment start

    consider asserting seg_ids[:, 0] != 0 or handle leading padding explicitly
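The leading-padding concern in comment 3 can be made concrete with a sketch. Both helpers below are hypothetical and written in plain Python; the first encodes the assumption under review (index 0 is always a segment start), the second checks for padding first. Segment ID 0 is assumed to mean padding.

```python
def segment_starts_assuming_first(seg_ids):
    """Marks index 0 as a start unconditionally (the assumption under review)."""
    return [
        i == 0 or (seg_ids[i] != seg_ids[i - 1] and seg_ids[i] != 0)
        for i in range(len(seg_ids))
    ]


def segment_starts_checked(seg_ids):
    """Marks index 0 as a start only if it is not padding."""
    return [
        seg_ids[i] != 0 and (i == 0 or seg_ids[i] != seg_ids[i - 1])
        for i in range(len(seg_ids))
    ]


# A row that begins with two padding tokens.
leading_pad_row = [0, 0, 1, 1, 2]
assumed = segment_starts_assuming_first(leading_pad_row)
checked = segment_starts_checked(leading_pad_row)
```

The two versions differ only at index 0, where the first one flags a padding token as a segment start.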

2 files reviewed, 3 comments

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as draft December 23, 2025 01:15
…actually reordered segment_ids and segment_pos

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
pre-commit-ci bot and others added 3 commits December 23, 2025 22:18
…ncing

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 24, 2025 00:23
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/jax/attention.py, line 833-844 (link)

    logic: contradictory assertions will always fail when is_segment_ids_reordered=True

    line 833 asserts not is_thd (meaning is_thd must be False)
    line 839 asserts is_thd (meaning is_thd must be True)

    both assertions cannot be satisfied simultaneously, causing the function to always fail when segment_pos=None and is_segment_ids_reordered=True

  2. transformer_engine/jax/attention.py, line 850-851 (link)

    style: assumption that first token is always a valid segment may be incorrect if THD format allows leading padding

    if seg_ids[0, 0] == 0 (first token is padding), this assumption causes incorrect segment_pos calculation. verify THD format guarantees no leading padding or add validation

2 files reviewed, 2 comments

for more information, see https://pre-commit.ci

Fix lint errors

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/incorrect-sequence-descr-from-seg-ids-only branch from 4101df5 to 74a352e Compare December 24, 2025 01:46
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/jax/attention.py, line 833-844 (link)

    logic: contradictory assertions on lines 833 and 839 - both will execute when is_segment_ids_reordered=True, but line 833 requires is_thd=False while line 839 requires is_thd=True

    this means when is_segment_ids_reordered=True:

    • if is_thd=True: line 833 fails with "not supported for load balanced reordered (Striped) THD"
    • if is_thd=False: line 839 fails with "not supported for load balanced reordered (Dual Chunk) BSHD"

    based on line 822 comments and PR description, only THD with load balancing should set is_segment_ids_reordered=True, so line 839's assertion should be removed (it contradicts the intended behavior)

  2. transformer_engine/jax/attention.py, line 856 (link)

    style: potential edge case: (seg_ids[..., 1:] != 0) condition may not handle all segment transitions correctly

    consider sequence: [1, 1, 0, 2, 2] where 0 is padding in the middle

    • transition from 1→0 at index 2: (1 != 0) & (0 != 0) = False (correctly not marked as segment start)
    • transition from 0→2 at index 3: (0 != 2) & (2 != 0) = True (marked as segment start)

    however, the assumption on line 850 is that "the first token belongs to a segment and is not a padded token", which conflicts with the possibility of having padding in the middle. verify whether THD format allows mid-sequence padding, and if so, this logic needs adjustment. does THD format allow padding (segment_id=0) in the middle of a sequence, or is padding only at the end?

  3. tests/jax/test_fused_attn.py, line 687-689 (link)

    style: simplify boolean expression

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
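The mid-sequence padding walkthrough in comment 2 above can be checked mechanically. The sketch below evaluates the condition as the comment states it, (previous ID != current ID) and (current ID != 0), for each transition. The helper name is hypothetical, and the second row is an extra case not taken from the thread.

```python
def transition_is_start(prev_id, cur_id):
    """A token starts a segment if its ID changed and it is not padding."""
    return (prev_id != cur_id) and (cur_id != 0)


def starts_by_index(row):
    """Evaluate the condition for indices 1..n-1 (index 0 has no predecessor)."""
    return {i: transition_is_start(row[i - 1], row[i]) for i in range(1, len(row))}


# The row from the review comment: padding between segment 1 and segment 2.
starts = starts_by_index([1, 1, 0, 2, 2])

# Extra case: the same segment ID appears again after padding.
resumed = starts_by_index([1, 0, 1])
```

In the second row the condition treats the token after the padding as a new segment start. Whether that is the desired outcome depends on whether THD allows mid-sequence padding, which the thread leaves open.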

2 files reviewed, 3 comments

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM pending CI, thanks!

@ksivaman
Member

/te-ci jax L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/jax/attention.py, line 833-844 (link)

    logic: contradictory assertions will always fail when is_segment_ids_reordered=True

    Line 833 asserts not is_thd (THD must be False), but line 839 asserts is_thd (THD must be True). When is_segment_ids_reordered=True, the first assertion will pass only if is_thd=False, but then the second assertion will always fail because it requires is_thd=True.

    Based on the comments and documentation (line 822: "Only THD with load balancing is expected to have this flag set to True"), the correct logic should be:

  2. tests/jax/test_fused_attn.py, line 671-672 (link)

    style: comment doesn't match implementation logic

    Comment says "if no CP and load balancing", but the code on line 683 does the opposite - it explicitly passes segment_pos when cp_size > 1 and self.cp_load_balanced (i.e., WITH CP and load balancing), and passes None otherwise.

2 files reviewed, 2 comments

@ksivaman ksivaman merged commit 26c82db into NVIDIA:main Dec 31, 2025
13 of 14 checks passed
@KshitijLakhani
Collaborator Author

@ksivaman this PR has a breaking change (FYI for TE release notes)

@KshitijLakhani
Collaborator Author

KshitijLakhani commented Jan 5, 2026

@mgoldfarb-nvidia @huanghua1994 @mingxu1067 - even though this PR has been merged, it would be great to get your feedback on this, in case you spot something that needs to be reverted. Thanks !

Labels

attention bug Something isn't working jax

3 participants