Skip to content

Commit 30518c8

Browse files
authored
[Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3) (#37565)
* Add length-aware batching to BatchElements and ModelHandler - Add length_fn and bucket_boundaries parameters to ModelHandler.__init__ to support length-aware bucketed keying for ML inference batching - Add WithLengthBucketKey DoFn to route elements by length buckets - Update BatchElements to support length-aware batching when max_batch_duration_secs is set, reducing padding waste for variable-length sequences (e.g., NLP workloads) - Default bucket boundaries: [16, 32, 64, 128, 256, 512] - Add comprehensive tests validating bucket assignment, mixed-length batching, and padding efficiency improvements (77% vs 68% on bimodal data) - All formatting (yapf) and lint (pylint 10/10) checks passed * Refine length bucketing docs and fix boundary inclusivity Expands parameter documentation for clarity and replaces bisect_left with bisect_right to ensure bucket boundaries are inclusive on the lower bound. Updates util_test.py assertions accordingly.
1 parent 9524b56 commit 30518c8

4 files changed

Lines changed: 347 additions & 2 deletions

File tree

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def __init__(
178178
max_batch_duration_secs: Optional[int] = None,
179179
max_batch_weight: Optional[int] = None,
180180
element_size_fn: Optional[Callable[[Any], int]] = None,
181+
batch_length_fn: Optional[Callable[[Any], int]] = None,
182+
batch_bucket_boundaries: Optional[list[int]] = None,
181183
large_model: bool = False,
182184
model_copies: Optional[int] = None,
183185
**kwargs):
@@ -190,6 +192,17 @@ def __init__(
190192
before emitting; used in streaming contexts.
191193
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
192194
element_size_fn: a function that returns the size (weight) of an element.
195+
batch_length_fn: a callable mapping an element to its length (int). When
196+
set together with max_batch_duration_secs, enables length-aware bucketed
197+
keying so that elements of similar length are batched together, reducing
198+
padding waste for variable-length inputs. Bucket assignment uses
199+
bisect_right so boundaries are lower-inclusive: e.g., for boundaries
200+
[10, 50], buckets are (-inf, 10), [10, 50), [50, inf).
201+
batch_bucket_boundaries: a sorted list of positive boundary values for
202+
length bucketing. Boundaries are lower-inclusive (bisect_right
203+
semantics): bucket i covers lengths in [boundaries[i-1], boundaries[i]).
204+
Requires batch_length_fn. Defaults to [16, 32, 64, 128, 256, 512] when
205+
batch_length_fn is set.
193206
large_model: set to true if your model is large enough to run into
194207
memory pressure if you load multiple copies.
195208
model_copies: The exact number of models that you would like loaded
@@ -209,6 +222,10 @@ def __init__(
209222
self._batching_kwargs['max_batch_weight'] = max_batch_weight
210223
if element_size_fn is not None:
211224
self._batching_kwargs['element_size_fn'] = element_size_fn
225+
if batch_length_fn is not None:
226+
self._batching_kwargs['length_fn'] = batch_length_fn
227+
if batch_bucket_boundaries is not None:
228+
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
212229
self._large_model = large_model
213230
self._model_copies = model_copies
214231
self._share_across_processes = large_model or (model_copies is not None)

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,45 @@ def test_max_batch_duration_secs_only(self):
22792279

22802280
self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
22812281

2282+
def test_batch_length_fn_and_batch_bucket_boundaries(self):
2283+
"""batch_length_fn and batch_bucket_boundaries passed through to kwargs."""
2284+
handler = FakeModelHandlerForBatching(
2285+
batch_length_fn=len, batch_bucket_boundaries=[16, 32, 64])
2286+
kwargs = handler.batch_elements_kwargs()
2287+
2288+
self.assertIs(kwargs['length_fn'], len)
2289+
self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])
2290+
2291+
def test_batch_length_fn_only(self):
2292+
"""batch_length_fn alone is passed through without bucket_boundaries."""
2293+
handler = FakeModelHandlerForBatching(batch_length_fn=len)
2294+
kwargs = handler.batch_elements_kwargs()
2295+
2296+
self.assertIs(kwargs['length_fn'], len)
2297+
self.assertNotIn('bucket_boundaries', kwargs)
2298+
2299+
def test_batch_bucket_boundaries_without_batch_length_fn(self):
2300+
"""Passing batch_bucket_boundaries without batch_length_fn should fail in
2301+
BatchElements.
2302+
2303+
Note: ModelHandler.__init__ doesn't validate this; the error is raised
2304+
by BatchElements when batch_elements_kwargs are used."""
2305+
handler = FakeModelHandlerForBatching(batch_bucket_boundaries=[10, 20])
2306+
kwargs = handler.batch_elements_kwargs()
2307+
# The kwargs are stored, but BatchElements will reject them
2308+
self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
2309+
self.assertNotIn('length_fn', kwargs)
2310+
2311+
def test_batching_kwargs_none_values_omitted(self):
2312+
"""None values for batch_length_fn and batch_bucket_boundaries are not in
2313+
kwargs."""
2314+
handler = FakeModelHandlerForBatching(
2315+
min_batch_size=5, batch_length_fn=None, batch_bucket_boundaries=None)
2316+
kwargs = handler.batch_elements_kwargs()
2317+
self.assertNotIn('length_fn', kwargs)
2318+
self.assertNotIn('bucket_boundaries', kwargs)
2319+
self.assertEqual(kwargs['min_batch_size'], 5)
2320+
22822321

22832322
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
22842323
def load_model(self):

sdks/python/apache_beam/transforms/util.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
# pytype: skip-file
2222

23+
import bisect
2324
import collections
2425
import contextlib
2526
import hashlib
@@ -1208,6 +1209,30 @@ def process(self, element):
12081209
yield (self.key, element)
12091210

12101211

1212+
class WithLengthBucketKey(DoFn):
1213+
"""Keys elements with (worker_uuid, length_bucket) for length-aware
1214+
stateful batching. Elements of similar length are routed to the same
1215+
state partition, reducing padding waste."""
1216+
def __init__(self, length_fn, bucket_boundaries):
1217+
self.shared_handle = shared.Shared()
1218+
self._length_fn = length_fn
1219+
self._bucket_boundaries = bucket_boundaries
1220+
1221+
def setup(self):
1222+
self.key = self.shared_handle.acquire(
1223+
load_shared_key, "WithLengthBucketKey").key
1224+
1225+
def _get_bucket(self, length):
1226+
# bisect_right: boundaries are lower-inclusive.
1227+
# e.g., for boundaries [10, 50], buckets are (-inf, 10), [10, 50), [50, inf)
1228+
return bisect.bisect_right(self._bucket_boundaries, length)
1229+
1230+
def process(self, element):
1231+
length = self._length_fn(element)
1232+
bucket = self._get_bucket(length)
1233+
yield ((self.key, bucket), element)
1234+
1235+
12111236
@typehints.with_input_types(T)
12121237
@typehints.with_output_types(list[T])
12131238
class BatchElements(PTransform):
@@ -1267,7 +1292,18 @@ class BatchElements(PTransform):
12671292
downstream operations (mostly for testing)
12681293
record_metrics: (optional) whether or not to record beam metrics on
12691294
distributions of the batch size. Defaults to True.
1295+
length_fn: (optional) a callable mapping an element to its length (int).
1296+
When set together with bucket_boundaries, enables length-aware bucketed
1297+
keying on the stateful path so that elements of similar length are
1298+
routed to the same batch, reducing padding waste.
1299+
bucket_boundaries: (optional) a sorted list of positive boundary values
1300+
for length bucketing. Boundaries are lower-inclusive (bisect_right
1301+
semantics): e.g., for boundaries [10, 50], buckets are (-inf, 10),
1302+
[10, 50), [50, inf). Defaults to [16, 32, 64, 128, 256, 512] when
1303+
length_fn is set. Requires length_fn.
12701304
"""
1305+
_DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]
1306+
12711307
def __init__(
12721308
self,
12731309
min_batch_size=1,
@@ -1280,7 +1316,17 @@ def __init__(
12801316
element_size_fn=lambda x: 1,
12811317
variance=0.25,
12821318
clock=time.time,
1283-
record_metrics=True):
1319+
record_metrics=True,
1320+
length_fn=None,
1321+
bucket_boundaries=None):
1322+
if bucket_boundaries is not None and length_fn is None:
1323+
raise ValueError('bucket_boundaries requires length_fn to be set.')
1324+
if bucket_boundaries is not None:
1325+
if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or
1326+
bucket_boundaries != sorted(bucket_boundaries)):
1327+
raise ValueError(
1328+
'bucket_boundaries must be a non-empty sorted list of '
1329+
'positive values.')
12841330
self._batch_size_estimator = _BatchSizeEstimator(
12851331
min_batch_size=min_batch_size,
12861332
max_batch_size=max_batch_size,
@@ -1294,13 +1340,23 @@ def __init__(
12941340
self._element_size_fn = element_size_fn
12951341
self._max_batch_dur = max_batch_duration_secs
12961342
self._clock = clock
1343+
self._length_fn = length_fn
1344+
if length_fn is not None and bucket_boundaries is None:
1345+
self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES
1346+
else:
1347+
self._bucket_boundaries = bucket_boundaries
12971348

12981349
def expand(self, pcoll):
12991350
if getattr(pcoll.pipeline.runner, 'is_streaming', False):
13001351
raise NotImplementedError("Requires stateful processing (BEAM-2687)")
13011352
elif self._max_batch_dur is not None:
13021353
coder = coders.registry.get_coder(pcoll)
1303-
return pcoll | ParDo(WithSharedKey()) | ParDo(
1354+
if self._length_fn is not None:
1355+
keying_dofn = WithLengthBucketKey(
1356+
self._length_fn, self._bucket_boundaries)
1357+
else:
1358+
keying_dofn = WithSharedKey()
1359+
return pcoll | ParDo(keying_dofn) | ParDo(
13041360
_pardo_stateful_batch_elements(
13051361
coder,
13061362
self._batch_size_estimator,

0 commit comments

Comments
 (0)