Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
523 changes: 523 additions & 0 deletions benchmarks/decode_benchmark.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions cassandra/bytesio.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ cdef class BytesIOReader:
cdef char *buf_ptr
cdef Py_ssize_t pos
cdef Py_ssize_t size
cdef Py_ssize_t _initial_offset
cdef char *read(self, Py_ssize_t n = ?) except NULL
12 changes: 9 additions & 3 deletions cassandra/bytesio.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ cdef class BytesIOReader:
"""
This class provides efficient support for reading bytes from a 'bytes' buffer,
by returning char * values directly without allocating intermediate objects.

An optional offset allows reading from the middle of an existing buffer,
avoiding a copy when only a suffix of the bytes is needed.
"""

def __init__(self, bytes buf):
def __init__(self, bytes buf, Py_ssize_t offset=0):
if offset < 0 or offset > len(buf):
raise ValueError("offset %d out of range for buffer of length %d" % (offset, len(buf)))
self.buf = buf
self.size = len(buf)
self.buf_ptr = self.buf
self._initial_offset = offset
self.size = len(buf) - offset
self.buf_ptr = <char*>self.buf + offset

cdef char *read(self, Py_ssize_t n = -1) except NULL:
"""Read at most size bytes from the file
Expand Down
63 changes: 44 additions & 19 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,14 +630,25 @@ def readable_io_bytes(self):
def readable_cql_frame_bytes(self):
return self.cql_frame_buffer.tell()

@staticmethod
def _reset_buffer(buf):
    """
    Return a fresh BytesIO holding only the unconsumed tail of ``buf``.

    Slicing ``getbuffer()`` instead of calling ``read()`` skips one
    full-buffer bytes allocation on the hot receive path; the new
    BytesIO still materializes its own backing store. The returned
    buffer's cursor is left at the end so subsequent writes append.
    """
    tail = buf.getbuffer()[buf.tell():]
    fresh = io.BytesIO(tail)
    fresh.seek(0, 2)  # 2 == SEEK_END
    return fresh

def reset_io_buffer(self):
self._io_buffer = io.BytesIO(self._io_buffer.read())
self._io_buffer.seek(0, 2) # 2 == SEEK_END
self._io_buffer = self._reset_buffer(self._io_buffer)

def reset_cql_frame_buffer(self):
if self.is_checksumming_enabled:
self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read())
self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END
self._cql_frame_buffer = self._reset_buffer(self._cql_frame_buffer)
else:
self.reset_io_buffer()

Expand Down Expand Up @@ -671,7 +682,9 @@ class Connection(object):

CALLBACK_ERR_THREAD_THRESHOLD = 100

in_buffer_size = 4096
# 64 KiB recv buffer reduces the number of syscalls when reading
# large result sets, at a modest per-connection memory cost.
in_buffer_size = 65536
out_buffer_size = 4096

cql_version = None
Expand Down Expand Up @@ -1191,19 +1204,23 @@ def control_conn_disposed(self):

@defunct_on_error
def _read_frame_header(self):
buf = self._io_buffer.cql_frame_buffer.getvalue()
pos = len(buf)
cql_buf = self._io_buffer.cql_frame_buffer
pos = cql_buf.tell()
if pos:
version = buf[0] & PROTOCOL_VERSION_MASK
if version not in ProtocolVersion.SUPPORTED_VERSIONS:
raise ProtocolError("This version of the driver does not support protocol version %d" % version)
# this frame header struct is everything after the version byte
header_size = frame_header_v3.size + 1
if pos >= header_size:
flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1)
if body_len < 0:
raise ProtocolError("Received negative body length: %r" % body_len)
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
buf = cql_buf.getbuffer()
try:
version = buf[0] & PROTOCOL_VERSION_MASK
if version not in ProtocolVersion.SUPPORTED_VERSIONS:
raise ProtocolError("This version of the driver does not support protocol version %d" % version)
# this frame header struct is everything after the version byte
header_size = frame_header_v3.size + 1
if pos >= header_size:
flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1)
if body_len < 0:
raise ProtocolError("Received negative body length: %r" % body_len)
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
finally:
del buf # release memoryview before any buffer mutation
return pos

@defunct_on_error
Expand Down Expand Up @@ -1257,8 +1274,16 @@ def process_io_buffer(self):
return
else:
frame = self._current_frame
self._io_buffer.cql_frame_buffer.seek(frame.body_offset)
msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset)
# Use memoryview to avoid intermediate allocation, then
# convert to bytes. Explicitly release the memoryview
# before any buffer mutation (seek / reset).
cql_buf = self._io_buffer.cql_frame_buffer
buf = cql_buf.getbuffer()
try:
msg = bytes(buf[frame.body_offset:frame.end_pos])
finally:
del buf # release memoryview before buffer mutation
cql_buf.seek(frame.end_pos)
self.process_msg(frame, msg)
self._io_buffer.reset_cql_frame_buffer()
self._current_frame = None
Expand Down
40 changes: 39 additions & 1 deletion cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,43 @@ class NotSupportedError(Exception):
class InternalError(Exception):
pass


class BytesReader:
    """
    Minimal read-only reader over a bytes object.

    Offers the read() interface of a file-like object without the
    internal-copy overhead of io.BytesIO.

    Contract difference from io.BytesIO.read(n): when fewer than n
    bytes remain, read(n) raises EOFError instead of returning short
    data. This is deliberate — protocol parsing should fail fast on
    truncated or malformed frames.
    """
    __slots__ = ('_data', '_pos', '_size')

    def __init__(self, data):
        # Convert a memoryview once at construction time so read()
        # never has to check the input type again.
        if isinstance(data, memoryview):
            data = bytes(data)
        self._data = data
        self._pos = 0
        self._size = len(data)

    def read(self, n=-1):
        start = self._pos
        if n < 0:
            # Negative n mirrors file-object semantics: drain the rest.
            self._pos = self._size
            return self._data[start:]
        stop = start + n
        if stop > self._size:
            raise EOFError("Cannot read past the end of the buffer")
        self._pos = stop
        return self._data[start:stop]

    def remaining_buffer(self):
        """Return (underlying_bytes, current_position) for zero-copy handoff."""
        return self._data, self._pos


ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])

HEADER_DIRECTION_TO_CLIENT = 0x80
Expand Down Expand Up @@ -1154,7 +1191,8 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre
body = decompressor(body)
flags ^= COMPRESSED_FLAG

body = io.BytesIO(body)
# Use lightweight BytesReader instead of io.BytesIO to avoid buffer copy
body = BytesReader(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= TRACING_FLAG
Expand Down
10 changes: 8 additions & 2 deletions cassandra/row_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,19 @@ def make_recv_results_rows(ColumnParser colparser):
desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy,
[ColDesc(md[0], md[1], md[2]) for md in column_metadata],
make_deserializers(self.column_types), protocol_version)
reader = BytesIOReader(f.read())
# Zero-copy handoff: reuse the underlying bytes buffer at its current
# position instead of copying via f.read().
if hasattr(f, 'remaining_buffer'):
buf_data, buf_offset = f.remaining_buffer()
reader = BytesIOReader(buf_data, buf_offset)
else:
reader = BytesIOReader(f.read())
try:
self.parsed_rows = colparser.parse_rows(reader, desc)
except Exception as e:
# Use explicitly the TupleRowParser to display better error messages for column decoding failures
rowparser = TupleRowParser()
reader.buf_ptr = reader.buf
reader.buf_ptr = <char*>reader.buf + reader._initial_offset
reader.pos = 0
rowcount = read_int(reader)
for i in range(rowcount):
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/test_bytesio_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest

try:
from cassandra.bytesio import BytesIOReader

has_cython = True
except ImportError:
has_cython = False


@pytest.mark.skipif(not has_cython, reason="Cython extensions not compiled")
class BytesIOReaderTest(unittest.TestCase):
    """Constructor-level tests for the Cython BytesIOReader offset parameter.

    Note: BytesIOReader.read() is a cdef method and cannot be invoked
    from Python, so reading with an offset is covered indirectly by the
    end-to-end decode_message test in test_protocol.py, which exercises
    the Cython row parser path (remaining_buffer -> BytesIOReader(buf, offset)).
    """

    def test_construct_no_offset(self):
        # Should not raise
        BytesIOReader(b"\x00\x01\x02\x03\x04\x05")

    def test_construct_with_zero_offset(self):
        BytesIOReader(b"hello world", 0)

    def test_construct_with_offset(self):
        BytesIOReader(b"header_row_data", 7)

    def test_construct_offset_at_end(self):
        payload = b"abcdef"
        BytesIOReader(payload, len(payload))

    def test_construct_negative_offset_raises(self):
        self.assertRaises(ValueError, BytesIOReader, b"hello", -1)

    def test_construct_offset_past_end_raises(self):
        self.assertRaises(ValueError, BytesIOReader, b"hello", 6)

    def test_construct_offset_way_past_end_raises(self):
        self.assertRaises(ValueError, BytesIOReader, b"hello", 100)

    def test_construct_empty_buffer_zero_offset(self):
        BytesIOReader(b"", 0)

    def test_construct_empty_buffer_nonzero_offset_raises(self):
        self.assertRaises(ValueError, BytesIOReader, b"", 1)
42 changes: 41 additions & 1 deletion tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator,
_ConnectionIOBuffer)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ProtocolHandler)
Expand Down Expand Up @@ -571,3 +572,42 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange):
second_run = list(itertools.islice(gen.generate(0, 2), 5))

assert first_run == second_run


class ResetBufferTest(unittest.TestCase):
    """Tests for _ConnectionIOBuffer._reset_buffer static method."""

    @staticmethod
    def _buffer_at(data, position):
        # Helper: a BytesIO containing `data` with its cursor at `position`.
        buf = BytesIO()
        buf.write(data)
        buf.seek(position)
        return buf

    def test_preserves_remaining_data(self):
        # Position 17 is just past "already_consumed_"
        source = self._buffer_at(b"already_consumed_new_data", 17)
        fresh = _ConnectionIOBuffer._reset_buffer(source)
        self.assertEqual(fresh.getvalue(), b"new_data")
        # Cursor is at SEEK_END, ready for further writes
        self.assertEqual(fresh.tell(), len(b"new_data"))

    def test_empty_remaining(self):
        fresh = _ConnectionIOBuffer._reset_buffer(self._buffer_at(b"all_consumed", 12))
        self.assertEqual(fresh.getvalue(), b"")
        self.assertEqual(fresh.tell(), 0)

    def test_nothing_consumed(self):
        fresh = _ConnectionIOBuffer._reset_buffer(self._buffer_at(b"all_remaining", 0))
        self.assertEqual(fresh.getvalue(), b"all_remaining")
        # Cursor is at SEEK_END, ready for further writes
        self.assertEqual(fresh.tell(), len(b"all_remaining"))

    def test_new_buffer_is_writable(self):
        fresh = _ConnectionIOBuffer._reset_buffer(self._buffer_at(b"head_tail", 5))
        fresh.seek(0, 2)  # seek to end
        fresh.write(b"_more")
        self.assertEqual(fresh.getvalue(), b"tail_more")
Loading
Loading