1515# limitations under the License.
1616#
1717import logging
18+ from collections import defaultdict
1819from collections .abc import Callable
1920from collections .abc import Mapping
2021from typing import Any
@@ -189,7 +190,7 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
189190 if isinstance (request , list ):
190191 values = []
191192 responses = []
192- requests_map : dict [Any , Any ] = {}
193+ requests_map : dict [Any , list [ beam . Row ]] = defaultdict ( list )
193194 batch_size = len (request )
194195 raw_query = self .query_template
195196 if batch_size > 1 :
@@ -208,25 +209,29 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
208209 "Make sure the values passed in `fields` are the "
209210 "keys in the input `beam.Row`." + str (e ))
210211 values .extend (current_values )
211- requests_map [self .create_row_key (req )] = req
212+ requests_map [self .create_row_key (req )]. append ( req )
212213 query = raw_query .format (* values )
213214
214215 responses_dict = self ._execute_query (query )
215- unmatched_requests = requests_map .copy ()
216+ unmatched_requests = {
217+ key : list (reqs )
218+ for key , reqs in requests_map .items ()
219+ }
216220 if responses_dict :
217221 for response in responses_dict :
218222 response_row = beam .Row (** response )
219223 response_key = self .create_row_key (response_row )
220224 if response_key in unmatched_requests :
221- req = unmatched_requests .pop (response_key )
222- responses .append ((req , response_row ))
225+ for req in unmatched_requests .pop (response_key ):
226+ responses .append ((req , response_row ))
223227 if unmatched_requests :
224228 if self .throw_exception_on_empty_results :
225229 raise ValueError (f"no matching row found for query: { query } " )
226230 else :
227231 _LOGGER .warning ('no matching row found for query: %s' , query )
228- for req in unmatched_requests .values ():
229- responses .append ((req , beam .Row ()))
232+ for reqs in unmatched_requests .values ():
233+ for req in reqs :
234+ responses .append ((req , beam .Row ()))
230235 return responses
231236 else :
232237 request_dict = request ._asdict ()
0 commit comments