# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Invocation-side implementation of gRPC Python."""
|
|
|
|
import logging
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
import grpc
|
|
from grpc import _common
|
|
from grpc import _grpcio_metadata
|
|
from grpc._cython import cygrpc
|
|
from grpc.framework.foundation import callable_util
|
|
|
|
# Module-level logger.
_LOGGER = logging.getLogger(__name__)

# User-agent string appended to every channel's options (see _options).
_USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__)

# Flags value passed to every cygrpc operation built in this module.
_EMPTY_FLAGS = 0

# The operation types that are due from the completion queue immediately after
# an RPC of each arity is started; used to seed _RPCState.due.
_UNARY_UNARY_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.send_message,
    cygrpc.OperationType.send_close_from_client,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_message,
    cygrpc.OperationType.receive_status_on_client,
)
_UNARY_STREAM_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.send_message,
    cygrpc.OperationType.send_close_from_client,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_status_on_client,
)
_STREAM_UNARY_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_message,
    cygrpc.OperationType.receive_status_on_client,
)
_STREAM_STREAM_INITIAL_DUE = (
    cygrpc.OperationType.send_initial_metadata,
    cygrpc.OperationType.receive_initial_metadata,
    cygrpc.OperationType.receive_status_on_client,
)

# Message logged when a connectivity subscription callback raises.
_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
    'Exception calling channel subscription callback!')

# repr/str templates for a terminated _Rendezvous; the non-OK variant also
# carries the debug error string.
_OK_RENDEZVOUS_REPR_FORMAT = ('<_Rendezvous of RPC that terminated with:\n'
                              '\tstatus = {}\n'
                              '\tdetails = "{}"\n'
                              '>')

_NON_OK_RENDEZVOUS_REPR_FORMAT = ('<_Rendezvous of RPC that terminated with:\n'
                                  '\tstatus = {}\n'
                                  '\tdetails = "{}"\n'
                                  '\tdebug_error_string = "{}"\n'
                                  '>')
|
|
|
|
|
|
def _deadline(timeout):
    """Turn a relative timeout into an absolute time.time()-based deadline.

    Args:
      timeout: A duration in seconds, or None.

    Returns:
      The current time plus timeout, or None if timeout is None.
    """
    if timeout is None:
        return None
    return time.time() + timeout
|
|
|
|
|
|
def _unknown_code_details(unknown_cygrpc_code, details):
    """Build the details message used when a status code is not recognized.

    Args:
      unknown_cygrpc_code: The unrecognized code value to report.
      details: The details that accompanied that code.

    Returns:
      A formatted message embedding both values.
    """
    template = 'Server sent unknown code {} and details "{}"'
    return template.format(unknown_cygrpc_code, details)
|
|
|
|
|
|
def _wait_once_until(condition, until):
    """Wait on a condition once, bounded by an absolute deadline.

    Args:
      condition: An object whose wait method is invoked (callers in this file
        pass a threading.Condition that they already hold).
      until: A time.time()-based deadline, or None to wait with no timeout.

    Raises:
      grpc.FutureTimeoutError: If until is already in the past.
    """
    if until is None:
        condition.wait()
        return
    time_left = until - time.time()
    if time_left < 0:
        raise grpc.FutureTimeoutError()
    condition.wait(timeout=time_left)
|
|
|
|
|
|
class _RPCState(object):
    """Mutable bookkeeping for one RPC, shared between threads.

    Code in this module reads and writes these attributes while holding
    `condition`.
    """

    def __init__(self, due, initial_metadata, trailing_metadata, code, details):
        """Initializes the state.

        Args:
          due: An iterable of cygrpc.OperationType values initially due.
          initial_metadata: The initial metadata, or None if not yet known.
          trailing_metadata: The trailing metadata, or None if not yet known.
          code: The status code, or None while the RPC has not terminated.
          details: The status details, or None if not yet known.
        """
        self.condition = threading.Condition()

        # The cygrpc.OperationType objects representing events due from the
        # RPC's completion queue.
        self.due = set(due)

        # Results of the RPC, filled in as events arrive.
        self.initial_metadata = initial_metadata
        self.response = None
        self.trailing_metadata = trailing_metadata
        self.code = code
        self.details = details
        self.debug_error_string = None

        # grpc.Future.cancel and grpc.Future.cancelled have slightly unusual
        # semantics, so cancellation is tracked apart from the rest of the
        # RPC's result: this records whether cancellation was requested
        # before the RPC terminated.
        self.cancelled = False

        # Callables to run at termination; set to None once they are taken.
        self.callbacks = []

        # Fork epoch observed when this state was created.
        self.fork_epoch = cygrpc.get_fork_epoch()

    def reset_postfork_child(self):
        """Replaces the condition with a fresh threading.Condition."""
        self.condition = threading.Condition()
