473 lines
16 KiB
Python
473 lines
16 KiB
Python
|
# Copyright 2017 gRPC authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Implementation of gRPC Python interceptors."""
|
||
|
|
||
|
import collections
|
||
|
import sys
|
||
|
|
||
|
import grpc
|
||
|
|
||
|
|
||
|
class _ServicePipeline(object):
    """Chains interceptors exposing `intercept_service` in front of a thunk."""

    def __init__(self, interceptors):
        # Snapshot the iterable so the chain is fixed at construction time.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk, index):
        """Return a one-argument callable that resumes the chain at `index`."""

        def resume(context):
            return self._intercept_at(thunk, index, context)

        return resume

    def _intercept_at(self, thunk, index, context):
        """Call the interceptor at `index`, or `thunk` once the chain is spent."""
        if index >= len(self.interceptors):
            return thunk(context)
        current = self.interceptors[index]
        # Each interceptor receives a continuation bound to the next position.
        following = self._continuation(thunk, index + 1)
        return current.intercept_service(following, context)

    def execute(self, thunk, context):
        """Run `context` through the interceptors in order, ending at `thunk`."""
        return self._intercept_at(thunk, 0, context)
|
||
|
|
||
|
|
||
|
def service_pipeline(interceptors):
    """Build a `_ServicePipeline`, or return None when `interceptors` is falsy."""
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
|
||
|
|
||
|
|
||
|
class _ClientCallDetails(
        collections.namedtuple(
            '_ClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials']),
        grpc.ClientCallDetails):
    """Named-tuple implementation of `grpc.ClientCallDetails`.

    Fields, in order: method, timeout, metadata, credentials.
    """
|
||
|
|
||
|
|
||
|
def _unwrap_client_call_details(call_details, default_details):
    """Resolve the four call-detail fields, falling back per attribute.

    Each of `method`, `timeout`, `metadata` and `credentials` is read from
    `call_details`; when that lookup raises AttributeError the value is taken
    from `default_details` instead.

    Returns:
      A `(method, timeout, metadata, credentials)` tuple.
    """
    resolved = []
    for field in ('method', 'timeout', 'metadata', 'credentials'):
        try:
            value = getattr(call_details, field)
        except AttributeError:
            # The default is consulted only when the primary object lacks
            # the attribute.
            value = getattr(default_details, field)
        resolved.append(value)
    return tuple(resolved)
|
||
|
|
||
|
|
||
|
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):
    """Already-completed outcome that carries an exception raised while intercepting.

    The object presents itself as a finished, non-cancellable call/future:
    `result()` and iteration re-raise the stored exception, while
    `exception()` and `traceback()` return what was captured.
    """

    def __init__(self, exception, traceback):
        """Store the raised `exception` and its `traceback` object."""
        super(_FailureOutcome, self).__init__()
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self):
        """No metadata is available for a failed interception."""
        return None

    def trailing_metadata(self):
        """No metadata is available for a failed interception."""
        return None

    def code(self):
        """Always report `grpc.StatusCode.INTERNAL`."""
        return grpc.StatusCode.INTERNAL

    def details(self):
        """Return the fixed description of this failure."""
        return 'Exception raised while intercepting the RPC'

    def cancel(self):
        """Nothing to cancel; always returns False."""
        return False

    def cancelled(self):
        """Always False."""
        return False

    def is_active(self):
        """Always False."""
        return False

    def time_remaining(self):
        """Always None."""
        return None

    def running(self):
        """Always False."""
        return False

    def done(self):
        """Always True."""
        return True

    def result(self, ignored_timeout=None):
        """Raise the stored exception; the timeout is ignored."""
        raise self._exception

    def exception(self, ignored_timeout=None):
        """Return the stored exception; the timeout is ignored."""
        return self._exception

    def traceback(self, ignored_timeout=None):
        """Return the stored traceback; the timeout is ignored."""
        return self._traceback

    def add_callback(self, callback):
        """Refuse the callback by returning False; it is never invoked here."""
        return False

    def add_done_callback(self, fn):
        """Invoke `fn` immediately with this outcome."""
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        """Python 3 iterator protocol: raise the stored exception.

        Without this method `iter(outcome)` fails on Python 3 with a
        TypeError about a non-iterator instead of surfacing the real error.
        """
        raise self._exception

    def next(self):
        """Python 2 iterator protocol; delegates to `__next__`."""
        return self.__next__()
