Skip to content

Commit 4db9046

Browse files
committed
fix gemini comments
1 parent 321cb61 commit 4db9046

File tree

2 files changed

+51
-51
lines changed

2 files changed

+51
-51
lines changed

sdks/python/apache_beam/yaml/yaml_mapping.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
#
1717

1818
"""This module defines the basic MapToFields operation."""
19+
20+
import datetime
1921
import itertools
22+
import json
2023
import re
24+
import threading
25+
import uuid
2126
from collections import abc
2227
from collections.abc import Callable
2328
from collections.abc import Collection
@@ -29,10 +34,6 @@
2934
from typing import TypeVar
3035
from typing import Union
3136

32-
import json
33-
import threading
34-
import uuid
35-
3637
import apache_beam as beam
3738
from apache_beam.io.filesystems import FileSystems
3839
from apache_beam.portability.api import schema_pb2
@@ -62,7 +63,24 @@
6263
except ImportError:
6364
MiniRacer = None
6465

65-
_js_thread_funcs = {}
66+
67+
class _JsThreadContext:
68+
def __init__(self):
69+
self._local = threading.local()
70+
71+
def get_funcs(self):
72+
if not hasattr(self._local, 'funcs'):
73+
self._local.funcs = {}
74+
return self._local.funcs
75+
76+
def __getstate__(self):
77+
return {}
78+
79+
def __setstate__(self, state):
80+
self._local = threading.local()
81+
82+
83+
_js_contexts = _JsThreadContext()
6684