|
|
|
|
|
|
def _abort(state, code, details):
    """Terminate an RPC's state locally unless it has already terminated.

    Does nothing if state.code is already set. Otherwise records the code and
    details, substitutes empty initial metadata if none has been received, and
    sets the trailing metadata to empty.

    Args:
      state: The RPC state object to update.
      code: The status code to record.
      details: The status details to record.
    """
    if state.code is not None:
        return
    state.code = code
    state.details = details
    if state.initial_metadata is None:
        state.initial_metadata = ()
    state.trailing_metadata = ()
|
|
|
|
|
|
def _handle_event(event, state, response_deserializer):
    """Fold the operations of a completed batch into an RPC's state.

    Each operation's type is removed from state.due and its payload (initial
    metadata, response message, or final status) is recorded on state.

    Args:
      event: A completion event exposing batch_operations.
      state: The _RPCState to update.
      response_deserializer: The deserializer passed to _common.deserialize
        for a received message.

    Returns:
      A list of the callbacks that were registered on state, to be invoked by
      the caller; it is populated only when the batch contains the
      receive_status_on_client operation, after which state.callbacks is None.
    """
    callbacks = []
    for batch_operation in event.batch_operations:
        operation_type = batch_operation.type()
        state.due.remove(operation_type)
        if operation_type == cygrpc.OperationType.receive_initial_metadata:
            state.initial_metadata = batch_operation.initial_metadata()
        elif operation_type == cygrpc.OperationType.receive_message:
            serialized_response = batch_operation.message()
            if serialized_response is not None:
                response = _common.deserialize(serialized_response,
                                               response_deserializer)
                if response is None:
                    details = 'Exception deserializing response!'
                    _abort(state, grpc.StatusCode.INTERNAL, details)
                else:
                    state.response = response
        elif operation_type == cygrpc.OperationType.receive_status_on_client:
            state.trailing_metadata = batch_operation.trailing_metadata()
            # A locally-set code (e.g. from _abort) takes precedence over the
            # status carried by the event.
            if state.code is None:
                cygrpc_code = batch_operation.code()
                code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get(
                    cygrpc_code)
                if code is None:
                    state.code = grpc.StatusCode.UNKNOWN
                    # Report the raw code that failed to map; the lookup
                    # result is None here and would be useless in the message.
                    state.details = _unknown_code_details(
                        cygrpc_code, batch_operation.details())
                else:
                    state.code = code
                    state.details = batch_operation.details()
                    state.debug_error_string = batch_operation.error_string()
            callbacks.extend(state.callbacks)
            state.callbacks = None
    return callbacks
|
|
|
|
|
|
def _event_handler(state, response_deserializer):
    """Create the completion-queue tag callable for an RPC.

    Args:
      state: The _RPCState of the RPC.
      response_deserializer: The deserializer forwarded to _handle_event.

    Returns:
      A callable taking an event. It applies the event to state under
      state.condition, notifies waiters, runs any returned callbacks after
      releasing the condition, and returns a truthy value only when nothing
      remains due and state.fork_epoch is not older than the current fork
      epoch.
    """

    def handle_event(event):
        with state.condition:
            pending_callbacks = _handle_event(event, state,
                                              response_deserializer)
            state.condition.notify_all()
            nothing_due = not state.due
        # Callbacks run outside the condition.
        for pending_callback in pending_callbacks:
            pending_callback()
        return nothing_due and state.fork_epoch >= cygrpc.get_fork_epoch()

    return handle_event
|
|
|
|
|
|
def _consume_request_iterator(request_iterator, state, call, request_serializer,
                              event_handler):
    """Start a daemon thread that sends every request from an iterator.

    The thread pulls requests one at a time, serializes each, starts a
    send-message operation and waits for it to leave state.due before pulling
    the next. When the iterator is exhausted it starts a
    send-close-from-client operation. It stops early if the RPC terminates or
    is cancelled, if an operation cannot be started, or if iterating or
    serializing fails (in which case the call is cancelled and the state
    aborted).

    Args:
      request_iterator: An iterator of request values.
      state: The _RPCState of the RPC.
      call: The call object on which operations are started.
      request_serializer: The serializer passed to _common.serialize.
      event_handler: The tag passed along with each started operation.
    """
    # With fork support enabled, waits use a 1.0 second timeout instead of
    # blocking indefinitely; block_if_fork_in_progress runs after each wake-up.
    condition_wait_timeout = 1.0 if cygrpc.is_fork_support_enabled() else None

    def consume_request_iterator():  # pylint: disable=too-many-branches
        while True:
            return_from_user_request_generator_invoked = False
            try:
                # The thread may die in user-code. Do not block fork for this.
                cygrpc.enter_user_request_generator()
                request = next(request_iterator)
            except StopIteration:
                break
            except Exception:  # pylint: disable=broad-except
                cygrpc.return_from_user_request_generator()
                return_from_user_request_generator_invoked = True
                code = grpc.StatusCode.UNKNOWN
                details = 'Exception iterating requests!'
                _LOGGER.exception(details)
                call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
                            details)
                _abort(state, code, details)
                return
            finally:
                # Balance enter_user_request_generator exactly once.
                if not return_from_user_request_generator_invoked:
                    cygrpc.return_from_user_request_generator()

            serialized_request = _common.serialize(request, request_serializer)
            with state.condition:
                if state.code is not None or state.cancelled:
                    return
                if serialized_request is None:
                    code = grpc.StatusCode.INTERNAL
                    details = 'Exception serializing request!'
                    call.cancel(
                        _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code],
                        details)
                    _abort(state, code, details)
                    return
                operations = (cygrpc.SendMessageOperation(
                    serialized_request, _EMPTY_FLAGS),)
                if not call.operate(operations, event_handler):
                    return
                state.due.add(cygrpc.OperationType.send_message)
                # Wait until the send completes or the RPC terminates.
                while True:
                    state.condition.wait(condition_wait_timeout)
                    cygrpc.block_if_fork_in_progress(state)
                    if state.code is not None:
                        return
                    if cygrpc.OperationType.send_message not in state.due:
                        break

        # The iterator is exhausted: half-close if the RPC is still live.
        with state.condition:
            if state.code is None:
                operations = (
                    cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),)
                if call.operate(operations, event_handler):
                    state.due.add(cygrpc.OperationType.send_close_from_client)

    consumption_thread = cygrpc.ForkManagedThread(
        target=consume_request_iterator)
    consumption_thread.setDaemon(True)
    consumption_thread.start()
