Skip to content

Commit 278d66e

Browse files
committed
client-routes: preserve mixed event route state
1 parent 0842348 commit 278d66e

3 files changed

Lines changed: 76 additions & 7 deletions

File tree

cassandra/client_routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def handle_client_routes_change(self, connection: 'Connection', timeout: float,
294294
return
295295

296296
routes = self._query_routes_for_change_event(connection, timeout, pairs)
297-
self._routes.merge(routes, affected_host_ids=set(host_uuids))
297+
self._routes.merge(routes, affected_host_ids={host_id for _, host_id in pairs})
298298

299299
def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float,
300300
connection_ids: Set[str]) -> List[_Route]:

cassandra/connection.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,25 @@ def resolve(self) -> Tuple[str, int]:
468468
def __eq__(self, other):
469469
return (isinstance(other, ClientRoutesEndPoint) and
470470
self._host_id == other._host_id and
471-
self._original_address == other._original_address)
471+
self._original_address == other._original_address and
472+
self._original_port == other._original_port)
472473

473474
def __hash__(self):
474-
return hash((self._host_id, self._original_address))
475+
return hash((self._host_id, self._original_address, self._original_port))
476+
477+
def _comparison_key(self):
478+
return (self._host_id, self._original_address,
479+
self._original_port is None, self._original_port)
475480

476481
def __lt__(self, other):
477-
return ((self._host_id, self._original_address) <
478-
(other._host_id, other._original_address))
482+
return self._comparison_key() < other._comparison_key()
479483

480484
def __str__(self):
481485
return str("%s (host_id=%s)" % (self._original_address, self._host_id))
482486

483487
def __repr__(self):
484-
return "<%s: host_id=%s, original_addr=%s>" % (
485-
self.__class__.__name__, self._host_id, self._original_address)
488+
return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % (
489+
self.__class__.__name__, self._host_id, self._original_address, self._original_port)
486490

487491

488492
class _Frame(object):

tests/unit/test_client_routes.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,35 @@ def test_handle_change_merges_when_host_ids_present(self, mock_query):
233233
self.assertIsNotNone(handler._routes.get_by_host_id(existing_host))
234234
self.assertIsNotNone(handler._routes.get_by_host_id(new_host))
235235

236+
@patch.object(_ClientRoutesHandler, '_query_routes_for_change_event')
237+
def test_handle_change_preserves_routes_for_unrelated_connection_ids(self, mock_query):
238+
"""Routes for unrelated connection_ids in mixed events should not be removed."""
239+
handler = _ClientRoutesHandler(self.config)
240+
mock_conn = Mock()
241+
242+
conn_id = str(self.conn_id)
243+
changed_host = uuid.uuid4()
244+
unrelated_host = uuid.uuid4()
245+
246+
handler._routes.update([
247+
_Route(connection_id=conn_id, host_id=changed_host, address="old.com", port=9042),
248+
_Route(connection_id=conn_id, host_id=unrelated_host, address="keep.com", port=9042),
249+
])
250+
251+
mock_query.return_value = [
252+
_Route(connection_id=conn_id, host_id=changed_host, address="new.com", port=9042),
253+
]
254+
255+
handler.handle_client_routes_change(
256+
mock_conn, 5.0,
257+
ClientRoutesChangeType.UPDATE_NODES,
258+
connection_ids=[conn_id, "unrelated-conn-id"],
259+
host_ids=[str(changed_host), str(unrelated_host)],
260+
)
261+
262+
self.assertEqual(handler._routes.get_by_host_id(changed_host).address, "new.com")
263+
self.assertEqual(handler._routes.get_by_host_id(unrelated_host).address, "keep.com")
264+
236265
@patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections')
237266
def test_handle_change_updates_when_no_host_ids(self, mock_query):
238267
"""When no host_ids are provided, routes should be fully replaced."""
@@ -388,6 +417,42 @@ def test_resolve_host_missing_port_raises(self):
388417
with self.assertRaises(ValueError):
389418
self.handler.resolve_host(host_id)
390419

420+
def test_endpoint_identity_includes_original_port(self):
421+
host_id = uuid.uuid4()
422+
first = ClientRoutesEndPoint(
423+
host_id=host_id,
424+
handler=self.handler,
425+
original_address="10.0.0.1",
426+
original_port=9042,
427+
)
428+
second = ClientRoutesEndPoint(
429+
host_id=host_id,
430+
handler=self.handler,
431+
original_address="10.0.0.1",
432+
original_port=9142,
433+
)
434+
435+
self.assertNotEqual(first, second)
436+
self.assertEqual(len({first, second}), 2)
437+
438+
def test_endpoint_ordering_handles_missing_original_port(self):
439+
host_id = uuid.uuid4()
440+
without_port = ClientRoutesEndPoint(
441+
host_id=host_id,
442+
handler=self.handler,
443+
original_address="10.0.0.1",
444+
original_port=None,
445+
)
446+
with_port = ClientRoutesEndPoint(
447+
host_id=host_id,
448+
handler=self.handler,
449+
original_address="10.0.0.1",
450+
original_port=9042,
451+
)
452+
453+
self.assertCountEqual(
454+
sorted([without_port, with_port]), [without_port, with_port])
455+
391456

392457
class TestClientRoutesEndPointFactory(unittest.TestCase):
393458

0 commit comments

Comments
 (0)