Fisher-guided Adaptive Fine-Tuning (FisherAdapTune) is a model-agnostic framework for adaptive, parameter-efficient fine-tuning. Instead of selecting trainable parameters from fixed architectural rules, it tracks how each parameter group's Fisher geometry evolves during training and keeps updating only the groups that still show meaningful task-driven drift.
The method uses the Jensen-Shannon distance between consecutive Fisher distributions as a scale-invariant signal of adaptation: parameter groups whose Fisher structure has stabilized are progressively frozen, while groups whose curvature keeps shifting remain trainable. This turns fine-tuning into a dynamic, task-aware process that reduces unnecessary updates and improves the efficiency-generalization trade-off without adding inference-time overhead.
💡 FisherAdapTune recovers the standard transfer-learning hierarchy automatically, freezing generic low-level components earlier while keeping task-dependent modules trainable for longer.
💡 Fisher-drift based selection improves the efficiency-generalization trade-off by preserving groups that continue to adapt while avoiding unnecessary updates to stabilized parameters.
💡 Across crack segmentation experiments, FisherAdapTune improves robustness under distribution shift and supports stronger zero-shot transfer without adding inference-time overhead.
- Overview
- Repository Structure
- Installation
- Quick Start
- Real-World Example: SAM2 Crack Segmentation
- Key Hyperparameters
- Citation
- Contact
- License
FisherAdapTune wraps any PyTorch model and optimizer with a Fisher-guided freeze loop:
- Fisher collection: diagonal FIM statistics are accumulated via AdaFisher hooks on Linear, Conv2d, BatchNorm2d, and LayerNorm layers.
- JS divergence tracking: the Jensen-Shannon distance between consecutive Fisher histograms is computed per parameter chunk. A low, stable JS distance signals that a chunk has stopped learning.
- Iterative freezing: parameter groups whose JS scores fall below an adaptive threshold are masked and frozen. Frozen parameter groups are skipped in forward/backward passes, saving computation in later training stages.
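As a rough illustration of the JS-distance signal described above, the sketch below compares two consecutive Fisher snapshots of a single chunk. It is a simplification: it normalises the raw diagonal Fisher values directly instead of building the histograms the library uses, and the helper names (normalise, js_distance) and the toy values are hypothetical, not part of the repository's API.

```python
import math


def normalise(values, eps=1e-12):
    """Turn non-negative Fisher values into a probability vector (assumed scheme)."""
    total = sum(values) + eps * len(values)
    return [(v + eps) / total for v in values]


def kl_bits(p, q):
    """Kullback-Leibler divergence in bits; zero-probability terms contribute 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_distance(p, q):
    """Jensen-Shannon distance: square root of the JS divergence, in [0, 1]."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    jsd = 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)
    return math.sqrt(max(jsd, 0.0))  # clamp tiny negative round-off


# Toy diagonal Fisher values for one chunk at two consecutive updates.
prev_fisher = [0.9, 0.4, 0.1, 0.05]
curr_fisher = [0.8, 0.5, 0.1, 0.05]

# A small shift in the distribution gives a small, non-zero distance.
d_small = js_distance(normalise(prev_fisher), normalise(curr_fisher))

# Rescaling every value by a constant leaves the normalised distribution,
# and therefore the distance, essentially unchanged (scale invariance).
d_scaled = js_distance(normalise(prev_fisher),
                       normalise([10 * v for v in prev_fisher]))

print(f"drift: {d_small:.4f}, rescaled copy: {d_scaled:.2e}")
```

Using base-2 logarithms keeps the distance bounded by 1, which makes scores comparable across chunks of different sizes.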
The trainer is fully plug-and-play: you can supply the model, optimizer, data loaders, and two callables (train_step_fn, val_step_fn). No subclassing required. See the Quick Start section.
FisherAdapTune/
├── scripts/                          # Core library
│   ├── __init__.py
│   ├── adafisher.py                  # AdaFisher diagonal Fisher Information Matrix estimator
│   ├── fisher_core.py                # JS divergence, masking, freezing
│   ├── trainer.py                    # FisherAdapTuneTrainer (main user interface)
│   └── utils.py
├── crack_segmentation/               # Example SAM2 fine-tuning application
│   ├── train_sam2.py
│   ├── config_sam2.yaml
│   └── dataset.py
├── examples/
│   └── minimal_image_classifier.py   # Self-contained runnable example
├── requirements.txt
└── pyproject.toml
conda create -n fisheradaptune python=3.11 -y
conda activate fisheradaptune
pip install -e .
Dependencies: torch >= 2.0, numpy >= 1.24, pyyaml >= 6.0
Optional: wandb (logging), matplotlib (JS-distance plots)
import torch
import torch.nn as nn
from scripts import EarlyStopping, FisherAdapTuneTrainer
# 1. Your model and optimizer
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
# NOTE: always set weight_decay=0.0 here; FisherAdapTune applies it internally
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
# 2. Step functions: do NOT call .backward() or optimizer.step() inside
loss_fn = nn.CrossEntropyLoss()

def train_step(model, batch):
    x, y = batch
    return loss_fn(model(x), y)

def val_step(model, batch):
    x, y = batch
    with torch.inference_mode():
        logits = model(x)
        loss = loss_fn(logits, y).item()
        acc = (logits.argmax(1) == y).float().mean().item()
    return {"loss": loss, "accuracy": acc}
# 3. Trainer (train_loader and val_loader are your own data loaders, defined elsewhere)
trainer = FisherAdapTuneTrainer(
    model               = model,
    optimizer           = optimizer,
    train_loader        = train_loader,
    train_step_fn       = train_step,
    val_loader          = val_loader,
    val_step_fn         = val_step,
    num_epochs          = 5,
    weight_decay        = 5e-5,
    freeze_interval     = 300,  # steps between freeze decisions
    fisher_ema_interval = 50,   # steps between Fisher/JS updates
    fisher_slice_blocks = 4,    # chunks (parameter groups) per weight tensor
)
trainer.fit()
See examples/minimal_image_classifier.py for a complete runnable example (it works with synthetic data, so no dataset is required).
The crack_segmentation/ directory contains a full application of FisherAdapTune to fine-tune SAM2 for binary crack segmentation.
cd FisherAdapTune
conda activate fisheradaptune
python crack_segmentation/train_sam2.py \
    --config crack_segmentation/config_sam2.yaml \
    --train-data-path /path/to/train \
    --val-data-path /path/to/val
Key CLI flags (all override the YAML config):
| Flag | Description |
|---|---|
| --train-data-path | Training image/mask root |
| --val-data-path | Validation image/mask root |
| --freeze-interval | Steps between chunk-freeze passes |
| --fisher-ema-interval | Steps between Fisher/JS updates |
| --js-distance-mode | Histogram mode: log (default) or raw |
| --disable-wandb | Skip W&B logging |
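The override behaviour of these flags can be pictured with a small sketch. It is not taken from train_sam2.py: the merge_config helper is hypothetical, a plain dict stands in for the parsed YAML file, and the values in it are placeholders. Only flags that were actually passed replace the corresponding config entries.

```python
import argparse


def build_parser():
    """Declare a few of the flags listed above; None means 'not given on the CLI'."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--freeze-interval", type=int, default=None)
    parser.add_argument("--fisher-ema-interval", type=int, default=None)
    parser.add_argument("--js-distance-mode", choices=["log", "raw"], default=None)
    return parser


def merge_config(yaml_config, cli_args):
    """Return a copy of the YAML settings with explicitly passed flags applied."""
    merged = dict(yaml_config)
    for key, value in vars(cli_args).items():
        if value is not None:
            merged[key] = value
    return merged


# Stand-in for the contents of a YAML config file (placeholder values).
yaml_config = {
    "freeze_interval": 3000,
    "fisher_ema_interval": 300,
    "js_distance_mode": "log",
}

# Equivalent to running the script with: --freeze-interval 500
args = build_parser().parse_args(["--freeze-interval", "500"])
config = merge_config(yaml_config, args)
print(config)
# {'freeze_interval': 500, 'fisher_ema_interval': 300, 'js_distance_mode': 'log'}
```

Defaulting every flag to None is what lets the merge distinguish "not supplied" from "supplied with the same value as the YAML file".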
| Parameter | Default | Description |
|---|---|---|
| fisher_ema_interval | 300 | How often (steps) to update Fisher/JS stats |
| freeze_interval | 3000 | How often (steps) to evaluate and freeze parameter groups |
| fisher_slice_blocks | 4 | Number of chunks (parameter groups) per weight tensor |
| fisher_slice_mode | "row" | Split axis: "row" or "column" |
| js_distance_mode | "log" | Histogram normalisation: "log" or "raw" |
| chunk_selection_metric | "total_variation" | Chunk ranking: "total_variation" or "mean_js" |
| js_variance_lambda | 1.0 | Freeze threshold = mean + λ × std of JS scores |
| fisher_ema_decay | 0.9 | EMA decay for Fisher tensors |
| prev_js_ema_decay | 0.9 | EMA decay for the JS-distance signal |
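To make a few of these hyperparameters concrete, here is a minimal sketch of how they might interact. It rests on stated assumptions: the function names are hypothetical, the standard deviation is taken as the population standard deviation, rows are split as evenly as possible, and the chunk ranking selected by chunk_selection_metric is not reproduced. The library's actual implementation may differ on each point.

```python
import statistics


def ema_update(previous, current, decay=0.9):
    """Exponential moving average, as used for smoothing Fisher or JS signals."""
    return decay * previous + (1.0 - decay) * current


def freeze_threshold(js_scores, js_variance_lambda=1.0):
    """Adaptive threshold: mean + lambda * std of the per-chunk JS scores."""
    mean = statistics.fmean(js_scores)
    std = statistics.pstdev(js_scores)  # assumption: population std
    return mean + js_variance_lambda * std


def row_chunks(num_rows, fisher_slice_blocks=4):
    """Split the rows of a weight tensor into near-equal [start, stop) ranges."""
    base, extra = divmod(num_rows, fisher_slice_blocks)
    bounds, start = [], 0
    for i in range(fisher_slice_blocks):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


# Smoothed JS score per chunk: three chunks have settled, one is still drifting.
js_scores = [0.01, 0.02, 0.02, 0.40]
threshold = freeze_threshold(js_scores, js_variance_lambda=1.0)

# Chunks scoring below the threshold become freeze candidates.
candidates = [i for i, score in enumerate(js_scores) if score < threshold]

print(f"threshold = {threshold:.4f}, freeze candidates = {candidates}")
print("row ranges for a 10-row weight:", row_chunks(10, fisher_slice_blocks=4))
print("smoothed JS after one update:", ema_update(0.5, 0.1, decay=0.9))
```

In this sketch a larger js_variance_lambda raises the threshold and so makes more chunks eligible for freezing, while a higher EMA decay makes the signals react more slowly to a single noisy update.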
If you find FisherAdapTune useful, please cite:
@misc{rostami2026fisheradaptune,
title = {Fisher-Guided Progressive Parameter Selection for Adaptive Fine-Tuning},
author = {Rostami, Ghodsiyeh and Chen, Po-Han and Hosseini, Mahdi S.},
year = {2026},
eprint = {2606.10196},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
doi = {10.48550/arXiv.2606.10196},
url = {https://arxiv.org/abs/2606.10196}
}
Feel free to contact us with questions about the FisherAdapTune paper or this repository, bug reports, or collaboration proposals. We welcome technical feedback that improves reproducibility, implementation clarity, and future extensions of the method.
This project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