|
|
|
|
|
|
class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
    """The client-side handle of an RPC.

    One object serves as the raised error, the future, the call and, for
    response-streaming RPCs, the response iterator. All reads and writes of
    the shared _RPCState happen under its condition.
    """

    def __init__(self, state, call, response_deserializer, deadline):
        """Initializes the rendezvous.

        Args:
          state: The _RPCState of the RPC.
          call: The underlying call object; may be None when the RPC has
            already terminated.
          response_deserializer: The deserializer for response messages, or
            None.
          deadline: The absolute time.time()-based deadline, or None.
        """
        super(_Rendezvous, self).__init__()
        self._state = state
        self._call = call
        self._response_deserializer = response_deserializer
        self._deadline = deadline

    def cancel(self):
        """Cancels the RPC if it has not yet terminated.

        Returns:
          True if this invocation cancelled the RPC; False if the RPC had
          already terminated (including by an earlier cancellation).
        """
        with self._state.condition:
            if self._state.code is not None:
                return False
            code = grpc.StatusCode.CANCELLED
            details = 'Locally cancelled by application!'
            self._call.cancel(
                _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details)
            self._state.cancelled = True
            _abort(self._state, code, details)
            self._state.condition.notify_all()
            return True

    def cancelled(self):
        """Returns whether cancellation was requested before termination."""
        with self._state.condition:
            return self._state.cancelled

    def running(self):
        """Returns True while no status code has been recorded."""
        with self._state.condition:
            return self._state.code is None

    def done(self):
        """Returns True once a status code has been recorded."""
        with self._state.condition:
            return self._state.code is not None

    def result(self, timeout=None):
        """Waits for termination and returns the response.

        Args:
          timeout: Seconds to wait, or None to wait indefinitely.

        Returns:
          The response if the RPC terminated with StatusCode.OK.

        Raises:
          grpc.FutureTimeoutError: If the timeout elapses first.
          grpc.FutureCancelledError: If the RPC was cancelled.
          _Rendezvous: This object itself, for any other non-OK termination.
        """
        until = None if timeout is None else time.time() + timeout
        with self._state.condition:
            while True:
                if self._state.code is None:
                    _wait_once_until(self._state.condition, until)
                elif self._state.code is grpc.StatusCode.OK:
                    return self._state.response
                elif self._state.cancelled:
                    raise grpc.FutureCancelledError()
                else:
                    raise self

    def exception(self, timeout=None):
        """Waits for termination and returns the RPC's exception, if any.

        Args:
          timeout: Seconds to wait, or None to wait indefinitely.

        Returns:
          None if the RPC terminated with StatusCode.OK, otherwise this
          object.

        Raises:
          grpc.FutureTimeoutError: If the timeout elapses first.
          grpc.FutureCancelledError: If the RPC was cancelled.
        """
        until = None if timeout is None else time.time() + timeout
        with self._state.condition:
            while True:
                if self._state.code is None:
                    _wait_once_until(self._state.condition, until)
                elif self._state.code is grpc.StatusCode.OK:
                    return None
                elif self._state.cancelled:
                    raise grpc.FutureCancelledError()
                else:
                    return self

    def traceback(self, timeout=None):
        """Waits for termination and returns a traceback for a failed RPC.

        Args:
          timeout: Seconds to wait, or None to wait indefinitely.

        Returns:
          None if the RPC terminated with StatusCode.OK, otherwise a traceback
          object obtained by raising and catching this object here.

        Raises:
          grpc.FutureTimeoutError: If the timeout elapses first.
          grpc.FutureCancelledError: If the RPC was cancelled.
        """
        until = None if timeout is None else time.time() + timeout
        with self._state.condition:
            while True:
                if self._state.code is None:
                    _wait_once_until(self._state.condition, until)
                elif self._state.code is grpc.StatusCode.OK:
                    return None
                elif self._state.cancelled:
                    raise grpc.FutureCancelledError()
                else:
                    try:
                        raise self
                    except grpc.RpcError:
                        return sys.exc_info()[2]

    def add_done_callback(self, fn):
        """Arranges for fn(self) to be called when the RPC terminates.

        If the RPC has already terminated, fn is called immediately on the
        calling thread, outside the condition.
        """
        with self._state.condition:
            if self._state.code is None:
                self._state.callbacks.append(lambda: fn(self))
                return

        fn(self)

    def _next(self):
        """Blocks until the next response message or the end of the stream.

        Returns:
          The next response.

        Raises:
          StopIteration: If the RPC terminated with StatusCode.OK and no
            further response is available.
          _Rendezvous: This object itself, if the RPC terminated otherwise.
        """
        with self._state.condition:
            if self._state.code is None:
                event_handler = _event_handler(self._state,
                                               self._response_deserializer)
                operating = self._call.operate(
                    (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                    event_handler)
                if operating:
                    self._state.due.add(cygrpc.OperationType.receive_message)
            elif self._state.code is grpc.StatusCode.OK:
                raise StopIteration()
            else:
                raise self
            while True:
                self._state.condition.wait()
                if self._state.response is not None:
                    # Hand the message over and clear the slot for the next.
                    response = self._state.response
                    self._state.response = None
                    return response
                elif cygrpc.OperationType.receive_message not in self._state.due:
                    if self._state.code is grpc.StatusCode.OK:
                        raise StopIteration()
                    elif self._state.code is not None:
                        raise self

    def __iter__(self):
        return self

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()

    def is_active(self):
        """Returns True while no status code has been recorded."""
        with self._state.condition:
            return self._state.code is None

    def time_remaining(self):
        """Returns the non-negative seconds until the deadline, or None."""
        if self._deadline is None:
            return None
        else:
            return max(self._deadline - time.time(), 0)

    def add_callback(self, callback):
        """Registers a termination callback.

        Returns:
          True if the callback was registered; False if the callbacks have
          already been taken for delivery.
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def initial_metadata(self):
        """Blocks until initial metadata is available and returns it."""
        with self._state.condition:
            while self._state.initial_metadata is None:
                self._state.condition.wait()
            return self._state.initial_metadata

    def trailing_metadata(self):
        """Blocks until trailing metadata is available and returns it."""
        with self._state.condition:
            while self._state.trailing_metadata is None:
                self._state.condition.wait()
            return self._state.trailing_metadata

    def code(self):
        """Blocks until the status code is available and returns it."""
        with self._state.condition:
            while self._state.code is None:
                self._state.condition.wait()
            return self._state.code

    def details(self):
        """Blocks until the details are available and returns them decoded."""
        with self._state.condition:
            while self._state.details is None:
                self._state.condition.wait()
            return _common.decode(self._state.details)

    def debug_error_string(self):
        """Blocks until the debug error string is set and returns it decoded.

        NOTE(review): nothing in this class assigns debug_error_string, so
        this waits until some other party populates it on the state.
        """
        with self._state.condition:
            while self._state.debug_error_string is None:
                self._state.condition.wait()
            return _common.decode(self._state.debug_error_string)

    def _repr(self):
        """Returns the text shared by __repr__ and __str__."""
        with self._state.condition:
            if self._state.code is None:
                return '<_Rendezvous object of in-flight RPC>'
            elif self._state.code is grpc.StatusCode.OK:
                return _OK_RENDEZVOUS_REPR_FORMAT.format(
                    self._state.code, self._state.details)
            else:
                return _NON_OK_RENDEZVOUS_REPR_FORMAT.format(
                    self._state.code, self._state.details,
                    self._state.debug_error_string)

    def __repr__(self):
        return self._repr()

    def __str__(self):
        return self._repr()

    def __del__(self):
        # Cancel an RPC that is still in flight when its handle is collected.
        with self._state.condition:
            if self._state.code is None:
                self._state.code = grpc.StatusCode.CANCELLED
                self._state.details = 'Cancelled upon garbage collection!'
                self._state.cancelled = True
                self._call.cancel(
                    _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code],
                    self._state.details)
                self._state.condition.notify_all()
|
|
|
|
|
|
def _start_unary_request(request, timeout, request_serializer):
    """Compute the deadline and serialize the single request of an RPC.

    Args:
      request: The request value.
      timeout: Seconds allowed for the RPC, or None.
      request_serializer: The serializer passed to _common.serialize.

    Returns:
      A (deadline, serialized_request, rendezvous) tuple. On success the
      rendezvous is None; if serialization yields None, serialized_request is
      None and rendezvous is an already-terminated _Rendezvous with
      StatusCode.INTERNAL.
    """
    deadline = _deadline(timeout)
    serialized_request = _common.serialize(request, request_serializer)
    if serialized_request is not None:
        return deadline, serialized_request, None
    failed_state = _RPCState((), (), (), grpc.StatusCode.INTERNAL,
                             'Exception serializing request!')
    failed_rendezvous = _Rendezvous(failed_state, None, None, deadline)
    return deadline, None, failed_rendezvous
|
|
|
|
|
|
def _end_unary_response_blocking(state, call, with_call, deadline):
    """Produce the outcome of a blocking, response-unary invocation.

    Args:
      state: The terminated _RPCState.
      call: The underlying call object.
      with_call: Whether to return a call handle alongside the response.
      deadline: The deadline given to any _Rendezvous created here.

    Returns:
      state.response, or a (state.response, _Rendezvous) pair if with_call.

    Raises:
      _Rendezvous: If the RPC did not terminate with StatusCode.OK.
    """
    if state.code is not grpc.StatusCode.OK:
        raise _Rendezvous(state, None, None, deadline)
    if not with_call:
        return state.response
    rendezvous = _Rendezvous(state, call, None, deadline)
    return state.response, rendezvous
|
|
|
|
|
|
def _stream_unary_invocation_operationses(metadata):
    """Build the two initial operation batches of a stream-unary RPC.

    Args:
      metadata: The invocation metadata, or None.

    Returns:
      A pair of operation tuples: first the batch that sends initial metadata
      and receives the message and status, then the batch that receives the
      initial metadata.
    """
    primary_batch = (
        cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
    )
    initial_metadata_batch = (
        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),)
    return (primary_batch, initial_metadata_batch)
|
|
|
|
|
|
def _stream_unary_invocation_operationses_and_tags(metadata):
    """Pair each stream-unary initial operation batch with a None tag.

    Args:
      metadata: The invocation metadata, or None.

    Returns:
      A tuple of (operations, None) pairs, one per batch.
    """
    batches = _stream_unary_invocation_operationses(metadata)
    return tuple((operations, None) for operations in batches)
|
|
|
|
|
|
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Invokes a unary-request, unary-response RPC on a channel."""

    def __init__(self, channel, managed_call, method, request_serializer,
                 response_deserializer):
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def _prepare(self, request, timeout, metadata):
        """Serializes the request and builds the state and operations.

        Returns:
          A (state, operations, deadline, rendezvous) tuple. If serialization
          failed, the first three are None and rendezvous is the failure;
          otherwise rendezvous is None.
        """
        deadline, serialized_request, rendezvous = _start_unary_request(
            request, timeout, self._request_serializer)
        if serialized_request is None:
            return None, None, None, rendezvous
        state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
        # The whole RPC is conducted as a single batch.
        operations = (
            cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        return state, operations, deadline, None

    def _blocking(self, request, timeout, metadata, credentials):
        """Runs the RPC to completion on the calling thread.

        Returns:
          A (state, call) pair.

        Raises:
          _Rendezvous: If the request could not be serialized.
        """
        state, operations, deadline, rendezvous = self._prepare(
            request, timeout, metadata)
        if state is None:
            raise rendezvous
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._channel.segregated_call(
            0, self._method, None, deadline, metadata, call_credentials,
            ((operations, None),))
        # One batch was started, so one event completes the RPC.
        event = call.next_event()
        _handle_event(event, state, self._response_deserializer)
        return state, call

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        state, call = self._blocking(request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, False, None)

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        state, call = self._blocking(request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, True, None)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        state, operations, deadline, rendezvous = self._prepare(
            request, timeout, metadata)
        if state is None:
            raise rendezvous
        event_handler = _event_handler(state, self._response_deserializer)
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._managed_call(0, self._method, None, deadline, metadata,
                                  call_credentials, (operations,),
                                  event_handler)
        return _Rendezvous(state, call, self._response_deserializer, deadline)
|
|
|
|
|
|
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Invokes a unary-request, stream-response RPC on a channel."""

    def __init__(self, channel, managed_call, method, request_serializer,
                 response_deserializer):
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        """Starts the RPC and returns a _Rendezvous to iterate responses.

        Raises:
          _Rendezvous: If the request could not be serialized.
        """
        deadline, serialized_request, rendezvous = _start_unary_request(
            request, timeout, self._request_serializer)
        if serialized_request is None:
            raise rendezvous
        state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
        # Two batches: everything sent plus the final status, and separately
        # the receipt of initial metadata.
        send_and_status_batch = (
            cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        initial_metadata_batch = (
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),)
        operationses = (send_and_status_batch, initial_metadata_batch)
        event_handler = _event_handler(state, self._response_deserializer)
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._managed_call(0, self._method, None, deadline, metadata,
                                  call_credentials, operationses,
                                  event_handler)
        return _Rendezvous(state, call, self._response_deserializer, deadline)
|
|
|
|
|
|
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Invokes a stream-request, unary-response RPC on a channel."""

    def __init__(self, channel, managed_call, method, request_serializer,
                 response_deserializer):
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def _blocking(self, request_iterator, timeout, metadata, credentials):
        """Runs the RPC to completion, draining events on the calling thread.

        Returns:
          A (state, call) pair.
        """
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._channel.segregated_call(
            0, self._method, None, deadline, metadata, call_credentials,
            _stream_unary_invocation_operationses_and_tags(metadata))
        # Requests are sent from a separate thread with a None tag.
        _consume_request_iterator(request_iterator, state, call,
                                  self._request_serializer, None)
        nothing_due = False
        while not nothing_due:
            event = call.next_event()
            with state.condition:
                _handle_event(event, state, self._response_deserializer)
                state.condition.notify_all()
                nothing_due = not state.due
        return state, call

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        state, call = self._blocking(request_iterator, timeout, metadata,
                                     credentials)
        return _end_unary_response_blocking(state, call, False, None)

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None):
        state, call = self._blocking(request_iterator, timeout, metadata,
                                     credentials)
        return _end_unary_response_blocking(state, call, True, None)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None):
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
        event_handler = _event_handler(state, self._response_deserializer)
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._managed_call(
            0, self._method, None, deadline, metadata, call_credentials,
            _stream_unary_invocation_operationses(metadata), event_handler)
        _consume_request_iterator(request_iterator, state, call,
                                  self._request_serializer, event_handler)
        return _Rendezvous(state, call, self._response_deserializer, deadline)