6785
_str_expression_fields = {
6886
'AssignTimestamps': 'timestamp',
@@ -181,9 +199,6 @@ def _check_mapping_arguments(
181199
raise ValueError(f'{transform_name} cannot specify "name" without "path"')
182200

183201

184-
185-
186-
187202
# TODO(yaml) Improve type inference for JS UDFs
188203
def py_value_to_js_dict(py_value):
189204
if ((isinstance(py_value, tuple) and hasattr(py_value, '_asdict')) or
@@ -200,25 +215,18 @@ def py_value_to_js_dict(py_value):
200215
def js_to_py(obj):
201216
"""Converts mini-racer mapped objects to standard Python types.
202217
203-
This is needed because ctx.eval returns JSMappedObjectImpl and JSArrayImpl
204-
for JS objects and arrays, which are not picklable and would fail when Beam
205-
tries to serialize rows containing them. We also preserve datetime objects
206-
which are correctly produced by ctx.eval for JS Date objects.
218+
This is needed because ctx.eval returns objects that implement Mapping
219+
and Iterable but are not picklable (like JSMappedObjectImpl and JSArrayImpl),
220+
which would fail when Beam tries to serialize rows containing them.
221+
We also preserve datetime objects which are correctly produced by ctx.eval
222+
for JS Date objects (ctx.call may instead return them as strings).
207223
"""
208-
import datetime
209-
from collections import abc
210-
211-
type_name = type(obj).__name__
212-
if type_name == 'JSMappedObjectImpl':
213-
return {k: js_to_py(v) for k, v in dict(obj).items()}
214-
elif type_name == 'JSArrayImpl':
215-
return [js_to_py(v) for v in list(obj)]
216-
elif isinstance(obj, datetime.datetime):
224+
if isinstance(obj, datetime.datetime):
217225
return obj
218-
elif isinstance(obj, dict):
226+
elif isinstance(obj, Mapping):
219227
return {k: js_to_py(v) for k, v in obj.items()}
220-
elif not isinstance(obj, str) and isinstance(obj, abc.Iterable):
221-
return [js_to_py(v) for v in list(obj)]
228+
elif not isinstance(obj, str) and isinstance(obj, Iterable):
229+
return [js_to_py(v) for v in obj]
222230
else:
223231
return obj
224232

@@ -230,7 +238,8 @@ def _expand_javascript_mapping_func(
230238

231239
if MiniRacer is None:
232240
raise ValueError(
233-
"JavaScript mapping functions require the 'mini-racer' package to be installed.")
241+
"JavaScript mapping functions require the 'mini-racer' package to be installed."
242+
)
234243

235244
udf_code = None
236245
if path:
@@ -239,36 +248,27 @@ def _expand_javascript_mapping_func(
239248
udf_code = FileSystems.open(path).read().decode()
240249
elif expression:
241250
udf_code = f"var func = (__row__) => {{ " + " ".join([
242-
f"const {n} = __row__.{n};"
243-
for n in original_fields if n in expression
251+
f"const {n} = __row__.{n};" for n in original_fields if n in expression
244252
]) + f" return ({expression}); }}"
245253
elif callable:
246254
udf_code = f"var func = {callable}"
247255

248256
udf_key = str(uuid.uuid4())
249257

250258
def js_wrapper(row):
251-
tid = threading.get_ident()
252-
253-
global _js_thread_funcs
254-
# MiniRacer contexts are not picklable and cannot be shared across threads.
255-
# We use a global dict keyed by thread ID to lazily create and cache a
256-
# context per thread.
257-
if tid not in _js_thread_funcs:
258-
_js_thread_funcs[tid] = {}
259-
260-
if udf_key not in _js_thread_funcs[tid]:
259+
funcs = _js_contexts.get_funcs()
260+
261+
if udf_key not in funcs:
261262
ctx = MiniRacer()
262263
ctx.eval(udf_code)
263-
# We use ctx.eval instead of ctx.call to ensure that JavaScript Date
264-
# objects are correctly returned as Python datetime objects.
265-
# We JSON-serialize the arguments to pass them safely to eval.
264+
# We use ctx.call for efficiency.
265+
# Note: Unlike ctx.eval, ctx.call may return JS Date objects as ISO-8601 strings (e.g. '2026-04-17T18:00:00.000Z') instead of datetime.
266266
if expression or callable:
267-
_js_thread_funcs[tid][udf_key] = lambda x: ctx.eval(f"func({json.dumps(x)})")
267+
funcs[udf_key] = lambda x: ctx.call("func", x)
268268
else:
269-
_js_thread_funcs[tid][udf_key] = lambda x: ctx.eval(f"{name}({json.dumps(x)})")
270-
271-
func = _js_thread_funcs[tid][udf_key]
269+
funcs[udf_key] = lambda x: ctx.call(name, x)
270+
271+
func = funcs[udf_key]
272272
row_as_dict = py_value_to_js_dict(row)
273273
try:
274274
result = func(row_as_dict)

sdks/python/apache_beam/yaml/yaml_udf_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,13 @@ def g(x):
373373
conductor=389,
374374
row=beam.Row(rank=2, values=[7, 8, 9])),
375375
]))
376+
376377
@unittest.skipIf(MiniRacer is None, 'py_mini_racer not installed.')
377378
def test_map_to_fields_js_date(self):
378379
import datetime
379380
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
380-
pickle_library='cloudpickle',
381-
yaml_experimental_features=['javascript'])) as p:
381+
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
382+
])) as p:
382383
elements = p | beam.Create([beam.Row(label='11a')])
383384
result = elements | YamlTransform(
384385
'''
@@ -392,12 +393,11 @@ def test_map_to_fields_js_date(self):
392393
return new Date('2026-04-17T18:00:00Z')
393394
}
394395
''')
395-
396-
expected_date = datetime.datetime(2026, 4, 17, 18, 0, 0, tzinfo=datetime.timezone.utc)
397-
396+
397+
expected_date = '2026-04-17T18:00:00.000Z'
398+
398399
assert_that(
399-
result | as_rows(),
400-
equal_to([
400+
result | as_rows(), equal_to([
401401
beam.Row(date=expected_date),
402402
]))
403403

0 commit comments

Comments
 (0)