Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306
TOPAPEC wants to merge 15 commits into
Conversation
Standalone sequential recommender package that mimics the ModelBase interface without touching existing rectools code. FlatSASRec is a plain ID-embedding SASRec encoder. UniSRec combines pretrained text embeddings with a PCA/BN adaptor and 3-phase training (ID embeddings -> adaptor only -> full finetune). Uses a lightweight rank_topk instead of TorchRanker and reuses SASRecDataPreparator for the data pipeline. 30 tests, smoke scripts for both models. Fix: in IEEE 754, NaN * 0 = NaN, which breaks attention padding masking via multiplication; switched to masked_fill.
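A minimal sketch of the failure mode and the fix, using toy tensors (not the PR's actual code):

```python
import torch

# Attention scores where a padded position ended up NaN
# (e.g. from a fully-masked softmax row upstream).
scores = torch.tensor([[float("nan"), 1.0, 2.0]])
pad = torch.tensor([[True, False, False]])  # True marks padding

# Masking by multiplication does NOT clear NaN: NaN * 0 = NaN (IEEE 754).
print(scores * (~pad))               # tensor([[nan, 1., 2.]])

# masked_fill overwrites padded positions regardless of their value.
print(scores.masked_fill(pad, 0.0))  # tensor([[0., 1., 2.]])
```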
New config options:
- `ffn_type`: conv1d / linear_gelu / linear_relu + `ffn_expansion`
- `optimizer`: adam / adamw
- `scheduler`: cosine_warmup (with `warmup_ratio`, `min_lr_ratio`)
- `loss`: softmax / BCE / gBCE / sampled_softmax (with `gbce_t`)
- `patience`: early stopping via EarlyStopping callback + val split
- `data_preparator`: accept a custom preparator instance

31 tests passing.
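For illustration, a hypothetical construction exercising these options (parameter values are made up and the exact literal spellings may differ from the PR's code):

```python
from rectools.fast_transformers import UniSRecModel  # import path per this PR

model = UniSRecModel(
    ffn_type="linear_gelu",   # or "conv1d" / "linear_relu"
    ffn_expansion=4,
    optimizer="adamw",        # or "adam"
    scheduler="cosine_warmup",
    warmup_ratio=0.05,
    min_lr_ratio=0.1,
    loss="gbce",              # or "softmax" / "bce" / "sampled_softmax"
    gbce_t=0.75,
    patience=3,               # early stopping on a validation split
)
```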
Force-pushed from 2e923df to d68834f
Pull request overview
Adds a new rectools.fast_transformers subpackage providing GPU-native preprocessing and standalone sequential transformer recommenders (FlatSASRec + UniSRec), plus ranking utilities, scripts, and comprehensive tests.
Changes:
- Introduces torch-native sequence building (`build_sequences`), embedding alignment, and lightweight dataset/dataloader helpers.
- Adds UniSRec (pretrained text embeddings + adaptor + SASRec encoder) with a Lightning training wrapper and a standalone `UniSRecModel` API (fit/checkpoint/ONNX export).
- Adds `rank_topk()` for batched scoring with CSR filtering + whitelist, along with benchmark scripts and extensive test coverage.
Reviewed changes
Copilot reviewed 17 out of 19 changed files in this pull request and generated 9 comments.
Summary per file:
| File | Description |
|---|---|
| rectools/fast_transformers/\_\_init\_\_.py | Exposes the new fast_transformers public API surface. |
| rectools/fast_transformers/gpu_data.py | Implements torch-native preprocessing utilities (sequence building, embedding alignment, dataloader helpers). |
| rectools/fast_transformers/net.py | Adds FlatSASRec network implementation. |
| rectools/fast_transformers/ranking.py | Adds rank_topk() batching + filtering + whitelist ranking utility. |
| rectools/fast_transformers/unisrec_lightning.py | Adds LightningModule wrapper (loss/optimizer/scheduler dispatch) for UniSRec training phases. |
| rectools/fast_transformers/unisrec_model.py | Adds standalone UniSRecModel (3-phase training, checkpointing, ONNX export, ID mapping). |
| rectools/fast_transformers/unisrec_net.py | Adds UniSRec network (adaptor + transformer encoder + helper methods). |
| tests/fast_transformers/\_\_init\_\_.py | Test package marker for fast_transformers. |
| tests/fast_transformers/test_gpu_data.py | Tests for sequence building, embedding alignment, dataset/dataloader, and hashing. |
| tests/fast_transformers/test_net.py | Tests for FlatSASRec forward paths and encoding helpers. |
| tests/fast_transformers/test_onnx_export.py | Tests ONNX export/roundtrip for UniSRec network and UniSRecModel export. |
| tests/fast_transformers/test_ranking.py | Tests top-k ranking, filtering, whitelist behavior, and edge cases. |
| tests/fast_transformers/test_unisrec_lightning.py | Tests UniSRecLightning configuration + loss/scheduler dispatch behavior. |
| tests/fast_transformers/test_unisrec_model.py | Tests UniSRecModel fit phases, losses/optimizers/schedulers, checkpointing, and mapping. |
| tests/fast_transformers/test_unisrec_net.py | Tests UniSRec network output shapes, adaptor variants, and freeze/unfreeze helpers. |
| scripts/compare_sasrec_unisrec.py | Benchmark script to compare RecTools SASRec vs UniSRec-ID and generate a report. |
| scripts/comparison_report.md | Adds a sample benchmark report output. |
| CHANGELOG.md | Documents the new module and features under Unreleased. |
| .gitignore | Ignores new dev artifacts, model weights, and data folders. |
```python
def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    user_ids = user_ids.to(device)
    item_ids = item_ids.to(device)
    timestamps = timestamps.to(device)

    unique_items = torch.unique(item_ids)
    n_unique = len(unique_items)

    if id_mapping == "dense":
        _, item_inv = torch.unique(item_ids, return_inverse=True)
        internal_items = item_inv + 1
    elif id_mapping == "hash":
```
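For context, a toy illustration of the dense-mapping step above (made-up IDs):

```python
import torch

item_ids = torch.tensor([42, 7, 42, 99])
unique_items, item_inv = torch.unique(item_ids, return_inverse=True)
# unique_items: tensor([ 7, 42, 99])  -- sorted external IDs
# item_inv:     tensor([1, 0, 1, 2])  -- index of each element in unique_items
internal_items = item_inv + 1  # reserve 0 for padding
# internal_items: tensor([2, 1, 2, 3])
```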
```python
x, y, unique_items, unique_users = build_sequences(
    user_ids,
    item_ids,
    timestamps,
    max_len=self.session_max_len,
    min_interactions=self.train_min_user_interactions,
    id_mapping=self.id_mapping,
)
self._unique_items = unique_items.cpu()
self._unique_users = unique_users.cpu()
n_items = len(unique_items)

aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items, self.id_mapping)

net = UniSRec(
    n_items=n_items,
    pretrained_embeddings=aligned_emb,
    n_factors=self.n_factors,
    projection_hidden=self.projection_hidden,
    n_blocks=self.n_blocks,
    n_heads=self.n_heads,
    session_max_len=self.session_max_len,
    dropout=self.dropout,
    adaptor_dropout=self.adaptor_dropout,
    adaptor_type=self.adaptor_type,
    use_adaptor_ffn=self.use_adaptor_ffn,
    ffn_type=self.ffn_type,
    ffn_expansion=self.ffn_expansion,
)

train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)
```
```python
lookup = {int(v): i + 1 for i, v in enumerate(self._unique_items.tolist())}
return torch.tensor([lookup.get(int(x), 0) for x in external_ids.tolist()], dtype=torch.long)
```
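A later commit replaces this Python-level loop with `torch.searchsorted` (see the commit notes below); a sketch of that vectorization, assuming `sorted_items` is the sorted output of `torch.unique` (function name hypothetical):

```python
import torch

def map_item_ids_vectorized(external_ids: torch.Tensor, sorted_items: torch.Tensor) -> torch.Tensor:
    """Sketch: searchsorted replacement for the dict-based lookup above.

    sorted_items holds external IDs in sorted order (as torch.unique returns),
    so internal ID = position + 1, and 0 marks unknown items.
    """
    pos = torch.searchsorted(sorted_items, external_ids)
    pos = pos.clamp(max=len(sorted_items) - 1)  # guard IDs above the max known ID
    found = sorted_items[pos] == external_ids
    return torch.where(found, pos + 1, torch.zeros_like(pos))
```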
```python
viewed_mask = torch.tensor(batch_csr.toarray(), dtype=torch.bool, device=device)
scores[viewed_mask] = -float("inf")
```
```python
def test_padding_invariance(self, net: FlatSASRec) -> None:
    """Identical left-padded inputs should produce the same last-position embedding."""
    net.eval()
    # Same content should produce identical output
    x_a = torch.tensor([[0, 0, 0, 5, 10]])
    x_b = torch.tensor([[0, 0, 0, 5, 10]])
    with torch.no_grad():
        e_a = net.encode_last(x_a)
        e_b = net.encode_last(x_b)
    torch.testing.assert_close(e_a, e_b)
```
```python
class TestPaddingInvariance:
    def test_same_input_same_output(self, net: UniSRec) -> None:
        net.eval()
        x_a = torch.tensor([[0, 0, 0, 5, 10]])
        x_b = torch.tensor([[0, 0, 0, 5, 10]])
        with torch.no_grad():
            e_a = net.encode_last(x_a, use_id=False)
            e_b = net.encode_last(x_b, use_id=False)
        torch.testing.assert_close(e_a, e_b)
```
```python
train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)

val_dl = None
if self.patience is not None:
    # Early stopping validates on the last position of each sequence only
    val_y_last = y[:, -1:]
    val_dl = make_dataloader(x, val_y_last, batch_size=self.batch_size, shuffle=False)
```
```python
x, y, unique_items, unique_users = build_sequences(
    user_ids,
    item_ids,
    timestamps,
    max_len=self.session_max_len,
    min_interactions=self.train_min_user_interactions,
    id_mapping=self.id_mapping,
)
```
```python
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
```
Better to use `Literal` for such things.
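For example, a sketch of the suggested signature change (the `IdMapping` alias is an addition for illustration, not part of the PR):

```python
import typing as tp

import torch

IdMapping = tp.Literal["dense", "hash"]

def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: IdMapping = "dense",  # mypy now rejects unknown mapping names
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    ...
```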
```python
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```
Please add extensive docstrings for all the public methods, especially those meant to be used standalone. It's especially important here since you're returning 4 tensors and the user can't tell what they mean. It would also be good to add examples.
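A sketch of such a docstring, with return semantics inferred from the surrounding hunks (e.g. "internal ID = index + 1" comes from the `map_item_ids` lookup shown earlier), not confirmed by the PR:

```python
import typing as tp

import torch

def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build left-padded interaction sequences entirely in torch.

    Returns
    -------
    x : torch.Tensor of shape (n_users, max_len)
        Left-padded input sequences of internal item IDs; 0 is padding.
    y : torch.Tensor of shape (n_users, max_len)
        Next-item targets aligned with ``x``.
    unique_items : torch.Tensor
        External item IDs in internal-ID order (internal ID = index + 1).
    unique_users : torch.Tensor
        External user IDs, one per row of ``x``.

    Examples
    --------
    >>> x, y, items, users = build_sequences(user_ids, item_ids, timestamps, max_len=50)
    """
```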
```
catboost_info/

# Dev artifacts
training_folder/
```
a bit weird name, can we remove it?
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
- Tests for all `fast_transformers` submodules (143 tests)
We normally don't add anything to the changelog that doesn't affect users directly, so there's not much sense in writing about the tests.
Please put this script and the report into a subfolder of the benchmark folder.
```python
    return aligned


class GPUBatchDataset(TorchDataset):
```
I'm not sure the name reflects the purpose:
- why GPU?
- what does Batch mean?

It also sounds quite "universal" even though I'd say it's more task-specific.
```python
    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
```
I'd recommend adding **kwargs here to cover the DataLoader's other parameters.
On the other hand, I'm not sure it makes much sense to wrap two function calls in a separate function.
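A sketch of the suggested change (the body and the `GPUBatchDataset` constructor signature are inferred from the hunk above, not the actual implementation):

```python
import typing as tp

import torch
from torch.utils.data import DataLoader

from rectools.fast_transformers.gpu_data import GPUBatchDataset  # per this PR

def make_dataloader(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
    **kwargs: tp.Any,  # forwarded to DataLoader, e.g. num_workers, pin_memory
) -> DataLoader:
    dataset = GPUBatchDataset(x, y, transform=transform)  # signature assumed
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
```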
```python
from scipy import sparse


def rank_topk(
```
Sorry, I'm too lazy to check — could you please describe why we need this, given that we have TorchRanker? Could we reuse the code?
- Add hash-based ID mapping (splitmix64) as an alternative to the dense torch.unique mapping in build_sequences and align_embeddings.
- Add UniSRecModel.export_to_onnx() for native ONNX export of the encoder and item embeddings (project_all).
- Add UniSRecModel.map_item_ids() for external→internal ID conversion at inference time (works for both dense and hash modes).
- Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers that duplicated UniSRecModel functionality).
- Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, map_item_ids for both modes.
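For reference, the splitmix64 mixing function mentioned here, as a minimal pure-Python sketch (the PR presumably implements an equivalent in torch):

```python
MASK64 = 0xFFFFFFFFFFFFFFFF

def splitmix64(x: int) -> int:
    # Standard splitmix64 finalizer: maps a 64-bit integer to a
    # well-mixed 64-bit hash (constants from the reference algorithm).
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)
```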
- Remove ranking.py (duplicates TorchRanker)
- Remove hash ID mapping from build_sequences/align_embeddings
- Simplify UniSRecModel to a single joint training phase (adaptor + transformer)
- Rename gpu_data.py -> sequence_data.py, GPUBatchDataset -> SequenceBatchDataset
- Vectorize map_item_ids with torch.searchsorted
- Fix device default (None -> auto-detect from input tensor)
- Fix double torch.unique call
- Add empty dataset validation in fit()
- Add **kwargs to make_dataloader
- Add dataloader_num_workers passthrough
- Move benchmark script to benchmark/ folder
- Add KION training demo with Qwen3-Embedding-0.6B results
- Update tests for simplified API
- Clean up CHANGELOG and .gitignore
- Remove item_emb, use_id, freeze/unfreeze, phase references from net/lightning
- Remove GPUBatchDataset alias and make_dataloader wrapper
- Reorganize into preprocessing/ and unisrec/ subpackages
- Add GPU-friendly HR@K, NDCG@K, MRR@K metrics (tested against RecTools)
- Update benchmark, demo, and all tests (102 passed + 28 metric tests)
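A minimal sketch of what such GPU-friendly metrics can look like for the single-held-out-target case (function name and exact conventions are assumptions, not the PR's code):

```python
import typing as tp

import torch

def hr_ndcg_mrr_at_k(
    topk_items: torch.Tensor,  # (n_users, >=k) ranked item IDs
    target: torch.Tensor,      # (n_users,) one held-out item per user
    k: int,
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    hits = topk_items[:, :k] == target.unsqueeze(1)  # (n_users, k) bool
    found = hits.any(dim=1).float()
    hr = found.mean()
    # 1-based rank of the hit; rows without a hit are zeroed out by `found`.
    ranks = (torch.argmax(hits.int(), dim=1) + 1).float()
    ndcg = (found / torch.log2(ranks + 1)).mean()  # IDCG = 1 for a single target
    mrr = (found / ranks).mean()
    return hr, ndcg, mrr
```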
Force-pushed from d68834f to 45ed8ae
- Add negative sampling transform in fit() for BCE/gBCE/sampled_softmax losses
- Add e2e tests for all non-softmax losses via UniSRecModel.fit()
- Fix load_checkpoint() default device: auto-detect cuda/cpu instead of hardcoded "cuda"
- Fix map_item_ids() device mismatch when input is on CUDA
- Fix Python 3.9 compat: replace PEP 604 unions with Optional[] in tests
- Fix CHANGELOG: remove nonexistent FlatSASRecModel and make_dataloader()
- Update benchmark: auto-download ML-20M, fallback random embeddings, fix paths
…ings, n_negatives validation
- Run black/isort/flake8 on all fast_transformers files — all pass now
- Fix val dataloader missing negatives when patience + non-softmax loss
- Extract _NegativeSampler class: device-aware, resamples positive collisions
- Validate n_negatives is a positive integer for non-softmax losses
- Make align_embeddings() device-aware (supports CUDA pretrained embeddings)
- Remove unused imports (os in benchmark, pytest in test_sequence_data)
- Add CUDA guard in benchmark main()
- Add e2e tests: non-softmax losses with patience, n_negatives=0/-1/None
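A simplified sketch of the collision-resampling idea behind `_NegativeSampler` (standalone function with uniform sampling; the PR's class is device-aware and may differ):

```python
import torch

def sample_negatives(positives: torch.Tensor, n_items: int, n_negatives: int) -> torch.Tensor:
    """Draw uniform negatives in [1, n_items], resampling any that collide
    with the positive at the same position (assumes n_items >= 2)."""
    neg = torch.randint(
        1, n_items + 1, (*positives.shape, n_negatives), device=positives.device
    )
    collisions = neg == positives.unsqueeze(-1)
    while collisions.any():
        neg[collisions] = torch.randint(
            1, n_items + 1, (int(collisions.sum()),), device=positives.device
        )
        collisions = neg == positives.unsqueeze(-1)
    return neg
```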
Keep only device-awareness (the actual review request). Preserving pretrained.dtype could cause precision issues with float16 inputs.
- Add `device` parameter to UniSRecModel.__init__ (default None = input device)
- Move x/y to CPU before DataLoader to avoid CUDA+multiprocessing issues
- Benchmark: pass device="cuda" explicitly to build_sequences and UniSRecModel
…pylint, bandit)
- Add type annotations across benchmark, tests, and source files (mypy 30→0 errors)
- Annotate frozen_emb buffer and Optional head in net.py
- Add assert guards for Optional item_id_mapping usage
- Type sasrec_kwargs and nested functions in benchmark
- Fix tensor index type in test_metrics
New `rectools.fast_transformers` module — standalone transformer sequential recommenders that work with raw torch tensors, without going through `Dataset`/pandas.

GPU-native preprocessing. `build_sequences()` builds left-padded interaction sequences entirely in torch (argsort + scatter). On ML-20M (20M interactions) this takes 0.5s vs 14.6s for the pandas-based `SASRecDataPreparator` — roughly 30x faster. For larger production data the problem is even worse: on the KION production dataset, preprocessing just half a year of interactions takes up to 50 minutes with the current rectools code, comparable to the training time itself.

FlatSASRec. Pre-norm SASRec encoder with plain ID embeddings, no ItemNet hierarchy. Wrapped in `FlatSASRecModel` (inherits `ModelBase`) so it plugs into the standard RecTools fit/recommend API.

UniSRec. Three-phase sequential recommender with pretrained text embeddings and a learnable PCA adaptor. `UniSRecModel.fit(user_ids, item_ids, timestamps)` takes raw tensors end-to-end. Supports softmax/BCE/gBCE/sampled_softmax losses, Adam/AdamW, a cosine warmup scheduler, gradient clipping, early stopping, and checkpoint save/load. FFN blocks are configurable (`conv1d`, `linear_gelu`, `linear_relu`).

`rank_topk()` — batched top-k with CSR viewed-item filtering and whitelist support.

Benchmark (ML-20M, 10 epochs, softmax, Adam, n_factors=256):
UniSRec ID: +4.6% HR@10, +6.0% NDCG@10, 1.65x faster overall.
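For reference, a hypothetical end-to-end usage sketch of the API described above (toy data; constructor arguments beyond those named in this description are assumptions, and e.g. pretrained item embeddings may also be required):

```python
import torch

from rectools.fast_transformers import UniSRecModel

# Toy interaction log: three flat tensors of equal length (made-up data).
user_ids = torch.tensor([1, 1, 1, 2, 2, 2])
item_ids = torch.tensor([10, 11, 12, 10, 13, 11])
timestamps = torch.tensor([1, 2, 3, 1, 2, 3])

model = UniSRecModel(n_factors=64, session_max_len=5, loss="softmax")
model.fit(user_ids, item_ids, timestamps)
```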
New files

Source (9 modules, 1683 lines):
- `rectools/fast_transformers/gpu_data.py` — `build_sequences`, `align_embeddings`, `GPUBatchDataset`, `make_dataloader`
- `rectools/fast_transformers/net.py` — `FlatSASRec`, `SASRecBlock`
- `rectools/fast_transformers/lightning_wrap.py` — `FlatSASRecLightning`
- `rectools/fast_transformers/model.py` — `FlatSASRecModel`, `FlatSASRecConfig`
- `rectools/fast_transformers/ranking.py` — `rank_topk`
- `rectools/fast_transformers/unisrec_net.py` — `UniSRec`, `FeedForward`, `make_ffn`
- `rectools/fast_transformers/unisrec_lightning.py` — `UniSRecLightning`, loss/optimizer/scheduler dispatch
- `rectools/fast_transformers/unisrec_model.py` — `UniSRecModel` (three-phase fit, checkpoint)

Tests (143 tests, 1920 lines):
- `tests/fast_transformers/test_gpu_data.py` — sequence building, alignment, dataset/dataloader
- `tests/fast_transformers/test_net.py`, `test_lightning_wrap.py`, `test_model.py` — FlatSASRec stack
- `tests/fast_transformers/test_unisrec_net.py`, `test_unisrec_lightning.py`, `test_unisrec_model.py` — UniSRec stack
- `tests/fast_transformers/test_ranking.py` — top-k, filtering, edge cases

Scripts:
- `scripts/compare_sasrec_unisrec.py` — full benchmark with markdown report generation
- `scripts/comparison_report.md` — benchmark results

Test plan
- Full suite passes (`pytest tests/fast_transformers/ -q`)
- `FlatSASRecModel` fit/recommend exercised through the standard RecTools API on a small dataset