|
|
|
|
|
|
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Invokes a stream-request, stream-response RPC on a channel."""

    def __init__(self, channel, managed_call, method, request_serializer,
                 response_deserializer):
        self._channel = channel
        self._managed_call = managed_call
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        """Starts the RPC and returns a _Rendezvous to iterate responses."""
        deadline = _deadline(timeout)
        state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
        metadata_and_status_batch = (
            cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        initial_metadata_batch = (
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),)
        operationses = (metadata_and_status_batch, initial_metadata_batch)
        event_handler = _event_handler(state, self._response_deserializer)
        call_credentials = (None if credentials is None else
                            credentials._credentials)
        call = self._managed_call(0, self._method, None, deadline, metadata,
                                  call_credentials, operationses,
                                  event_handler)
        # Requests are sent from a separate thread as the iterator yields.
        _consume_request_iterator(request_iterator, state, call,
                                  self._request_serializer, event_handler)
        return _Rendezvous(state, call, self._response_deserializer, deadline)
|
|
|
|
|
|
class _ChannelCallState(object):
    """Per-channel bookkeeping for managed calls."""

    def __init__(self, channel):
        """Initializes the state.

        Args:
          channel: The underlying channel object.
        """
        self.channel = channel
        # Guards managed_calls.
        self.lock = threading.Lock()
        # Number of managed calls that have not yet completed.
        self.managed_calls = 0
        self.threading = False

    def reset_postfork_child(self):
        """Resets the managed-call count to zero."""
        self.managed_calls = 0
|
|
|
|
|
|
def _run_channel_spin_thread(state):
    """Start a daemon thread that drives the channel's managed calls.

    The thread repeatedly takes the next call event from state.channel and
    invokes the event's tag with the event. Each time a tag reports that its
    call completed, state.managed_calls is decremented under state.lock; the
    thread exits when the count reaches zero.

    Args:
      state: The _ChannelCallState of the channel.
    """

    def channel_spin():
        while True:
            cygrpc.block_if_fork_in_progress(state)
            event = state.channel.next_call_event()
            if event.completion_type == cygrpc.CompletionType.queue_timeout:
                continue
            if not event.tag(event):
                continue
            with state.lock:
                state.managed_calls -= 1
                if state.managed_calls == 0:
                    return

    channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin)
    channel_spin_thread.setDaemon(True)
    channel_spin_thread.start()
|
|
|
|
|
|
def _channel_managed_call_management(state):
    """Return a factory of managed calls bound to a channel's call state.

    Args:
      state: The _ChannelCallState of the channel.

    Returns:
      The `create` callable documented below.
    """

    # pylint: disable=too-many-arguments
    def create(flags, method, host, deadline, metadata, credentials,
               operationses, event_handler):
        """Creates a cygrpc.IntegratedCall.

        Args:
          flags: An integer bitfield of call flags.
          method: The RPC method.
          host: A host string for the created call.
          deadline: A float to be the deadline of the created call or None if
            the call is to have an infinite deadline.
          metadata: The metadata for the call or None.
          credentials: A cygrpc.CallCredentials or None.
          operationses: An iterable of iterables of cygrpc.Operations to be
            started on the call.
          event_handler: A behavior to call to handle the events resultant from
            the operations on the call.

        Returns:
          A cygrpc.IntegratedCall with which to conduct an RPC.
        """
        # Every batch is tagged with the same event handler.
        operationses_and_tags = tuple(
            (operations, event_handler) for operations in operationses)
        with state.lock:
            call = state.channel.integrated_call(flags, method, host, deadline,
                                                 metadata, credentials,
                                                 operationses_and_tags)
            # The first outstanding managed call starts the spin thread.
            is_first_managed_call = state.managed_calls == 0
            state.managed_calls += 1
            if is_first_managed_call:
                _run_channel_spin_thread(state)
            return call

    return create
|
|
|
|
|
|
class _ChannelConnectivityState(object):
    """Per-channel bookkeeping for connectivity subscriptions."""

    def __init__(self, channel):
        """Initializes the state.

        Args:
          channel: The underlying channel object.
        """
        self.lock = threading.RLock()
        self.channel = channel
        self.reset_postfork_child()

    def reset_postfork_child(self):
        """Puts every field except lock and channel back to its initial value."""
        # Whether a polling thread has been started and not yet stopped.
        self.polling = False
        # The most recently observed connectivity, or None.
        self.connectivity = None
        # Whether a subscriber has asked for a connection attempt.
        self.try_to_connect = False
        # [callback, last delivered connectivity] pairs, one per subscriber.
        self.callbacks_and_connectivities = []
        # Whether a delivery thread is currently running.
        self.delivering = False
|
|
|
|
|
|
def _deliveries(state):
    """Collect the callbacks that have not seen the current connectivity.

    Each such subscriber's recorded connectivity is updated to
    state.connectivity as a side effect.

    Args:
      state: The connectivity state holding callbacks_and_connectivities and
        connectivity.

    Returns:
      A list of the callbacks whose recorded connectivity was out of date.
    """
    current = state.connectivity
    stale_callbacks = []
    for entry in state.callbacks_and_connectivities:
        subscriber, last_delivered = entry
        if last_delivered is current:
            continue
        stale_callbacks.append(subscriber)
        entry[1] = state.connectivity
    return stale_callbacks
|
|
|
|
|
|
def _deliver(state, initial_connectivity, initial_callbacks):
    """Deliver connectivity updates until every subscriber is up to date.

    Calls each callback with the connectivity, then, under state.lock, asks
    _deliveries for subscribers that fell behind in the meantime and repeats.
    When none remain, clears state.delivering and returns.

    Args:
      state: The _ChannelConnectivityState of the channel.
      initial_connectivity: The connectivity passed to the first round of
        callbacks.
      initial_callbacks: The callbacks to call in the first round.
    """
    connectivity = initial_connectivity
    pending = initial_callbacks
    while True:
        for callback in pending:
            cygrpc.block_if_fork_in_progress(state)
            callable_util.call_logging_exceptions(
                callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
                connectivity)
        with state.lock:
            pending = _deliveries(state)
            if not pending:
                state.delivering = False
                return
            connectivity = state.connectivity
|
|
|
|
|
|
def _spawn_delivery(state, callbacks):
    """Start a thread running _deliver and mark the state as delivering.

    Args:
      state: The _ChannelConnectivityState of the channel.
      callbacks: The callbacks to call with the current state.connectivity.
    """
    delivery_args = (state, state.connectivity, callbacks)
    delivering_thread = cygrpc.ForkManagedThread(
        target=_deliver, args=delivery_args)
    delivering_thread.start()
    state.delivering = True
|
|
|
|
|
|
# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll.
|
|
def _poll_connectivity(state, channel, initial_try_to_connect):
    """Poll the channel's connectivity and dispatch changes to subscribers.

    Records the initial connectivity and delivers it to the current
    subscribers, then loops: watches for a change with a 0.2 second deadline,
    and re-checks connectivity whenever the watch succeeded or a connection
    attempt was requested. Returns once there are neither subscribers nor a
    pending connection request, after clearing state.polling and
    state.connectivity.

    Args:
      state: The _ChannelConnectivityState of the channel.
      channel: The underlying channel object.
      initial_try_to_connect: Whether the first check should try to connect.
    """
    try_to_connect = initial_try_to_connect
    connectivity = channel.check_connectivity_state(try_to_connect)
    with state.lock:
        state.connectivity = (
            _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
                connectivity])
        # Every subscriber present at this point receives the first reading.
        subscribers = []
        for entry in state.callbacks_and_connectivities:
            subscriber, unused_but_known_to_be_none_connectivity = entry
            subscribers.append(subscriber)
            entry[1] = state.connectivity
        callbacks = tuple(subscribers)
        if callbacks:
            _spawn_delivery(state, callbacks)
    while True:
        event = channel.watch_connectivity_state(connectivity,
                                                 time.time() + 0.2)
        cygrpc.block_if_fork_in_progress(state)
        with state.lock:
            if (not state.callbacks_and_connectivities and
                    not state.try_to_connect):
                state.polling = False
                state.connectivity = None
                return
            # Consume the pending connection request, if any.
            try_to_connect = state.try_to_connect
            state.try_to_connect = False
        if not (event.success or try_to_connect):
            continue
        connectivity = channel.check_connectivity_state(try_to_connect)
        with state.lock:
            state.connectivity = (
                _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
                    connectivity])
            # A running delivery thread picks up the change by itself.
            if not state.delivering:
                callbacks = _deliveries(state)
                if callbacks:
                    _spawn_delivery(state, callbacks)
|
|
|
|
|
|
def _moot(state):
    """Drop every connectivity subscription, emptying the list in place.

    Args:
      state: The connectivity state whose subscriptions are removed.
    """
    with state.lock:
        state.callbacks_and_connectivities[:] = []
|
|
|
|
|
|
def _subscribe(state, callback, try_to_connect):
    """Register a connectivity callback.

    The first subscriber (when no polling thread is running) starts a daemon
    polling thread. Otherwise, if a connectivity is already known and no
    delivery is in progress, the callback is delivered the current
    connectivity right away; failing that it is recorded as not yet notified.

    Args:
      state: The _ChannelConnectivityState of the channel.
      callback: The callable to register.
      try_to_connect: A value whose truthiness says whether a connection
        attempt is requested.
    """
    wants_connection = bool(try_to_connect)
    with state.lock:
        if not state.callbacks_and_connectivities and not state.polling:
            polling_thread = cygrpc.ForkManagedThread(
                target=_poll_connectivity,
                args=(state, state.channel, wants_connection))
            polling_thread.setDaemon(True)
            polling_thread.start()
            state.polling = True
            state.callbacks_and_connectivities.append([callback, None])
            return
        if not state.delivering and state.connectivity is not None:
            _spawn_delivery(state, (callback,))
            state.try_to_connect |= wants_connection
            state.callbacks_and_connectivities.append(
                [callback, state.connectivity])
            return
        state.try_to_connect |= wants_connection
        state.callbacks_and_connectivities.append([callback, None])
|
|
|
|
|
|
def _unsubscribe(state, callback):
    """Remove the first subscription whose callback equals the given one.

    Does nothing if no subscription matches.

    Args:
      state: The connectivity state holding callbacks_and_connectivities.
      callback: The callback to remove.
    """
    with state.lock:
        entries = state.callbacks_and_connectivities
        for position, (subscribed_callback,
                       unused_connectivity) in enumerate(entries):
            if callback == subscribed_callback:
                del entries[position]
                break
|
|
|
|
|
|
def _options(options):
    """Append the gRPC Python user-agent option to the given options.

    Args:
      options: An iterable of channel options.

    Returns:
      A new list holding the given options followed by the
      primary-user-agent-string option.
    """
    combined = list(options)
    combined.append(
        (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT))
    return combined
|
|
|
|
|
|
class Channel(grpc.Channel):
    """A cygrpc.Channel-backed implementation of grpc.Channel."""

    def __init__(self, target, options, credentials):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
        """
        self._channel = cygrpc.Channel(
            _common.encode(target), _options(options), credentials)
        self._call_state = _ChannelCallState(self._channel)
        self._connectivity_state = _ChannelConnectivityState(self._channel)
        cygrpc.fork_register_channel(self)

    def subscribe(self, callback, try_to_connect=None):
        """Registers a connectivity callback; see _subscribe."""
        _subscribe(self._connectivity_state, callback, try_to_connect)

    def unsubscribe(self, callback):
        """Removes a previously registered connectivity callback."""
        _unsubscribe(self._connectivity_state, callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):
        """Creates a unary-unary multi-callable for the given method."""
        return _UnaryUnaryMultiCallable(
            self._channel, _channel_managed_call_management(self._call_state),
            _common.encode(method), request_serializer, response_deserializer)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        """Creates a unary-stream multi-callable for the given method."""
        return _UnaryStreamMultiCallable(
            self._channel, _channel_managed_call_management(self._call_state),
            _common.encode(method), request_serializer, response_deserializer)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):
        """Creates a stream-unary multi-callable for the given method."""
        return _StreamUnaryMultiCallable(
            self._channel, _channel_managed_call_management(self._call_state),
            _common.encode(method), request_serializer, response_deserializer)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):
        """Creates a stream-stream multi-callable for the given method."""
        return _StreamStreamMultiCallable(
            self._channel, _channel_managed_call_management(self._call_state),
            _common.encode(method), request_serializer, response_deserializer)

    def _close(self):
        """Closes the underlying channel and drops all subscriptions."""
        self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!')
        _moot(self._connectivity_state)

    def _close_on_fork(self):
        """Closes the underlying channel for a fork and drops subscriptions."""
        self._channel.close_on_fork(cygrpc.StatusCode.cancelled,
                                    'Channel closed due to fork')
        _moot(self._connectivity_state)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the with-block.
        return False

    def close(self):
        """Closes the channel."""
        self._close()

    def __del__(self):
        # TODO(https://github.com/grpc/grpc/issues/12531): Several releases
        # after 1.12 (1.16 or thereabouts?) add a "self._channel.close" call
        # here (or more likely, call self._close() here). We don't do this today
        # because many valid use cases today allow the channel to be deleted
        # immediately after stubs are created. After a sufficient period of time
        # has passed for all users to be trusted to hang on to their channels
        # for as long as they are in use and to close them after using them,
        # then deletion of this grpc._channel.Channel instance can be made to
        # effect closure of the underlying cygrpc.Channel instance.
        cygrpc.fork_unregister_channel(self)
        # __del__ also runs for an instance whose __init__ raised before
        # _connectivity_state was assigned; there is nothing to moot then.
        connectivity_state = getattr(self, '_connectivity_state', None)
        if connectivity_state is not None:
            _moot(connectivity_state)
|