diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py index 80b2477a6d..e447e37df2 100644 --- a/cassandra/client_routes.py +++ b/cassandra/client_routes.py @@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float, return routes = self._query_routes_for_change_event(connection, timeout, pairs) - self._routes.merge(routes, affected_host_ids=set(host_uuids)) + self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs}) def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, connection_ids: Set[str]) -> List[_Route]: @@ -322,27 +322,25 @@ def _query_all_routes_for_connections(self, connection: 'Connection', timeout: f def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: """ - Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + Query current routes for hosts affected by a CLIENT_ROUTES_CHANGE event. - Takes a list of (connection_id, host_id) pairs that represent the exact - routes affected by an operation. This provides precise updates without - fetching unrelated routes. - - If the pairs list is empty or None, falls back to a complete refresh - of all routes for safety. + The in-memory route store keeps a single preferred route per host. When + any configured connection_id changes for a host, fetch all configured + connection_ids for that host so the existing preferred route can be + retained if it is still present. :param connection: Connection to execute query on :param timeout: Query timeout in seconds - :param route_pairs: List of (connection_id, host_id) tuples + :param route_pairs: List of affected (connection_id, host_id) tuples :return: List of _Route """ unique_pairs = list(dict.fromkeys(route_pairs)) - conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + conn_ids = sorted(self._connection_ids) host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) - log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " - "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + log.debug("[client routes] Querying routes from CLIENT_ROUTES_CHANGE " + "for host_ids (first 5 of %d): %s", len(host_ids), host_ids[:5]) conn_ph = ', '.join('?' for _ in conn_ids) host_ph = ', '.join('?' for _ in host_ids) diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..b873243a6f 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -428,7 +428,7 @@ class ClientRoutesEndPoint(EndPoint): _host_id: uuid.UUID _handler: _ClientRoutesHandler _original_address: str - _original_port: int + _original_port: Optional[int] def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None: """ @@ -467,15 +467,24 @@ def resolve(self) -> Tuple[str, int]: def __eq__(self, other): return (isinstance(other, ClientRoutesEndPoint) and - self._host_id == other._host_id and - self._original_address == other._original_address) + self._identity_key() == other._identity_key()) def __hash__(self): - return hash((self._host_id, self._original_address)) + return hash(self._identity_key()) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return self._ordering_key() < other._ordering_key() + + def _identity_key(self): + return self._host_id, self._original_address, self._original_port + + def _ordering_key(self): + return ( + self._host_id, + self._original_address, + self._original_port is None, + self._original_port, + ) def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..8c76852c6b 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -545,12 +545,14 @@ class TestGetHostPortMapping(unittest.TestCase): @classmethod def setUpClass(cls): + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] + cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( - proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + proxies=[ClientRouteProxy(connection_id, "127.0.0.1") + for connection_id in cls.connection_ids])) cls.session = cls.cluster.connect() - cls.host_ids = [uuid.uuid4() for _ in range(3)] - cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] cls.expected = [] for idx, host_id in enumerate(cls.host_ids): @@ -639,8 +641,8 @@ def test_get_routes_for_change_event_all_pairs(self): self._sort_routes(expected) self.assertEqual(got, expected) - def test_get_routes_for_change_event_single_pair(self): - """Querying a single (connection_id, host_id) pair returns one route.""" + def test_get_routes_for_change_event_single_host(self): + """Querying a single changed host returns all configured routes for it.""" cc = self.cluster.control_connection target_conn_id = self.connection_ids[0] target_host_id = self.host_ids[0] @@ -650,8 +652,7 @@ def test_get_routes_for_change_event_single_pair(self): got = self._routes_to_dicts(routes) self._sort_routes(got) filtered = [r for r in self.expected - if r['connection_id'] == target_conn_id - and r['host_id'] == target_host_id] + if r['host_id'] == target_host_id] expected = self._expected_dicts(filtered) self._sort_routes(expected) self.assertEqual(got, expected) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 0aa82fc76a..925bef4353 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -233,6 +233,92 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query): self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query): + """Routes for unrelated connection_ids in mixed events should not be removed.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + conn_id = str(self.conn_id) + changed_host = uuid.uuid4() + unrelated_host = uuid.uuid4() + + handler._routes.update([ + _Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042), + _Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id, "unrelated-conn-id"], + host_ids=[str(changed_host), str(unrelated_host)], + ) + + self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com") + self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com") + + def test_handle_change_preserves_preferred_route_for_same_host(self): + conn_a = str(uuid.uuid4()) + conn_b = str(uuid.uuid4()) + host_id = uuid.uuid4() + config = ClientRoutesConfig([ + ClientRouteProxy(conn_a), + ClientRouteProxy(conn_b), + ]) + handler = _ClientRoutesHandler(config) + handler._routes.update([ + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ]) + + table_routes = [ + _Route(connection_id=conn_a, host_id=host_id, + address="changed.example.com", port=9042), + _Route(connection_id=conn_b, host_id=host_id, + address="current.example.com", port=9042), + ] + + def wait_for_response(query_msg, timeout): + conn_placeholders = query_msg.query.split( + "connection_id IN (", 1)[1].split(")", 1)[0].count("?") + conn_ids = { + param.decode("utf-8") + for param in query_msg.query_params[:conn_placeholders] + } + host_ids = { + uuid.UUID(bytes=param) + for param in query_msg.query_params[conn_placeholders:] + } + rows = [ + (route.connection_id, route.host_id, route.address, + route.port, route.port) + for route in table_routes + if route.connection_id in conn_ids and route.host_id in host_ids + ] + return Mock( + column_names=["connection_id", "host_id", "address", "port", "tls_port"], + parsed_rows=rows, + ) + + mock_conn = Mock() + mock_conn.wait_for_response.side_effect = wait_for_response + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_a], + host_ids=[str(host_id)], + ) + + route = handler._routes.get_by_host_id(host_id) + self.assertEqual(route.connection_id, conn_b) + self.assertEqual(route.address, "current.example.com") + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') def test_handle_change_updates_when_no_host_ids(self, mock_query): """When no host_ids are provided, routes should be fully replaced.""" @@ -347,6 +433,44 @@ def test_resolve_falls_back_when_no_mapping(self): ) self.assertEqual(ep.resolve(), ("10.0.0.1", 9042)) + def test_original_port_is_part_of_identity(self): + """Endpoints that only differ by original port should not compare equal.""" + host_id = uuid.uuid4() + ep_without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + ) + ep_with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertNotEqual(ep_without_port, ep_with_port) + self.assertNotEqual(hash(ep_without_port), hash(ep_with_port)) + + def test_sorting_handles_missing_original_port(self): + """Ordering should remain deterministic when original_port is None.""" + host_id = uuid.uuid4() + ep_without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + ) + ep_with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertEqual( + sorted([ep_without_port, ep_with_port]), + [ep_with_port, ep_without_port], + ) + @patch('cassandra.client_routes.socket.getaddrinfo', return_value=[(socket.AF_INET, socket.SOCK_STREAM, 0, '', ("192.168.1.100", 9042))]) def test_resolve_returns_address_when_route_exists(self, _mock_getaddrinfo):