|
||
|
|
||
|
|
||
|
class _UnaryOutcome(grpc.Call, grpc.Future):
    """Completed outcome pairing a response value with its underlying call.

    Call-related queries are forwarded to the wrapped call object; the
    future-related methods answer as an already-finished, successful future.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    # --- Call-related methods: delegate to the wrapped call. ---

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    # --- Future-related methods: fixed answers for a finished outcome. ---

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        """Return the stored response; the timeout is ignored."""
        return self._response

    def exception(self, ignored_timeout=None):
        """Always None; the timeout is ignored."""
        return None

    def traceback(self, ignored_timeout=None):
        """Always None; the timeout is ignored."""
        return None

    def add_done_callback(self, fn):
        """Invoke `fn` immediately with this outcome."""
        fn(self)
|
||
|
|
||
|
|
||
|
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes invocations through an interceptor.

    `thunk` maps a method name to the underlying multi-callable, so the
    interceptor may redirect the call to a different method.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke through the interceptor and return only the response."""
        response, _unused_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self, request, timeout=None, metadata=None,
                   credentials=None):
        """Invoke through the interceptor; return `(result, outcome)`."""
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request):
            # Fields missing from the interceptor-supplied details fall back
            # to the details of the original invocation.
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            try:
                underlying = self._thunk(method)
                response, call = underlying.with_call(
                    request,
                    timeout=deadline,
                    metadata=headers,
                    credentials=creds)
                return _UnaryOutcome(response, call)
            except grpc.RpcError:
                # RPC errors propagate to the interceptor unchanged.
                raise
            except Exception as error:  # pylint:disable=broad-except
                return _FailureOutcome(error, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, defaults, request)
        return outcome.result(), outcome

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke through the interceptor; return `(result, outcome)`."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke through the interceptor using the underlying `future` call.

        Any exception escaping the interceptor is returned wrapped in a
        `_FailureOutcome` rather than raised.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request):
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            underlying = self._thunk(method)
            return underlying.future(
                request,
                timeout=deadline,
                metadata=headers,
                credentials=creds)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, defaults, request)
        except Exception as error:  # pylint:disable=broad-except
            return _FailureOutcome(error, sys.exc_info()[2])
|
||
|
|
||
|
|
||
|
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes invocations through an interceptor.

    `thunk` maps a method name to the underlying multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke through the interceptor and return whatever it returns.

        Any exception escaping the interceptor is returned wrapped in a
        `_FailureOutcome` rather than raised.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request):
            # Fields missing from the interceptor-supplied details fall back
            # to the details of the original invocation.
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            underlying = self._thunk(method)
            return underlying(
                request,
                timeout=deadline,
                metadata=headers,
                credentials=creds)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, defaults, request)
        except Exception as error:  # pylint:disable=broad-except
            return _FailureOutcome(error, sys.exc_info()[2])
|
||
|
|
||
|
|
||
|
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes invocations through an interceptor.

    `thunk` maps a method name to the underlying multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        """Invoke through the interceptor and return only the response."""
        response, _unused_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)
        return response

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None):
        """Invoke through the interceptor; return `(result, outcome)`."""
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request_iterator):
            # Fields missing from the interceptor-supplied details fall back
            # to the details of the original invocation.
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            try:
                underlying = self._thunk(method)
                response, call = underlying.with_call(
                    request_iterator,
                    timeout=deadline,
                    metadata=headers,
                    credentials=creds)
                return _UnaryOutcome(response, call)
            except grpc.RpcError:
                # RPC errors propagate to the interceptor unchanged.
                raise
            except Exception as error:  # pylint:disable=broad-except
                return _FailureOutcome(error, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, defaults, request_iterator)
        return outcome.result(), outcome

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None):
        """Invoke through the interceptor; return `(result, outcome)`."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None):
        """Invoke through the interceptor using the underlying `future` call.

        Any exception escaping the interceptor is returned wrapped in a
        `_FailureOutcome` rather than raised.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request_iterator):
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            underlying = self._thunk(method)
            return underlying.future(
                request_iterator,
                timeout=deadline,
                metadata=headers,
                credentials=creds)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, defaults, request_iterator)
        except Exception as error:  # pylint:disable=broad-except
            return _FailureOutcome(error, sys.exc_info()[2])
|
||
|
|
||
|
|
||
|
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes invocations through an interceptor.

    `thunk` maps a method name to the underlying multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        """Invoke through the interceptor and return whatever it returns.

        Any exception escaping the interceptor is returned wrapped in a
        `_FailureOutcome` rather than raised.
        """
        defaults = _ClientCallDetails(self._method, timeout, metadata,
                                      credentials)

        def continuation(new_details, request_iterator):
            # Fields missing from the interceptor-supplied details fall back
            # to the details of the original invocation.
            method, deadline, headers, creds = _unwrap_client_call_details(
                new_details, defaults)
            underlying = self._thunk(method)
            return underlying(
                request_iterator,
                timeout=deadline,
                metadata=headers,
                credentials=creds)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, defaults, request_iterator)
        except Exception as error:  # pylint:disable=broad-except
            return _FailureOutcome(error, sys.exc_info()[2])
|
||
|
|
||
|
|
||
|
class _Channel(grpc.Channel):
    """Channel wrapper that applies a single interceptor to a wrapped channel.

    For each arity, the interceptor is used only when it is an instance of
    the matching client-interceptor type; otherwise the wrapped channel's own
    multi-callable is returned directly.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):

        def thunk(m):
            return self._channel.unary_unary(m, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(m):
            return self._channel.unary_stream(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(m):
            return self._channel.stream_unary(m, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):

        def thunk(m):
            return self._channel.stream_stream(m, request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Returning False lets any in-flight exception propagate.
        return False

    def close(self):
        self._channel.close()
|
||
|
|
||
|
|
||
|
def intercept_channel(channel, *interceptors):
    """Wrap `channel` so that calls pass through the given interceptors.

    Interceptors are applied in reverse order, so the first interceptor
    listed becomes the outermost wrapper.

    Args:
      channel: The channel to wrap.
      *interceptors: Zero or more client interceptors; each must be an
        instance of at least one of the four gRPC client-interceptor types.

    Returns:
      `channel` itself when no interceptors are given, otherwise the
      outermost `_Channel` wrapper.

    Raises:
      TypeError: If an interceptor is not an instance of any of the
        supported client-interceptor types.
    """
    if not interceptors:
        # Nothing to wrap.
        return channel

    supported_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # `interceptors` is already a tuple, so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, supported_types):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel
|