1414
1515import logging
1616import threading
17+ from multiprocessing import pool
18+ from concurrent import futures
1719
1820from opencensus .trace import execution_context
21+ from opencensus .trace import tracer
22+ from opencensus .trace .propagation import binary_format
1923
2024log = logging .getLogger (__name__ )
2125
@@ -27,13 +31,29 @@ def trace_integration(tracer=None):
2731 log .info ("Integrated module: {}" .format (MODULE_NAME ))
2832 # Wrap the threading start function
2933 start_func = getattr (threading .Thread , "start" )
30- setattr (threading .Thread , start_func .__name__ ,
31- wrap_threading_start (start_func ))
34+ setattr (
35+ threading .Thread , start_func .__name__ , wrap_threading_start (start_func )
36+ )
3237
3338 # Wrap the threading run function
3439 run_func = getattr (threading .Thread , "run" )
35- setattr (threading .Thread , run_func .__name__ ,
36- wrap_threading_run (run_func ))
40+ setattr (threading .Thread , run_func .__name__ , wrap_threading_run (run_func ))
41+
42+ # Wrap the threading run function
43+ apply_async_func = getattr (pool .Pool , "apply_async" )
44+ setattr (
45+ pool .Pool ,
46+ apply_async_func .__name__ ,
47+ wrap_apply_async (apply_async_func ),
48+ )
49+
50+ # Wrap the threading run function
51+ submit_func = getattr (futures .ThreadPoolExecutor , "submit" )
52+ setattr (
53+ futures .ThreadPoolExecutor ,
54+ submit_func .__name__ ,
55+ wrap_submit (submit_func ),
56+ )
3757
3858
3959def wrap_threading_start (start_func ):
@@ -42,7 +62,9 @@ def wrap_threading_start(start_func):
4262 """
4363
4464 def call (self ):
45- self .__opencensus_tracer = execution_context .get_opencensus_tracer ()
65+ self ._opencensus_context = (
66+ execution_context .get_opencensus_full_context ()
67+ )
4668 return start_func (self )
4769
4870 return call
@@ -54,7 +76,80 @@ def wrap_threading_run(run_func):
5476 """
5577
5678 def call (self ):
57- execution_context .set_opencensus_tracer (self .__opencensus_tracer )
79+ execution_context .set_opencensus_full_context (
80+ * self ._opencensus_context
81+ )
5882 return run_func (self )
5983
6084 return call
85+
86+
87+ def wrap_apply_async (apply_async_func ):
88+ """Wrap the apply_async function of multiprocessing.pools. Get the function
89+ that will be called and wrap it then add the opencensus context."""
90+
91+ def call (self , func , args = (), kwds = {}, ** kwargs ):
92+ wrapped_func = wrap_task_func (func )
93+ _tracer = execution_context .get_opencensus_tracer ()
94+ propagator = binary_format .BinaryFormatPropagator ()
95+
96+ wrapped_kwargs = {}
97+ print (_tracer )
98+ wrapped_kwargs ["span_context_binary" ] = propagator .to_header (
99+ _tracer .span_context
100+ )
101+ wrapped_kwargs ["kwds" ] = kwds
102+ wrapped_kwargs ["sampler" ] = _tracer .sampler
103+ wrapped_kwargs ["exporter" ] = _tracer .exporter
104+ wrapped_kwargs ["propagator" ] = _tracer .propagator
105+
106+ return apply_async_func (
107+ self , wrapped_func , args = args , kwds = wrapped_kwargs , ** kwargs
108+ )
109+
110+ return call
111+
112+
113+ def wrap_submit (submit_func ):
114+ """Wrap the apply_async function of multiprocessing.pools. Get the function
115+ that will be called and wrap it then add the opencensus context."""
116+
117+ def call (self , func , * args , ** kwargs ):
118+ wrapped_func = wrap_task_func (func )
119+ _tracer = execution_context .get_opencensus_tracer ()
120+ propagator = binary_format .BinaryFormatPropagator ()
121+
122+ wrapped_kwargs = {}
123+ wrapped_kwargs ["span_context_binary" ] = propagator .to_header (
124+ _tracer .span_context
125+ )
126+ wrapped_kwargs ["kwds" ] = kwargs
127+ wrapped_kwargs ["sampler" ] = _tracer .sampler
128+ wrapped_kwargs ["exporter" ] = _tracer .exporter
129+ wrapped_kwargs ["propagator" ] = _tracer .propagator
130+
131+ return submit_func (self , wrapped_func , * args , ** wrapped_kwargs )
132+
133+ return call
134+
135+
136+ class wrap_task_func (object ):
137+ """Wrap the function given to apply_async to get the tracer from context,
138+ execute the function then clear the context."""
139+
140+ def __init__ (self , func ):
141+ self .func = func
142+
143+ def __call__ (self , * args , ** kwargs ):
144+ kwds = kwargs .pop ("kwds" )
145+
146+ span_context_binary = kwargs .pop ("span_context_binary" )
147+ propagator = binary_format .BinaryFormatPropagator ()
148+ kwargs ["span_context" ] = propagator .from_header (span_context_binary )
149+
150+ _tracer = tracer .Tracer (** kwargs )
151+ execution_context .set_opencensus_tracer (_tracer )
152+ with _tracer .span (name = threading .current_thread ().name ):
153+ result = self .func (* args , ** kwds )
154+ execution_context .clean ()
155+ return result
0 commit comments