diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx
index 98e8676bbc..a1833887d4 100644
--- a/cassandra/deserializers.pyx
+++ b/cassandra/deserializers.pyx
@@ -481,7 +481,15 @@ cpdef Deserializer find_deserializer(cqltype):
 
 
 def obj_array(list objs):
-    """Create a (Cython) array of objects given a list of objects"""
+    """Create a (Cython) array of objects given a list of objects.
+
+    An empty list is returned as-is (not wrapped), since ``cython_array``
+    does not support zero-length shapes.  Callers that assign the result
+    to a ``cdef Deserializer[::1]`` typed memoryview must therefore guard
+    against empty input before the assignment.
+    """
+    if not objs:
+        return objs
     cdef object[:] arr
     cdef Py_ssize_t i
     arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
diff --git a/cassandra/serializers.pxd b/cassandra/serializers.pxd
new file mode 100644
index 0000000000..60297077a8
--- /dev/null
+++ b/cassandra/serializers.pxd
@@ -0,0 +1,20 @@
+# Copyright ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+cdef class Serializer:
+    # The cqltypes._CassandraType this serializer was constructed for
+    cdef object cqltype
+
+    cpdef bytes serialize(self, object value, int protocol_version)
diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx
new file mode 100644
index 0000000000..8bbc0b27be
--- /dev/null
+++ b/cassandra/serializers.pyx
@@ -0,0 +1,585 @@
+# Copyright ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cython-optimized serializers for CQL types.
+
+Mirrors the architecture of deserializers.pyx. Currently implements
+optimized serialization for:
+- FloatType (4-byte big-endian float)
+- DoubleType (8-byte big-endian double)
+- Int32Type (4-byte big-endian signed int)
+- VectorType (type-specialized for float/double/int32, generic fallback)
+
+For all other types, GenericSerializer delegates to the Python-level
+cqltype.serialize() classmethod.
+"""
+
+from libc.stdint cimport int32_t
+from libc.string cimport memcpy
+from libc.math cimport isinf, isnan
+from libc.float cimport FLT_MAX
+from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING
+from cython.view cimport array as cython_array
+
+from cassandra import cqltypes
+from operator import index as _operator_index
+
+cdef bint is_little_endian
+from cassandra.util import is_little_endian
+
+
+# ---------------------------------------------------------------------------
+# Range-check helpers (match struct.pack error semantics)
+# ---------------------------------------------------------------------------
+
+cdef inline void _check_float_range(double value) except *:
+    """Raise OverflowError for finite values that overflow float32.
+
+    Uses the same semantics as ``struct.pack('>f', value)``: cast the
+    double to float and reject only if the result is ±inf while the
+    input was finite.  This correctly accepts values slightly above
+    ``FLT_MAX`` that round down (e.g. 3.4028235e38).
+
+    Intentionally raises OverflowError (not struct.error) so callers
+    can catch a single exception type.  inf, -inf, and nan pass
+    through unchanged.
+    """
+    if not isinf(value) and not isnan(value):
+        if isinf(<float>value):  # test the value AFTER narrowing to float32
+            raise OverflowError(
+                "Value %r too large for float32 (max %r)" % (value, FLT_MAX))
+
+
+cdef inline void _check_int32_range(object value) except *:
+    """Raise OverflowError for values outside the signed int32 range.
+
+    Intentionally raises OverflowError (not struct.error) so callers
+    can catch a single exception type.  The accepted range matches
+    struct.pack('>i', ...): [-2147483648, 2147483647].  The check must
+    be done on the Python int *before* the C-level ``int32_t`` cast,
+    which would silently truncate.
+
+    The value should already have been coerced via ``_coerce_int()``
+    (i.e. the ``__index__`` protocol) before being passed here.
+    """
+    if value > 2147483647 or value < -2147483648:
+        raise OverflowError(
+            "Value %r out of range for int32 "
+            "(must be between -2147483648 and 2147483647)" % (value,))
+
+
+cdef inline object _coerce_int(object value):
+    """Coerce *value* to a Python ``int`` via the ``__index__`` protocol.
+
+    This matches ``struct.pack('>i', value)`` semantics, which accepts any
+    object implementing ``__index__`` (e.g. numpy integer scalars).  Raises
+    ``TypeError`` for objects that do not support the protocol.
+    """
+    return _operator_index(value)
+
+
+# ---------------------------------------------------------------------------
+# Base class
+# ---------------------------------------------------------------------------
+
+cdef class Serializer:
+    """Cython-based serializer class for a cqltype"""
+
+    def __init__(self, cqltype):
+        self.cqltype = cqltype
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        raise NotImplementedError
+
+
+# ---------------------------------------------------------------------------
+# Scalar serializers
+# ---------------------------------------------------------------------------
+
+cdef class SerFloatType(Serializer):
+    """Serialize a Python float to 4-byte big-endian IEEE 754."""
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        _check_float_range(value)
+        cdef float val = value
+        cdef char out[4]
+        cdef char *src = <char *>&val  # byte view of the native float
+
+        if is_little_endian:
+            out[0] = src[3]
+            out[1] = src[2]
+            out[2] = src[1]
+            out[3] = src[0]
+        else:
+            memcpy(out, src, 4)
+
+        return PyBytes_FromStringAndSize(out, 4)
+
+
+cdef class SerDoubleType(Serializer):
+    """Serialize a Python float to 8-byte big-endian IEEE 754."""
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        cdef double val = value
+        cdef char out[8]
+        cdef char *src = <char *>&val  # byte view of the native double
+
+        if is_little_endian:
+            out[0] = src[7]
+            out[1] = src[6]
+            out[2] = src[5]
+            out[3] = src[4]
+            out[4] = src[3]
+            out[5] = src[2]
+            out[6] = src[1]
+            out[7] = src[0]
+        else:
+            memcpy(out, src, 8)
+
+        return PyBytes_FromStringAndSize(out, 8)
+
+
+cdef class SerInt32Type(Serializer):
+    """Serialize a Python int to 4-byte big-endian signed int32."""
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        value = _coerce_int(value)
+        _check_int32_range(value)
+        cdef int32_t val = value
+        cdef char out[4]
+        cdef char *src = <char *>&val  # byte view of the native int32
+
+        if is_little_endian:
+            out[0] = src[3]
+            out[1] = src[2]
+            out[2] = src[1]
+            out[3] = src[0]
+        else:
+            memcpy(out, src, 4)
+
+        return PyBytes_FromStringAndSize(out, 4)
+
+
+# ---------------------------------------------------------------------------
+# Type detection helpers
+# ---------------------------------------------------------------------------
+
+cdef inline bint _is_float_type(object subtype):
+    return subtype is cqltypes.FloatType or issubclass(subtype, cqltypes.FloatType)
+
+cdef inline bint _is_double_type(object subtype):
+    return subtype is cqltypes.DoubleType or issubclass(subtype, cqltypes.DoubleType)
+
+cdef inline bint _is_int32_type(object subtype):
+    return subtype is cqltypes.Int32Type or issubclass(subtype, cqltypes.Int32Type)
+
+
+# ---------------------------------------------------------------------------
+# VectorType serializer
+# ---------------------------------------------------------------------------
+
+cdef class SerVectorType(Serializer):
+    """
+    Optimized Cython serializer for VectorType.
+
+    For float, double, and int32 vectors, pre-allocates a contiguous buffer
+    and uses C-level byte swapping.  For other subtypes, falls back to
+    per-element Python serialization.
+    """
+
+    cdef Py_ssize_t vector_size
+    cdef object subtype
+    # 0 = generic, 1 = float, 2 = double, 3 = int32
+    cdef int type_code
+
+    def __init__(self, cqltype):
+        super().__init__(cqltype)
+        self.vector_size = cqltype.vector_size
+        self.subtype = cqltype.subtype
+
+        if _is_float_type(self.subtype):
+            self.type_code = 1
+        elif _is_double_type(self.subtype):
+            self.type_code = 2
+        elif _is_int32_type(self.subtype):
+            self.type_code = 3
+        else:
+            self.type_code = 0
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        cdef object result
+        cdef Py_ssize_t v_length = len(value)
+        if v_length != self.vector_size:
+            raise ValueError(
+                "Expected sequence of size %d for vector of type %s and "
+                "dimension %d, observed sequence of length %d" % (
+                    self.vector_size, self.subtype.typename,
+                    self.vector_size, v_length))
+
+        if self.type_code == 1:
+            result = self._serialize_float_buffer(value)
+            if result is not None:
+                return result
+        elif self.type_code == 2:
+            result = self._serialize_double_buffer(value)
+            if result is not None:
+                return result
+        elif self.type_code == 3:
+            result = self._serialize_int32_buffer(value)
+            if result is not None:
+                return result
+
+        # Keep indexable sequences on the fast path.  Fall back to tuple()
+        # only for iterable-only inputs so the Cython path still matches
+        # Python VectorType.serialize() input semantics.
+        if v_length != 0 and not isinstance(value, (list, tuple)):
+            try:
+                value[0]
+            except (TypeError, KeyError, IndexError):
+                value = tuple(value)
+
+        if self.type_code == 1:
+            return self._serialize_float(value)
+        elif self.type_code == 2:
+            return self._serialize_double(value)
+        elif self.type_code == 3:
+            return self._serialize_int32(value)
+        else:
+            return self._serialize_generic(value, protocol_version)
+
+    cdef inline bytes _serialize_float(self, object values):
+        """Serialize a list of floats into a contiguous big-endian buffer.
+
+        ``values`` is already an indexable sequence (normalized in
+        ``serialize()``), so integer indexing is safe and fast.
+        """
+        cdef Py_ssize_t i
+        cdef Py_ssize_t buf_size = self.vector_size * 4
+        if buf_size == 0:
+            return b""
+
+        cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+        cdef char *buf = PyBytes_AS_STRING(result)
+        cdef double dval
+        cdef float val
+        cdef char *src
+        cdef char *dst
+
+        for i in range(self.vector_size):
+            dval = values[i]
+            _check_float_range(dval)
+            val = <float>dval  # explicit narrowing, range-checked above
+            src = <char *>&val
+            dst = buf + i * 4
+
+            if is_little_endian:
+                dst[0] = src[3]
+                dst[1] = src[2]
+                dst[2] = src[1]
+                dst[3] = src[0]
+            else:
+                memcpy(dst, src, 4)
+
+        return result
+
+    cdef inline object _serialize_float_buffer(self, object values):
+        """Fast path for contiguous float32 buffers (e.g. numpy float32 arrays).
+
+        No ``_check_float_range`` is needed: the typed memoryview
+        ``float[::1]`` constrains values to C float (IEEE 754 float32),
+        so overflow is impossible by definition.  Returns ``None`` when
+        *values* does not support the buffer protocol with the required
+        format, letting the caller fall through to the element-wise path.
+        """
+        cdef float[::1] view
+        cdef Py_ssize_t buf_size = self.vector_size * 4
+        cdef Py_ssize_t i
+        cdef object result
+        cdef char *buf
+        cdef char *dst
+        cdef char *src
+        cdef float val
+
+        try:
+            view = values
+        except (TypeError, ValueError):
+            return None
+
+        if buf_size == 0:
+            return b""
+
+        result = PyBytes_FromStringAndSize(NULL, buf_size)
+        buf = PyBytes_AS_STRING(result)
+
+        if is_little_endian:
+            for i in range(self.vector_size):
+                val = view[i]
+                src = <char *>&val
+                dst = buf + i * 4
+                dst[0] = src[3]
+                dst[1] = src[2]
+                dst[2] = src[1]
+                dst[3] = src[0]
+        else:
+            memcpy(buf, &view[0], buf_size)
+
+        return result
+
+    cdef inline bytes _serialize_double(self, object values):
+        """Serialize a list of doubles into a contiguous big-endian buffer.
+
+        ``values`` is already an indexable sequence (normalized in
+        ``serialize()``), so integer indexing is safe and fast.
+        """
+        cdef Py_ssize_t i
+        cdef Py_ssize_t buf_size = self.vector_size * 8
+        if buf_size == 0:
+            return b""
+
+        cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+        cdef char *buf = PyBytes_AS_STRING(result)
+        cdef double val
+        cdef char *src
+        cdef char *dst
+
+        for i in range(self.vector_size):
+            val = values[i]
+            src = <char *>&val
+            dst = buf + i * 8
+
+            if is_little_endian:
+                dst[0] = src[7]
+                dst[1] = src[6]
+                dst[2] = src[5]
+                dst[3] = src[4]
+                dst[4] = src[3]
+                dst[5] = src[2]
+                dst[6] = src[1]
+                dst[7] = src[0]
+            else:
+                memcpy(dst, src, 8)
+
+        return result
+
+    cdef inline object _serialize_double_buffer(self, object values):
+        """Fast path for contiguous float64 buffers (e.g. numpy float64 arrays).
+
+        Returns ``None`` when *values* does not expose a compatible
+        buffer, letting the caller fall through to the element-wise path.
+        """
+        cdef double[::1] view
+        cdef Py_ssize_t buf_size = self.vector_size * 8
+        cdef Py_ssize_t i
+        cdef object result
+        cdef char *buf
+        cdef char *dst
+        cdef char *src
+        cdef double val
+
+        try:
+            view = values
+        except (TypeError, ValueError):
+            return None
+
+        if buf_size == 0:
+            return b""
+
+        result = PyBytes_FromStringAndSize(NULL, buf_size)
+        buf = PyBytes_AS_STRING(result)
+
+        if is_little_endian:
+            for i in range(self.vector_size):
+                val = view[i]
+                src = <char *>&val
+                dst = buf + i * 8
+                dst[0] = src[7]
+                dst[1] = src[6]
+                dst[2] = src[5]
+                dst[3] = src[4]
+                dst[4] = src[3]
+                dst[5] = src[2]
+                dst[6] = src[1]
+                dst[7] = src[0]
+        else:
+            memcpy(buf, &view[0], buf_size)
+
+        return result
+
+    cdef inline bytes _serialize_int32(self, object values):
+        """Serialize a list of int32 values into a contiguous big-endian buffer.
+
+        ``values`` is already an indexable sequence (normalized in
+        ``serialize()``), so integer indexing is safe and fast.
+        """
+        cdef Py_ssize_t i
+        cdef Py_ssize_t buf_size = self.vector_size * 4
+        if buf_size == 0:
+            return b""
+
+        cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
+        cdef char *buf = PyBytes_AS_STRING(result)
+        cdef int32_t val
+        cdef object item
+        cdef char *src
+        cdef char *dst
+
+        for i in range(self.vector_size):
+            item = _coerce_int(values[i])
+            _check_int32_range(item)
+            val = item
+            src = <char *>&val
+            dst = buf + i * 4
+
+            if is_little_endian:
+                dst[0] = src[3]
+                dst[1] = src[2]
+                dst[2] = src[1]
+                dst[3] = src[0]
+            else:
+                memcpy(dst, src, 4)
+
+        return result
+
+    cdef inline object _serialize_int32_buffer(self, object values):
+        """Fast path for contiguous int32 buffers (e.g. numpy int32 arrays).
+
+        No ``_coerce_int`` / ``_check_int32_range`` is needed: the typed
+        memoryview ``int32_t[::1]`` only accepts 4-byte signed integer
+        buffers, whatever the width of C ``int``.  Returns ``None`` when
+        *values* does not expose a compatible buffer, letting the caller
+        fall through to the element-wise path.
+        """
+        cdef int32_t[::1] view
+        cdef Py_ssize_t buf_size = self.vector_size * 4
+        cdef Py_ssize_t i
+        cdef object result
+        cdef char *buf
+        cdef char *dst
+        cdef char *src
+
+        try:
+            view = values
+        except (TypeError, ValueError):
+            return None
+
+        if buf_size == 0:
+            return b""
+
+        result = PyBytes_FromStringAndSize(NULL, buf_size)
+        buf = PyBytes_AS_STRING(result)
+
+        if is_little_endian:
+            for i in range(self.vector_size):
+                src = <char *>&view[i]
+                dst = buf + i * 4
+                dst[0] = src[3]
+                dst[1] = src[2]
+                dst[2] = src[1]
+                dst[3] = src[0]
+        else:
+            memcpy(buf, &view[0], buf_size)
+
+        return result
+
+    cdef inline bytes _serialize_generic(self, object values, int protocol_version):
+        """Fallback: element-by-element Python serialization for non-optimized types."""
+        import io
+        from cassandra.marshal import uvint_pack
+
+        serialized_size = self.subtype.serial_size()
+        buf = io.BytesIO()
+        for item in values:
+            item_bytes = self.subtype.serialize(item, protocol_version)
+            if serialized_size is None:
+                buf.write(uvint_pack(len(item_bytes)))
+            buf.write(item_bytes)
+        return buf.getvalue()
+
+
+# ---------------------------------------------------------------------------
+# Generic serializer (fallback for all other types)
+# ---------------------------------------------------------------------------
+
+cdef class GenericSerializer(Serializer):
+    """
+    Wraps a generic cqltype for serialization, delegating to the Python-level
+    cqltype.serialize() classmethod.
+    """
+
+    cpdef bytes serialize(self, object value, int protocol_version):
+        return self.cqltype.serialize(value, protocol_version)
+
+    def __repr__(self):
+        return "GenericSerializer(%s)" % (self.cqltype,)
+
+
+# ---------------------------------------------------------------------------
+# Lookup and factory
+# ---------------------------------------------------------------------------
+
+cdef dict _ser_classes = {}
+
+cpdef Serializer find_serializer(cqltype):
+    """Find a serializer for a cqltype."""
+
+    # For VectorType, use SerVectorType only if parameterized (has a valid subtype).
+    # Un-parameterized VectorType (base class) would crash _is_float_type() etc.
+    if issubclass(cqltype, cqltypes.VectorType):
+        if getattr(cqltype, 'subtype', None) is not None:
+            return SerVectorType(cqltype)
+        return GenericSerializer(cqltype)
+
+    # For scalar types with dedicated serializers, look up by name
+    name = 'Ser' + cqltype.__name__
+    cls = _ser_classes.get(name)
+    if cls is not None:
+        return cls(cqltype)
+
+    # Fallback to generic
+    return GenericSerializer(cqltype)
+
+
+def make_serializers(cqltypes_list):
+    """Create a Cython typed array of Serializer objects for each given cqltype.
+
+    Returns an ``obj_array`` (Cython typed memoryview) matching the
+    ``make_deserializers()`` convention for O(1) C-level indexed access.
+    """
+    return obj_array([find_serializer(ct) for ct in cqltypes_list])
+
+
+def obj_array(list objs):
+    """Create a (Cython) array of objects given a list of objects.
+
+    Mirrors ``deserializers.obj_array()`` so both sides share the same
+    typed-memoryview convention.  Returns the plain list for empty input
+    since ``cython_array`` does not support zero-length shapes.  Callers
+    that use ``cdef Serializer[::1]`` typed memoryviews must guard
+    against empty input before assignment.
+    """
+    if not objs:
+        return objs
+    cdef object[:] arr
+    cdef Py_ssize_t i
+    arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
+    for i, obj in enumerate(objs):
+        arr[i] = obj
+    return arr
+
+
+# Build the lookup dict for scalar serializers at module load time
+_ser_classes['SerFloatType'] = SerFloatType
+_ser_classes['SerDoubleType'] = SerDoubleType
+_ser_classes['SerInt32Type'] = SerInt32Type
diff --git a/tests/unit/cython/test_serializers.py b/tests/unit/cython/test_serializers.py
new file mode 100644
index 0000000000..87729f44a2
--- /dev/null
+++ b/tests/unit/cython/test_serializers.py
@@ -0,0 +1,669 @@
+# Copyright ScyllaDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for Cython-optimized serializers (cassandra.serializers).
+
+Verifies byte-for-byte equivalence with the Python-level cqltype.serialize()
+implementations, plus correct error behavior for edge cases.
+"""
+
+import math
+import struct
+import unittest
+
+import numpy as np
+
+from tests.unit.cython.utils import cythontest
+
+from cassandra.cython_deps import HAVE_CYTHON
+
+try:
+    from tests import VERIFY_CYTHON
+except ImportError:
+    VERIFY_CYTHON = False
+
+from cassandra.cqltypes import (
+    FloatType,
+    DoubleType,
+    Int32Type,
+    VectorType,
+    UTF8Type,
+    LongType,
+    BooleanType,
+)
+
+# Import serializers only if Cython is available (compiled .so present).
+# When VERIFY_CYTHON is set (CI mode), let ImportError propagate so build +# failures are not silently swallowed. +if HAVE_CYTHON or VERIFY_CYTHON: + from cassandra.serializers import ( + Serializer, + SerFloatType, + SerDoubleType, + SerInt32Type, + SerVectorType, + GenericSerializer, + find_serializer, + make_serializers, + ) + +# Protocol version used in tests (value doesn't affect scalar serialization) +PROTO = 4 + + +def _make_vector_type(subtype, size): + """Create a VectorType parameterized with the given subtype class and size.""" + return VectorType.apply_parameters([subtype, size], None) + + +# --------------------------------------------------------------------------- +# Scalar serializer equivalence tests +# --------------------------------------------------------------------------- + + +@cythontest +class TestSerFloatTypeEquivalence(unittest.TestCase): + """Verify SerFloatType produces identical bytes to FloatType.serialize().""" + + def setUp(self): + self.ser = SerFloatType(FloatType) + + def _assert_equiv(self, value): + cython_bytes = self.ser.serialize(value, PROTO) + python_bytes = FloatType.serialize(value, PROTO) + self.assertEqual(cython_bytes, python_bytes, "Mismatch for value %r" % value) + + def test_zero(self): + self._assert_equiv(0.0) + + def test_negative_zero(self): + self._assert_equiv(-0.0) + + def test_positive_values(self): + for val in [1.0, 0.5, 3.14, 100.0, 1e10]: + self._assert_equiv(val) + + def test_negative_values(self): + for val in [-1.0, -0.5, -3.14, -100.0, -1e10]: + self._assert_equiv(val) + + def test_flt_max(self): + flt_max = 3.4028234663852886e38 + self._assert_equiv(flt_max) + self._assert_equiv(-flt_max) + + def test_flt_max_rounding_boundary(self): + """Values slightly above FLT_MAX that round down should be accepted. + + 3.4028235e38 is above FLT_MAX (3.4028234663852886e38) but + struct.pack('>f', 3.4028235e38) rounds it down to FLT_MAX. + The Cython serializer must accept the same inputs. 
+ """ + self._assert_equiv(3.4028235e38) + self._assert_equiv(-3.4028235e38) + + def test_subnormal_values(self): + """Subnormal (denormalized) floats should serialize correctly.""" + self._assert_equiv(1e-45) + self._assert_equiv(-1e-45) + self._assert_equiv(1.4e-45) # smallest positive float32 + + def test_inf(self): + self._assert_equiv(float("inf")) + + def test_neg_inf(self): + self._assert_equiv(float("-inf")) + + def test_nan(self): + """NaN bytes must match (NaN != NaN, so compare bytes directly).""" + cython_bytes = self.ser.serialize(float("nan"), PROTO) + python_bytes = FloatType.serialize(float("nan"), PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_overflow_positive(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(1e40, PROTO) + + def test_overflow_negative(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(-1e40, PROTO) + + def test_overflow_dbl_max(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(1.7976931348623157e308, PROTO) + + def test_type_error_string(self): + with self.assertRaises(TypeError): + self.ser.serialize("not a float", PROTO) + + def test_type_error_none(self): + with self.assertRaises(TypeError): + self.ser.serialize(None, PROTO) + + +@cythontest +class TestSerDoubleTypeEquivalence(unittest.TestCase): + """Verify SerDoubleType produces identical bytes to DoubleType.serialize().""" + + def setUp(self): + self.ser = SerDoubleType(DoubleType) + + def _assert_equiv(self, value): + cython_bytes = self.ser.serialize(value, PROTO) + python_bytes = DoubleType.serialize(value, PROTO) + self.assertEqual(cython_bytes, python_bytes, "Mismatch for value %r" % value) + + def test_zero(self): + self._assert_equiv(0.0) + + def test_negative_zero(self): + self._assert_equiv(-0.0) + + def test_normal_values(self): + for val in [1.0, -1.0, 3.14, -3.14, 1e100, -1e100, 1e-100]: + self._assert_equiv(val) + + def test_dbl_max(self): 
+ self._assert_equiv(1.7976931348623157e308) + self._assert_equiv(-1.7976931348623157e308) + + def test_inf(self): + self._assert_equiv(float("inf")) + self._assert_equiv(float("-inf")) + + def test_nan(self): + cython_bytes = self.ser.serialize(float("nan"), PROTO) + python_bytes = DoubleType.serialize(float("nan"), PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_type_error_string(self): + with self.assertRaises(TypeError): + self.ser.serialize("not a double", PROTO) + + def test_type_error_none(self): + with self.assertRaises(TypeError): + self.ser.serialize(None, PROTO) + + +@cythontest +class TestSerInt32TypeEquivalence(unittest.TestCase): + """Verify SerInt32Type produces identical bytes to Int32Type.serialize().""" + + def setUp(self): + self.ser = SerInt32Type(Int32Type) + + def _assert_equiv(self, value): + cython_bytes = self.ser.serialize(value, PROTO) + python_bytes = Int32Type.serialize(value, PROTO) + self.assertEqual(cython_bytes, python_bytes, "Mismatch for value %r" % value) + + def test_zero(self): + self._assert_equiv(0) + + def test_positive_values(self): + for val in [1, 42, 127, 255, 32767, 65535]: + self._assert_equiv(val) + + def test_negative_values(self): + for val in [-1, -42, -128, -32768]: + self._assert_equiv(val) + + def test_int32_max(self): + self._assert_equiv(2147483647) + + def test_int32_min(self): + self._assert_equiv(-2147483648) + + def test_overflow_positive(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(2147483648, PROTO) + + def test_overflow_negative(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(-2147483649, PROTO) + + def test_overflow_large_python_int(self): + """Python ints have arbitrary precision; must still reject out-of-range.""" + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize(2**100, PROTO) + + def test_overflow_large_negative_python_int(self): + with self.assertRaises((OverflowError, 
struct.error)): + self.ser.serialize(-(2**100), PROTO) + + def test_type_error_string(self): + with self.assertRaises(TypeError): + self.ser.serialize("not an int", PROTO) + + def test_type_error_none(self): + with self.assertRaises(TypeError): + self.ser.serialize(None, PROTO) + + def test_index_protocol(self): + """Objects implementing __index__ should be accepted, like struct.pack.""" + + class MyInt: + def __index__(self): + return 42 + + cython_bytes = self.ser.serialize(MyInt(), PROTO) + python_bytes = Int32Type.serialize(42, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + +# --------------------------------------------------------------------------- +# VectorType serializer equivalence tests +# --------------------------------------------------------------------------- + + +@cythontest +class TestSerVectorTypeFloat(unittest.TestCase): + """Verify SerVectorType float fast-path matches VectorType.serialize().""" + + def setUp(self): + self.vec_type = _make_vector_type(FloatType, 4) + self.ser = SerVectorType(self.vec_type) + + def _assert_equiv(self, values): + cython_bytes = self.ser.serialize(values, PROTO) + python_bytes = self.vec_type.serialize(values, PROTO) + self.assertEqual( + cython_bytes, python_bytes, "Mismatch for values %r" % (values,) + ) + + def test_basic(self): + self._assert_equiv([1.0, 2.0, 3.0, 4.0]) + + def test_zeros(self): + self._assert_equiv([0.0, 0.0, 0.0, 0.0]) + + def test_negative(self): + self._assert_equiv([-1.0, -2.5, -0.001, -100.0]) + + def test_mixed_special(self): + self._assert_equiv([float("inf"), float("-inf"), 0.0, -0.0]) + + def test_nan_element(self): + """NaN in vector should serialize identically.""" + cython_bytes = self.ser.serialize([1.0, float("nan"), 3.0, 4.0], PROTO) + python_bytes = self.vec_type.serialize([1.0, float("nan"), 3.0, 4.0], PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_element_overflow(self): + with self.assertRaises((OverflowError, struct.error)): + 
self.ser.serialize([1.0, 1e40, 3.0, 4.0], PROTO) + + def test_wrong_length_short(self): + with self.assertRaises(ValueError): + self.ser.serialize([1.0, 2.0], PROTO) + + def test_wrong_length_long(self): + with self.assertRaises(ValueError): + self.ser.serialize([1.0, 2.0, 3.0, 4.0, 5.0], PROTO) + + def test_empty_list_for_nonempty_vector(self): + with self.assertRaises(ValueError): + self.ser.serialize([], PROTO) + + +@cythontest +class TestSerVectorTypeDouble(unittest.TestCase): + """Verify SerVectorType double fast-path matches VectorType.serialize().""" + + def setUp(self): + self.vec_type = _make_vector_type(DoubleType, 3) + self.ser = SerVectorType(self.vec_type) + + def _assert_equiv(self, values): + cython_bytes = self.ser.serialize(values, PROTO) + python_bytes = self.vec_type.serialize(values, PROTO) + self.assertEqual( + cython_bytes, python_bytes, "Mismatch for values %r" % (values,) + ) + + def test_basic(self): + self._assert_equiv([1.0, 2.0, 3.0]) + + def test_large_values(self): + self._assert_equiv([1e100, -1e100, 1e-100]) + + def test_special(self): + self._assert_equiv([float("inf"), float("-inf"), 0.0]) + + +@cythontest +class TestSerVectorTypeInt32(unittest.TestCase): + """Verify SerVectorType int32 fast-path matches VectorType.serialize().""" + + def setUp(self): + self.vec_type = _make_vector_type(Int32Type, 3) + self.ser = SerVectorType(self.vec_type) + + def _assert_equiv(self, values): + cython_bytes = self.ser.serialize(values, PROTO) + python_bytes = self.vec_type.serialize(values, PROTO) + self.assertEqual( + cython_bytes, python_bytes, "Mismatch for values %r" % (values,) + ) + + def test_basic(self): + self._assert_equiv([1, 2, 3]) + + def test_boundaries(self): + self._assert_equiv([2147483647, -2147483648, 0]) + + def test_element_overflow(self): + with self.assertRaises((OverflowError, struct.error)): + self.ser.serialize([1, 2147483648, 3], PROTO) + + def test_element_overflow_negative(self): + with 
self.assertRaises((OverflowError, struct.error)): + self.ser.serialize([1, -2147483649, 3], PROTO) + + def test_vector_index_protocol(self): + """Objects implementing __index__ in vector elements should be accepted.""" + + class MyInt: + def __index__(self): + return 42 + + cython_bytes = self.ser.serialize([MyInt(), MyInt(), MyInt()], PROTO) + python_bytes = self.vec_type.serialize([42, 42, 42], PROTO) + self.assertEqual(cython_bytes, python_bytes) + + +@cythontest +class TestSerVectorTypeGenericFallback(unittest.TestCase): + """Verify SerVectorType generic fallback matches VectorType.serialize().""" + + def test_utf8_vector(self): + vec_type = _make_vector_type(UTF8Type, 3) + ser = SerVectorType(vec_type) + values = ["hello", "world", "test"] + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_boolean_vector(self): + vec_type = _make_vector_type(BooleanType, 4) + ser = SerVectorType(vec_type) + values = [True, False, True, False] + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + +@cythontest +class TestSerVectorTypeHighDimensional(unittest.TestCase): + """Test with realistic high-dimensional vectors (embedding use case).""" + + def test_float_1536_dim(self): + """1536-dim float vector (typical for embedding models).""" + vec_type = _make_vector_type(FloatType, 1536) + ser = SerVectorType(vec_type) + values = [float(i) / 1536.0 for i in range(1536)] + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + self.assertEqual(len(cython_bytes), 1536 * 4) + + def test_double_768_dim(self): + vec_type = _make_vector_type(DoubleType, 768) + ser = SerVectorType(vec_type) + values = [float(i) / 768.0 for i in range(768)] + cython_bytes = ser.serialize(values, PROTO) + python_bytes = 
@cythontest
class TestSerializerRoundTrip(unittest.TestCase):
    """Serialize with Cython, deserialize with Python cqltype.deserialize().

    Scalar tests run every sample value inside its own ``subTest`` so a
    failure names the offending value and the remaining values are still
    checked. Vector tests assert the element count before comparing
    elements, so a truncated result cannot pass unnoticed.
    """

    def test_float_round_trip(self):
        ser = SerFloatType(FloatType)
        for val in [0.0, 1.0, -1.0, 3.14, float("inf"), float("-inf")]:
            with self.subTest(val=val):
                serialized = ser.serialize(val, PROTO)
                deserialized = FloatType.deserialize(serialized, PROTO)
                if math.isinf(val):
                    # Infinities must come back exactly.
                    self.assertEqual(val, deserialized)
                else:
                    # Finite values pass through a 4-byte float, so only
                    # approximate equality can be expected (e.g. 3.14).
                    self.assertAlmostEqual(val, deserialized, places=5)

    def test_double_round_trip(self):
        ser = SerDoubleType(DoubleType)
        for val in [0.0, 1.0, -1.0, 3.141592653589793, 1e100, float("inf")]:
            with self.subTest(val=val):
                serialized = ser.serialize(val, PROTO)
                deserialized = DoubleType.deserialize(serialized, PROTO)
                self.assertEqual(val, deserialized)

    def test_int32_round_trip(self):
        ser = SerInt32Type(Int32Type)
        for val in [0, 1, -1, 2147483647, -2147483648, 42]:
            with self.subTest(val=val):
                serialized = ser.serialize(val, PROTO)
                deserialized = Int32Type.deserialize(serialized, PROTO)
                self.assertEqual(val, deserialized)

    def test_float_vector_round_trip(self):
        vec_type = _make_vector_type(FloatType, 4)
        ser = SerVectorType(vec_type)
        values = [1.5, -2.5, 3.14, 0.0]
        serialized = ser.serialize(values, PROTO)
        deserialized = vec_type.deserialize(serialized, PROTO)
        # zip() stops at the shorter input, so without this check a short
        # (or empty) result would make the loop below pass vacuously.
        self.assertEqual(len(deserialized), len(values))
        for orig, deser in zip(values, deserialized):
            self.assertAlmostEqual(orig, deser, places=5)

    def test_int32_vector_round_trip(self):
        vec_type = _make_vector_type(Int32Type, 3)
        ser = SerVectorType(vec_type)
        values = [2147483647, -2147483648, 0]
        serialized = ser.serialize(values, PROTO)
        deserialized = vec_type.deserialize(serialized, PROTO)
        self.assertEqual(list(deserialized), values)
list(data) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index): + return self._data[index] + + def __iter__(self): + return iter(self._data) + + vec_type = _make_vector_type(FloatType, 3) + ser = SerVectorType(vec_type) + values = IndexableSequence([1.0, 2.0, 3.0]) + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize([1.0, 2.0, 3.0], PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_numpy_float32_array_input(self): + """NumPy float32 arrays should hit the float buffer fast path.""" + vec_type = _make_vector_type(FloatType, 4) + ser = SerVectorType(vec_type) + values = np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_numpy_float64_array_input(self): + """NumPy float64 arrays should hit the double buffer fast path.""" + vec_type = _make_vector_type(DoubleType, 3) + ser = SerVectorType(vec_type) + values = np.asarray([1.0, -2.5, 3.14], dtype=np.float64) + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_numpy_int32_array_input(self): + """NumPy int32 arrays should hit the int32 buffer fast path.""" + vec_type = _make_vector_type(Int32Type, 3) + ser = SerVectorType(vec_type) + values = np.asarray([2147483647, -2147483648, 0], dtype=np.int32) + cython_bytes = ser.serialize(values, PROTO) + python_bytes = vec_type.serialize(values, PROTO) + self.assertEqual(cython_bytes, python_bytes) + + def test_numpy_dtype_mismatch_fallthrough(self): + """dtype mismatch should fall through to element-wise path correctly. + + A float64 array passed to a FloatType vector cannot bind to + float[::1], so the serializer must fall through to the + element-wise path and still produce correct bytes. 
@cythontest
class TestMakeSerializers(unittest.TestCase):
    """Exercise the make_serializers() batch factory."""

    def _assert_serializer_classes(self, serializers, expected_classes):
        """Check the result's length, then each entry's class by index.

        Only ``len()`` and integer indexing are used on *serializers*,
        so no assumption is made about the container type returned.
        """
        self.assertEqual(len(serializers), len(expected_classes))
        for position, expected_cls in enumerate(expected_classes):
            self.assertIsInstance(serializers[position], expected_cls)

    def test_basic(self):
        serializers = make_serializers(
            [FloatType, DoubleType, Int32Type, UTF8Type]
        )
        self._assert_serializer_classes(
            serializers,
            [SerFloatType, SerDoubleType, SerInt32Type, GenericSerializer],
        )

    def test_empty(self):
        self.assertEqual(len(make_serializers([])), 0)

    def test_with_vector_type(self):
        float_vector = _make_vector_type(FloatType, 3)
        serializers = make_serializers([float_vector, Int32Type])
        self._assert_serializer_classes(
            serializers, [SerVectorType, SerInt32Type]
        )