diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 483843c2a6..1181c6f686 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -28,6 +28,7 @@ from copy import copy from functools import partial, reduce, wraps from itertools import groupby, count, chain +import enum import json import logging from typing import Any, Dict, Optional, Union, Tuple @@ -514,8 +515,9 @@ def __init__(self, load_balancing_policy=None, retry_policy=None, class ProfileManager(object): - def __init__(self): + def __init__(self, pools_allowed: bool=True): self.profiles = dict() + self.pools_allowed = pools_allowed def _profiles_without_explicit_lbps(self): names = (profile_name for @@ -527,6 +529,8 @@ def _profiles_without_explicit_lbps(self): ) def distance(self, host): + if not self.pools_allowed: + return HostDistance.IGNORED distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \ HostDistance.LOCAL if HostDistance.LOCAL in distances else \ @@ -542,10 +546,14 @@ def check_supported(self): p.load_balancing_policy.check_supported() def on_up(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_down(host) @@ -619,6 +627,31 @@ class _ConfigMode(object): PROFILES = 2 +class ControlConnectionQueryFallback(enum.Enum): + """ + Controls how application queries use the control connection when node pools + are unavailable. + + ``Disabled`` requires a usable node pool for application queries. If the + driver cannot establish one during session startup, it raises + :class:`NoHostAvailable`. + + ``Fallback`` still attempts to create node pools, but allows application + queries to fall back to the control connection when no usable node pool is + available. 
Session startup is allowed to proceed even if the initial pool + attempts all fail. + + ``SkipPoolCreation`` disables node-pool creation for the session and uses + the control-connection fallback path for application queries. + + The fallback path is not used for requests targeted to an explicit host. + """ + + Disabled = "Disabled" + Fallback = "Fallback" + SkipPoolCreation = "SkipPoolCreation" + + class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. @@ -939,6 +972,16 @@ def default_retry_policy(self, policy): If set to :const:`None`, there will be no timeout for these queries. """ + allow_control_connection_query_fallback: ControlConnectionQueryFallback = ControlConnectionQueryFallback.Disabled + """ + Controls whether application queries may fall back to the control connection. + + ``Disabled`` keeps the old behavior. + ``Fallback`` enables control-connection fallback when no usable node pools exist. + ``SkipPoolCreation`` skips node-pool creation and uses the control connection fallback path. + This fallback is still not used for requests targeted to an explicit host. + """ + idle_heartbeat_interval = 30 """ Interval, in seconds, on which to heartbeat idle connections. 
This helps @@ -1225,7 +1268,8 @@ def __init__(self, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, application_info:Optional[ApplicationInfoBase]=None, - client_routes_config:Optional[ClientRoutesConfig]=None + client_routes_config:Optional[ClientRoutesConfig]=None, + allow_control_connection_query_fallback:ControlConnectionQueryFallback=ControlConnectionQueryFallback.Disabled ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1243,6 +1287,10 @@ def __init__(self, if port < 1 or port > 65535: raise ValueError("Invalid port number (%s) (1-65535)" % port) + if not isinstance(allow_control_connection_query_fallback, ControlConnectionQueryFallback): + raise TypeError( + "allow_control_connection_query_fallback must be a ControlConnectionQueryFallback value") + if connection_class is not None: self.connection_class = connection_class @@ -1404,7 +1452,8 @@ def __init__(self, else: self.timestamp_generator = MonotonicTimestampGenerator() - self.profile_manager = ProfileManager() + self.profile_manager = ProfileManager( + pools_allowed=allow_control_connection_query_fallback != ControlConnectionQueryFallback.SkipPoolCreation) self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( self.load_balancing_policy, self.default_retry_policy, @@ -1473,6 +1522,7 @@ def __init__(self, self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout + self.allow_control_connection_query_fallback = allow_control_connection_query_fallback self.metadata_request_timeout = self.control_connection_timeout if metadata_request_timeout is None else metadata_request_timeout self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout @@ -1815,7 +1865,8 @@ def get_all_pools(self): return pools def is_shard_aware(self): - return 
bool(self.get_all_pools()[0].host.sharding_info) + pools = self.get_all_pools() + return bool(pools and pools[0].host.sharding_info) def shard_aware_stats(self): if self.is_shard_aware(): @@ -1920,7 +1971,7 @@ def on_up(self, host): """ Intended for internal use only. """ - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return log.debug("Waiting to acquire lock for handling up status of node %s", host) @@ -2028,7 +2079,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. """ - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return with host.lock: @@ -2633,20 +2684,24 @@ def __init__(self, cluster, hosts, keyspace=None): # create connection pools in parallel self._initial_connect_futures = set() - for host in hosts: - future = self.add_or_renew_pool(host, is_host_addition=False) - if future: - self._initial_connect_futures.add(future) - - futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) - while futures.not_done and not any(f.result() for f in futures.done): - futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) - - if not any(f.result() for f in self._initial_connect_futures): - msg = "Unable to connect to any servers" - if self.keyspace: - msg += " using keyspace '%s'" % self.keyspace - raise NoHostAvailable(msg, [h.address for h in hosts]) + fallback_mode = self.cluster.allow_control_connection_query_fallback + if fallback_mode is not ControlConnectionQueryFallback.SkipPoolCreation: + for host in hosts: + future = self.add_or_renew_pool(host, is_host_addition=False) + if future: + self._initial_connect_futures.add(future) + + futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + while futures.not_done and not any(f.result() for f in 
futures.done): + futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) + + # Only Disabled requires an initial pool to come up. + if not any(f.result() for f in self._initial_connect_futures) and \ + fallback_mode is ControlConnectionQueryFallback.Disabled: + msg = "Unable to connect to any servers" + if self.keyspace: + msg += " using keyspace '%s'" % self.keyspace + raise NoHostAvailable(msg, [h.address for h in hosts]) self.session_id = uuid.uuid4() @@ -3245,6 +3300,9 @@ def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return None + distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None @@ -3315,6 +3373,9 @@ def update_created_pools(self): For internal use only. """ + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return set() + futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) @@ -4650,6 +4711,7 @@ class ResponseFuture(object): _spec_execution_plan = NoSpeculativeExecutionPlan() _continuous_paging_session = None _host = None + _control_connection_query_attempted = False _TABLET_ROUTING_CTYPE = None _warned_timeout = False @@ -4670,6 +4732,7 @@ def __init__(self, session, message, query, timeout, metrics=None, prepared_stat self._callback_lock = Lock() self._start_time = start_time or time.time() self._host = host + self._control_connection_query_attempted = False self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() @@ -4748,11 +4811,22 @@ def _on_timeout(self, _attempts=0): self._connection.orphaned_threshold_reached = True pool.return_connection(self._connection, stream_was_orphaned=True) + elif self._connection.is_control_connection: + with 
self._connection.lock: + self._connection.orphaned_request_ids.add(self._req_id) + if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + self._connection.orphaned_threshold_reached = True errors = self._errors if not errors: if self.is_schema_agreed: - key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' + if self._current_host is None: + key = 'no host queried before timeout' + elif self._connection is not None and self._connection.is_control_connection: + control_host = self.session.cluster.get_control_connection_host() + key = str(control_host.endpoint) if control_host is not None else str(self._connection.endpoint) + else: + key = str(self._current_host.endpoint) errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} else: connection = self.session.cluster.control_connection._connection @@ -4810,14 +4884,110 @@ def send_request(self, error_no_hosts=True): self._on_timeout() return True if error_no_hosts: + if self._fallback_to_control_connection(): + req_id = self._query_control_connection() + if req_id is not None: + self._req_id = req_id + return True + self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) return False + def _has_usable_node_pool(self): + try: + pools = tuple(self.session._pools.values()) + except (AttributeError, TypeError): + return False + + return any(pool and not pool.is_shutdown for pool in pools) + + def _fallback_to_control_connection(self): + fallback_mode = self.session.cluster.allow_control_connection_query_fallback + if fallback_mode is ControlConnectionQueryFallback.Disabled: + return False + if self._host or self._control_connection_query_attempted: + return False + if fallback_mode is ControlConnectionQueryFallback.SkipPoolCreation: + return True + return not self._has_usable_node_pool() + + def _borrow_control_connection(self, connection): + with connection.lock: + if 
connection.in_flight >= connection.max_request_id: + raise NoConnectionsAvailable("All request IDs are currently in use") + connection.in_flight += 1 + return connection.get_request_id() + + def _release_control_connection_request(self, connection, request_id): + with connection.lock: + connection.in_flight -= 1 + connection.request_ids.append(request_id) + connection._requests.pop(request_id, None) + + def _handle_control_connection_response(self, connection, cb, response): + with connection.lock: + connection.in_flight -= 1 + cb(response) + + def _query_control_connection(self, message=None, cb=None, connection=None, host=None): + self._control_connection_query_attempted = True + + if message is None: + message = self.message + + if connection is None: + control_connection = self.session.cluster.control_connection + connection = control_connection._connection if control_connection else None + if not connection: + self._errors['control connection'] = ConnectionException("Control connection is not connected") + return None + + if host is None: + host = self.session.cluster.get_control_connection_host() or connection.endpoint + self._current_host = host + + request_id = None + request_sent = False + try: + request_id = self._borrow_control_connection(connection) + self._connection = connection + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + if cb is None: + cb = partial(self._set_result, host, connection, None) + cb = partial(self._handle_control_connection_response, connection, cb) + + log.debug("No usable node pools; falling back to control connection for host %s", host) + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) + request_sent = True + self.attempted_hosts.append(host) + return request_id + except NoConnectionsAvailable as exc: + log.debug("Control 
connection is at capacity") + self._errors[host] = exc + except ConnectionBusy as exc: + log.debug("Control connection is busy") + self._errors[host] = exc + except Exception as exc: + log.debug("Error querying control connection", exc_info=True) + self._errors[host] = exc + if self._metrics is not None: + self._metrics.on_connection_error() + finally: + if request_id is not None and not request_sent: + self._release_control_connection_request(connection, request_id) + + return None + def _query(self, host, message=None, cb=None): if message is None: message = self.message + self._control_connection_query_attempted = False + pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -4928,12 +5098,17 @@ def start_fetching_next_page(self): self._event.clear() self._final_result = _NOT_SET self._final_exception = None + self._control_connection_query_attempted = False self._start_timer() self.send_request() def _reprepare(self, prepare_message, host, connection, pool): cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) - request_id = self._query(host, prepare_message, cb=cb) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(prepare_message, cb=cb, + connection=connection, host=host) + else: + request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() @@ -4972,6 +5147,8 @@ def _set_result(self, host, connection, pool, response): if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) + if connection is not None: + connection.keyspace = response.new_keyspace # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, 
otherwise the event @@ -5148,10 +5325,13 @@ def _execute_after_prepare(self, host, connection, pool, response): new_metadata_id = response.result_metadata_id if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id - + # use self._query to re-use the same host and # at the same time properly borrow the connection - request_id = self._query(host) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(connection=connection, host=host) + else: + request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() @@ -5264,6 +5444,11 @@ def _retry_task(self, reuse_connection, host): # to retry the operation return + if self._control_connection_query_attempted: + self._control_connection_query_attempted = False + self.send_request() + return + if reuse_connection and self._query(host) is not None: return diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index de8518d271..44b7b63f67 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -48,6 +48,8 @@ Clusters and Sessions .. autoattribute:: control_connection_timeout + .. autoattribute:: allow_control_connection_query_fallback + .. autoattribute:: idle_heartbeat_interval .. autoattribute:: idle_heartbeat_timeout @@ -106,6 +108,9 @@ Clusters and Sessions .. automethod:: set_meta_refresh_enabled +.. autoclass:: ControlConnectionQueryFallback + :members: + .. 
autoclass:: ExecutionProfile (load_balancing_policy=, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=, speculative_execution_policy=None) :members: :exclude-members: consistency_level diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index cafe6ae9c9..98d71993fd 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -259,10 +259,8 @@ class SensitiveModel(Model): rows[-1] rows[-1:] - # ignore DeprecationWarning('The loop argument is deprecated since Python 3.8, and scheduled for removal in Python 3.10.') - relevant_warnings = [warn for warn in w if "The loop argument is deprecated" not in str(warn.message)] + warning_messages = [str(warn.message) for warn in w] - assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[0].message) - assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[1].message) - assert "ModelQuerySet indexing with negative indices support will be removed in 4.0." in str(relevant_warnings[2].message) - assert "ModelQuerySet slicing with negative indices support will be removed in 4.0." in str(relevant_warnings[3].message) + assert sum("__table_name_case_sensitive__ will be removed in 4.0." in message for message in warning_messages) == 2 + assert sum("ModelQuerySet indexing with negative indices support will be removed in 4.0." in message for message in warning_messages) == 1 + assert sum("ModelQuerySet slicing with negative indices support will be removed in 4.0." 
in message for message in warning_messages) == 1 diff --git a/tests/integration/standard/conftest.py b/tests/integration/standard/conftest.py index 3adaf371b0..9934cfcbbb 100644 --- a/tests/integration/standard/conftest.py +++ b/tests/integration/standard/conftest.py @@ -37,6 +37,7 @@ "test_ip_change": 4, "test_authentication": 4, "test_authentication_misconfiguration": 4, + "test_control_connection_query_fallback": 4, "test_custom_cluster": 4, "test_query": 4, # Group 5: tablets (destructive — decommissions a node) diff --git a/tests/integration/standard/test_control_connection_query_fallback.py b/tests/integration/standard/test_control_connection_query_fallback.py new file mode 100644 index 0000000000..e64763a72c --- /dev/null +++ b/tests/integration/standard/test_control_connection_query_fallback.py @@ -0,0 +1,115 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import pytest + +from cassandra.cluster import ControlConnectionQueryFallback, NoHostAvailable + +from tests.integration import USE_CASS_EXTERNAL, TestCluster, local, remove_cluster, use_cluster + + +_CLUSTER_NAME = "control_connection_query_fallback" +_UNREACHABLE_BROADCAST_RPC_ADDRESS = "127.255.255.1" + + +def setup_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + ccm_cluster = use_cluster(_CLUSTER_NAME, [1], start=False) + ccm_cluster.nodes["node1"].set_configuration_options(values={ + "broadcast_rpc_address": _UNREACHABLE_BROADCAST_RPC_ADDRESS, + }) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + +def teardown_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + +@local +class ControlConnectionQueryFallbackIntegrationTests(unittest.TestCase): + + def setUp(self): + self.cluster = None + + def tearDown(self): + if self.cluster is not None: + self.cluster.shutdown() + + def _assert_unreachable_broadcast_rpc_metadata(self): + hosts = self.cluster.metadata.all_hosts() + assert len(hosts) == 1 + + host = hosts[0] + assert host.broadcast_rpc_address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert host.endpoint.address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + return host + + def test_disabled_raises_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Disabled, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + with pytest.raises(NoHostAvailable): + self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert self.cluster.control_connection._connection is not None + assert self.cluster.get_all_pools() == [] + + def test_fallback_executes_queries_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + connect_timeout=1, + 
monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version + + def test_no_node_pool_fallback_executes_queries_without_creating_pools(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures == set() + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index b6f2da5372..bb6de88fc0 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -13,6 +13,7 @@ # limitations under the License. 
import unittest +from concurrent.futures import Future import logging import socket from types import SimpleNamespace @@ -22,9 +23,9 @@ from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion -from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, default_lbp_factory, \ +from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, ControlConnectionQueryFallback, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.connection import ConnectionBusy +from cassandra.connection import ConnectionBusy, ConnectionException from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -186,6 +187,52 @@ def test_port_range(self): with pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port=invalid_port) + def test_control_connection_query_fallback_modes(self): + assert Cluster().allow_control_connection_query_fallback is ControlConnectionQueryFallback.Disabled + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=False) + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=True) + assert ( + Cluster(allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback) + .allow_control_connection_query_fallback + is ControlConnectionQueryFallback.Fallback + ) + assert Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation + 
).allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation + + def test_control_connection_query_fallback_no_node_pool_mode_skips_pool_creation(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + with patch.object(Session, "add_or_renew_pool") as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_not_called() + assert session._initial_connect_futures == set() + assert session._pools == {} + assert session.update_created_pools() == set() + + def test_control_connection_query_fallback_fallback_tolerates_empty_initial_pools(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + future = Future() + future.set_result(False) + + with patch.object(Session, "add_or_renew_pool", return_value=future) as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + assert session._initial_connect_futures == {future} + assert session._pools == {} + def test_compression_autodisabled_without_libraries(self): with patch.dict('cassandra.cluster.locally_supported_compressions', {}, clear=True): with patch('cassandra.cluster.log') as patched_logger: @@ -550,6 +597,31 @@ def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): with 
pytest.raises(ValueError): session.wait_for_schema_agreement(wait_time=1, scope='planet') + @mock_session_pools + def test_set_keyspace_for_all_pools_reports_all_errors(self, *_): + cluster = Cluster() + session = Session( + cluster, + [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())], + ) + + pool1 = Mock(host='host1') + pool2 = Mock(host='host2') + keyspace_error = ConnectionException("boom") + + pool1._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool1, [keyspace_error]) + ) + pool2._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool2, []) + ) + session._pools = {'host1': pool1, 'host2': pool2} + + callback = Mock() + session._set_keyspace_for_all_pools('ks', callback) + + callback.assert_called_once() + assert callback.call_args.args[0] == {'host1': [keyspace_error]} class ProtocolVersionTests(unittest.TestCase): diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..9673b0d634 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -19,7 +19,7 @@ from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion, ControlConnectionQueryFallback from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, @@ -41,6 +41,7 @@ def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] s.cluster.control_connection._tablets_routing_v1 = False + s.cluster.allow_control_connection_query_fallback = 
ControlConnectionQueryFallback.Disabled return s def make_pool(self): @@ -49,6 +50,22 @@ def make_pool(self): pool.borrow_connection.return_value = [Mock(), Mock()] return pool + def make_control_connection(self): + connection = Mock(spec=Connection) + connection.endpoint = 'control-host' + connection.lock = RLock() + connection.in_flight = 0 + connection.max_request_id = 100 + connection.request_ids = deque() + connection._requests = {} + connection.orphaned_request_ids = set() + connection.orphaned_threshold = 75 + connection.orphaned_threshold_reached = False + connection.is_control_connection = True + connection.get_request_id.return_value = 7 + connection.send_msg.return_value = 128 + return connection + def make_session(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] @@ -391,6 +408,268 @@ def test_all_pools_shutdown(self): with pytest.raises(NoHostAvailable): rf.result() + def test_control_connection_fallback_disabled_by_default(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + rf = self.make_response_future(session) + rf.send_request() + + connection.send_msg.assert_not_called() + with pytest.raises(NoHostAvailable): + rf.result() + + def test_control_connection_fallback_updates_connection_keyspace(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + + def set_keyspace_for_all_pools(keyspace, callback): + session.keyspace = keyspace + callback({}) + + session._set_keyspace_for_all_pools.side_effect = set_keyspace_for_all_pools + + connection = 
self.make_control_connection() + connection.keyspace = 'oldks' + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, new_keyspace='newks') + connection.send_msg.call_args[1]['cb'](result) + + assert connection.keyspace == 'newks' + assert session.keyspace == 'newks' + assert rf.result().current_rows == [] + + def test_control_connection_fallback_when_no_usable_pools(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.SkipPoolCreation + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + connection.send_msg.assert_called_once_with( + rf.message, 7, cb=ANY, encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, result_metadata=[]) + assert connection.in_flight == 1 + assert rf.attempted_hosts == [control_host] + + cb = connection.send_msg.call_args[1]['cb'] + expected_result = (object(), object()) + cb(self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_retries_after_server_error(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + 
session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + connection.get_request_id.side_effect = [7, 8] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + first_response = Mock(spec=ServerError, info={}) + first_response.summary = 'boom' + first_response.to_exception.return_value = first_response + connection.send_msg.call_args[1]['cb'](first_response) + + rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, control_host) + + # The retry decision must come from the future state, not the live connection reference. + rf._connection = Mock(is_control_connection=False) + + rf._retry_task(False, control_host) + + assert connection.send_msg.call_count == 2 + assert connection.send_msg.call_args_list[1][0][0] is rf.message + assert connection.send_msg.call_args_list[1][0][1] == 8 + assert rf.attempted_hosts == [control_host, control_host] + + expected_result = (object(), object()) + connection.send_msg.call_args_list[1][1]['cb']( + self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_fetches_next_page(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + connection.get_request_id.side_effect = [7, 8] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + 
session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + assert rf.send_request() + + first_response = self.make_mock_response(['col'], [(1,)]) + first_response.paging_state = b'next-page' + connection.send_msg.call_args[1]['cb'](first_response) + + assert rf.result().current_rows == [(['col'], [(1,)])] + assert rf.has_more_pages + + rf.start_fetching_next_page() + + assert connection.send_msg.call_count == 2 + assert connection.send_msg.call_args_list[1][0][0] is rf.message + assert connection.send_msg.call_args_list[1][0][1] == 8 + assert rf.message.paging_state == b'next-page' + + second_response = self.make_mock_response(['col'], [(2,)]) + connection.send_msg.call_args_list[1][1]['cb'](second_response) + + assert connection.in_flight == 0 + assert rf.result().current_rows == [(['col'], [(2,)])] + + def test_control_connection_fallback_reprepares_prepared_statement(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster.protocol_version = ProtocolVersion.V4 + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + session.submit.side_effect = lambda fn, *args, **kwargs: fn(*args, **kwargs) + + query_id = b'a' * 16 + prepared_statement = Mock( + query_id=query_id, + query_string="SELECT * FROM foobar", + keyspace="FooKeyspace", + result_metadata=[], + result_metadata_id=None) + session.cluster._prepared_statements = {query_id: prepared_statement} + + connection = self.make_control_connection() + connection.keyspace = "FooKeyspace" + connection.get_request_id.side_effect = [7, 8, 9] + session.cluster.control_connection._connection = connection + control_host = Mock(endpoint=connection.endpoint) + session.cluster.get_control_connection_host.return_value = control_host + + rf = self.make_response_future(session) + rf.prepared_statement = 
prepared_statement + assert rf.send_request() + + missing = Mock(spec=PreparedQueryNotFound, info=query_id) + connection.send_msg.call_args_list[0][1]['cb'](missing) + + assert connection.send_msg.call_count == 2 + prepare_message = connection.send_msg.call_args_list[1][0][0] + assert isinstance(prepare_message, PrepareMessage) + assert prepare_message.query == "SELECT * FROM foobar" + assert connection.send_msg.call_args_list[1][0][1] == 8 + + prepared_response = Mock( + spec=ResultMessage, + kind=RESULT_KIND_PREPARED, + query_id=query_id, + column_metadata=[], + result_metadata_id=None) + connection.send_msg.call_args_list[1][1]['cb'](prepared_response) + + assert connection.send_msg.call_count == 3 + assert connection.send_msg.call_args_list[2][0][0] is rf.message + assert connection.send_msg.call_args_list[2][0][1] == 9 + + expected_result = (['col'], [(1,)]) + connection.send_msg.call_args_list[2][1]['cb']( + self.make_mock_response(expected_result[0], expected_result[1])) + + assert connection.in_flight == 0 + assert rf.result()[0] == expected_result + + def test_control_connection_fallback_not_used_when_pool_can_serve(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + pool = Mock(is_shutdown=False) + pool.borrow_connection.side_effect = NoConnectionsAvailable() + session._pools = {'ip1': pool} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + rf = self.make_response_future(session) + rf.send_request() + + connection.send_msg.assert_not_called() + with pytest.raises(NoHostAvailable): + rf.result() + + def test_control_connection_fallback_orphans_stream_on_timeout(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + 
session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1'] + session._pools = {} + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + def send_msg(message, request_id, cb, **kwargs): + connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata')) + return 128 + + connection.send_msg.side_effect = send_msg + + rf = self.make_response_future(session) + rf.send_request() + rf._on_timeout() + + assert 7 in connection.orphaned_request_ids + assert connection.in_flight == 1 + with pytest.raises(OperationTimedOut): + rf.result() + + def test_control_connection_fallback_timeout_without_metadata_host_uses_connection_endpoint(self): + session = self.make_basic_session() + session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [] + session._pools = {} + session.cluster.get_control_connection_host.return_value = None + connection = self.make_control_connection() + session.cluster.control_connection._connection = connection + + def send_msg(message, request_id, cb, **kwargs): + connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata')) + return 128 + + connection.send_msg.side_effect = send_msg + + rf = self.make_response_future(session) + assert rf.send_request() + rf._on_timeout() + + with pytest.raises(OperationTimedOut) as exc_info: + rf.result() + + assert exc_info.value.errors == { + 'control-host': 'Client request timeout. See Session.execute[_async](timeout)' + } + def test_first_pool_shutdown(self): session = self.make_basic_session() session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']