Skip to content

Fix rank mismatch in MaxText synthetic data sharding#4122

Open
lukebaumann wants to merge 1 commit into
AI-Hypercomputer:mainfrom
lukebaumann:fix-synthetic-sharding
Open

Fix rank mismatch in MaxText synthetic data sharding#4122
lukebaumann wants to merge 1 commit into
AI-Hypercomputer:mainfrom
lukebaumann:fix-synthetic-sharding

Conversation

@lukebaumann

@lukebaumann lukebaumann commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR fixes a rank mismatch issue in MaxText synthetic data sharding during data loading.

Root Cause

SyntheticDataIterator was using the legacy config.data_sharding which resolved to a 1D sharding spec P(('data', 'fsdp')) (after filtering). When applied to 2D output tensors of shape (batch, seq), JAX sharding validation failed with AssertionError: (1, 2) (rank mismatch) on JAX builds that strictly enforce this check.

Solution

Modified SyntheticDataIterator to use sharding.get_input_data_sharding(config, mesh). This helper uses config.input_data_sharding_logical_axes which correctly resolves to a 2D sharding spec P(('data', 'fsdp'), None), matching the rank of the output tensors.

Also removed the unused PartitionSpec as P import in synthetic_data_processing.py.

Tests

Added a new unit test tests/unit/synthetic_data_test.py which:

  1. Forces 4 CPU devices.
  2. Creates a 2x2 mesh.
  3. Initializes SyntheticDataIterator with llama3.1-8b config.
  4. Verifies the output shape is (8, 16) and sharding is exactly P(('data', 'fsdp'), None).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 9, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

* Change SyntheticDataIterator to use get_input_data_sharding instead of manual 1D sharding.
* This ensures the sharding spec is 2D, matching the rank of the output tensors.
* Fixes AssertionError: (1, 2) in JAX sharding validation on some JAX builds.
* Remove unused PartitionSpec import in synthetic_data_processing.py.
* Add unit test `tests/unit/synthetic_data_test.py` to verify synthetic data sharding.
self.config = config
data_pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh)
data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
data_pspec_shardings = sharding.get_input_data_sharding(config, mesh)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love you are using functions in sharding.py!

"enable_checkpointing": False,
"dataset_type": "synthetic",
"model_name": "llama3.1-8b",
"max_target_length": 16,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add another testing instance using explicit sharding? Basically everything are same except adding shard_mode=explicit? Explicit sharding data pipeline has been broken but never get protected.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants