@@ -78,12 +78,11 @@ def test_wrap_requests(self):
7878 self .assertEqual (expected_name , mock_tracer .current_span .name )
7979
8080 def test_wrap_session_request (self ):
81- def wrapped (* args , ** kwargs ):
82- result = mock .Mock ()
83- result .status_code = 200
84- return result
81+ wrapped = mock .Mock (return_value = mock .Mock (status_code = 200 ))
8582
86- mock_tracer = MockTracer ()
83+ mock_tracer = MockTracer (propagator = mock .Mock (
84+ to_headers = lambda x : {'x-trace' : 'some-value' })
85+ )
8786
8887 patch = mock .patch (
8988 'opencensus.trace.ext.requests.trace.execution_context.'
@@ -92,10 +91,11 @@ def wrapped(*args, **kwargs):
9291
9392 url = 'http://localhost:8080'
9493 request_method = 'POST'
94+ kwargs = {}
9595
9696 with patch :
9797 result = trace .wrap_session_request (
98- wrapped , 'Session.request' , (request_method , url ), {} )
98+ wrapped , 'Session.request' , (request_method , url ), kwargs )
9999
100100 expected_attributes = {
101101 'http.url' : url ,
@@ -106,12 +106,79 @@ def wrapped(*args, **kwargs):
106106 mock_tracer .current_span .span_kind )
107107 self .assertEqual (expected_attributes ,
108108 mock_tracer .current_span .attributes )
109+ self .assertEqual (kwargs ['headers' ]['x-trace' ], 'some-value' )
109110 self .assertEqual (expected_name , mock_tracer .current_span .name )
110111
112+ def test_header_is_passed_in (self ):
113+ wrapped = mock .Mock (return_value = mock .Mock (status_code = 200 ))
114+ mock_tracer = MockTracer (propagator = mock .Mock (
115+ to_headers = lambda x : {'x-trace' : 'some-value' })
116+ )
117+
118+ patch = mock .patch (
119+ 'opencensus.trace.ext.requests.trace.execution_context.'
120+ 'get_opencensus_tracer' ,
121+ return_value = mock_tracer )
122+
123+ url = 'http://localhost:8080'
124+ request_method = 'POST'
125+ kwargs = {}
126+
127+ with patch :
128+ result = trace .wrap_session_request (
129+ wrapped , 'Session.request' , (request_method , url ), kwargs )
130+
131+ self .assertEqual (kwargs ['headers' ]['x-trace' ], 'some-value' )
132+
133+ def test_headers_are_preserved (self ):
134+ wrapped = mock .Mock (return_value = mock .Mock (status_code = 200 ))
135+ mock_tracer = MockTracer (propagator = mock .Mock (
136+ to_headers = lambda x : {'x-trace' : 'some-value' })
137+ )
138+
139+ patch = mock .patch (
140+ 'opencensus.trace.ext.requests.trace.execution_context.'
141+ 'get_opencensus_tracer' ,
142+ return_value = mock_tracer )
143+
144+ url = 'http://localhost:8080'
145+ request_method = 'POST'
146+ kwargs = {'headers' : {'key' : 'value' }}
147+
148+ with patch :
149+ result = trace .wrap_session_request (
150+ wrapped , 'Session.request' , (request_method , url ), kwargs )
151+
152+ self .assertEqual (kwargs ['headers' ]['key' ], 'value' )
153+ self .assertEqual (kwargs ['headers' ]['x-trace' ], 'some-value' )
154+
155+
156+ def test_tracer_headers_are_overwritten (self ):
157+ wrapped = mock .Mock (return_value = mock .Mock (status_code = 200 ))
158+ mock_tracer = MockTracer (propagator = mock .Mock (
159+ to_headers = lambda x : {'x-trace' : 'some-value' })
160+ )
161+
162+ patch = mock .patch (
163+ 'opencensus.trace.ext.requests.trace.execution_context.'
164+ 'get_opencensus_tracer' ,
165+ return_value = mock_tracer )
166+
167+ url = 'http://localhost:8080'
168+ request_method = 'POST'
169+ kwargs = {'headers' : {'x-trace' : 'original-value' }}
170+
171+ with patch :
172+ result = trace .wrap_session_request (
173+ wrapped , 'Session.request' , (request_method , url ), kwargs )
174+
175+ self .assertEqual (kwargs ['headers' ]['x-trace' ], 'some-value' )
111176
112177class MockTracer (object ):
113- def __init__ (self ):
178+ def __init__ (self , propagator = None ):
114179 self .current_span = None
180+ self .span_context = {}
181+ self .propagator = propagator
115182
116183 def start_span (self ):
117184 span = mock .Mock ()
0 commit comments