diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..c129bfb3a5 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -468,21 +468,25 @@ def resolve(self) -> Tuple[str, int]: def __eq__(self, other): return (isinstance(other, ClientRoutesEndPoint) and self._host_id == other._host_id and - self._original_address == other._original_address) + self._original_address == other._original_address and + self._original_port == other._original_port) def __hash__(self): - return hash((self._host_id, self._original_address)) + return hash((self._host_id, self._original_address, self._original_port)) + + def _comparison_key(self): + return (self._host_id, self._original_address, + self._original_port is None, self._original_port) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return self._comparison_key() < other._comparison_key() def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) def __repr__(self): - return "<%s: host_id=%s, original_addr=%s>" % ( - self.__class__.__name__, self._host_id, self._original_address) + return "<%s: host_id=%s, original_addr=%s, original_port=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address, self._original_port) class _Frame(object): diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py index 0aa82fc76a..7fa75a1f13 100644 --- a/tests/unit/test_client_routes.py +++ b/tests/unit/test_client_routes.py @@ -388,6 +388,42 @@ def test_resolve_host_missing_port_raises(self): with self.assertRaises(ValueError): self.handler.resolve_host(host_id) + def test_endpoint_identity_includes_original_port(self): + host_id = uuid.uuid4() + first = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + second = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9142, + ) + + self.assertNotEqual(first, second) + self.assertEqual(len({first, second}), 2) + + def test_endpoint_ordering_handles_missing_original_port(self): + host_id = uuid.uuid4() + without_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=None, + ) + with_port = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + + self.assertCountEqual( + sorted([without_port, with_port]), [without_port, with_port]) + class TestClientRoutesEndPointFactory(unittest.TestCase):