diff --git a/benchmarks/decode_benchmark.py b/benchmarks/decode_benchmark.py new file mode 100644 index 0000000000..292f8502c3 --- /dev/null +++ b/benchmarks/decode_benchmark.py @@ -0,0 +1,523 @@ +#!/usr/bin/env python3 +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Isolated benchmark for ProtocolHandler.decode_message(). + +Measures the throughput of decoding synthetic RESULT/ROWS messages of +varying sizes. Does NOT require a live Cassandra/Scylla cluster. + +Run on both ``master`` and the ``remove_copies`` branch to compare: + + python benchmarks/decode_benchmark.py + python benchmarks/decode_benchmark.py --scenarios small_100,large_5k_1KB + python benchmarks/decode_benchmark.py --cython-only --iterations 20 + python benchmarks/decode_benchmark.py --cprofile medium_1k_1KB +""" + +from __future__ import print_function + +import argparse +import gc +import os +import statistics +import struct +import sys +import time + +# --------------------------------------------------------------------------- +# Pin to a single CPU core for consistent results +# --------------------------------------------------------------------------- +try: + os.sched_setaffinity(0, {0}) +except (AttributeError, OSError): + # sched_setaffinity is Linux-only; silently skip on other platforms + pass + +# --------------------------------------------------------------------------- +# Make sure the driver package is importable from the repo root +# --------------------------------------------------------------------------- +_benchdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(_benchdir, '..')) + +from io import BytesIO +from cassandra.marshal import int32_pack, int64_pack, double_pack +from cassandra.protocol import ( + write_int, write_short, write_string, write_value, + ProtocolHandler, _ProtocolHandler, HAVE_CYTHON, +) + +# --------------------------------------------------------------------------- +# CQL type codes (native protocol v4) +# --------------------------------------------------------------------------- +TYPE_BIGINT = 0x0002 +TYPE_BLOB = 0x0003 +TYPE_DOUBLE = 0x0007 +TYPE_INT = 0x0009 +TYPE_VARCHAR = 0x000D + +# Metadata flag +_FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + +# ResultMessage kind +_RESULT_KIND_ROWS = 0x0002 + +# ResultMessage opcode +_OPCODE_RESULT = 0x08 + + +# ====================================================================== +# Synthetic message construction +# ====================================================================== + +def _build_rows_body(columns, row_values_fn, row_count): + """ + Build the raw bytes for a RESULT/ROWS message body. + + Parameters + ---------- + columns : list of (name: str, type_code: int) + Column definitions. + row_values_fn : callable() -> list[bytes|None] + Returns one row of pre-encoded cell values each time it is called. + row_count : int + Number of rows to encode. + + Returns + ------- + bytes + Complete RESULT body ready for ``decode_message()``. + """ + buf = BytesIO() + + # kind = ROWS + write_int(buf, _RESULT_KIND_ROWS) + + # --- metadata --- + write_int(buf, _FLAGS_GLOBAL_TABLES_SPEC) + write_int(buf, len(columns)) + write_string(buf, "ks") + write_string(buf, "tbl") + for col_name, type_code in columns: + write_string(buf, col_name) + write_short(buf, type_code) + + # --- rows --- + write_int(buf, row_count) + for _ in range(row_count): + for cell in row_values_fn(): + write_value(buf, cell) + + return buf.getvalue() + + +# ====================================================================== +# Scenario definitions +# ====================================================================== + +def _make_text(size): + """Return a UTF-8 encoded bytes value of exactly *size* bytes.""" + return (b'x' * size) + + +def _scenario_small_100(): + """100 rows, 3 cols (text 50B, int, bigint) ~3 KB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(50) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 100) + + +def _scenario_medium_1k_256B(): + """1000 rows, 3 cols (text 256B, int, bigint) ~273 KB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(256) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_medium_1k_1KB(): + """1000 rows, 3 cols (text 1024B, int, bigint) ~1 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(1024) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_large_5k_1KB(): + """5000 rows, 3 cols (text 1024B, int, bigint) ~5 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(1024) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 5000) + + +def _scenario_large_1k_4KB(): + """1000 rows, 3 cols (text 4096B, int, bigint) ~4 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(4096) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_wide_5k_doubles(): + """5000 rows, 10 cols (10x double) ~586 KB""" + columns = [("col_d%d" % i, TYPE_DOUBLE) for i in range(10)] + vals = [double_pack(1.0 + i * 0.1) for i in range(10)] + row = lambda: list(vals) + return _build_rows_body(columns, row, 5000) + + +def _scenario_wide_1k_20cols(): + """1000 rows, 20 cols (10x text 64B, 5x int, 3x bigint, 2x double) ~850 KB""" + columns = [] + for i in range(10): + columns.append(("col_text%d" % i, TYPE_VARCHAR)) + for i in range(5): + columns.append(("col_int%d" % i, TYPE_INT)) + for i in range(3): + columns.append(("col_bigint%d" % i, TYPE_BIGINT)) + for i in range(2): + columns.append(("col_double%d" % i, TYPE_DOUBLE)) + + text_val = _make_text(64) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + double_val = double_pack(3.14159) + + def row(): + cells = [] + for _ in range(10): + cells.append(text_val) + for _ in range(5): + cells.append(int_val) + for _ in range(3): + cells.append(bigint_val) + for _ in range(2): + cells.append(double_val) + return cells + + return _build_rows_body(columns, row, 1000) + + +def _scenario_blob_1k_16KB(): + """1000 rows, 2 cols (int, 16 KB blob) ~16 MB""" + columns = [ + ("col_int", TYPE_INT), + ("col_blob", TYPE_BLOB), + ] + int_val = int32_pack(42) + blob_val = os.urandom(16384) + row = lambda: [int_val, blob_val] + return _build_rows_body(columns, row, 1000) + + +SCENARIOS = { + "small_100": ("100 rows, 3 cols (text 50B, int, bigint)", _scenario_small_100), + "medium_1k_256B": ("1000 rows, 3 cols (text 256B, int, bigint)", _scenario_medium_1k_256B), + "medium_1k_1KB": ("1000 rows, 3 cols (text 1024B, int, bigint)", _scenario_medium_1k_1KB), + "large_5k_1KB": ("5000 rows, 3 cols (text 1024B, int, bigint)", _scenario_large_5k_1KB), + "large_1k_4KB": ("1000 rows, 3 cols (text 4096B, int, bigint)", _scenario_large_1k_4KB), + "wide_5k_doubles": ("5000 rows, 10 cols (10x double)", _scenario_wide_5k_doubles), + "wide_1k_20cols": ("1000 rows, 20 cols (10x text64, 5x int, ...)", _scenario_wide_1k_20cols), + "blob_1k_16KB": ("1000 rows, 2 cols (int, 16 KB blob)", _scenario_blob_1k_16KB), +} + +# Ordered list so output is deterministic +SCENARIO_ORDER = [ + "small_100", + "medium_1k_256B", + "medium_1k_1KB", + "large_5k_1KB", + "large_1k_4KB", + "wide_5k_doubles", + "wide_1k_20cols", + "blob_1k_16KB", +] + + +# ====================================================================== +# Benchmark runner +# ====================================================================== + +def _decode(handler, body): + """Call decode_message with the standard benchmark parameters.""" + return handler.decode_message( + protocol_version=4, + protocol_features=None, + user_type_map={}, + stream_id=0, + flags=0, + opcode=_OPCODE_RESULT, + body=body, + decompressor=None, + result_metadata=None, + ) + + +def _run_iterations(handler, body, iterations, warmup): + """ + Run *warmup* + *iterations* decode calls, return list of elapsed + times (seconds) for the measured iterations only. + """ + # Warm-up: let JIT / caches settle + for _ in range(warmup): + _decode(handler, body) + + gc.disable() + try: + times = [] + for _ in range(iterations): + t0 = time.perf_counter() + _decode(handler, body) + t1 = time.perf_counter() + times.append(t1 - t0) + finally: + gc.enable() + return times + + +def _format_time(seconds): + """Human-readable time string.""" + if seconds < 1e-3: + return "%.1f us" % (seconds * 1e6) + elif seconds < 1.0: + return "%.2f ms" % (seconds * 1e3) + else: + return "%.3f s" % seconds + + +def _format_throughput(body_size, seconds): + """MB/s throughput string.""" + mb = body_size / (1024 * 1024) + return "%.1f MB/s" % (mb / seconds) + + +def _report(label, times, body_size, row_count): + """Print a single result line.""" + t_min = min(times) + t_med = statistics.median(times) + t_mean = statistics.mean(times) + rows_per_sec = row_count / t_med + + if rows_per_sec >= 1e6: + rps_str = "%.2fM rows/s" % (rows_per_sec / 1e6) + elif rows_per_sec >= 1e3: + rps_str = "%.0fK rows/s" % (rows_per_sec / 1e3) + else: + rps_str = "%.0f rows/s" % rows_per_sec + + print(" %-14s min=%s median=%s mean=%s (%s, %s)" % ( + label, + _format_time(t_min), + _format_time(t_med), + _format_time(t_mean), + _format_throughput(body_size, t_med), + rps_str, + )) + + +def _extract_row_count(scenario_name): + """Infer the row count from the scenario name for rows/s reporting.""" + mapping = { + "small_100": 100, + "medium_1k_256B": 1000, + "medium_1k_1KB": 1000, + "large_5k_1KB": 5000, + "large_1k_4KB": 1000, + "wide_5k_doubles": 5000, + "wide_1k_20cols": 1000, + "blob_1k_16KB": 1000, + } + return mapping.get(scenario_name, 0) + + +def run_benchmark(scenarios, iterations, warmup, cython_only, python_only, cprofile_scenario): + """ + Run the benchmark for each requested scenario. + """ + print("=" * 78) + print("Decode Benchmark") + print("=" * 78) + print(" Cython available : %s" % HAVE_CYTHON) + print(" Iterations : %d (+ %d warmup)" % (iterations, warmup)) + print(" CPU pinned : %s" % _is_pinned()) + print() + + handlers = [] + if not python_only: + if HAVE_CYTHON: + handlers.append(("Cython", ProtocolHandler)) + elif not cython_only: + print(" [NOTE] Cython extensions not available, skipping Cython path\n") + if not cython_only: + handlers.append(("Python", _ProtocolHandler)) + + if not handlers: + print("ERROR: no handlers selected (Cython not available and --cython-only set)") + sys.exit(1) + + profiler = None + if cprofile_scenario: + import cProfile + profiler = cProfile.Profile() + + for name in scenarios: + desc, builder = SCENARIOS[name] + body = builder() + body_size = len(body) + row_count = _extract_row_count(name) + + print("Scenario: %s (%s, %s body)" % ( + name, desc, _format_size(body_size))) + + for label, handler in handlers: + if profiler and name == cprofile_scenario: + profiler.enable() + + times = _run_iterations(handler, body, iterations, warmup) + + if profiler and name == cprofile_scenario: + profiler.disable() + + _report(label + ":", times, body_size, row_count) + + print() + + if profiler: + print("-" * 78) + print("cProfile results for scenario '%s':" % cprofile_scenario) + print("-" * 78) + import pstats + stats = pstats.Stats(profiler) + stats.strip_dirs() + stats.sort_stats('cumulative') + stats.print_stats(30) + + +def _format_size(nbytes): + """Human-readable byte size.""" + if nbytes >= 1024 * 1024: + return "%.1f MB" % (nbytes / (1024 * 1024)) + elif nbytes >= 1024: + return "%.1f KB" % (nbytes / 1024) + else: + return "%d B" % nbytes + + +def _is_pinned(): + """Check if the process is pinned to a single CPU core.""" + try: + affinity = os.sched_getaffinity(0) + return len(affinity) == 1 + except (AttributeError, OSError): + return False + + +# ====================================================================== +# CLI +# ====================================================================== + +def main(): + parser = argparse.ArgumentParser( + description="Isolated decode_message benchmark (no cluster required)") + parser.add_argument( + "--iterations", "-n", type=int, default=10, + help="Number of timed iterations per scenario (default: 10)") + parser.add_argument( + "--warmup", "-w", type=int, default=3, + help="Number of warmup iterations (default: 3)") + parser.add_argument( + "--scenarios", "-s", type=str, default=None, + help="Comma-separated list of scenarios to run (default: all). " + "Available: %s" % ", ".join(SCENARIO_ORDER)) + parser.add_argument( + "--cython-only", action="store_true", default=False, + help="Only benchmark the Cython (fast) path") + parser.add_argument( + "--python-only", action="store_true", default=False, + help="Only benchmark the pure-Python path") + parser.add_argument( + "--cprofile", type=str, default=None, metavar="SCENARIO", + help="Enable cProfile for the named scenario and print top-30 stats") + parser.add_argument( + "--list", action="store_true", default=False, + help="List available scenarios and exit") + + args = parser.parse_args() + + if args.list: + print("Available scenarios:") + for name in SCENARIO_ORDER: + desc, builder = SCENARIOS[name] + print(" %-20s %s" % (name, desc)) + sys.exit(0) + + if args.cython_only and args.python_only: + parser.error("--cython-only and --python-only are mutually exclusive") + + if args.scenarios: + selected = [s.strip() for s in args.scenarios.split(",")] + for s in selected: + if s not in SCENARIOS: + parser.error("Unknown scenario: %s" % s) + else: + selected = list(SCENARIO_ORDER) + + if args.cprofile and args.cprofile not in selected: + parser.error("--cprofile scenario '%s' is not in the selected scenarios" % args.cprofile) + + run_benchmark( + scenarios=selected, + iterations=args.iterations, + warmup=args.warmup, + cython_only=args.cython_only, + python_only=args.python_only, + cprofile_scenario=args.cprofile, + ) + + +if __name__ == "__main__": + main() diff --git a/cassandra/bytesio.pxd b/cassandra/bytesio.pxd index d52d3fa8fe..40edcd996d 100644 --- a/cassandra/bytesio.pxd +++ b/cassandra/bytesio.pxd @@ -17,4 +17,5 @@ cdef class BytesIOReader: cdef char *buf_ptr cdef Py_ssize_t pos cdef Py_ssize_t size + cdef Py_ssize_t _initial_offset cdef char *read(self, Py_ssize_t n = ?) except NULL diff --git a/cassandra/bytesio.pyx b/cassandra/bytesio.pyx index 1a57911fcf..4244e74d7c 100644 --- a/cassandra/bytesio.pyx +++ b/cassandra/bytesio.pyx @@ -16,12 +16,18 @@ cdef class BytesIOReader: """ This class provides efficient support for reading bytes from a 'bytes' buffer, by returning char * values directly without allocating intermediate objects. + + An optional offset allows reading from the middle of an existing buffer, + avoiding a copy when only a suffix of the bytes is needed. """ - def __init__(self, bytes buf): + def __init__(self, bytes buf, Py_ssize_t offset=0): + if offset < 0 or offset > len(buf): + raise ValueError("offset %d out of range for buffer of length %d" % (offset, len(buf))) self.buf = buf - self.size = len(buf) - self.buf_ptr = self.buf + self._initial_offset = offset + self.size = len(buf) - offset + self.buf_ptr = self.buf + offset cdef char *read(self, Py_ssize_t n = -1) except NULL: """Read at most size bytes from the file diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..6453437011 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -630,14 +630,25 @@ def readable_io_bytes(self): def readable_cql_frame_bytes(self): return self.cql_frame_buffer.tell() + @staticmethod + def _reset_buffer(buf): + """ + Reset a BytesIO buffer by discarding consumed data. + Avoid an intermediate bytes copy from .read(); slice the existing buffer. + BytesIO will still materialize its own backing store, but this reduces + one full-buffer allocation on the hot receive path. + """ + pos = buf.tell() + new_buf = io.BytesIO(buf.getbuffer()[pos:]) + new_buf.seek(0, 2) # 2 == SEEK_END + return new_buf + def reset_io_buffer(self): - self._io_buffer = io.BytesIO(self._io_buffer.read()) - self._io_buffer.seek(0, 2) # 2 == SEEK_END + self._io_buffer = self._reset_buffer(self._io_buffer) def reset_cql_frame_buffer(self): if self.is_checksumming_enabled: - self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read()) - self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END + self._cql_frame_buffer = self._reset_buffer(self._cql_frame_buffer) else: self.reset_io_buffer() @@ -671,7 +682,9 @@ class Connection(object): CALLBACK_ERR_THREAD_THRESHOLD = 100 - in_buffer_size = 4096 + # 64 KiB recv buffer reduces the number of syscalls when reading + # large result sets, at a modest per-connection memory cost. + in_buffer_size = 65536 out_buffer_size = 4096 cql_version = None @@ -1191,19 +1204,23 @@ def control_conn_disposed(self): @defunct_on_error def _read_frame_header(self): - buf = self._io_buffer.cql_frame_buffer.getvalue() - pos = len(buf) + cql_buf = self._io_buffer.cql_frame_buffer + pos = cql_buf.tell() if pos: - version = buf[0] & PROTOCOL_VERSION_MASK - if version not in ProtocolVersion.SUPPORTED_VERSIONS: - raise ProtocolError("This version of the driver does not support protocol version %d" % version) - # this frame header struct is everything after the version byte - header_size = frame_header_v3.size + 1 - if pos >= header_size: - flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1) - if body_len < 0: - raise ProtocolError("Received negative body length: %r" % body_len) - self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) + buf = cql_buf.getbuffer() + try: + version = buf[0] & PROTOCOL_VERSION_MASK + if version not in ProtocolVersion.SUPPORTED_VERSIONS: + raise ProtocolError("This version of the driver does not support protocol version %d" % version) + # this frame header struct is everything after the version byte + header_size = frame_header_v3.size + 1 + if pos >= header_size: + flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1) + if body_len < 0: + raise ProtocolError("Received negative body length: %r" % body_len) + self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) + finally: + del buf # release memoryview before any buffer mutation return pos @defunct_on_error @@ -1257,8 +1274,16 @@ def process_io_buffer(self): return else: frame = self._current_frame - self._io_buffer.cql_frame_buffer.seek(frame.body_offset) - msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) + # Use memoryview to avoid intermediate allocation, then + # convert to bytes. Explicitly release the memoryview + # before any buffer mutation (seek / reset). + cql_buf = self._io_buffer.cql_frame_buffer + buf = cql_buf.getbuffer() + try: + msg = bytes(buf[frame.body_offset:frame.end_pos]) + finally: + del buf # release memoryview before buffer mutation + cql_buf.seek(frame.end_pos) self.process_msg(frame, msg) self._io_buffer.reset_cql_frame_buffer() self._current_frame = None diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..445543ac11 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -53,6 +53,43 @@ class NotSupportedError(Exception): class InternalError(Exception): pass + +class BytesReader: + """ + Lightweight reader for bytes data without BytesIO overhead. + Provides the same read() interface but operates directly on a + bytes object, avoiding internal buffer copies. + + Unlike io.BytesIO.read(n), read(n) raises EOFError when fewer than + n bytes remain. This is intentional: protocol parsing should fail + fast on truncated or malformed frames rather than silently returning + short data. + """ + __slots__ = ('_data', '_pos', '_size') + + def __init__(self, data): + # Materialize memoryview up front so read() never needs to check + self._data = bytes(data) if isinstance(data, memoryview) else data + self._pos = 0 + self._size = len(self._data) + + def read(self, n=-1): + if n < 0: + result = self._data[self._pos:] + self._pos = self._size + else: + end = self._pos + n + if end > self._size: + raise EOFError("Cannot read past the end of the buffer") + result = self._data[self._pos:end] + self._pos = end + return result + + def remaining_buffer(self): + """Return (underlying_bytes, current_position) for zero-copy handoff.""" + return self._data, self._pos + + ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) HEADER_DIRECTION_TO_CLIENT = 0x80 @@ -1154,7 +1191,8 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre body = decompressor(body) flags ^= COMPRESSED_FLAG - body = io.BytesIO(body) + # Use lightweight BytesReader instead of io.BytesIO to avoid buffer copy + body = BytesReader(body) if flags & TRACING_FLAG: trace_id = UUID(bytes=body.read(16)) flags ^= TRACING_FLAG diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..5a99dfac36 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -35,13 +35,19 @@ def make_recv_results_rows(ColumnParser colparser): desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, [ColDesc(md[0], md[1], md[2]) for md in column_metadata], make_deserializers(self.column_types), protocol_version) - reader = BytesIOReader(f.read()) + # Zero-copy handoff: reuse the underlying bytes buffer at its current + # position instead of copying via f.read(). + if hasattr(f, 'remaining_buffer'): + buf_data, buf_offset = f.remaining_buffer() + reader = BytesIOReader(buf_data, buf_offset) + else: + reader = BytesIOReader(f.read()) try: self.parsed_rows = colparser.parse_rows(reader, desc) except Exception as e: # Use explicitly the TupleRowParser to display better error messages for column decoding failures rowparser = TupleRowParser() - reader.buf_ptr = reader.buf + reader.buf_ptr = reader.buf + reader._initial_offset reader.pos = 0 rowcount = read_int(reader) for i in range(rowcount): diff --git a/tests/unit/test_bytesio_reader.py b/tests/unit/test_bytesio_reader.py new file mode 100644 index 0000000000..bb7efe1193 --- /dev/null +++ b/tests/unit/test_bytesio_reader.py @@ -0,0 +1,67 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import pytest + +try: + from cassandra.bytesio import BytesIOReader + + has_cython = True +except ImportError: + has_cython = False + + +@pytest.mark.skipif(not has_cython, reason="Cython extensions not compiled") +class BytesIOReaderTest(unittest.TestCase): + """Tests for the Cython BytesIOReader, including the offset parameter. + + Note: BytesIOReader.read() is a cdef method, so it cannot be called + directly from Python. Reading with an offset is exercised through the + end-to-end decode_message test in test_protocol.py which goes through + the Cython row parser path (remaining_buffer -> BytesIOReader(buf, offset)). + """ + + def test_construct_no_offset(self): + # Should not raise + reader = BytesIOReader(b"\x00\x01\x02\x03\x04\x05") + + def test_construct_with_zero_offset(self): + reader = BytesIOReader(b"hello world", 0) + + def test_construct_with_offset(self): + reader = BytesIOReader(b"header_row_data", 7) + + def test_construct_offset_at_end(self): + data = b"abcdef" + reader = BytesIOReader(data, len(data)) + + def test_construct_negative_offset_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", -1) + + def test_construct_offset_past_end_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", 6) + + def test_construct_offset_way_past_end_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", 100) + + def test_construct_empty_buffer_zero_offset(self): + reader = BytesIOReader(b"", 0) + + def test_construct_empty_buffer_nonzero_offset_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"", 1) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..b88b5bee2d 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,8 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator, + _ConnectionIOBuffer) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -571,3 +572,42 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class ResetBufferTest(unittest.TestCase): + """Tests for _ConnectionIOBuffer._reset_buffer static method.""" + + def test_preserves_remaining_data(self): + buf = BytesIO() + buf.write(b"already_consumed_new_data") + buf.seek(17) # position after "already_consumed_" + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"new_data") + # Cursor is at SEEK_END, ready for further writes + self.assertEqual(result.tell(), len(b"new_data")) + + def test_empty_remaining(self): + buf = BytesIO() + buf.write(b"all_consumed") + buf.seek(12) + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"") + self.assertEqual(result.tell(), 0) + + def test_nothing_consumed(self): + buf = BytesIO() + buf.write(b"all_remaining") + buf.seek(0) + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"all_remaining") + # Cursor is at SEEK_END, ready for further writes + self.assertEqual(result.tell(), len(b"all_remaining")) + + def test_new_buffer_is_writable(self): + buf = BytesIO() + buf.write(b"head_tail") + buf.seek(5) + result = _ConnectionIOBuffer._reset_buffer(buf) + result.seek(0, 2) # seek to end + result.write(b"_more") + self.assertEqual(result.getvalue(), b"tail_more") diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..7ad4d0bf2c 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -13,15 +13,26 @@ # limitations under the License. import unittest +from io import BytesIO from unittest.mock import Mock from cassandra import ProtocolVersion, UnsupportedOperation from cassandra.protocol import ( - PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, - _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, - _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + PrepareMessage, + QueryMessage, + ExecuteMessage, + UnsupportedOperation, + _PAGING_OPTIONS_FLAG, + _WITH_SERIAL_CONSISTENCY_FLAG, + _PAGE_SIZE_FLAG, + _WITH_PAGING_STATE_FLAG, + BatchMessage, + BytesReader, + ProtocolHandler, + SupportedMessage, + ReadyMessage, + write_stringmultimap, ) from cassandra.query import BatchType from cassandra.marshal import uint32_unpack @@ -30,7 +41,6 @@ class MessageTest(unittest.TestCase): - def test_prepare_message(self): """ Test to check the appropriate calls are made @@ -45,28 +55,38 @@ def test_prepare_message(self): io = Mock() message.send_body(io, 4) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)]) + self._check_calls(io, [(b"\x00\x00\x00\x01",), (b"a",)]) io.reset_mock() message.send_body(io, 5) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)]) + self._check_calls(io, [(b"\x00\x00\x00\x01",), (b"a",), (b"\x00\x00\x00\x00",)]) def test_execute_message(self): - message = ExecuteMessage('1', [], 4) + message = ExecuteMessage("1", [], 4) io = Mock() message.send_body(io, 4) - self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) + self._check_calls( + io, [(b"\x00\x01",), (b"1",), (b"\x00\x04",), (b"\x01",), (b"\x00\x00",)] + ) io.reset_mock() - message.result_metadata_id = 'foo' + message.result_metadata_id = "foo" message.send_body(io, 5) - self._check_calls(io, [(b'\x00\x01',), (b'1',), - (b'\x00\x03',), (b'foo',), - (b'\x00\x04',), - (b'\x00\x00\x00\x01',), (b'\x00\x00',)]) + self._check_calls( + io, + [ + (b"\x00\x01",), + (b"1",), + (b"\x00\x03",), + (b"foo",), + (b"\x00\x04",), + (b"\x00\x00\x00\x01",), + (b"\x00\x00",), + ], + ) def test_query_message(self): """ @@ -82,11 +102,16 @@ def test_query_message(self): io = Mock() message.send_body(io, 4) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) + self._check_calls( + io, [(b"\x00\x00\x00\x01",), (b"a",), (b"\x00\x03",), (b"\x00",)] + ) io.reset_mock() message.send_body(io, 5) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) + self._check_calls( + io, + [(b"\x00\x00\x00\x01",), (b"a",), (b"\x00\x03",), (b"\x00\x00\x00\x00",)], + ) def _check_calls(self, io, expected): assert tuple(c[1] for c in io.write.mock_calls) == tuple(expected) @@ -112,80 +137,242 @@ def test_prepare_flag(self): io.reset_mock() def test_prepare_flag_with_keyspace(self): - message = PrepareMessage("a", keyspace='ks') + message = PrepareMessage("a", keyspace="ks") io = Mock() for version in ProtocolVersion.SUPPORTED_VERSIONS: if ProtocolVersion.uses_keyspace_flag(version): message.send_body(io, version) - self._check_calls(io, [ - (b'\x00\x00\x00\x01',), - (b'a',), - (b'\x00\x00\x00\x01',), - (b'\x00\x02',), - (b'ks',), - ]) + self._check_calls( + io, + [ + (b"\x00\x00\x00\x01",), + (b"a",), + (b"\x00\x00\x00\x01",), + (b"\x00\x02",), + (b"ks",), + ], + ) else: with pytest.raises(UnsupportedOperation): message.send_body(io, version) io.reset_mock() def test_keyspace_flag_raises_before_v5(self): - keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks') - io = Mock(name='io') + keyspace_message = QueryMessage("a", consistency_level=3, keyspace="ks") + io = Mock(name="io") - with pytest.raises(UnsupportedOperation, match='Keyspaces.*set'): + with pytest.raises(UnsupportedOperation, match="Keyspaces.*set"): keyspace_message.send_body(io, protocol_version=4) io.assert_not_called() def test_keyspace_written_with_length(self): - io = Mock(name='io') + io = Mock(name="io") base_expected = [ - (b'\x00\x00\x00\x01',), - (b'a',), - (b'\x00\x03',), - (b'\x00\x00\x00\x80',), # options w/ keyspace flag + (b"\x00\x00\x00\x01",), + (b"a",), + (b"\x00\x03",), + (b"\x00\x00\x00\x80",), # options w/ keyspace flag ] - QueryMessage('a', consistency_level=3, keyspace='ks').send_body( + QueryMessage("a", consistency_level=3, keyspace="ks").send_body( io, protocol_version=5 ) - self._check_calls(io, base_expected + [ - (b'\x00\x02',), # length of keyspace string - (b'ks',), - ]) + self._check_calls( + io, + base_expected + + [ + (b"\x00\x02",), # length of keyspace string + (b"ks",), + ], + ) io.reset_mock() - QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( + QueryMessage("a", consistency_level=3, keyspace="keyspace").send_body( io, protocol_version=5 ) - self._check_calls(io, base_expected + [ - (b'\x00\x08',), # length of keyspace string - (b'keyspace',), - ]) + self._check_calls( + io, + base_expected + + [ + (b"\x00\x08",), # length of keyspace string + (b"keyspace",), + ], + ) def test_batch_message_with_keyspace(self): self.maxDiff = None - io = Mock(name='io') + io = Mock(name="io") batch = BatchMessage( batch_type=BatchType.LOGGED, - queries=((False, 'stmt a', ('param a',)), - (False, 'stmt b', ('param b',)), - (False, 'stmt c', ('param c',)) - ), + queries=( + (False, "stmt a", ("param a",)), + (False, "stmt b", ("param b",)), + (False, "stmt c", ("param c",)), + ), consistency_level=3, - keyspace='ks' + keyspace="ks", ) batch.send_body(io, protocol_version=5) - self._check_calls(io, - ((b'\x00',), (b'\x00\x03',), (b'\x00',), - (b'\x00\x00\x00\x06',), (b'stmt a',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',), - (b'\x00\x03',), - (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) + self._check_calls( + io, + ( + (b"\x00",), + (b"\x00\x03",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt a",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param a",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt b",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param b",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt c",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param c",), + (b"\x00\x03",), + (b"\x00\x00\x00\x80",), + (b"\x00\x02",), + (b"ks",), + ), + ) + + +class BytesReaderTest(unittest.TestCase): + """Tests for the BytesReader class used in decode_message.""" + + def test_read_exact(self): + r = BytesReader(b"abcdef") + self.assertEqual(r.read(3), b"abc") + self.assertEqual(r.read(3), b"def") + + def test_read_sequential(self): + r = BytesReader(b"\x00\x01\x02\x03") + self.assertEqual(r.read(1), b"\x00") + self.assertEqual(r.read(2), b"\x01\x02") + self.assertEqual(r.read(1), b"\x03") + + def test_read_zero_bytes(self): + r = BytesReader(b"abc") + self.assertEqual(r.read(0), b"") + self.assertEqual(r.read(3), b"abc") + + def test_read_all_no_args(self): + r = BytesReader(b"hello") + self.assertEqual(r.read(), b"hello") + + def test_read_all_negative(self): + r = BytesReader(b"hello") + self.assertEqual(r.read(-1), b"hello") + + def test_read_all_after_partial(self): + r = BytesReader(b"hello world") + r.read(6) + self.assertEqual(r.read(), b"world") + + def test_read_past_end_raises(self): + r = BytesReader(b"abc") + with self.assertRaises(EOFError): + r.read(4) + + def test_read_past_end_after_partial(self): + r = BytesReader(b"abc") + r.read(2) + with self.assertRaises(EOFError): + r.read(2) + + def test_empty_data(self): + r = BytesReader(b"") + self.assertEqual(r.read(), b"") + self.assertEqual(r.read(0), b"") + with self.assertRaises(EOFError): + r.read(1) + + def test_memoryview_input(self): + data = b"hello world" + r = BytesReader(memoryview(data)) + result = r.read(5) + self.assertIsInstance(result, bytes) + self.assertEqual(result, b"hello") + + def test_return_type_is_bytes(self): + r = BytesReader(b"\x00\x01\x02") + result = r.read(3) + self.assertIsInstance(result, bytes) + + def test_remaining_buffer(self): + r = BytesReader(b"header_row_data") + r.read(7) # consume "header_" + buf, pos = r.remaining_buffer() + self.assertEqual(buf, b"header_row_data") + self.assertEqual(pos, 7) + self.assertEqual(buf[pos:], b"row_data") + + def test_remaining_buffer_at_start(self): + r = BytesReader(b"all_data") + buf, pos = r.remaining_buffer() + self.assertEqual(pos, 0) + self.assertEqual(buf, b"all_data") + + +class DecodeMessageTest(unittest.TestCase): + """ + End-to-end tests for ProtocolHandler.decode_message using BytesReader. + + These verify that real message types round-trip through the decode path + that now uses BytesReader instead of io.BytesIO. + """ + + def _decode(self, opcode, body): + return ProtocolHandler.decode_message( + protocol_version=ProtocolVersion.MAX_SUPPORTED, + protocol_features=None, + user_type_map={}, + stream_id=0, + flags=0, + opcode=opcode, + body=body, + decompressor=None, + result_metadata=None, + ) + + def test_ready_message_empty_body(self): + """ReadyMessage has an empty body (opcode 0x02).""" + msg = self._decode(0x02, b"") + self.assertIsInstance(msg, ReadyMessage) + self.assertEqual(msg.stream_id, 0) + self.assertIsNone(msg.trace_id) + self.assertIsNone(msg.custom_payload) + + def test_supported_message_with_body(self): + """SupportedMessage reads a stringmultimap from body (opcode 0x06).""" + buf = BytesIO() + write_stringmultimap( + buf, + { + "CQL_VERSION": ["3.4.5"], + "COMPRESSION": ["lz4", "snappy"], + }, ) + body = buf.getvalue() + msg = self._decode(0x06, body) + self.assertIsInstance(msg, SupportedMessage) + self.assertEqual(msg.cql_versions, ["3.4.5"]) + self.assertEqual(msg.options["COMPRESSION"], ["lz4", "snappy"]) + + def test_decode_with_memoryview_body(self): + """decode_message should accept a memoryview body (BytesReader materializes it).""" + buf = BytesIO() + write_stringmultimap(buf, {"CQL_VERSION": ["3.0.0"]}) + body = memoryview(buf.getvalue()) + msg = self._decode(0x06, body) + self.assertIsInstance(msg, SupportedMessage) + self.assertEqual(msg.cql_versions, ["3.0.0"])