Add audio model adapters and improve SSM partition specs#1352
Open
vishesh9131 wants to merge 31 commits intoapple:mainfrom
Open
Add audio model adapters and improve SSM partition specs#1352vishesh9131 wants to merge 31 commits intoapple:mainfrom
vishesh9131 wants to merge 31 commits intoapple:mainfrom
Conversation
Introduces AudioModelAdapter and ASRModelAdapter for efficient fine-tuning of audio models, along with comprehensive tests. Adds input_grain_csv_test.py for CSV/TSV input processing tests. Updates SSM partition spec helpers to support sequence parallelism and adds corresponding tests in ssm_test.py. Updates .gitignore to exclude run_specific_test.sh.
The function returns PartitionSpec but was annotated as dict[str, PartitionSpec]. This mismatch was causing CI test failures.
The test file was added without implementing the csv_dataset and tsv_dataset functions in input_grain.py. This caused import errors in CI.
- Fixed TypeError by wrapping single tensor inputs in tuples for F() calls in both adapter.py and adapter_test.py - Fixed parameter count assertion by including layer_norm.bias in the count calculation
- Fixed state passing by extracting encoder_adapter and decoder_adapter from the full state dict in adapt_encoder_features and adapt_decoder_features - Fixed expected parameter count from 33664 to 33600 in test_parameter_counts
- Made prng_key and state required parameters in adapt_encoder_features and adapt_decoder_features - Removed fallback direct module calls which don't work outside invocation context - Updated test_direct_call_fallback to pass required prng_key and state parameters
- Install uv and use uv pip install to respect ml-dtypes override - Fixes dependency conflict with ml-dtypes>=0.5,<0.6 vs tensorflow<0.5.0
- uv pip requires either a venv or --system flag - actions/setup-python creates venv but uv doesn't auto-detect it - Use --system to install into the Python environment directly
- jit_mamba_scan needs 6 positional args for shard_map compatibility - This is a nested function within JAX jit decorator - pylint 2.17+ flags this as R0917 (too-many-positional-arguments)
…ents - MambaConfig has 24 parameters to match HuggingFace's PretrainedConfig - pylint 2.17 added R0917 check which flags this legitimate case - Add disable comment to suppress the warning
- uv pip install works correctly with actions/setup-python venv - Only uv pip install --system was causing issues - This matches the Dockerfile approach which works correctly
- uv doesn't work properly with actions/setup-python in GitHub Actions - Upstream uses pip successfully with same dependency versions - This should work correctly as-is
- pip doesn't respect tool.uv override-dependencies - Set VIRTUAL_ENV to pythonLocation so uv detects the venv - This allows uv to honor ml-dtypes>=0.5,<0.6 override
- Environment variables don't persist across GitHub Actions steps - Export VIRTUAL_ENV in the same run block as uv pip install - This ensures uv can detect the venv and honor ml-dtypes override
- Revert to pip (matching upstream) with legacy resolver - Legacy resolver can install despite ml-dtypes version conflict - This allows tensorflow 2.17.1 and jax 0.6.2 to coexist
- isort 7.0 has stricter import ordering than 5.x - Fixed 5 files to match CI expectations - Added pylint disables for too-many-positional-arguments in gpu_attention.py - These are pre-existing code style issues, not related to our changes
- Legacy resolver may skip or break google-cloud-aiplatform installation - Force reinstall with --no-deps to fix pytype import errors - This ensures pytype can analyze vertexai_tensorboard.py
- Legacy resolver breaks transformers package installation - Add transformers==4.51.3 to force reinstall for pytype - This fixes param_converter.py import errors
- Create .venv and export VIRTUAL_ENV + PATH via GITHUB_ENV/GITHUB_PATH - Install pip+uv, then install extras with uv (honors ml-dtypes override) - Remove legacy resolver and post-install hacks (aiplatform/transformers) - Ensures pytype can resolve imports reliably
… after uv install
Contributor
Author
|
Hey @samos123 |
Refactored multiple test files to use parenthesized 'with' statements for improved readability and consistency. Added or adjusted trailing commas, improved docstring formatting, and added or updated pylint disables where appropriate. No functional changes were made.
Contributor
Author
|
@zhiyun @ruomingp @changlan @jiarui-lu2 |
Contributor
Author
|
@changlan |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
took me long enough to ship this one
Hey @markblee @jiarui-lu2, it’s been ages since my last merge... feels good to finally ship this one.
New Features
Audio Model Adapters (
axlearn/audio/adapter.py,axlearn/audio/adapter_test.py)AudioModelAdapter, a general-purpose bottleneck adapter for fine-tuning audio models.scale=0.01).ASRModelAdapterfor encoder–decoder ASR models.adapt_encoder_featuresandadapt_decoder_featuresmethods with functional state handling.SSM Partition Spec Improvements (
axlearn/common/ssm.py,axlearn/common/ssm_test.py)default_mamba_dim_to_partition_specsnow shards the sequence dimension when"seq"is present.default_output_partition_specnow returns aPartitionSpecinstead of a dictionary."seq", model over"model".PallasLinearScanMambaRecurrencefor consistency.Bug Fixes
default_output_partition_spec.F().layer_norm.bias.input_grain_csv_test.py.Testing
All tests pass with broad coverage.
Impact