Skip to content
This repository was archived by the owner on Sep 17, 2025. It is now read-only.

Commit 821ea4a

Browse files
geobeaureyang
authored andcommitted
Thread pools integration (#363)
* Handle thread pools * Clean tracing of debug informations * Improve testing * Add threadpool executor * Implement testing * Fix lint
1 parent 1df8f58 commit 821ea4a

4 files changed

Lines changed: 245 additions & 9 deletions

File tree

opencensus/trace/execution_context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def set_opencensus_attr(attr_key, attr_value):
3838
setattr(_thread_local, 'attrs', attrs)
3939

4040

41+
def set_opencensus_attrs(attrs):
42+
setattr(_thread_local, 'attrs', attrs)
43+
44+
4145
def get_opencensus_attr(attr_key):
4246
attrs = getattr(_thread_local, 'attrs', None)
4347

@@ -47,6 +51,10 @@ def get_opencensus_attr(attr_key):
4751
return None
4852

4953

54+
def get_opencensus_attrs():
55+
return getattr(_thread_local, 'attrs', None)
56+
57+
5058
def get_current_span():
5159
return getattr(_thread_local, 'current_span', None)
5260

@@ -55,6 +63,30 @@ def set_current_span(current_span):
5563
setattr(_thread_local, 'current_span', current_span)
5664

5765

66+
def get_opencensus_full_context():
67+
_tracer = get_opencensus_tracer()
68+
_span = get_current_span()
69+
_attrs = get_opencensus_attrs()
70+
return _tracer, _span, _attrs
71+
72+
73+
def set_opencensus_full_context(tracer, span, attrs):
74+
set_opencensus_tracer(tracer)
75+
set_current_span(span)
76+
if not attrs:
77+
set_opencensus_attrs({})
78+
else:
79+
set_opencensus_attrs(attrs)
80+
81+
82+
def clean():
83+
setattr(_thread_local, 'attrs', {})
84+
if hasattr(_thread_local, 'current_span'):
85+
delattr(_thread_local, 'current_span')
86+
if hasattr(_thread_local, 'tracer'):
87+
delattr(_thread_local, 'tracer')
88+
89+
5890
def clear():
5991
"""Clear the thread local, used in test."""
6092
_thread_local.__dict__.clear()

opencensus/trace/ext/threading/trace.py

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
import logging
1616
import threading
17+
from multiprocessing import pool
18+
from concurrent import futures
1719

1820
from opencensus.trace import execution_context
21+
from opencensus.trace import tracer
22+
from opencensus.trace.propagation import binary_format
1923

2024
log = logging.getLogger(__name__)
2125

@@ -27,13 +31,29 @@ def trace_integration(tracer=None):
2731
log.info("Integrated module: {}".format(MODULE_NAME))
2832
# Wrap the threading start function
2933
start_func = getattr(threading.Thread, "start")
30-
setattr(threading.Thread, start_func.__name__,
31-
wrap_threading_start(start_func))
34+
setattr(
35+
threading.Thread, start_func.__name__, wrap_threading_start(start_func)
36+
)
3237

3338
# Wrap the threading run function
3439
run_func = getattr(threading.Thread, "run")
35-
setattr(threading.Thread, run_func.__name__,
36-
wrap_threading_run(run_func))
40+
setattr(threading.Thread, run_func.__name__, wrap_threading_run(run_func))
41+
42+
# Wrap the threading run function
43+
apply_async_func = getattr(pool.Pool, "apply_async")
44+
setattr(
45+
pool.Pool,
46+
apply_async_func.__name__,
47+
wrap_apply_async(apply_async_func),
48+
)
49+
50+
# Wrap the threading run function
51+
submit_func = getattr(futures.ThreadPoolExecutor, "submit")
52+
setattr(
53+
futures.ThreadPoolExecutor,
54+
submit_func.__name__,
55+
wrap_submit(submit_func),
56+
)
3757

3858

3959
def wrap_threading_start(start_func):
@@ -42,7 +62,9 @@ def wrap_threading_start(start_func):
4262
"""
4363

4464
def call(self):
45-
self.__opencensus_tracer = execution_context.get_opencensus_tracer()
65+
self._opencensus_context = (
66+
execution_context.get_opencensus_full_context()
67+
)
4668
return start_func(self)
4769

4870
return call
@@ -54,7 +76,80 @@ def wrap_threading_run(run_func):
5476
"""
5577

5678
def call(self):
57-
execution_context.set_opencensus_tracer(self.__opencensus_tracer)
79+
execution_context.set_opencensus_full_context(
80+
*self._opencensus_context
81+
)
5882
return run_func(self)
5983

6084
return call
85+
86+
87+
def wrap_apply_async(apply_async_func):
88+
"""Wrap the apply_async function of multiprocessing.pools. Get the function
89+
that will be called and wrap it then add the opencensus context."""
90+
91+
def call(self, func, args=(), kwds={}, **kwargs):
92+
wrapped_func = wrap_task_func(func)
93+
_tracer = execution_context.get_opencensus_tracer()
94+
propagator = binary_format.BinaryFormatPropagator()
95+
96+
wrapped_kwargs = {}
97+
print(_tracer)
98+
wrapped_kwargs["span_context_binary"] = propagator.to_header(
99+
_tracer.span_context
100+
)
101+
wrapped_kwargs["kwds"] = kwds
102+
wrapped_kwargs["sampler"] = _tracer.sampler
103+
wrapped_kwargs["exporter"] = _tracer.exporter
104+
wrapped_kwargs["propagator"] = _tracer.propagator
105+
106+
return apply_async_func(
107+
self, wrapped_func, args=args, kwds=wrapped_kwargs, **kwargs
108+
)
109+
110+
return call
111+
112+
113+
def wrap_submit(submit_func):
114+
"""Wrap the apply_async function of multiprocessing.pools. Get the function
115+
that will be called and wrap it then add the opencensus context."""
116+
117+
def call(self, func, *args, **kwargs):
118+
wrapped_func = wrap_task_func(func)
119+
_tracer = execution_context.get_opencensus_tracer()
120+
propagator = binary_format.BinaryFormatPropagator()
121+
122+
wrapped_kwargs = {}
123+
wrapped_kwargs["span_context_binary"] = propagator.to_header(
124+
_tracer.span_context
125+
)
126+
wrapped_kwargs["kwds"] = kwargs
127+
wrapped_kwargs["sampler"] = _tracer.sampler
128+
wrapped_kwargs["exporter"] = _tracer.exporter
129+
wrapped_kwargs["propagator"] = _tracer.propagator
130+
131+
return submit_func(self, wrapped_func, *args, **wrapped_kwargs)
132+
133+
return call
134+
135+
136+
class wrap_task_func(object):
137+
"""Wrap the function given to apply_async to get the tracer from context,
138+
execute the function then clear the context."""
139+
140+
def __init__(self, func):
141+
self.func = func
142+
143+
def __call__(self, *args, **kwargs):
144+
kwds = kwargs.pop("kwds")
145+
146+
span_context_binary = kwargs.pop("span_context_binary")
147+
propagator = binary_format.BinaryFormatPropagator()
148+
kwargs["span_context"] = propagator.from_header(span_context_binary)
149+
150+
_tracer = tracer.Tracer(**kwargs)
151+
execution_context.set_opencensus_tracer(_tracer)
152+
with _tracer.span(name=threading.current_thread().name):
153+
result = self.func(*args, **kwds)
154+
execution_context.clean()
155+
return result

tests/unit/trace/ext/threading/test_threading_trace.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
import unittest
1616
import threading
1717
import mock
18+
from multiprocessing.pool import Pool
19+
from concurrent.futures import ThreadPoolExecutor
1820

19-
from opencensus.trace import span as span_module
2021
from opencensus.trace.ext.threading import trace
21-
from opencensus.trace import execution_context
22+
from opencensus.trace import execution_context, tracer
2223

2324
class Test_threading_trace(unittest.TestCase):
2425

@@ -29,36 +30,51 @@ def tearDown(self):
2930
def test_trace_integration(self):
3031
mock_wrap_start = mock.Mock()
3132
mock_wrap_run = mock.Mock()
33+
mock_wrap_apply_async = mock.Mock()
34+
3235
mock_threading = mock.Mock()
36+
mock_pool = mock.Mock()
3337

3438
wrap_start_result = 'wrap start result'
3539
wrap_run_result = 'wrap run result'
40+
wrap_apply_async_result = 'wrap apply_async result'
3641
mock_wrap_start.return_value = wrap_start_result
3742
mock_wrap_run.return_value = wrap_run_result
43+
mock_wrap_apply_async.return_value = wrap_apply_async_result
3844

3945
mock_start_func = mock.Mock()
4046
mock_run_func = mock.Mock()
47+
mock_apply_async_func = mock.Mock()
4148
mock_start_func.__name__ = 'start'
4249
mock_run_func.__name__ = 'run'
50+
mock_apply_async_func.__name__ = 'apply_async'
4351
setattr(mock_threading.Thread, 'start', mock_start_func)
4452
setattr(mock_threading.Thread, 'run', mock_run_func)
53+
setattr(mock_pool.Pool, 'apply_async', mock_apply_async_func)
4554

4655
patch_wrap_start = mock.patch(
4756
'opencensus.trace.ext.threading.trace.wrap_threading_start',
4857
mock_wrap_start)
4958
patch_wrap_run = mock.patch(
5059
'opencensus.trace.ext.threading.trace.wrap_threading_run',
5160
mock_wrap_run)
61+
patch_wrap_apply_async = mock.patch(
62+
'opencensus.trace.ext.threading.trace.wrap_apply_async',
63+
mock_wrap_apply_async)
5264
patch_threading = mock.patch(
5365
'opencensus.trace.ext.threading.trace.threading', mock_threading)
66+
patch_pool = mock.patch(
67+
'opencensus.trace.ext.threading.trace.pool', mock_pool)
5468

55-
with patch_wrap_start, patch_wrap_run, patch_threading:
69+
with patch_wrap_start, patch_wrap_run, patch_wrap_apply_async, patch_threading, patch_pool:
5670
trace.trace_integration()
5771

5872
self.assertEqual(getattr(mock_threading.Thread, 'start'),
5973
wrap_start_result)
6074
self.assertEqual(getattr(mock_threading.Thread, 'run'),
6175
wrap_run_result)
76+
self.assertEqual(getattr(mock_pool.Pool, 'apply_async'),
77+
wrap_apply_async_result)
6278

6379
def test_wrap_threading(self):
6480
global global_tracer
@@ -75,11 +91,45 @@ def test_wrap_threading(self):
7591
t.join()
7692
assert isinstance(global_tracer, MockTracer)
7793

94+
def test_wrap_pool(self):
95+
_tracer = tracer.Tracer()
96+
execution_context.set_opencensus_tracer(tracer)
97+
98+
trace.trace_integration()
99+
context = tracer.Tracer().span_context
100+
print(context.trace_id)
101+
102+
pool = Pool(processes=1)
103+
with _tracer.span(name='span1'):
104+
result = pool.apply_async(fake_pooled_func, ()).get(timeout=1)
105+
106+
self.assertEqual(result, context.trace_id)
107+
108+
def test_wrap_futures(self):
109+
_tracer = tracer.Tracer()
110+
execution_context.set_opencensus_tracer(tracer)
111+
112+
trace.trace_integration()
113+
context = tracer.Tracer().span_context
114+
print(context.trace_id)
115+
116+
pool = ThreadPoolExecutor(max_workers=1)
117+
with _tracer.span(name='span1'):
118+
future = pool.submit(fake_pooled_func)
119+
result = future.result()
120+
121+
self.assertEqual(result, context.trace_id)
122+
78123
def fake_threaded_func(self):
79124
global global_tracer
80125
global_tracer = execution_context.get_opencensus_tracer()
81126

82127

128+
def fake_pooled_func():
129+
_tracer = execution_context.get_opencensus_tracer()
130+
return _tracer.span_context.trace_id
131+
132+
83133
class MockTracer(object):
84134
def __init__(self, span=None):
85135
self.span = span

tests/unit/trace/test_execution_context.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
import mock
17+
import threading
1618

1719
from opencensus.trace import execution_context
1820

@@ -38,3 +40,60 @@ def test_has_attrs(self):
3840
result = execution_context.get_opencensus_attr(key)
3941

4042
self.assertEqual(result, value)
43+
44+
def test_get_and_set_full_context(self):
45+
mock_tracer_get = mock.Mock()
46+
mock_span_get = mock.Mock()
47+
execution_context.set_opencensus_tracer(mock_tracer_get)
48+
execution_context.set_current_span(mock_span_get)
49+
50+
execution_context.set_opencensus_attr("test", "test_value")
51+
52+
tracer, span, attrs = execution_context.get_opencensus_full_context()
53+
54+
self.assertEqual(mock_tracer_get, tracer)
55+
self.assertEqual(mock_span_get, span)
56+
self.assertEqual({"test":"test_value"}, attrs)
57+
58+
mock_tracer_set = mock.Mock()
59+
mock_span_set = mock.Mock()
60+
61+
execution_context.set_opencensus_full_context(mock_tracer_set, mock_span_set, None)
62+
self.assertEqual(mock_tracer_set, execution_context.get_opencensus_tracer())
63+
self.assertEqual(mock_span_set, execution_context.get_current_span())
64+
self.assertEqual({}, execution_context.get_opencensus_attrs())
65+
66+
execution_context.set_opencensus_full_context(mock_tracer_set, mock_span_set, {"test": "test_value"})
67+
self.assertEqual("test_value", execution_context.get_opencensus_attr("test"))
68+
69+
def test_clean_tracer(self):
70+
mock_tracer = mock.Mock()
71+
some_value = mock.Mock()
72+
execution_context.set_opencensus_tracer(mock_tracer)
73+
74+
thread_local = threading.local()
75+
setattr(thread_local, 'random_non_oc_attr', some_value)
76+
77+
execution_context.clean()
78+
79+
self.assertNotEqual(mock_tracer, execution_context.get_opencensus_tracer())
80+
self.assertEqual(some_value, getattr(thread_local, 'random_non_oc_attr'))
81+
82+
def test_clean_span(self):
83+
mock_span = mock.Mock()
84+
some_value = mock.Mock()
85+
execution_context.set_current_span(mock_span)
86+
87+
thread_local = threading.local()
88+
setattr(thread_local, 'random_non_oc_attr', some_value)
89+
90+
execution_context.clean()
91+
92+
self.assertNotEqual(mock_span, execution_context.get_current_span())
93+
self.assertEqual(some_value, getattr(thread_local, 'random_non_oc_attr'))
94+
95+
96+
97+
98+
99+

0 commit comments

Comments
 (0)