|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import collections |
| 16 | +import mock |
| 17 | +import threading |
15 | 18 | import unittest |
16 | 19 |
|
17 | | -import mock |
| 20 | +from google.api_core import bidi |
| 21 | +from google.protobuf import proto_builder |
| 22 | +from grpc.framework.foundation import logging_pool |
| 23 | +import grpc |
18 | 24 |
|
19 | 25 | from opencensus.ext.grpc import client_interceptor |
20 | 26 | from opencensus.trace import execution_context |
@@ -282,6 +288,108 @@ def test_intercept_stream_stream_not_trace(self): |
282 | 288 | self.assertFalse(mock_tracer.end_span.called) |
283 | 289 |
|
284 | 290 |
|
| 291 | +class TestGrpcInterface(unittest.TestCase): |
| 292 | + |
| 293 | + def setUp(self): |
| 294 | + self._server = _start_server() |
| 295 | + self._port = self._server.add_insecure_port('[::]:0') |
| 296 | + self._channel = grpc.insecure_channel('localhost:%d' % self._port) |
| 297 | + |
| 298 | + def tearDown(self): |
| 299 | + self._server.stop(None) |
| 300 | + self._channel.close() |
| 301 | + |
| 302 | + def _intercepted_channel(self, tracer=None): |
| 303 | + return grpc.intercept_channel( |
| 304 | + self._channel, |
| 305 | + client_interceptor.OpenCensusClientInterceptor(tracer=tracer)) |
| 306 | + |
| 307 | + def test_bidi_rpc_stream(self): |
| 308 | + event = threading.Event() |
| 309 | + |
| 310 | + def _helper(request_iterator, context): |
| 311 | + counter = 0 |
| 312 | + for _ in request_iterator: |
| 313 | + counter += 1 |
| 314 | + if counter == 2: |
| 315 | + event.set() |
| 316 | + yield |
| 317 | + |
| 318 | + self._server.add_generic_rpc_handlers( |
| 319 | + (StreamStreamRpcHandler(_helper),)) |
| 320 | + self._server.start() |
| 321 | + |
| 322 | + rpc = bidi.BidiRpc( |
| 323 | + self._intercepted_channel().stream_stream( |
| 324 | + '', EmptyMessage.SerializeToString), |
| 325 | + initial_request=EmptyMessage()) |
| 326 | + done_event = threading.Event() |
| 327 | + rpc.add_done_callback(lambda _: done_event.set()) |
| 328 | + |
| 329 | + rpc.open() |
| 330 | + rpc.send(EmptyMessage()) |
| 331 | + self.assertTrue(event.wait(timeout=1)) |
| 332 | + rpc.close() |
| 333 | + self.assertTrue(done_event.wait(timeout=1)) |
| 334 | + |
| 335 | + @mock.patch('opencensus.trace.execution_context.get_opencensus_tracer') |
| 336 | + def test_close_span_on_done(self, mock_tracer): |
| 337 | + def _helper(request_iterator, context): |
| 338 | + for _ in request_iterator: |
| 339 | + yield EmptyMessage() |
| 340 | + yield |
| 341 | + |
| 342 | + self._server.add_generic_rpc_handlers( |
| 343 | + (StreamStreamRpcHandler(_helper), )) |
| 344 | + self._server.start() |
| 345 | + |
| 346 | + mock_tracer.return_value = mock_tracer |
| 347 | + rpc = self._intercepted_channel(NoopTracer()).stream_stream( |
| 348 | + method='', |
| 349 | + request_serializer=EmptyMessage.SerializeToString, |
| 350 | + response_deserializer=EmptyMessage.FromString)(iter( |
| 351 | + [EmptyMessage()])) |
| 352 | + |
| 353 | + for resp in rpc: |
| 354 | + pass |
| 355 | + |
| 356 | + self.assertEqual(mock_tracer.end_span.call_count, 1) |
| 357 | + |
| 358 | + |
| 359 | +EmptyMessage = proto_builder.MakeSimpleProtoClass( |
| 360 | + collections.OrderedDict([]), |
| 361 | + full_name='tests.test_client_interceptor.EmptyMessage') |
| 362 | + |
| 363 | + |
| 364 | +def _start_server(): |
| 365 | + """Starts an insecure grpc server.""" |
| 366 | + return grpc.server(logging_pool.pool(max_workers=1), |
| 367 | + options=(('grpc.so_reuseport', 0), )) |
| 368 | + |
| 369 | + |
| 370 | +class StreamStreamMethodHandler(grpc.RpcMethodHandler): |
| 371 | + |
| 372 | + def __init__(self, stream_handler_func): |
| 373 | + self.request_streaming = True |
| 374 | + self.response_streaming = True |
| 375 | + self.request_deserializer = None |
| 376 | + self.response_serializer = EmptyMessage.SerializeToString |
| 377 | + self.unary_unary = None |
| 378 | + self.unary_stream = None |
| 379 | + self.stream_unary = None |
| 380 | + self.stream_stream = stream_handler_func |
| 381 | + |
| 382 | + |
| 383 | +class StreamStreamRpcHandler(grpc.GenericRpcHandler): |
| 384 | + |
| 385 | + def __init__(self, stream_stream_handler): |
| 386 | + self._stream_stream_handler = stream_stream_handler |
| 387 | + |
| 388 | + def service(self, handler_call_details): |
| 389 | + resp = StreamStreamMethodHandler(self._stream_stream_handler) |
| 390 | + return resp |
| 391 | + |
| 392 | + |
285 | 393 | class MockTracer(object): |
286 | 394 | def __init__(self, current_span): |
287 | 395 | self.current_span = current_span |
|
0 commit comments