Add Schur complement and its mat-free mode #35

Open
zitongzhan wants to merge 36 commits into release from memory-issue-swp

@zitongzhan
Collaborator

This pull request introduces significant improvements to the optimizer infrastructure, focusing on enhanced memory profiling, a new Schur complement optimizer, and better support for matrix-free operations.

Optimizer Enhancements

  • Added a new Schur optimizer class in bae.optim.optimizer, implementing the Schur complement method with support for both standard and matrix-free normal equations, block Jacobi preconditioning, and efficient memory usage.

  • Updated the LM optimizer to support a matrix_free_normal mode, allowing for more efficient computation and memory usage in large-scale problems.

  • Added a custom TrustRegion class that supports Warp, especially for use with the Schur optimizer.
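
As a rough illustration of what a matrix-free normal-equations mode usually means, the product $(J^T J)x$ can be evaluated as $J^T(Jx)$ without ever forming $J^T J$. This is a minimal sketch and not the code in bae.optim.optimizer: the helper names (`matvec`, `rmatvec`, `normal_matvec`, `normal_diag`) are hypothetical, plain Python lists stand in for sparse tensors, and the damping form (diagonal scaled by 1 + damping) is assumed from the snippet quoted in the review below.

```python
# Hypothetical helpers; dense row-major lists stand in for the sparse Jacobian.

def matvec(J, x):
    """Return J @ x."""
    return [sum(a * b for a, b in zip(row, x)) for row in J]

def rmatvec(J, y):
    """Return J^T @ y without building the transpose."""
    out = [0.0] * len(J[0])
    for row, yi in zip(J, y):
        for j, a in enumerate(row):
            out[j] += a * yi
    return out

def normal_diag(J):
    """Return diag(J^T J), i.e. the squared column norms of J."""
    return [sum(row[j] ** 2 for row in J) for j in range(len(J[0]))]

def normal_matvec(J, x, damping=0.0, diag=None):
    """Matrix-free (J^T J + damping * D) @ x, with D = diag(J^T J)."""
    out = rmatvec(J, matvec(J, x))
    if damping:
        out = [o + damping * d * xi for o, d, xi in zip(out, diag, x)]
    return out

J = [[1.0, 2.0], [0.0, 3.0], [4.0, 0.0]]
x = [1.0, -1.0]
undamped = normal_matvec(J, x)
damped = normal_matvec(J, x, damping=0.5, diag=normal_diag(J))
```

Only two matrix-vector products per call are needed, so memory scales with $J$ rather than with $J^T J$; the trade-off is that every CG iteration repeats those two products.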

Sparse Matrix and PyOps Improvements

  • Improved sparse matrix operations, including fixes to inv_op for correct tensor creation and a new test block in py_ops.py for diagonal operations on CUDA.

Comment thread bae/sparse/warp_wrappers.py Fixed
Comment thread bae/optim/optimizer.py Fixed
Comment thread bae/sparse/py_ops.py Fixed
Contributor

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces high-performance Triton kernels for sparse BSR operations, including matrix-vector multiplication, matrix-matrix multiplication, and transposition. It also implements a matrix-free NormalMatVec operator and a new Schur complement-based optimizer to improve the efficiency of bundle adjustment tasks. The bundle adjustment example was updated with CUDA memory snapshotting and Warp mempool reporting. Review feedback highlights a critical issue where in-place diagonal modifications in the LM and Schur optimizers cause damping factors to accumulate incorrectly during step rejections. Additionally, the reviewer recommends removing performance-hindering torch.cuda.empty_cache() calls, addressing potential divisions by zero in the Conjugate Gradient solver, and cleaning up redundant or commented-out code.

Comment thread bae/optim/optimizer.py
diag_scale *= 1.0 + pg['damping']
A.set_damping(diag_scale - 1.0)
else:
diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
Contributor

high

The diagonal_op_ function performs an in-place multiplication on the matrix A. Since this is inside the while loop, if a step is rejected and the loop repeats, the damping will be applied cumulatively (e.g., $(1+\lambda_1)(1+\lambda_2)\cdots$) instead of being applied to the original $J^T J$ diagonal. This deviates from the standard Levenberg-Marquardt algorithm and can lead to excessively aggressive damping. Consider cloning the matrix or resetting the diagonal before applying damping in each iteration.
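
The accumulation described in this comment can be reproduced with a toy diagonal. The sketch below is illustrative only: `damp_in_place` and `damp_from_original` are hypothetical names, a Python list stands in for the diagonal of $J^T J$, and the two damping values are made up to represent successive retries after rejected steps.

```python
def damp_in_place(diag, damping):
    """Mimics the reviewed pattern: scales the stored diagonal in place."""
    for i in range(len(diag)):
        diag[i] *= 1.0 + damping
    return diag

def damp_from_original(orig_diag, damping):
    """Returns a freshly damped copy; the original diagonal is untouched."""
    return [d * (1.0 + damping) for d in orig_diag]

orig = [2.0, 4.0]          # stand-in for diag(J^T J)
dampings = [1.0, 3.0]      # damping tried on the first attempt and on the retry

# In-place variant: the second attempt scales an already damped diagonal.
in_place = list(orig)
for lam in dampings:
    damp_in_place(in_place, lam)

# Reset variant: each attempt starts again from the original diagonal.
for lam in dampings:
    fresh = damp_from_original(orig, lam)
```

After the retry, `in_place` has been scaled by (1 + 1)(1 + 3) = 8, while `fresh` carries only the intended factor of 1 + 3 = 4. Keeping a copy of the undamped diagonal costs one vector, which is usually much cheaper than cloning the whole matrix.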

Collaborator Author

Will resolve in a future PR.

Comment thread bae/optim/optimizer.py Outdated
Comment thread bae/optim/optimizer.py Outdated
R = R.tensor()
else:
R = R.detach()
torch.cuda.empty_cache()
Contributor

medium

Calling torch.cuda.empty_cache() inside the optimization step is generally discouraged as it triggers a GPU synchronization and can significantly degrade performance. If memory management is a concern, it's better to optimize tensor lifecycles or use a dedicated memory pool. If this was added for debugging memory usage, it should be removed before merging.

Comment thread bae/optim/triton_kernel.py Outdated

Ap = matvec(p)
Ap_flat = Ap.reshape(-1)
alpha = (rz / torch.dot(p.reshape(-1), Ap_flat)).item()
Contributor

medium

Potential division by zero if torch.dot(p.reshape(-1), Ap_flat) is zero (e.g., if the matrix is singular or not positive definite). While $J^T J$ is positive semi-definite, numerical issues or zero curvature directions could cause this to be zero. Consider adding a small epsilon or a check for numerical stability.

Comment thread bae/optim/triton_kernel.py Outdated

rz_new = torch.dot(r_flat, z_flat)
beta = (rz_new / rz).item()
p.mul_(beta).add_(z)
Contributor

medium

Potential division by zero if rz is zero. Although the convergence check at line 692 should ideally terminate the loop if the residual is zero, a safety check for rz before division is recommended to prevent NaN values in case of numerical instability.
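
One way to address both division concerns raised for the CG solver is to test the denominators before dividing. The following is a minimal sketch under stated assumptions, not the solver in bae/optim/triton_kernel.py: `pcg`, `dot`, and the `eps` threshold are hypothetical, Python lists replace tensors, and breaking out of the loop on non-positive curvature is one possible policy among several.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pcg(matvec, b, precond=lambda r: list(r), tol=1e-10, max_iter=100, eps=1e-30):
    """Preconditioned CG with guards on both divisions (alpha and beta)."""
    x = [0.0] * len(b)
    r = list(b)                 # residual for the initial guess x = 0
    z = precond(r)
    p = list(z)
    rz = dot(r, z)
    for _ in range(max_iter):
        # Stop on convergence, or when rz is too small to divide by safely.
        if dot(r, r) ** 0.5 <= tol or rz <= eps:
            break
        Ap = matvec(p)
        pAp = dot(p, Ap)
        if pAp <= eps:          # zero or negative curvature along p
            break
        alpha = rz / pAp
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * ai for ri, ai in zip(r, Ap)]
        z = precond(r)
        rz_new = dot(r, z)
        beta = rz_new / rz      # rz > eps is guaranteed by the loop-top check
        p = [zi + beta * pi for zi, pi in zip(z, p)]
        rz = rz_new
    return x

A = [[4.0, 1.0], [1.0, 3.0]]
solution = pcg(lambda v: [sum(a * c for a, c in zip(row, v)) for row in A], [1.0, 2.0])
singular = pcg(lambda v: [0.0, 0.0], [1.0, 1.0])   # zero operator: pAp == 0
```

With the guards in place the singular case returns the initial guess instead of raising or producing NaN. In a GPU implementation the same checks can be expressed on tensors, although each `.item()` or Python-side comparison forces a host synchronization.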

Comment thread bae/sparse/warp_wrappers.py Outdated
Comment thread bae/utils/pysolvers.py Outdated
zitongzhan and others added 3 commits May 23, 2026 20:35
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Comment thread ba_example.py Fixed
Comment thread ba_example.py Fixed
Comment thread ba_example.py Fixed
Comment thread ba_example.py Fixed
Comment thread ba_example.py Fixed
Comment thread ba_example.py
from pathlib import Path
import warp as wp
from warp import sparse as wpsparse
from datapipes.bal_loader import get_problem, read_bal_data
Comment thread ba_example.py Fixed
Comment thread ba_example.py Fixed
Comment thread ba_example.py
import warp as wp
from warp import sparse as wpsparse
from datapipes.bal_loader import get_problem, read_bal_data
from bae.sparse.py_ops import *
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Comment thread ba_example.py
"observes": dataset["points_2d"],
"cidx": dataset["camera_index_of_observations"],
"pidx": dataset["point_index_of_observations"],
"points_2d": dataset["points_2d"],
Collaborator Author

@SEOKWOOPARK Let's keep the diff minimal and avoid changes unrelated to the purpose of the PR.

@zitongzhan
Collaborator Author

Profile Summary
Profiled the current ba_example.py on the Venice problem-1778-993923-pre dataset: 5,001,946 observations, 1,778 cameras, 993,923 points. I ran it with both matrix_free_normal=True and matrix_free_normal=False; ba_example.py currently defaults to False (disabled).

| Mode | Steady wall time | Main slow operators |
| --- | --- | --- |
| matrix_free_normal=True | 1.53 s | Warp BSR MV kernels inside linear.cg: ~1.19 s CUDA, ~84% |
| matrix_free_normal=False | 1.83 s | Split between explicit Schur warp_bsr_mm: ~0.66 s, and CG BSR MV: ~0.68 s |

Enabled
With matrix-free enabled, the bottleneck is still BSR matvec, now inside Warp CG. The hottest kernels were:

| Kernel / scope | CUDA time |
| --- | --- |
| bsr_mv_transpose_kernel... | 801 ms |
| bsr_mv_kernel_acf84b96... | 230 ms |
| bsr_mv_kernel_0d4f3dc9... | 163 ms |
| jacobian | 123 ms |

This corresponds to the repeated matrix-free Schur matvec in optimizer.py, especially the sparse.bsr_mv chain at lines 145-152 and the CG calls at lines 180 and 200.

Disabled
With matrix-free disabled, the cost shifts: explicit Schur construction becomes about as expensive as CG matvecs.

| Kernel / scope | CUDA time |
| --- | --- |
| warp_bsr_mm scope | 658 ms |
| _bsr_mm_compute_values... | 611 ms |
| bsr_mv_tiled_kernel... | 645 ms |
| jacobian | 122 ms |

That maps to the explicit Schur construction in optimizer.py: WV_i = sparse.bsr_mm(W, V_i) and WVi_Wt = sparse.bsr_mm(WV_i, Wt).
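
The trade-off between the two modes can be seen on a toy system. The sketch below is an assumption-laden illustration rather than the code in optimizer.py: it takes the reduced system to be $S = U - W V^{-1} W^T$ (matching the WV_i and WVi_Wt products named above), uses a diagonal $V$ instead of block-diagonal point blocks, replaces Warp BSR matrices with Python lists, and all function names are hypothetical.

```python
def matvec(M, x):
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def matmul(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def explicit_schur(U, W, v_inv):
    """Form S = U - W V^{-1} W^T with two matrix-matrix products."""
    n = len(v_inv)
    V_i = [[v_inv[i] if i == j else 0.0 for j in range(n)] for i in range(n)]
    WV_i = matmul(W, V_i)
    WVi_Wt = matmul(WV_i, transpose(W))
    return [[u - s for u, s in zip(ur, sr)] for ur, sr in zip(U, WVi_Wt)]

def schur_matvec(U, W, v_inv, x):
    """Matrix-free S @ x = U x - W (V^{-1} (W^T x)), three matvecs per call."""
    t = matvec(transpose(W), x)                 # W^T x
    t = [v * ti for v, ti in zip(v_inv, t)]     # V^{-1} (W^T x)
    Wt = matvec(W, t)                           # W V^{-1} W^T x
    return [u - w for u, w in zip(matvec(U, x), Wt)]

U = [[4.0, 0.0], [0.0, 5.0]]              # camera block
W = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]]    # camera-point coupling
v_inv = [0.5, 0.25, 1.0]                  # inverse of the diagonal point block
x = [1.0, 2.0]

S = explicit_schur(U, W, v_inv)
via_explicit = matvec(S, x)
via_matrix_free = schur_matvec(U, W, v_inv, x)
```

Both paths give the same product. The explicit path pays for the two matrix-matrix products once and then needs a single matvec per CG iteration, whereas the matrix-free path skips construction and pays for three matvecs (including a transposed one) per iteration, which is consistent with where the time shows up in the two profiles above.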
