Skip to content

perf(optim): implement Gram Newton-Schulz for Muon#301

Open
KakaruHayate wants to merge 1 commit into
openvpi:muon_lynxnet2from
KakaruHayate:gram_ns
Open

perf(optim): implement Gram Newton-Schulz for Muon#301
KakaruHayate wants to merge 1 commit into
openvpi:muon_lynxnet2from
KakaruHayate:gram_ns

Conversation

@KakaruHayate
Copy link
Copy Markdown

Description

This PR optimizes the orthogonalization step in the Muon optimizer by integrating the Gram Newton-Schulz (GramNS) algorithm. Additionally, it addresses a known training stability issue when using float16.

Key Changes

  • Stability Fix: We discovered that performing the initial spectral normalization in float32 is sufficient to prevent gradient overflow/underflow. This allows the subsequent Newton-Schulz iterations to run stably in float16 without strictly requiring bfloat16.
  • GramNS Integration: Replaced the standard Newton-Schulz 5 (NS5) iteration with GramNS, computing iterations on the smaller NxN Gram matrix to significantly save FLOPs on rectangular matrices.

Benchmarks

Performance comparison between standard NS5 and GramNS (Batch Size = 8).
GramNS demonstrates massive efficiency gains on highly rectangular matrices (up to ~40% time reduction for 8192x1024), while maintaining parity on square matrices.

Shape Batch NS5 (ms) GramNS (ms) Ratio (GramNS / NS5)
(512, 512) 8 2.8512 2.7637 0.969
(1024, 1024) 8 10.2181 10.6101 1.038
(2048, 2048) 8 75.5655 74.7877 0.990
(4096, 4096) 8 571.8510 582.7745 1.019
(2048, 1024) 8 17.1982 16.5473 0.962
(1024, 2048) 8 19.1216 16.0166 0.838
(4096, 1024) 8 30.4058 21.5448 0.709
(1024, 4096) 8 31.9332 22.8621 0.716
(8192, 1024) 8 57.0449 34.1205 0.598
(1024, 8192) 8 58.2899 33.4286 0.573
(4096, 2048) 8 122.8127 111.8067 0.910
(2048, 4096) 8 123.4342 110.8306 0.898
(8192, 4096) 8 969.0377 893.7069 0.922
(4096, 8192) 8 1017.5225 917.9643 0.902

- Fix fp16 stability: Using float32 exclusively for the initial spectral normalization step prevents instability, allowing the rest of the algorithm to safely execute in fp16.
- Integrate Gram Newton-Schulz: Computes iterations on the smaller Gram matrix.
- Benchmarks show up to a 42% time reduction for heavily rectangular matrices (e.g., 8192x1024 drops from 58ms to 33ms) with no performance penalty on square shapes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant