Skip to content

Commit a03784e

Browse files
Fix BigQueryEnrichmentHandler batch handling for duplicate keys (#38040)
* Fix BigQuery enrichment batch handling for duplicate keys * Apply yapf formatting for BigQuery enrichment changes
1 parent 392b869 commit a03784e

2 files changed

Lines changed: 59 additions & 7 deletions

File tree

sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
#
1717
import logging
18+
from collections import defaultdict
1819
from collections.abc import Callable
1920
from collections.abc import Mapping
2021
from typing import Any
@@ -189,7 +190,7 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
189190
if isinstance(request, list):
190191
values = []
191192
responses = []
192-
requests_map: dict[Any, Any] = {}
193+
requests_map: dict[Any, list[beam.Row]] = defaultdict(list)
193194
batch_size = len(request)
194195
raw_query = self.query_template
195196
if batch_size > 1:
@@ -208,25 +209,29 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
208209
"Make sure the values passed in `fields` are the "
209210
"keys in the input `beam.Row`." + str(e))
210211
values.extend(current_values)
211-
requests_map[self.create_row_key(req)] = req
212+
requests_map[self.create_row_key(req)].append(req)
212213
query = raw_query.format(*values)
213214

214215
responses_dict = self._execute_query(query)
215-
unmatched_requests = requests_map.copy()
216+
unmatched_requests = {
217+
key: list(reqs)
218+
for key, reqs in requests_map.items()
219+
}
216220
if responses_dict:
217221
for response in responses_dict:
218222
response_row = beam.Row(**response)
219223
response_key = self.create_row_key(response_row)
220224
if response_key in unmatched_requests:
221-
req = unmatched_requests.pop(response_key)
222-
responses.append((req, response_row))
225+
for req in unmatched_requests.pop(response_key):
226+
responses.append((req, response_row))
223227
if unmatched_requests:
224228
if self.throw_exception_on_empty_results:
225229
raise ValueError(f"no matching row found for query: {query}")
226230
else:
227231
_LOGGER.warning('no matching row found for query: %s', query)
228-
for req in unmatched_requests.values():
229-
responses.append((req, beam.Row()))
232+
for reqs in unmatched_requests.values():
233+
for req in reqs:
234+
responses.append((req, beam.Row()))
230235
return responses
231236
else:
232237
request_dict = request._asdict()

sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# limitations under the License.
1616
#
1717
import unittest
18+
from unittest import mock
1819

1920
from parameterized import parameterized
2021

22+
import apache_beam as beam
23+
2124
# pylint: disable=ungrouped-imports
2225
try:
2326
from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler
@@ -65,6 +68,50 @@ def test_valid_params(
6568
max_batch_size=max_batch_size,
6669
)
6770

71+
def test_batch_mode_fans_out_response_for_duplicate_keys(self):
72+
handler = BigQueryEnrichmentHandler(
73+
project=self.project,
74+
table_name='project.dataset.table',
75+
row_restriction_template="id='{}'",
76+
fields=['id'],
77+
min_batch_size=2,
78+
max_batch_size=2,
79+
)
80+
requests = [beam.Row(id='1', name='first'), beam.Row(id='1', name='second')]
81+
82+
with mock.patch.object(handler,
83+
'_execute_query',
84+
return_value=[{'id': '1', 'value': 'enriched'}]):
85+
responses = handler(requests)
86+
87+
self.assertEqual(
88+
responses,
89+
[
90+
(requests[0], beam.Row(id='1', value='enriched')),
91+
(requests[1], beam.Row(id='1', value='enriched')),
92+
],
93+
)
94+
95+
def test_batch_mode_emits_empty_rows_for_all_unmatched_duplicate_keys(self):
96+
handler = BigQueryEnrichmentHandler(
97+
project=self.project,
98+
table_name='project.dataset.table',
99+
row_restriction_template="id='{}'",
100+
fields=['id'],
101+
min_batch_size=2,
102+
max_batch_size=2,
103+
throw_exception_on_empty_results=False,
104+
)
105+
requests = [beam.Row(id='1', name='first'), beam.Row(id='1', name='second')]
106+
107+
with mock.patch.object(handler, '_execute_query', return_value=None):
108+
responses = handler(requests)
109+
110+
self.assertEqual(
111+
responses,
112+
[(requests[0], beam.Row()), (requests[1], beam.Row())],
113+
)
114+
68115

69116
if __name__ == '__main__':
70117
unittest.main()

0 commit comments

Comments
 (0)