diff --git a/benchmarks/vector_deserialize.py b/benchmarks/vector_deserialize.py new file mode 100644 index 0000000000..89fb5d7e66 --- /dev/null +++ b/benchmarks/vector_deserialize.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark for VectorType deserialization performance. + +Tests different optimization strategies: +1. Current implementation (Python with struct.unpack/numpy) +2. Python struct.unpack only +3. Numpy frombuffer + tolist() +4. Cython DesVectorType deserializer + +Run with: python benchmarks/vector_deserialize.py +""" + +import os +import sys +import time +import struct + +# Add project root to path so the benchmark can be run from any directory +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from cassandra.cqltypes import FloatType, DoubleType, Int32Type, LongType, ShortType +from cassandra.marshal import ( + float_pack, + double_pack, + int32_pack, + int64_pack, + int16_pack, +) + + +def create_test_data(vector_size, element_type): + """Create serialized test data for a vector.""" + if element_type == FloatType: + values = [float(i * 0.1) for i in range(vector_size)] + pack_fn = float_pack + elif element_type == DoubleType: + values = [float(i * 0.1) for i in range(vector_size)] + pack_fn = double_pack + elif element_type == Int32Type: + values = list(range(vector_size)) + pack_fn = int32_pack + elif element_type == LongType: + values = list(range(vector_size)) + pack_fn = int64_pack + elif element_type == ShortType: + values = [i % 32767 for i in range(vector_size)] + pack_fn = int16_pack + else: + raise ValueError(f"Unsupported element type: {element_type}") + + # Serialize the vector + serialized = b"".join(pack_fn(v) for v in values) + + return serialized, values + + +def benchmark_current_implementation(vector_type, serialized_data, iterations=10000): + """Benchmark the current VectorType.deserialize implementation.""" + protocol_version = 4 + + start = time.perf_counter() + for _ in range(iterations): + result = vector_type.deserialize(serialized_data, protocol_version) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_struct_optimization(vector_type, serialized_data, iterations=10000): + """Benchmark struct.unpack optimization.""" + vector_size = vector_type.vector_size + subtype = vector_type.subtype + + # Determine format string - subtype is a class, use identity or issubclass + if subtype is FloatType or ( + isinstance(subtype, type) and issubclass(subtype, FloatType) + ): + format_str = f">{vector_size}f" + elif subtype is DoubleType or ( + isinstance(subtype, type) and issubclass(subtype, DoubleType) + ): + format_str = f">{vector_size}d" + elif subtype is Int32Type or ( + isinstance(subtype, type) and issubclass(subtype, Int32Type) + ): + format_str = f">{vector_size}i" + elif subtype is LongType or ( + isinstance(subtype, type) and issubclass(subtype, LongType) + ): + format_str = f">{vector_size}q" + elif subtype is ShortType or ( + isinstance(subtype, type) and issubclass(subtype, ShortType) + ): + format_str = f">{vector_size}h" + else: + return None, None, None + + start = time.perf_counter() + for _ in range(iterations): + result = list(struct.unpack(format_str, serialized_data)) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_numpy_optimization(vector_type, serialized_data, iterations=10000): + """Benchmark numpy.frombuffer optimization.""" + try: + import numpy as np + except ImportError: + return None, None, None + + vector_size = vector_type.vector_size + subtype = vector_type.subtype + + # Determine dtype + if subtype is FloatType or ( + isinstance(subtype, type) and issubclass(subtype, FloatType) + ): + dtype = ">f4" + elif subtype is DoubleType or ( + isinstance(subtype, type) and issubclass(subtype, DoubleType) + ): + dtype = ">f8" + elif subtype is Int32Type or ( + isinstance(subtype, type) and issubclass(subtype, Int32Type) + ): + dtype = ">i4" + elif subtype is LongType or ( + isinstance(subtype, type) and issubclass(subtype, LongType) + ): + dtype = ">i8" + elif subtype is ShortType or ( + isinstance(subtype, type) and issubclass(subtype, ShortType) + ): + dtype = ">i2" + else: + return None, None, None + + start = time.perf_counter() + for _ in range(iterations): + arr = np.frombuffer(serialized_data, dtype=dtype, count=vector_size) + result = arr.tolist() + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_cython_deserializer(vector_type, serialized_data, iterations=10000): + """Benchmark Cython DesVectorType deserializer. + + This benchmark requires the Cython deserializers extension to be compiled. + When the extension is not available, or the type does not have a dedicated + DesVectorType deserializer, the benchmark is silently skipped (returns None). + """ + try: + from cassandra.deserializers import find_deserializer + except ImportError: + return None, None, None + + protocol_version = 4 + + # Get the Cython deserializer + deserializer = find_deserializer(vector_type) + + # Check if we got the Cython deserializer + if deserializer.__class__.__name__ != "DesVectorType": + return None, None, None + + start = time.perf_counter() + for _ in range(iterations): + result = deserializer.deserialize_bytes(serialized_data, protocol_version) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def verify_results(expected, *results): + """Verify that all results match expected values.""" + for i, result in enumerate(results): + if result is None: + continue + if len(result) != len(expected): + print(f" ❌ Result {i} length mismatch: {len(result)} vs {len(expected)}") + return False + for j, (a, b) in enumerate(zip(result, expected)): + # Use relative tolerance for floating point comparison + if isinstance(a, float) and isinstance(b, float): + # Allow 0.01% relative error for floats + if abs(a - b) > max(abs(a), abs(b)) * 1e-4 + 1e-7: + print(f" ❌ Result {i} value mismatch at index {j}: {a} vs {b}") + return False + elif abs(a - b) > 1e-9: + print(f" ❌ Result {i} value mismatch at index {j}: {a} vs {b}") + return False + return True + + +def run_benchmark_suite(vector_size, element_type, type_name, iterations=10000): + """Run complete benchmark suite for a given vector configuration.""" + print(f"\n{'=' * 80}") + print(f"Benchmark: Vector<{type_name}, {vector_size}>") + print(f"{'=' * 80}") + print(f"Iterations: {iterations:,}") + + # Create test data + from cassandra.cqltypes import lookup_casstype + + cass_typename = f"org.apache.cassandra.db.marshal.{element_type.__name__}" + vector_typename = ( + f"org.apache.cassandra.db.marshal.VectorType({cass_typename}, {vector_size})" + ) + vector_type = lookup_casstype(vector_typename) + + serialized_data, expected_values = create_test_data(vector_size, element_type) + data_size = len(serialized_data) + + print(f"Serialized size: {data_size:,} bytes") + print() + + # Run benchmarks + results = [] + + # 1. Current implementation (baseline) + print("1. Current implementation (baseline)...") + elapsed, per_op, result_current = benchmark_current_implementation( + vector_type, serialized_data, iterations + ) + results.append(result_current) + print(f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} μs") + baseline_time = per_op + + # 2. Struct optimization + print("2. Python struct.unpack optimization...") + elapsed, per_op, result_struct = benchmark_struct_optimization( + vector_type, serialized_data, iterations + ) + results.append(result_struct) + if per_op is not None: + speedup = baseline_time / per_op + print( + f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} μs, Speedup: {speedup:.2f}x" + ) + else: + print(" Not applicable for this type") + + # 3. Numpy with tolist() + print("3. Numpy frombuffer + tolist()...") + elapsed, per_op, result_numpy = benchmark_numpy_optimization( + vector_type, serialized_data, iterations + ) + results.append(result_numpy) + if per_op is not None: + speedup = baseline_time / per_op + print( + f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} μs, Speedup: {speedup:.2f}x" + ) + else: + print(" Numpy not available") + + # 4. Cython deserializer + print("4. Cython DesVectorType deserializer...") + elapsed, per_op, result_cython = benchmark_cython_deserializer( + vector_type, serialized_data, iterations + ) + if per_op is not None: + results.append(result_cython) + speedup = baseline_time / per_op + print( + f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} μs, Speedup: {speedup:.2f}x" + ) + else: + print(" Cython deserializers not available") + + # Verify results + print("\nVerifying results...") + if verify_results(expected_values, *results): + print(" ✓ All results match!") + else: + print(" ✗ Result mismatch detected!") + + return baseline_time + + +def main(): + """Run all benchmarks.""" + # Pin to single CPU core for consistent measurements + try: + import os + + os.sched_setaffinity(0, {0}) # Pin to CPU core 0 + print("Pinned to CPU core 0 for consistent measurements") + except (AttributeError, OSError) as e: + print(f"Could not pin to single core: {e}") + print("Running without CPU affinity...") + + print("=" * 80) + print("VectorType Deserialization Performance Benchmark") + print("=" * 80) + + # Test configurations: (vector_size, element_type, type_name, iterations) + test_configs = [ + # Small vectors + (3, FloatType, "float", 50000), + (4, FloatType, "float", 50000), + # Medium vectors (common in ML) + (128, FloatType, "float", 10000), + (384, FloatType, "float", 10000), + # Large vectors (embeddings) + (768, FloatType, "float", 5000), + (1536, FloatType, "float", 2000), + # Other types (smaller iteration counts) + (128, DoubleType, "double", 10000), + (768, DoubleType, "double", 5000), + (1536, DoubleType, "double", 2000), + (64, Int32Type, "int", 15000), + (128, Int32Type, "int", 10000), + ] + + summary = [] + + for vector_size, element_type, type_name, iterations in test_configs: + baseline = run_benchmark_suite(vector_size, element_type, type_name, iterations) + summary.append((f"Vector<{type_name}, {vector_size}>", baseline)) + + # Print summary + print("\n" + "=" * 80) + print("SUMMARY - Current Implementation Performance") + print("=" * 80) + for config, baseline_time in summary: + print(f"{config:30s}: {baseline_time:8.2f} μs") + + print("\n" + "=" * 80) + print("Benchmark complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/vector_serialize.py b/benchmarks/vector_serialize.py new file mode 100644 index 0000000000..422e0761d1 --- /dev/null +++ b/benchmarks/vector_serialize.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark for VectorType serialization performance. + +Tests different optimization strategies: +1. Current implementation (Python io.BytesIO loop) +2. Python struct.pack batch format string +3. Cython SerVectorType serializer (when available) +4. BoundStatement.bind() end-to-end with 1 vector column (when available) + +Run with: python benchmarks/vector_serialize.py +""" + +import sys +import time +import struct + +# Add parent directory to path +sys.path.insert(0, '.') + +from cassandra.cqltypes import FloatType, DoubleType, Int32Type, lookup_casstype +from cassandra.marshal import float_pack, double_pack, int32_pack + + +def create_test_values(vector_size, element_type): + """Create test values for serialization benchmarks.""" + if element_type == FloatType: + return [float(i * 0.1) for i in range(vector_size)] + elif element_type == DoubleType: + return [float(i * 0.1) for i in range(vector_size)] + elif element_type == Int32Type: + return list(range(vector_size)) + else: + raise ValueError(f"Unsupported element type: {element_type}") + + +def benchmark_current_implementation(vector_type, values, iterations=10000): + """Benchmark the current VectorType.serialize implementation (io.BytesIO loop).""" + protocol_version = 4 + + start = time.perf_counter() + for _ in range(iterations): + result = vector_type.serialize(values, protocol_version) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_struct_pack(vector_type, values, iterations=10000): + """Benchmark struct.pack batch format string optimization.""" + vector_size = vector_type.vector_size + subtype = vector_type.subtype + + # Determine format string + if subtype is FloatType or (isinstance(subtype, type) and issubclass(subtype, FloatType)): + format_str = f'>{vector_size}f' + elif subtype is DoubleType or (isinstance(subtype, type) and issubclass(subtype, DoubleType)): + format_str = f'>{vector_size}d' + elif subtype is Int32Type or (isinstance(subtype, type) and issubclass(subtype, Int32Type)): + format_str = f'>{vector_size}i' + else: + return None, None, None + + # Pre-compile the struct for fair comparison + packer = struct.Struct(format_str) + + start = time.perf_counter() + for _ in range(iterations): + result = packer.pack(*values) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_cython_serializer(vector_type, values, iterations=10000): + """Benchmark Cython SerVectorType serializer (when available).""" + try: + from cassandra.serializers import find_serializer + except ImportError: + return None, None, None + + protocol_version = 4 + + # Get the Cython serializer + serializer = find_serializer(vector_type) + + # Check if we got the Cython serializer (not generic fallback) + if serializer.__class__.__name__ != 'SerVectorType': + return None, None, None + + start = time.perf_counter() + for _ in range(iterations): + result = serializer.serialize(values, protocol_version) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, result + + +def benchmark_bind_statement(vector_type, values, iterations=10000): + """Benchmark BoundStatement.bind() end-to-end with 1 vector column. + + This simulates the full bind path for a prepared statement with a single + vector column, including column metadata lookup and serialization. + """ + from unittest.mock import MagicMock + + try: + from cassandra.query import BoundStatement, PreparedStatement, UNSET_VALUE + except ImportError: + return None, None, None + + # Create a mock PreparedStatement with one vector column + col_meta_mock = MagicMock() + col_meta_mock.keyspace_name = 'test_ks' + col_meta_mock.table_name = 'test_table' + col_meta_mock.name = 'vec_col' + col_meta_mock.type = vector_type + + prepared = MagicMock(spec=PreparedStatement) + prepared.protocol_version = 4 + prepared.column_metadata = [col_meta_mock] + prepared.column_encryption_policy = None + prepared.routing_key_indexes = None + prepared.is_idempotent = False + prepared.result_metadata = None + prepared.keyspace = 'test_ks' + + start = time.perf_counter() + for _ in range(iterations): + bs = BoundStatement.__new__(BoundStatement) + bs.prepared_statement = prepared + bs.values = [] + bs.raw_values = [values] + # Inline the core serialization path (no CE policy) + bs.values.append(vector_type.serialize(values, 4)) + end = time.perf_counter() + + elapsed = end - start + per_op = (elapsed / iterations) * 1_000_000 # microseconds + + return elapsed, per_op, bs.values[0] + + +def verify_results(reference, *results): + """Verify that all serialization results produce identical bytes.""" + for i, result in enumerate(results): + if result is None: + continue + if result != reference: + print(f" Result {i} mismatch: {len(result)} bytes vs {len(reference)} bytes (reference)") + # Show first divergence + for j in range(min(len(result), len(reference))): + if result[j] != reference[j]: + print(f" First difference at byte {j}: {result[j]:#04x} vs {reference[j]:#04x}") + break + return False + return True + + +def run_benchmark_suite(vector_size, element_type, type_name, iterations=10000): + """Run complete benchmark suite for a given vector configuration.""" + sep = '=' * 80 + print(f"\n{sep}") + print(f"Benchmark: Vector<{type_name}, {vector_size}>") + print(f"{sep}") + print(f"Iterations: {iterations:,}") + + # Create vector type + cass_typename = f'org.apache.cassandra.db.marshal.{element_type.__name__}' + vector_typename = f'org.apache.cassandra.db.marshal.VectorType({cass_typename}, {vector_size})' + vector_type = lookup_casstype(vector_typename) + + values = create_test_values(vector_size, element_type) + + # Get reference serialization for verification + reference_bytes = vector_type.serialize(values, 4) + data_size = len(reference_bytes) + + print(f"Serialized size: {data_size:,} bytes") + print() + + # Collect results for verification + all_results = [] + + # 1. Current implementation (baseline) + print("1. Current implementation (io.BytesIO loop, baseline)...") + elapsed, per_op, result = benchmark_current_implementation( + vector_type, values, iterations) + all_results.append(result) + print(f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} us") + baseline_time = per_op + + # 2. struct.pack batch format string + print("2. Python struct.pack batch format string...") + elapsed, per_op, result = benchmark_struct_pack( + vector_type, values, iterations) + all_results.append(result) + if per_op is not None: + speedup = baseline_time / per_op + print(f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} us, Speedup: {speedup:.2f}x") + else: + print(" Not applicable for this type") + + # 3. Cython serializer + print("3. Cython SerVectorType serializer...") + elapsed, per_op, result = benchmark_cython_serializer( + vector_type, values, iterations) + all_results.append(result) + if per_op is not None: + speedup = baseline_time / per_op + print(f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} us, Speedup: {speedup:.2f}x") + else: + print(" Cython serializers not available") + + # 4. BoundStatement.bind() end-to-end + print("4. BoundStatement.bind() end-to-end (1 vector column)...") + elapsed, per_op, result = benchmark_bind_statement( + vector_type, values, iterations) + all_results.append(result) + if per_op is not None: + speedup = baseline_time / per_op + print(f" Total: {elapsed:.4f}s, Per-op: {per_op:.2f} us, Overhead vs baseline: {speedup:.2f}x") + else: + print(" BoundStatement benchmark not available") + + # Verify results + print("\nVerifying results...") + if verify_results(reference_bytes, *all_results): + print(" All results match!") + else: + print(" Result mismatch detected!") + + return baseline_time + + +def main(): + """Run all benchmarks.""" + # Pin to single CPU core for consistent measurements + try: + import os + os.sched_setaffinity(0, {0}) # Pin to CPU core 0 + print("Pinned to CPU core 0 for consistent measurements") + except (AttributeError, OSError) as e: + print(f"Could not pin to single core: {e}") + print("Running without CPU affinity...") + + sep = '=' * 80 + print(sep) + print("VectorType Serialization Performance Benchmark") + print(sep) + + # Test configurations: (vector_size, element_type, type_name, iterations) + test_configs = [ + # Small vectors + (3, FloatType, "float", 50000), + + # Medium vectors (common in ML) + (128, FloatType, "float", 10000), + + # Large vectors (embeddings) + (768, FloatType, "float", 5000), + (1536, FloatType, "float", 2000), + + # Other types + (128, DoubleType, "double", 10000), + (768, DoubleType, "double", 5000), + (1536, DoubleType, "double", 2000), + (128, Int32Type, "int", 10000), + ] + + summary = [] + + for vector_size, element_type, type_name, iterations in test_configs: + baseline = run_benchmark_suite(vector_size, element_type, type_name, iterations) + summary.append((f"Vector<{type_name}, {vector_size}>", baseline)) + + # Print summary + print(f"\n{sep}") + print("SUMMARY - Serialization Baseline Performance (io.BytesIO loop)") + print(sep) + for config, baseline_time in summary: + print(f"{config:30s}: {baseline_time:8.2f} us") + + print(f"\n{sep}") + print("Benchmark complete!") + print(sep) + + +if __name__ == '__main__': + main() diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 1d66ce1ed9..7289d1548b 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -38,12 +38,27 @@ from tests.unit.cython.utils import cythontest from tests.util import assertEqual -from tests.integration import use_singledc, execute_until_pass, notprotocolv1, \ - BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, \ - greaterthanorequalcass3_10, TestCluster, requires_composite_type, \ - requires_vector_type -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \ - get_sample, get_all_samples, get_collection_sample +from tests.integration import ( + use_singledc, + execute_until_pass, + notprotocolv1, + BasicSharedKeyspaceUnitTestCase, + greaterthancass21, + lessthancass30, + greaterthanorequalcass3_10, + TestCluster, + requires_composite_type, + requires_vector_type, +) +from tests.integration.datatype_utils import ( + update_datatypes, + PRIMITIVE_DATATYPES, + COLLECTION_TYPES, + PRIMITIVE_DATATYPES_KEYS, + get_sample, + get_all_samples, + get_collection_sample, +) import pytest @@ -53,7 +68,6 @@ def setup_module(): class TypeTests(BasicSharedKeyspaceUnitTestCase): - @classmethod def setUpClass(cls): # cls._cass_version, cls. = get_server_versions() @@ -68,7 +82,7 @@ def test_can_insert_blob_type_as_string(self): s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)") - params = ['key1', b'blobbyblob'] + params = ["key1", b"blobbyblob"] query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)" s.execute(query, params) @@ -85,14 +99,17 @@ def test_can_insert_blob_type_as_bytearray(self): s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)") - params = ['key1', bytearray(b'blob1')] + params = ["key1", bytearray(b"blob1")] s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params) results = s.execute("SELECT * FROM blobbytes").one() for expected, actual in zip(params, results): assert expected == actual - @unittest.skipIf(not hasattr(cassandra, 'deserializers'), "Cython required for to test DesBytesTypeArray deserializer") + @unittest.skipIf( + not hasattr(cassandra, "deserializers"), + "Cython required for to test DesBytesTypeArray deserializer", + ) def test_des_bytes_type_array(self): """ Simple test to ensure the DesBytesTypeByteArray deserializer functionally works @@ -105,14 +122,15 @@ def test_des_bytes_type_array(self): """ original = None try: - original = cassandra.deserializers.DesBytesType - cassandra.deserializers.DesBytesType = cassandra.deserializers.DesBytesTypeByteArray + cassandra.deserializers.DesBytesType = ( + cassandra.deserializers.DesBytesTypeByteArray + ) s = self.session s.execute("CREATE TABLE blobbytes2 (a ascii PRIMARY KEY, b blob)") - params = ['key1', bytearray(b'blob1')] + params = ["key1", bytearray(b"blob1")] s.execute("INSERT INTO blobbytes2 (a, b) VALUES (%s, %s)", params) results = s.execute("SELECT * FROM blobbytes2").one() @@ -120,7 +138,7 @@ def test_des_bytes_type_array(self): assert expected == actual finally: if original is not None: - cassandra.deserializers.DesBytesType=original + cassandra.deserializers.DesBytesType = original def test_can_insert_primitive_datatypes(self): """ @@ -132,12 +150,12 @@ def test_can_insert_primitive_datatypes(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = ["zz"] - start_index = ord('a') + start_index = ord("a") for i, datatype in enumerate(PRIMITIVE_DATATYPES): alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) col_names.append(chr(start_index + i)) - s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list))) + s.execute("CREATE TABLE alltypes ({0})".format(", ".join(alpha_type_list))) # create the input params = [0] @@ -145,12 +163,19 @@ def test_can_insert_primitive_datatypes(self): params.append((get_sample(datatype))) # insert into table as a simple statement - columns_string = ', '.join(col_names) - placeholders = ', '.join(["%s"] * len(col_names)) - s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + columns_string = ", ".join(col_names) + placeholders = ", ".join(["%s"] * len(col_names)) + s.execute( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ), + params, + ) # verify data - results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM alltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -159,29 +184,46 @@ def test_can_insert_primitive_datatypes(self): for i, datatype in enumerate(PRIMITIVE_DATATYPES): single_col_name = chr(start_index + i) single_col_names = ["zz", single_col_name] - placeholders = ','.join(["%s"] * len(single_col_names)) - single_columns_string = ', '.join(single_col_names) + placeholders = ",".join(["%s"] * len(single_col_names)) + single_columns_string = ", ".join(single_col_names) for j, data_sample in enumerate(get_all_samples(datatype)): key = i + 1000 * j single_params = (key, data_sample) - s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(single_columns_string, placeholders), - single_params) + s.execute( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + single_columns_string, placeholders + ), + single_params, + ) # verify data - result = s.execute("SELECT {0} FROM alltypes WHERE zz=%s".format(single_columns_string), (key,)).one()[1] + result = s.execute( + "SELECT {0} FROM alltypes WHERE zz=%s".format( + single_columns_string + ), + (key,), + ).one()[1] compare_value = data_sample - if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address): + if isinstance(data_sample, ipaddress.IPv4Address) or isinstance( + data_sample, ipaddress.IPv6Address + ): compare_value = str(data_sample) assert result == compare_value # try the same thing with a prepared statement - placeholders = ','.join(["?"] * len(col_names)) + placeholders = ",".join(["?"] * len(col_names)) s.execute("TRUNCATE alltypes;") - insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + insert = s.prepare( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ) + ) s.execute(insert.bind(params)) # verify data - results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM alltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -193,8 +235,12 @@ def test_can_insert_primitive_datatypes(self): # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * FROM alltypes") - results = s.execute(select, - execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory)).one() + results = s.execute( + select, + execution_profile=s.execution_profile_clone_update( + EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory + ), + ).one() for expected, actual in zip(params, results.values()): assert actual == expected @@ -214,23 +260,37 @@ def test_can_insert_collection_datatypes(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = ["zz"] - start_index = ord('a') + start_index = ord("a") for i, collection_type in enumerate(COLLECTION_TYPES): for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): if collection_type == "map": - type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}, {3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) elif collection_type == "tuple": - type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} frozen<{2}<{3}>>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) else: - type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) alpha_type_list.append(type_string) - col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j))) + col_names.append( + "{0}_{1}".format(chr(start_index + i), chr(start_index + j)) + ) - s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list))) - columns_string = ', '.join(col_names) + s.execute("CREATE TABLE allcoltypes ({0})".format(", ".join(alpha_type_list))) + columns_string = ", ".join(col_names) # create the input for simple statement params = [0] @@ -239,11 +299,18 @@ def test_can_insert_collection_datatypes(self): params.append((get_collection_sample(collection_type, datatype))) # insert into table as a simple statement - placeholders = ', '.join(["%s"] * len(col_names)) - s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + placeholders = ", ".join(["%s"] * len(col_names)) + s.execute( + "INSERT INTO allcoltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ), + params, + ) # verify data - results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -254,26 +321,37 @@ def test_can_insert_collection_datatypes(self): params.append((get_collection_sample(collection_type, datatype))) # try the same thing with a prepared statement - placeholders = ','.join(["?"] * len(col_names)) - insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + placeholders = ",".join(["?"] * len(col_names)) + insert = s.prepare( + "INSERT INTO allcoltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ) + ) s.execute(insert.bind(params)) # verify data - results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected # verify data with prepared statement query - select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([0])).one() for expected, actual in zip(params, results): assert actual == expected # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * FROM allcoltypes") - results = s.execute(select, - execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, - row_factory=ordered_dict_factory)).one() + results = s.execute( + select, + execution_profile=s.execution_profile_clone_update( + EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory + ), + ).one() for expected, actual in zip(params, results.values()): assert actual == expected @@ -289,12 +367,16 @@ def test_can_insert_empty_strings_and_nulls(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = [] - string_types = set(('ascii', 'text', 'varchar')) - string_columns = set(('')) + string_types = set(("ascii", "text", "varchar")) + string_columns = set(("")) # this is just a list of types to try with empty strings - non_string_types = PRIMITIVE_DATATYPES - string_types - set(('blob', 'date', 'inet', 'time', 'timestamp')) + non_string_types = ( + PRIMITIVE_DATATYPES + - string_types + - set(("blob", "date", "inet", "time", "timestamp")) + ) non_string_columns = set() - start_index = ord('a') + start_index = ord("a") for i, datatype in enumerate(PRIMITIVE_DATATYPES): col_name = chr(start_index + i) alpha_type_list.append("{0} {1}".format(col_name, datatype)) @@ -304,32 +386,48 @@ def test_can_insert_empty_strings_and_nulls(self): if datatype in string_types: string_columns.add(col_name) - execute_until_pass(s, "CREATE TABLE all_empty ({0})".format(', '.join(alpha_type_list))) + execute_until_pass( + s, "CREATE TABLE all_empty ({0})".format(", ".join(alpha_type_list)) + ) # verify all types initially null with simple statement - columns_string = ','.join(col_names) + columns_string = ",".join(col_names) s.execute("INSERT INTO all_empty (zz) VALUES (2)") - results = s.execute("SELECT {0} FROM all_empty WHERE zz=2".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM all_empty WHERE zz=2".format(columns_string) + ).one() assert all(x is None for x in results) # verify all types initially null with prepared statement - select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM all_empty WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([2])).one() assert all(x is None for x in results) # insert empty strings for string-like fields - expected_values = dict((col, '') for col in string_columns) - columns_string = ','.join(string_columns) - placeholders = ','.join(["%s"] * len(string_columns)) - s.execute("INSERT INTO all_empty (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values()) + expected_values = dict((col, "") for col in string_columns) + columns_string = ",".join(string_columns) + placeholders = ",".join(["%s"] * len(string_columns)) + s.execute( + "INSERT INTO all_empty (zz, {0}) VALUES (3, {1})".format( + columns_string, placeholders + ), + expected_values.values(), + ) # verify string types empty with simple statement - results = s.execute("SELECT {0} FROM all_empty WHERE zz=3".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM all_empty WHERE zz=3".format(columns_string) + ).one() for expected, actual in zip(expected_values.values(), results): assert actual == expected # verify string types empty with prepared statement - results = s.execute(s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), [3]).one() + results = s.execute( + s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), + [3], + ).one() for expected, actual in zip(expected_values.values(), results): assert actual == expected @@ -337,11 +435,13 @@ def test_can_insert_empty_strings_and_nulls(self): for col in non_string_columns: query = "INSERT INTO all_empty (zz, {0}) VALUES (4, %s)".format(col) with pytest.raises(InvalidRequest): - s.execute(query, ['']) + s.execute(query, [""]) - insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col)) + insert = s.prepare( + "INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col) + ) with pytest.raises(TypeError): - s.execute(insert, ['']) + s.execute(insert, [""]) # verify that Nones can be inserted and overwrites existing data # create the input @@ -350,9 +450,11 @@ def test_can_insert_empty_strings_and_nulls(self): params.append((get_sample(datatype))) # insert the data - columns_string = ','.join(col_names) - placeholders = ','.join(["%s"] * len(col_names)) - simple_insert = "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders) + columns_string = ",".join(col_names) + placeholders = ",".join(["%s"] * len(col_names)) + simple_insert = "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format( + columns_string, placeholders + ) s.execute(simple_insert, params) # then insert None, which should null them out @@ -366,7 +468,9 @@ def test_can_insert_empty_strings_and_nulls(self): assert None == col # check via prepared statement - select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM all_empty WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([5])).one() for col in results: assert None == col @@ -374,8 +478,12 @@ def test_can_insert_empty_strings_and_nulls(self): # do the same thing again, but use a prepared statement to insert the nulls s.execute(simple_insert, params) - placeholders = ','.join(["?"] * len(col_names)) - insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders)) + placeholders = ",".join(["?"] * len(col_names)) + insert = s.prepare( + "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format( + columns_string, placeholders + ) + ) s.execute(insert, null_values) results = s.execute(query).one() @@ -393,10 +501,14 @@ def test_can_insert_empty_values_for_int32(self): s = self.session execute_until_pass(s, "CREATE TABLE empty_values (a text PRIMARY KEY, b int)") - execute_until_pass(s, "INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))") + execute_until_pass( + s, "INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))" + ) try: Int32Type.support_empty_values = True - results = execute_until_pass(s, "SELECT b FROM empty_values WHERE a='a'").one() + results = execute_until_pass( + s, "SELECT b FROM empty_values WHERE a='a'" + ).one() assert EMPTY is results.b finally: Int32Type.support_empty_values = False @@ -408,7 +520,7 @@ def test_timezone_aware_datetimes_are_timestamps(self): from zoneinfo import ZoneInfo - eastern_tz = ZoneInfo('US/Eastern') + eastern_tz = ZoneInfo("US/Eastern") dt = datetime(1997, 8, 29, 11, 14, tzinfo=eastern_tz) s = self.session @@ -440,24 +552,30 @@ def test_can_insert_tuples(self): # use this encoder in order to insert tuples s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - s.execute("CREATE TABLE tuple_type (a int PRIMARY KEY, b frozen>)") + s.execute( + "CREATE TABLE tuple_type (a int PRIMARY KEY, b frozen>)" + ) # test non-prepared statement - complete = ('foo', 123, True) - s.execute("INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,)) + complete = ("foo", 123, True) + s.execute( + "INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,) + ) result = s.execute("SELECT b FROM tuple_type WHERE a=0").one() assert complete == result.b - partial = ('bar', 456) + partial = ("bar", 456) partial_result = partial + (None,) s.execute("INSERT INTO tuple_type (a, b) VALUES (1, %s)", parameters=(partial,)) result = s.execute("SELECT b FROM tuple_type WHERE a=1").one() assert partial_result == result.b # test single value tuples - subpartial = ('zoo',) + subpartial = ("zoo",) subpartial_result = subpartial + (None, None) - s.execute("INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,)) + s.execute( + "INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,) + ) result = s.execute("SELECT b FROM tuple_type WHERE a=2").one() assert subpartial_result == result.b @@ -488,7 +606,9 @@ def test_can_insert_tuples_with_varying_lengths(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -499,8 +619,11 @@ def test_can_insert_tuples_with_varying_lengths(self): lengths = (1, 2, 3, 384) value_schema = [] for i in lengths: - value_schema += [' v_%s frozen>' % (i, ', '.join(['int'] * i))] - s.execute("CREATE TABLE tuple_lengths (k int PRIMARY KEY, %s)" % (', '.join(value_schema),)) + value_schema += [" v_%s frozen>" % (i, ", ".join(["int"] * i))] + s.execute( + "CREATE TABLE tuple_lengths (k int PRIMARY KEY, %s)" + % (", ".join(value_schema),) + ) # insert tuples into same key using different columns # and verify the results @@ -508,15 +631,20 @@ def test_can_insert_tuples_with_varying_lengths(self): # ensure tuples of larger sizes throw an error created_tuple = tuple(range(0, i + 1)) with pytest.raises(InvalidRequest): - s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) # ensure tuples of proper sizes are written and read correctly created_tuple = tuple(range(0, i)) - s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple) + ) result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,)).one() - assert tuple(created_tuple) == result['v_%s' % i] + assert tuple(created_tuple) == result["v_%s" % i] c.shutdown() def test_can_insert_tuples_all_primitive_datatypes(self): @@ -531,9 +659,11 @@ def test_can_insert_tuples_all_primitive_datatypes(self): s = c.connect(self.keyspace_name) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - s.execute("CREATE TABLE tuple_primitive (" - "k int PRIMARY KEY, " - "v frozen>)" % ','.join(PRIMITIVE_DATATYPES)) + s.execute( + "CREATE TABLE tuple_primitive (" + "k int PRIMARY KEY, " + "v frozen>)" % ",".join(PRIMITIVE_DATATYPES) + ) values = [] type_count = len(PRIMITIVE_DATATYPES) @@ -542,7 +672,9 @@ def test_can_insert_tuples_all_primitive_datatypes(self): # responses have trailing None values for every element that has not been written values.append(get_sample(data_type)) expected = tuple(values + [None] * (type_count - len(values))) - s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values))) + s.execute( + "INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values)) + ) result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,)).one() assert result.v == expected c.shutdown() @@ -556,7 +688,9 @@ def test_can_insert_tuples_all_collection_datatypes(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -567,62 +701,87 @@ def test_can_insert_tuples_all_collection_datatypes(self): # create list values for datatype in PRIMITIVE_DATATYPES_KEYS: - values.append('v_{0} frozen>>'.format(len(values), datatype)) + values.append( + "v_{0} frozen>>".format(len(values), datatype) + ) # create set values for datatype in PRIMITIVE_DATATYPES_KEYS: - values.append('v_{0} frozen>>'.format(len(values), datatype)) + values.append("v_{0} frozen>>".format(len(values), datatype)) # create map values for datatype in PRIMITIVE_DATATYPES_KEYS: datatype_1 = datatype_2 = datatype - if datatype == 'blob': + if datatype == "blob": # unhashable type: 'bytearray' - datatype_1 = 'ascii' - values.append('v_{0} frozen>>'.format(len(values), datatype_1, datatype_2)) + datatype_1 = "ascii" + values.append( + "v_{0} frozen>>".format( + len(values), datatype_1, datatype_2 + ) + ) # make sure we're testing all non primitive data types in the future - if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']): - raise NotImplemented('Missing datatype not implemented: {}'.format( - set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set']) - )) + if set(COLLECTION_TYPES) != set(["tuple", "list", "map", "set"]): + raise NotImplementedError( + "Missing datatype not implemented: {}".format( + set(COLLECTION_TYPES) - set(["tuple", "list", "map", "set"]) + ) + ) # create table - s.execute("CREATE TABLE tuple_non_primative (" - "k int PRIMARY KEY, " - "%s)" % ', '.join(values)) + s.execute( + "CREATE TABLE tuple_non_primative (" + "k int PRIMARY KEY, " + "%s)" % ", ".join(values) + ) i = 0 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: created_tuple = tuple([[get_sample(datatype)]]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) - - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) + + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: created_tuple = tuple([sortedset([get_sample(datatype)])]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) - - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) + + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: - if datatype == 'blob': + if datatype == "blob": # unhashable type: 'bytearray' - created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}]) + created_tuple = tuple([{get_sample("ascii"): get_sample(datatype)}]) else: created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 c.shutdown() @@ -632,9 +791,9 @@ def nested_tuples_schema_helper(self, depth): """ if depth == 0: - return 'int' + return "int" else: - return 'tuple<%s>' % self.nested_tuples_schema_helper(depth - 1) + return "tuple<%s>" % self.nested_tuples_schema_helper(depth - 1) def nested_tuples_creator_helper(self, depth): """ @@ -644,7 +803,7 @@ def nested_tuples_creator_helper(self, depth): if depth == 0: return 303 else: - return (self.nested_tuples_creator_helper(depth - 1), ) + return (self.nested_tuples_creator_helper(depth - 1),) def test_can_insert_nested_tuples(self): """ @@ -655,7 +814,9 @@ def test_can_insert_nested_tuples(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -663,27 +824,37 @@ def test_can_insert_nested_tuples(self): s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple # create a table with multiple sizes of nested tuples - s.execute("CREATE TABLE nested_tuples (" - "k int PRIMARY KEY, " - "v_1 frozen<%s>," - "v_2 frozen<%s>," - "v_3 frozen<%s>," - "v_32 frozen<%s>" - ")" % (self.nested_tuples_schema_helper(1), - self.nested_tuples_schema_helper(2), - self.nested_tuples_schema_helper(3), - self.nested_tuples_schema_helper(32))) + s.execute( + "CREATE TABLE nested_tuples (" + "k int PRIMARY KEY, " + "v_1 frozen<%s>," + "v_2 frozen<%s>," + "v_3 frozen<%s>," + "v_32 frozen<%s>" + ")" + % ( + self.nested_tuples_schema_helper(1), + self.nested_tuples_schema_helper(2), + self.nested_tuples_schema_helper(3), + self.nested_tuples_schema_helper(32), + ) + ) for i in (1, 2, 3, 32): # create tuple created_tuple = self.nested_tuples_creator_helper(i) # write tuple - s.execute("INSERT INTO nested_tuples (k, v_%s) VALUES (%s, %s)", (i, i, created_tuple)) + s.execute( + "INSERT INTO nested_tuples (k, v_%s) VALUES (%s, %s)", + (i, i, created_tuple), + ) # verify tuple was written and read correctly - result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i)).one() - assert created_tuple == result['v_%s' % i] + result = s.execute( + "SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i) + ).one() + assert created_tuple == result["v_%s" % i] c.shutdown() def test_can_insert_tuples_with_nulls(self): @@ -696,7 +867,9 @@ def test_can_insert_tuples_with_nulls(self): s = self.session - s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)") + s.execute( + "CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)" + ) insert = s.prepare("INSERT INTO tuples_nulls (k, t) VALUES (0, ?)") s.execute(insert, [(None, None, None, None)]) @@ -708,10 +881,10 @@ def test_can_insert_tuples_with_nulls(self): assert (None, None, None, None) == s.execute(read).one().t # also test empty strings where compatible - s.execute(insert, [('', None, None, b'')]) + s.execute(insert, [("", None, None, b"")]) result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") - assert ('', None, None, b'') == result.one().t - assert ('', None, None, b'') == s.execute(read).one().t + assert ("", None, None, b"") == result.one().t + assert ("", None, None, b"") == s.execute(read).one().t def test_insert_collection_with_null_fails(self): """ @@ -721,52 +894,62 @@ def test_insert_collection_with_null_fails(self): """ s = self.session columns = [] - for collection_type in ['list', 'set']: + for collection_type in ["list", "set"]: for simple_type in PRIMITIVE_DATATYPES_KEYS: - columns.append(f'{collection_type}_{simple_type} {collection_type}<{simple_type}>') + columns.append( + f"{collection_type}_{simple_type} {collection_type}<{simple_type}>" + ) for simple_type in PRIMITIVE_DATATYPES_KEYS: - columns.append(f'map_k_{simple_type} map<{simple_type}, ascii>') - columns.append(f'map_v_{simple_type} map') - s.execute(f'CREATE TABLE collection_nulls (k int PRIMARY KEY, {", ".join(columns)})') + columns.append(f"map_k_{simple_type} map<{simple_type}, ascii>") + columns.append(f"map_v_{simple_type} map") + s.execute( + f"CREATE TABLE collection_nulls (k int PRIMARY KEY, {', '.join(columns)})" + ) def raises_simple_and_prepared(exc_type, query_str, args): with pytest.raises(exc_type): s.execute(query_str, args) - p = s.prepare(query_str.replace('%s', '?')) + p = s.prepare(query_str.replace("%s", "?")) with pytest.raises(exc_type): s.execute(p, args) i = 0 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, set_{simple_type}) VALUES (%s, %s)' + query_str = ( + f"INSERT INTO collection_nulls (k, set_{simple_type}) VALUES (%s, %s)" + ) args = [i, sortedset([None, get_sample(simple_type)])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, list_{simple_type}) VALUES (%s, %s)' + query_str = ( + f"INSERT INTO collection_nulls (k, list_{simple_type}) VALUES (%s, %s)" + ) args = [i, [None, get_sample(simple_type)]] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, map_k_{simple_type}) VALUES (%s, %s)' - args = [i, OrderedMap([(get_sample(simple_type), 'abc'), (None, 'def')])] + query_str = ( + f"INSERT INTO collection_nulls (k, map_k_{simple_type}) VALUES (%s, %s)" + ) + args = [i, OrderedMap([(get_sample(simple_type), "abc"), (None, "def")])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, map_v_{simple_type}) VALUES (%s, %s)' - args = [i, OrderedMap([('abc', None), ('def', get_sample(simple_type))])] + query_str = ( + f"INSERT INTO collection_nulls (k, map_v_{simple_type}) VALUES (%s, %s)" + ) + args = [i, OrderedMap([("abc", None), ("def", get_sample(simple_type))])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 - - def test_can_insert_unicode_query_string(self): """ Test to ensure unicode strings can be used in a query """ s = self.session - s.execute(u"SELECT * FROM system.local WHERE key = 'ef\u2052ef'") - s.execute(u"SELECT * FROM system.local WHERE key = %s", (u"fe\u2051fe",)) + s.execute("SELECT * FROM system.local WHERE key = 'ef\u2052ef'") + s.execute("SELECT * FROM system.local WHERE key = %s", ("fe\u2051fe",)) @requires_composite_type def test_can_read_composite_type(self): @@ -785,13 +968,13 @@ def test_can_read_composite_type(self): s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc:123')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() assert 0 == result.a - assert ('abc', 123) == result.b + assert ("abc", 123) == result.b # CompositeType values can omit elements at the end s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() assert 0 == result.a - assert ('abc',) == result.b + assert ("abc",) == result.b @notprotocolv1 def test_special_float_cql_encoding(self): @@ -811,7 +994,7 @@ def test_special_float_cql_encoding(self): f float PRIMARY KEY, d double )""") - items = (float('nan'), float('inf'), float('-inf')) + items = (float("nan"), float("inf"), float("-inf")) def verify_insert_select(ins_statement, sel_statement): execute_concurrent_with_args(s, ins_statement, ((f, f) for f in items)) @@ -825,14 +1008,18 @@ def verify_insert_select(ins_statement, sel_statement): assert row.d == f # cql encoding - verify_insert_select('INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)', - 'SELECT * FROM float_cql_encoding WHERE f=%s') + verify_insert_select( + "INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)", + "SELECT * FROM float_cql_encoding WHERE f=%s", + ) s.execute("TRUNCATE float_cql_encoding") # prepared binding - verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'), - s.prepare('SELECT * FROM float_cql_encoding WHERE f=?')) + verify_insert_select( + s.prepare("INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)"), + s.prepare("SELECT * FROM float_cql_encoding WHERE f=?"), + ) @cythontest def test_cython_decimal(self): @@ -846,11 +1033,19 @@ def test_cython_decimal(self): @test_category data_types serialization """ - self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name)) + self.session.execute( + "CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name) + ) try: - self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name)) - results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) - assert str(results.one().dc) == '-1.08430792318105707' + self.session.execute( + "INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format( + self.function_table_name + ) + ) + results = self.session.execute( + "SELECT * FROM {0}".format(self.function_table_name) + ) + assert str(results.one().dc) == "-1.08430792318105707" finally: self.session.execute("DROP TABLE {0}".format(self.function_table_name)) @@ -877,37 +1072,66 @@ def test_smoke_duration_values(self): VALUES (?, ?) """) - nanosecond_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, - 10000000000000,-9223372036854775807, 9223372036854775807, - int("7FFFFFFFFFFFFFFF", 16), int("-7FFFFFFFFFFFFFFF", 16)] - month_day_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, - int("7FFFFFFF", 16), int("-7FFFFFFF", 16)] + nanosecond_smoke_values = [ + 0, + -1, + 1, + 100, + 1000, + 1000000, + 1000000000, + 10000000000000, + -9223372036854775807, + 9223372036854775807, + int("7FFFFFFFFFFFFFFF", 16), + int("-7FFFFFFFFFFFFFFF", 16), + ] + month_day_smoke_values = [ + 0, + -1, + 1, + 100, + 1000, + 1000000, + 1000000000, + int("7FFFFFFF", 16), + int("-7FFFFFFF", 16), + ] for nanosecond_value in nanosecond_smoke_values: for month_day_value in month_day_smoke_values: - # Must have the same sign if (month_day_value <= 0) != (nanosecond_value <= 0): continue - self.session.execute(prepared, (1, Duration(month_day_value, month_day_value, nanosecond_value))) + self.session.execute( + prepared, + (1, Duration(month_day_value, month_day_value, nanosecond_value)), + ) results = self.session.execute("SELECT * FROM duration_smoke") v = results.one()[1] - assert Duration(month_day_value, month_day_value, nanosecond_value) == v, "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value) + assert ( + Duration(month_day_value, month_day_value, nanosecond_value) == v + ), "Error encoding value {0},{0},{1}".format( + month_day_value, nanosecond_value + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16)))) + self.session.execute( + prepared, (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16))) + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0))) + self.session.execute( + prepared, (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0)) + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0))) + self.session.execute( + prepared, (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0)) + ) -class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): +class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): @greaterthancass21 @lessthancass30 def test_nested_types_with_protocol_version(self): @@ -921,26 +1145,30 @@ def test_nested_types_with_protocol_version(self): @test_category data_types serialization """ - ddl = '''CREATE TABLE {0}.t ( + ddl = """CREATE TABLE {0}.t ( k int PRIMARY KEY, - v list>>)'''.format(self.keyspace_name) + v list>>)""".format(self.keyspace_name) self.session.execute(ddl) - ddl = '''CREATE TABLE {0}.u ( + ddl = """CREATE TABLE {0}.u ( k int PRIMARY KEY, - v set>>)'''.format(self.keyspace_name) + v set>>)""".format(self.keyspace_name) self.session.execute(ddl) - ddl = '''CREATE TABLE {0}.v ( + ddl = """CREATE TABLE {0}.v ( k int PRIMARY KEY, v map>, frozen>>, - v1 frozen>)'''.format(self.keyspace_name) + v1 frozen>)""".format(self.keyspace_name) self.session.execute(ddl) - self.session.execute("CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format( + self.keyspace_name + ) + ) - ddl = '''CREATE TABLE {0}.w ( + ddl = """CREATE TABLE {0}.w ( k int PRIMARY KEY, - v frozen)'''.format(self.keyspace_name) + v frozen)""".format(self.keyspace_name) self.session.execute(ddl) @@ -952,17 +1180,23 @@ def test_nested_types_with_protocol_version(self): def read_inserts_at_level(self, proto_ver): session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: - results = session.execute('select * from t').one() + results = session.execute("select * from t").one() assert "[SortedSet([1, 2]), SortedSet([3, 5])]" == str(results.v) - results = session.execute('select * from u').one() + results = session.execute("select * from u").one() assert "SortedSet([[1, 2], [3, 5]])" == str(results.v) - results = session.execute('select * from v').one() - assert "{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}" == str(results.v) + results = session.execute("select * from v").one() + assert ( + "{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}" + == str(results.v) + ) - results = session.execute('select * from w').one() - assert "typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])" == str(results.v) + results = session.execute("select * from w").one() + assert ( + "typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])" + == str(results.v) + ) finally: session.cluster.shutdown() @@ -970,31 +1204,37 @@ def read_inserts_at_level(self, proto_ver): def run_inserts_at_version(self, proto_ver): session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: - p = session.prepare('insert into t (k, v) values (?, ?)') + p = session.prepare("insert into t (k, v) values (?, ?)") session.execute(p, (0, [{1, 2}, {3, 5}])) - p = session.prepare('insert into u (k, v) values (?, ?)') + p = session.prepare("insert into u (k, v) values (?, ?)") session.execute(p, (0, {(1, 2), (3, 5)})) - p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)') - session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four'))) + p = session.prepare("insert into v (k, v, v1) values (?, ?, ?)") + session.execute( + p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, "four")) + ) - p = session.prepare('insert into w (k, v) values (?, ?)') + p = session.prepare("insert into w (k, v) values (?, ?)") session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9]))) finally: session.cluster.shutdown() + @requires_vector_type class TypeTestsVector(BasicSharedKeyspaceUnitTestCase): - def _get_first_j(self, rs): rows = rs.all() assert len(rows) == 1 return rows[0].j def _get_row_simple(self, idx, table_name): - rs = self.session.execute("select j from {0}.{1} where i = {2}".format(self.keyspace_name, table_name, idx)) + rs = self.session.execute( + "select j from {0}.{1} where i = {2}".format( + self.keyspace_name, table_name, idx + ) + ) return self._get_first_j(rs) def _get_row_prepared(self, idx, table_name): @@ -1003,9 +1243,13 @@ def _get_row_prepared(self, idx, table_name): rs = self.session.execute(ps, [idx]) return self._get_first_j(rs) - def _round_trip_test(self, subtype, subtype_fn, test_fn, use_positional_parameters=True): + def _round_trip_test( + self, subtype, subtype_fn, test_fn, use_positional_parameters=True + ): - table_name = subtype.replace("<","A").replace(">", "B").replace(",", "C") + "isH" + table_name = ( + subtype.replace("<", "A").replace(">", "B").replace(",", "C") + "isH" + ) def random_subtype_vector(): return [subtype_fn() for _ in range(3)] @@ -1016,20 +1260,28 @@ def random_subtype_vector(): self.session.execute(ddl) if use_positional_parameters: - cql = "insert into {0}.{1} (i,j) values (%s,%s)".format(self.keyspace_name, table_name) + cql = "insert into {0}.{1} (i,j) values (%s,%s)".format( + self.keyspace_name, table_name + ) expected1 = random_subtype_vector() - data1 = {1:random_subtype_vector(), 2:expected1, 3:random_subtype_vector()} - for k,v in data1.items(): + data1 = { + 1: random_subtype_vector(), + 2: expected1, + 3: random_subtype_vector(), + } + for k, v in data1.items(): # Attempt a set of inserts using the driver's support for positional params - self.session.execute(cql, (k,v)) + self.session.execute(cql, (k, v)) - cql = "insert into {0}.{1} (i,j) values (?,?)".format(self.keyspace_name, table_name) + cql = "insert into {0}.{1} (i,j) values (?,?)".format( + self.keyspace_name, table_name + ) expected2 = random_subtype_vector() ps = self.session.prepare(cql) - data2 = {4:random_subtype_vector(), 5:expected2, 6:random_subtype_vector()} - for k,v in data2.items(): + data2 = {4: random_subtype_vector(), 5: expected2, 6: random_subtype_vector()} + for k, v in data2.items(): # Add some additional rows via prepared statements - self.session.execute(ps, [k,v]) + self.session.execute(ps, [k, v]) # Use prepared queries to gather data from the rows we added via simple queries and vice versa if use_positional_parameters: @@ -1042,36 +1294,52 @@ def random_subtype_vector(): test_fn(observed2[idx], expected2[idx]) def test_round_trip_integers(self): - self._round_trip_test("int", partial(random.randint, 0, 2 ** 31), assertEqual) - self._round_trip_test("bigint", partial(random.randint, 0, 2 ** 63), assertEqual) - self._round_trip_test("smallint", partial(random.randint, 0, 2 ** 15), assertEqual) - self._round_trip_test("tinyint", partial(random.randint, 0, (2 ** 7) - 1), assertEqual) - self._round_trip_test("varint", partial(random.randint, 0, 2 ** 63), assertEqual) + self._round_trip_test("int", partial(random.randint, 0, 2**31), assertEqual) + self._round_trip_test("bigint", partial(random.randint, 0, 2**63), assertEqual) + self._round_trip_test( + "smallint", partial(random.randint, 0, 2**15), assertEqual + ) + self._round_trip_test( + "tinyint", partial(random.randint, 0, (2**7) - 1), assertEqual + ) + self._round_trip_test("varint", partial(random.randint, 0, 2**63), assertEqual) def test_round_trip_floating_point(self): _almost_equal_test_fn = partial(pytest.approx, abs=1e-5) + def _random_decimal(): return Decimal(random.uniform(0.0, 100.0)) # Max value here isn't really connected to max value for floating point nums in IEEE 754... it's used here # mainly as a convenient benchmark - self._round_trip_test("float", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) - self._round_trip_test("double", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) + self._round_trip_test( + "float", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn + ) + self._round_trip_test( + "double", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn + ) self._round_trip_test("decimal", _random_decimal, _almost_equal_test_fn) def test_round_trip_text(self): def _random_string(): - return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(24)) + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(24) + ) self._round_trip_test("ascii", _random_string, assertEqual) self._round_trip_test("text", _random_string, assertEqual) def test_round_trip_date_and_time(self): _almost_equal_test_fn = partial(pytest.approx, abs=timedelta(seconds=1)) + def _random_datetime(): - return datetime.today() - timedelta(hours=random.randint(0,18), days=random.randint(1,1000)) + return datetime.today() - timedelta( + hours=random.randint(0, 18), days=random.randint(1, 1000) + ) + def _random_date(): return _random_datetime().date() + def _random_time(): return _random_datetime().time() @@ -1085,11 +1353,16 @@ def test_round_trip_uuid(self): def test_round_trip_miscellany(self): def _random_bytes(): - return random.getrandbits(32).to_bytes(4,'big') + return random.getrandbits(32).to_bytes(4, "big") + def _random_boolean(): return random.choice([True, False]) + def _random_duration(): - return Duration(random.randint(0,11), random.randint(0,11), random.randint(0,10000)) + return Duration( + random.randint(0, 11), random.randint(0, 11), random.randint(0, 10000) + ) + def _random_inet(): return socket.inet_ntoa(_random_bytes()) @@ -1100,11 +1373,13 @@ def _random_inet(): def test_round_trip_collections(self): def _random_seq(): - return [random.randint(0,100000) for _ in range(8)] + return [random.randint(0, 100000) for _ in range(8)] + def _random_set(): return set(_random_seq()) + def _random_map(): - return {k:v for (k,v) in zip(_random_seq(), _random_seq())} + return {k: v for (k, v) in zip(_random_seq(), _random_seq())} # Goal here is to test collections of both fixed and variable size subtypes self._round_trip_test("list", _random_seq, assertEqual) @@ -1118,44 +1393,76 @@ def _random_map(): def test_round_trip_vector_of_vectors(self): def _random_vector(): - return [random.randint(0,100000) for _ in range(2)] + return [random.randint(0, 100000) for _ in range(2)] self._round_trip_test("vector", _random_vector, assertEqual) self._round_trip_test("vector", _random_vector, assertEqual) def test_round_trip_tuples(self): def _random_tuple(): - return (random.randint(0,100000),random.randint(0,100000)) + return (random.randint(0, 100000), random.randint(0, 100000)) # Unfortunately we can't use positional parameters when inserting tuples because the driver will try to encode # them as lists before sending them to the server... and that confuses the parsing logic. - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) def test_round_trip_udts(self): def _udt_equal_test_fn(udt1, udt2): assert udt1.a == udt2.a assert udt1.b == udt2.b - self.session.execute("create type {}.fixed_type (a int, b int)".format(self.keyspace_name)) - self.session.execute("create type {}.mixed_type_one (a int, b varint)".format(self.keyspace_name)) - self.session.execute("create type {}.mixed_type_two (a varint, b int)".format(self.keyspace_name)) - self.session.execute("create type {}.var_type (a varint, b varint)".format(self.keyspace_name)) + self.session.execute( + "create type {}.fixed_type (a int, b int)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.mixed_type_one (a int, b varint)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.mixed_type_two (a varint, b int)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.var_type (a varint, b varint)".format(self.keyspace_name) + ) class GeneralUDT: def __init__(self, a, b): self.a = a self.b = b - self.cluster.register_user_type(self.keyspace_name,'fixed_type', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'mixed_type_one', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'mixed_type_two', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'var_type', GeneralUDT) + self.cluster.register_user_type(self.keyspace_name, "fixed_type", GeneralUDT) + self.cluster.register_user_type( + self.keyspace_name, "mixed_type_one", GeneralUDT + ) + self.cluster.register_user_type( + self.keyspace_name, "mixed_type_two", GeneralUDT + ) + self.cluster.register_user_type(self.keyspace_name, "var_type", GeneralUDT) def _random_udt(): - return GeneralUDT(random.randint(0,100000),random.randint(0,100000)) + return GeneralUDT(random.randint(0, 100000), random.randint(0, 100000)) self._round_trip_test("fixed_type", _random_udt, _udt_equal_test_fn) self._round_trip_test("mixed_type_one", _random_udt, _udt_equal_test_fn) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 7a8c584f75..11b623d70a 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -22,116 +22,176 @@ import cassandra from cassandra import util from cassandra.cqltypes import ( - CassandraType, DateRangeType, DateType, DecimalType, - EmptyValue, LongType, SetType, UTF8Type, - cql_typename, int8_pack, int64_pack, int64_unpack, lookup_casstype, - lookup_casstype_simple, parse_casstype_args, - int32_pack, Int32Type, ListType, MapType, VectorType, - FloatType + CassandraType, + DateRangeType, + DateType, + DecimalType, + EmptyValue, + LongType, + SetType, + UTF8Type, + cql_typename, + int8_pack, + int64_pack, + int64_unpack, + lookup_casstype, + lookup_casstype_simple, + parse_casstype_args, + int32_pack, + Int32Type, + ListType, + MapType, + VectorType, + FloatType, ) from cassandra.encoder import cql_quote from cassandra.pool import Host from cassandra.metadata import Token from cassandra.policies import ConvictionPolicy, SimpleConvictionPolicy from cassandra.protocol import ( - read_inet, read_longstring, read_string, - read_stringmap, write_inet, write_longstring, - write_string, write_stringmap + read_inet, + read_longstring, + read_string, + read_stringmap, + write_inet, + write_longstring, + write_string, + write_stringmap, ) from cassandra.query import named_tuple_factory from cassandra.util import ( - OPEN_BOUND, Date, DateRange, DateRangeBound, - DateRangePrecision, Time, ms_timestamp_from_datetime, - datetime_from_timestamp + OPEN_BOUND, + Date, + DateRange, + DateRangeBound, + DateRangePrecision, + Time, + ms_timestamp_from_datetime, + datetime_from_timestamp, ) from tests.unit.util import check_sequence_consistency import pytest class TypeTests(unittest.TestCase): - def test_lookup_casstype_simple(self): """ Ensure lookup_casstype_simple returns the correct classes """ - assert lookup_casstype_simple('AsciiType') == cassandra.cqltypes.AsciiType - assert lookup_casstype_simple('LongType') == cassandra.cqltypes.LongType - assert lookup_casstype_simple('BytesType') == cassandra.cqltypes.BytesType - assert lookup_casstype_simple('BooleanType') == cassandra.cqltypes.BooleanType - assert lookup_casstype_simple('CounterColumnType') == cassandra.cqltypes.CounterColumnType - assert lookup_casstype_simple('DecimalType') == cassandra.cqltypes.DecimalType - assert lookup_casstype_simple('DoubleType') == cassandra.cqltypes.DoubleType - assert lookup_casstype_simple('FloatType') == cassandra.cqltypes.FloatType - assert lookup_casstype_simple('InetAddressType') == cassandra.cqltypes.InetAddressType - assert lookup_casstype_simple('Int32Type') == cassandra.cqltypes.Int32Type - assert lookup_casstype_simple('UTF8Type') == cassandra.cqltypes.UTF8Type - assert lookup_casstype_simple('DateType') == cassandra.cqltypes.DateType - assert lookup_casstype_simple('SimpleDateType') == cassandra.cqltypes.SimpleDateType - assert lookup_casstype_simple('ByteType') == cassandra.cqltypes.ByteType - assert lookup_casstype_simple('ShortType') == cassandra.cqltypes.ShortType - assert lookup_casstype_simple('TimeUUIDType') == cassandra.cqltypes.TimeUUIDType - assert lookup_casstype_simple('TimeType') == cassandra.cqltypes.TimeType - assert lookup_casstype_simple('UUIDType') == cassandra.cqltypes.UUIDType - assert lookup_casstype_simple('IntegerType') == cassandra.cqltypes.IntegerType - assert lookup_casstype_simple('MapType') == cassandra.cqltypes.MapType - assert lookup_casstype_simple('ListType') == cassandra.cqltypes.ListType - assert lookup_casstype_simple('SetType') == cassandra.cqltypes.SetType - assert lookup_casstype_simple('CompositeType') == cassandra.cqltypes.CompositeType - assert lookup_casstype_simple('ColumnToCollectionType') == cassandra.cqltypes.ColumnToCollectionType - assert lookup_casstype_simple('ReversedType') == cassandra.cqltypes.ReversedType - assert lookup_casstype_simple('DurationType') == cassandra.cqltypes.DurationType - assert lookup_casstype_simple('DateRangeType') == cassandra.cqltypes.DateRangeType - - assert str(lookup_casstype_simple('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) + assert lookup_casstype_simple("AsciiType") == cassandra.cqltypes.AsciiType + assert lookup_casstype_simple("LongType") == cassandra.cqltypes.LongType + assert lookup_casstype_simple("BytesType") == cassandra.cqltypes.BytesType + assert lookup_casstype_simple("BooleanType") == cassandra.cqltypes.BooleanType + assert ( + lookup_casstype_simple("CounterColumnType") + == cassandra.cqltypes.CounterColumnType + ) + assert lookup_casstype_simple("DecimalType") == cassandra.cqltypes.DecimalType + assert lookup_casstype_simple("DoubleType") == cassandra.cqltypes.DoubleType + assert lookup_casstype_simple("FloatType") == cassandra.cqltypes.FloatType + assert ( + lookup_casstype_simple("InetAddressType") + == cassandra.cqltypes.InetAddressType + ) + assert lookup_casstype_simple("Int32Type") == cassandra.cqltypes.Int32Type + assert lookup_casstype_simple("UTF8Type") == cassandra.cqltypes.UTF8Type + assert lookup_casstype_simple("DateType") == cassandra.cqltypes.DateType + assert ( + lookup_casstype_simple("SimpleDateType") + == cassandra.cqltypes.SimpleDateType + ) + assert lookup_casstype_simple("ByteType") == cassandra.cqltypes.ByteType + assert lookup_casstype_simple("ShortType") == cassandra.cqltypes.ShortType + assert lookup_casstype_simple("TimeUUIDType") == cassandra.cqltypes.TimeUUIDType + assert lookup_casstype_simple("TimeType") == cassandra.cqltypes.TimeType + assert lookup_casstype_simple("UUIDType") == cassandra.cqltypes.UUIDType + assert lookup_casstype_simple("IntegerType") == cassandra.cqltypes.IntegerType + assert lookup_casstype_simple("MapType") == cassandra.cqltypes.MapType + assert lookup_casstype_simple("ListType") == cassandra.cqltypes.ListType + assert lookup_casstype_simple("SetType") == cassandra.cqltypes.SetType + assert ( + lookup_casstype_simple("CompositeType") == cassandra.cqltypes.CompositeType + ) + assert ( + lookup_casstype_simple("ColumnToCollectionType") + == cassandra.cqltypes.ColumnToCollectionType + ) + assert lookup_casstype_simple("ReversedType") == cassandra.cqltypes.ReversedType + assert lookup_casstype_simple("DurationType") == cassandra.cqltypes.DurationType + assert ( + lookup_casstype_simple("DateRangeType") == cassandra.cqltypes.DateRangeType + ) + + assert str(lookup_casstype_simple("unknown")) == str( + cassandra.cqltypes.mkUnrecognizedType("unknown") + ) def test_lookup_casstype(self): """ Ensure lookup_casstype returns the correct classes """ - assert lookup_casstype('AsciiType') == cassandra.cqltypes.AsciiType - assert lookup_casstype('LongType') == cassandra.cqltypes.LongType - assert lookup_casstype('BytesType') == cassandra.cqltypes.BytesType - assert lookup_casstype('BooleanType') == cassandra.cqltypes.BooleanType - assert lookup_casstype('CounterColumnType') == cassandra.cqltypes.CounterColumnType - assert lookup_casstype('DateType') == cassandra.cqltypes.DateType - assert lookup_casstype('DecimalType') == cassandra.cqltypes.DecimalType - assert lookup_casstype('DoubleType') == cassandra.cqltypes.DoubleType - assert lookup_casstype('FloatType') == cassandra.cqltypes.FloatType - assert lookup_casstype('InetAddressType') == cassandra.cqltypes.InetAddressType - assert lookup_casstype('Int32Type') == cassandra.cqltypes.Int32Type - assert lookup_casstype('UTF8Type') == cassandra.cqltypes.UTF8Type - assert lookup_casstype('DateType') == cassandra.cqltypes.DateType - assert lookup_casstype('TimeType') == cassandra.cqltypes.TimeType - assert lookup_casstype('ByteType') == cassandra.cqltypes.ByteType - assert lookup_casstype('ShortType') == cassandra.cqltypes.ShortType - assert lookup_casstype('TimeUUIDType') == cassandra.cqltypes.TimeUUIDType - assert lookup_casstype('UUIDType') == cassandra.cqltypes.UUIDType - assert lookup_casstype('IntegerType') == cassandra.cqltypes.IntegerType - assert lookup_casstype('MapType') == cassandra.cqltypes.MapType - assert lookup_casstype('ListType') == cassandra.cqltypes.ListType - assert lookup_casstype('SetType') == cassandra.cqltypes.SetType - assert lookup_casstype('CompositeType') == cassandra.cqltypes.CompositeType - assert lookup_casstype('ColumnToCollectionType') == cassandra.cqltypes.ColumnToCollectionType - assert lookup_casstype('ReversedType') == cassandra.cqltypes.ReversedType - assert lookup_casstype('DurationType') == cassandra.cqltypes.DurationType - assert lookup_casstype('DateRangeType') == cassandra.cqltypes.DateRangeType - - assert str(lookup_casstype('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) + assert lookup_casstype("AsciiType") == cassandra.cqltypes.AsciiType + assert lookup_casstype("LongType") == cassandra.cqltypes.LongType + assert lookup_casstype("BytesType") == cassandra.cqltypes.BytesType + assert lookup_casstype("BooleanType") == cassandra.cqltypes.BooleanType + assert ( + lookup_casstype("CounterColumnType") == cassandra.cqltypes.CounterColumnType + ) + assert lookup_casstype("DateType") == cassandra.cqltypes.DateType + assert lookup_casstype("DecimalType") == cassandra.cqltypes.DecimalType + assert lookup_casstype("DoubleType") == cassandra.cqltypes.DoubleType + assert lookup_casstype("FloatType") == cassandra.cqltypes.FloatType + assert lookup_casstype("InetAddressType") == cassandra.cqltypes.InetAddressType + assert lookup_casstype("Int32Type") == cassandra.cqltypes.Int32Type + assert lookup_casstype("UTF8Type") == cassandra.cqltypes.UTF8Type + assert lookup_casstype("DateType") == cassandra.cqltypes.DateType + assert lookup_casstype("TimeType") == cassandra.cqltypes.TimeType + assert lookup_casstype("ByteType") == cassandra.cqltypes.ByteType + assert lookup_casstype("ShortType") == cassandra.cqltypes.ShortType + assert lookup_casstype("TimeUUIDType") == cassandra.cqltypes.TimeUUIDType + assert lookup_casstype("UUIDType") == cassandra.cqltypes.UUIDType + assert lookup_casstype("IntegerType") == cassandra.cqltypes.IntegerType + assert lookup_casstype("MapType") == cassandra.cqltypes.MapType + assert lookup_casstype("ListType") == cassandra.cqltypes.ListType + assert lookup_casstype("SetType") == cassandra.cqltypes.SetType + assert lookup_casstype("CompositeType") == cassandra.cqltypes.CompositeType + assert ( + lookup_casstype("ColumnToCollectionType") + == cassandra.cqltypes.ColumnToCollectionType + ) + assert lookup_casstype("ReversedType") == cassandra.cqltypes.ReversedType + assert lookup_casstype("DurationType") == cassandra.cqltypes.DurationType + assert lookup_casstype("DateRangeType") == cassandra.cqltypes.DateRangeType + + assert str(lookup_casstype("unknown")) == str( + cassandra.cqltypes.mkUnrecognizedType("unknown") + ) with pytest.raises(ValueError): - lookup_casstype('AsciiType~') + lookup_casstype("AsciiType~") def test_casstype_parameterized(self): - assert LongType.cass_parameterized_type_with(()) == 'LongType' - assert LongType.cass_parameterized_type_with((), full=True) == 'org.apache.cassandra.db.marshal.LongType' - assert SetType.cass_parameterized_type_with([DecimalType], full=True) == 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' + assert LongType.cass_parameterized_type_with(()) == "LongType" + assert ( + LongType.cass_parameterized_type_with((), full=True) + == "org.apache.cassandra.db.marshal.LongType" + ) + assert ( + SetType.cass_parameterized_type_with([DecimalType], full=True) + == "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)" + ) - assert LongType.cql_parameterized_type() == 'bigint' + assert LongType.cql_parameterized_type() == "bigint" subtypes = (cassandra.cqltypes.UTF8Type, cassandra.cqltypes.UTF8Type) - assert 'map' == cassandra.cqltypes.MapType.apply_parameters(subtypes).cql_parameterized_type() + assert ( + "map" + == cassandra.cqltypes.MapType.apply_parameters( + subtypes + ).cql_parameterized_type() + ) def test_datetype_from_string(self): # Ensure all formats can be parsed, without exception @@ -144,8 +204,11 @@ def test_cql_typename(self): Smoke test cql_typename """ - assert cql_typename('DateType') == 'timestamp' - assert cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)') == 'list' + assert cql_typename("DateType") == "timestamp" + assert ( + cql_typename("org.apache.cassandra.db.marshal.ListType(IntegerType)") + == "list" + ) def test_named_tuple_colname_substitution(self): colnames = ("func(abc)", "[applied]", "func(func(abc))", "foo_bar", "foo_bar_") @@ -159,7 +222,7 @@ def test_named_tuple_colname_substitution(self): def test_parse_casstype_args(self): class FooType(CassandraType): - typename = 'org.apache.cassandra.db.marshal.FooType' + typename = "org.apache.cassandra.db.marshal.FooType" def __init__(self, subtypes, names): self.subtypes = subtypes @@ -167,17 +230,28 @@ def __init__(self, subtypes, names): @classmethod def apply_parameters(cls, subtypes, names): - return cls(subtypes, [unhexlify(name.encode()) if name is not None else name for name in names]) + return cls( + subtypes, + [ + unhexlify(name.encode()) if name is not None else name + for name in names + ], + ) class BarType(FooType): - typename = 'org.apache.cassandra.db.marshal.BarType' - - ctype = parse_casstype_args(''.join(( - 'org.apache.cassandra.db.marshal.FooType(', - '63697479:org.apache.cassandra.db.marshal.UTF8Type,', - 'BarType(61646472657373:org.apache.cassandra.db.marshal.UTF8Type),', - '7a6970:org.apache.cassandra.db.marshal.UTF8Type', - ')'))) + typename = "org.apache.cassandra.db.marshal.BarType" + + ctype = parse_casstype_args( + "".join( + ( + "org.apache.cassandra.db.marshal.FooType(", + "63697479:org.apache.cassandra.db.marshal.UTF8Type,", + "BarType(61646472657373:org.apache.cassandra.db.marshal.UTF8Type),", + "7a6970:org.apache.cassandra.db.marshal.UTF8Type", + ")", + ) + ) + ) assert FooType == ctype.__class__ @@ -189,17 +263,21 @@ class BarType(FooType): assert [b"address"] == ctype.subtypes[1].names assert UTF8Type == ctype.subtypes[2] - assert [b'city', None, b'zip'] == ctype.names + assert [b"city", None, b"zip"] == ctype.names def test_parse_casstype_vector(self): - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)") + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)" + ) assert issubclass(ctype, VectorType) assert 3 == ctype.vector_size assert FloatType == ctype.subtype def test_parse_casstype_vector_of_vectors(self): inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type) + ) assert issubclass(ctype, VectorType) assert 3 == ctype.vector_size sub_ctype = ctype.subtype @@ -208,38 +286,62 @@ def test_parse_casstype_vector_of_vectors(self): assert FloatType == sub_ctype.subtype def test_empty_value(self): - assert str(EmptyValue()) == 'EMPTY' + assert str(EmptyValue()) == "EMPTY" def test_datetype(self): now_time_seconds = time.time() - now_datetime = datetime.datetime.fromtimestamp(now_time_seconds, tz=datetime.timezone.utc) + now_datetime = datetime.datetime.fromtimestamp( + now_time_seconds, tz=datetime.timezone.utc + ) # Cassandra timestamps in millis now_timestamp = now_time_seconds * 1e3 # same results serialized - assert DateType.serialize(now_datetime, 0) == DateType.serialize(now_timestamp, 0) + assert DateType.serialize(now_datetime, 0) == DateType.serialize( + now_timestamp, 0 + ) # deserialize # epoc expected = 0 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None) + assert DateType.deserialize( + int64_pack(1000 * expected), 0 + ) == datetime.datetime.fromtimestamp( + expected, tz=datetime.timezone.utc + ).replace(tzinfo=None) # beyond 32b - expected = 2 ** 33 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + expected = 2**33 + assert DateType.deserialize( + int64_pack(1000 * expected), 0 + ) == datetime.datetime( + 2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc + ).replace(tzinfo=None) # less than epoc (PYTHON-119) expected = -770172256 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType.deserialize( + int64_pack(1000 * expected), 0 + ) == datetime.datetime( + 1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc + ).replace(tzinfo=None) # work around rounding difference among Python versions (PYTHON-230) expected = 1424817268.274 - assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType.deserialize( + int64_pack(int(1000 * expected)), 0 + ) == datetime.datetime( + 2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc + ).replace(tzinfo=None) # Large date overflow (PYTHON-452) expected = 2177403010.123 - assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType.deserialize( + int64_pack(int(1000 * expected)), 0 + ) == datetime.datetime( + 2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc + ).replace(tzinfo=None) # Large timestamp precision (GH-532) - timestamps far from epoch must # not lose precision due to floating-point conversions. @@ -264,10 +366,10 @@ def test_collection_null_support(self): """ int_list = ListType.apply_parameters([Int32Type]) value = ( - int32_pack(2) + # num items - int32_pack(-1) + # size of item1 - int32_pack(4) + # size of item2 - int32_pack(42) # item2 + int32_pack(2) # num items + + int32_pack(-1) # size of item1 + + int32_pack(4) # size of item2 + + int32_pack(42) # item2 ) assert [None, 42] == int_list.deserialize(value, 3) @@ -275,57 +377,59 @@ def test_collection_null_support(self): assert {None, 42} == set(set_list.deserialize(value, 3)) value = ( - int32_pack(2) + # num items - int32_pack(4) + # key size of item1 - int32_pack(42) + # key item1 - int32_pack(-1) + # value size of item1 - int32_pack(-1) + # key size of item2 - int32_pack(4) + # value size of item2 - int32_pack(42) # value of item2 + int32_pack(2) # num items + + int32_pack(4) # key size of item1 + + int32_pack(42) # key item1 + + int32_pack(-1) # value size of item1 + + int32_pack(-1) # key size of item2 + + int32_pack(4) # value size of item2 + + int32_pack(42) # value of item2 ) map_list = MapType.apply_parameters([Int32Type, Int32Type]) - assert [(42, None), (None, 42)] == map_list.deserialize(value, 3)._items # OrderedMapSerializedKey + assert [(42, None), (None, 42)] == map_list.deserialize( + value, 3 + )._items # OrderedMapSerializedKey def test_write_read_string(self): with tempfile.TemporaryFile() as f: - value = u'test' + value = "test" write_string(f, value) f.seek(0) assert read_string(f) == value def test_write_read_longstring(self): with tempfile.TemporaryFile() as f: - value = u'test' + value = "test" write_longstring(f, value) f.seek(0) assert read_longstring(f) == value def test_write_read_stringmap(self): with tempfile.TemporaryFile() as f: - value = {'key': 'value'} + value = {"key": "value"} write_stringmap(f, value) f.seek(0) assert read_stringmap(f) == value def test_write_read_inet(self): with tempfile.TemporaryFile() as f: - value = ('192.168.1.1', 9042) + value = ("192.168.1.1", 9042) write_inet(f, value) f.seek(0) assert read_inet(f) == value with tempfile.TemporaryFile() as f: - value = ('2001:db8:0:f101::1', 9042) + value = ("2001:db8:0:f101::1", 9042) write_inet(f, value) f.seek(0) assert read_inet(f) == value def test_cql_quote(self): - assert cql_quote(u'test') == "'test'" - assert cql_quote('test') == "'test'" - assert cql_quote(0) == '0' + assert cql_quote("test") == "'test'" + assert cql_quote("test") == "'test'" + assert cql_quote(0) == "0" class VectorTests(unittest.TestCase): @@ -339,7 +443,7 @@ def _round_trip_compare_fn(self, first, second): assert first == pytest.approx(second, rel=1e-5) elif isinstance(first, list): assert len(first) == len(second) - for (felem, selem) in zip(first, second): + for felem, selem in zip(first, second): self._round_trip_compare_fn(felem, selem) elif isinstance(first, set) or isinstance(first, frozenset): assert len(first) == len(second) @@ -347,7 +451,7 @@ def _round_trip_compare_fn(self, first, second): second_norm = self._normalize_set(second) assert first_norm == second_norm elif isinstance(first, dict): - for ((fk,fv), (sk,sv)) in zip(first.items(), second.items()): + for (fk, fv), (sk, sv) in zip(first.items(), second.items()): self._round_trip_compare_fn(fk, sk) self._round_trip_compare_fn(fv, sv) else: @@ -361,91 +465,141 @@ def _round_trip_test(self, data, ctype_str): assert serialized_size * len(data) == len(data_bytes) result = ctype.deserialize(data_bytes, 0) assert len(data) == len(result) - for idx in range(0,len(data)): + for idx in range(0, len(data)): self._round_trip_compare_fn(data[idx], result[idx]) def test_round_trip_basic_types_with_fixed_serialized_size(self): - self._round_trip_test([True, False, False, True], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)") - self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)") - self._round_trip_test([3, 2, 41, 12], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)") - self._round_trip_test([3, 2, 41, 12], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)") - self._round_trip_test([uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)") - self._round_trip_test([3, 2, 41, 12], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)") + self._round_trip_test( + [True, False, False, True], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)", + ) + self._round_trip_test( + [3.4, 2.9, 41.6, 12.0], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)", + ) + self._round_trip_test( + [3.4, 2.9, 41.6, 12.0], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)", + ) + self._round_trip_test( + [3, 2, 41, 12], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)", + ) + self._round_trip_test( + [3, 2, 41, 12], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)", + ) + self._round_trip_test( + [uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)", + ) + self._round_trip_test( + [3, 2, 41, 12], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)", + ) def test_round_trip_basic_types_without_fixed_serialized_size(self): # Varints - self._round_trip_test([3, 2, 41, 12], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + self._round_trip_test( + [3, 2, 41, 12], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)", + ) # ASCII text - self._round_trip_test(["abc", "def", "ghi", "jkl"], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)") + self._round_trip_test( + ["abc", "def", "ghi", "jkl"], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)", + ) # UTF8 text - self._round_trip_test(["abc", "def", "ghi", "jkl"], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)") + self._round_trip_test( + ["abc", "def", "ghi", "jkl"], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)", + ) # Time is something of a weird one. By rights it should be a fixed size type but C* code marks it as variable # size. We're forced to follow the C* code base (since that's who'll be providing the data we're parsing) so # we match what they're doing. - self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)") + self._round_trip_test( + [datetime.time(1, 1, 1), datetime.time(2, 2, 2), datetime.time(3, 3, 3)], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)", + ) # Duration (containts varints) - self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)") + self._round_trip_test( + [util.Duration(1, 1, 1), util.Duration(2, 2, 2), util.Duration(3, 3, 3)], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)", + ) def test_round_trip_collection_types(self): # List (subtype of fixed size) - self._round_trip_test([[1, 2, 3, 4], [5, 6], [7, 8, 9, 10], [11, 12]], \ + self._round_trip_test( + [[1, 2, 3, 4], [5, 6], [7, 8, 9, 10], [11, 12]], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType \ - (org.apache.cassandra.db.marshal.Int32Type), 4)") + (org.apache.cassandra.db.marshal.Int32Type), 4)", + ) # Set (subtype of fixed size) - self._round_trip_test([set([1, 2, 3, 4]), set([5, 6]), set([7, 8, 9, 10]), set([11, 12])], \ + self._round_trip_test( + [set([1, 2, 3, 4]), set([5, 6]), set([7, 8, 9, 10]), set([11, 12])], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType \ - (org.apache.cassandra.db.marshal.Int32Type), 4)") + (org.apache.cassandra.db.marshal.Int32Type), 4)", + ) # Map (subtype of fixed size) - self._round_trip_test([{1:1.2}, {2:3.4}, {3:5.6}, {4:7.8}], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ - (org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)") + self._round_trip_test( + [{1: 1.2}, {2: 3.4}, {3: 5.6}, {4: 7.8}], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ + (org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)", + ) # List (subtype without fixed size) - self._round_trip_test([["one","two"], ["three","four"], ["five","six"], ["seven","eight"]], \ + self._round_trip_test( + [["one", "two"], ["three", "four"], ["five", "six"], ["seven", "eight"]], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType \ - (org.apache.cassandra.db.marshal.AsciiType), 4)") + (org.apache.cassandra.db.marshal.AsciiType), 4)", + ) # Set (subtype without fixed size) - self._round_trip_test([set(["one","two"]), set(["three","four"]), set(["five","six"]), set(["seven","eight"])], \ + self._round_trip_test( + [ + set(["one", "two"]), + set(["three", "four"]), + set(["five", "six"]), + set(["seven", "eight"]), + ], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType \ - (org.apache.cassandra.db.marshal.AsciiType), 4)") + (org.apache.cassandra.db.marshal.AsciiType), 4)", + ) # Map (subtype without fixed size) - self._round_trip_test([{1:"one"}, {2:"two"}, {3:"three"}, {4:"four"}], \ - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ - (org.apache.cassandra.db.marshal.IntegerType,org.apache.cassandra.db.marshal.AsciiType), 4)") + self._round_trip_test( + [{1: "one"}, {2: "two"}, {3: "three"}, {4: "four"}], + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ + (org.apache.cassandra.db.marshal.IntegerType,org.apache.cassandra.db.marshal.AsciiType), 4)", + ) # List of lists (subtype without fixed size) - data = [[["one","two"],["three"]], [["four"],["five"]], [["six","seven","eight"]], [["nine"]]] + data = [ + [["one", "two"], ["three"]], + [["four"], ["five"]], + [["six", "seven", "eight"]], + [["nine"]], + ] ctype = "org.apache.cassandra.db.marshal.VectorType\ (org.apache.cassandra.db.marshal.ListType\ (org.apache.cassandra.db.marshal.ListType\ (org.apache.cassandra.db.marshal.AsciiType)), 4)" self._round_trip_test(data, ctype) # Set of sets (subtype without fixed size) - data = [set([frozenset(["one","two"]),frozenset(["three"])]),\ - set([frozenset(["four"]),frozenset(["five"])]),\ - set([frozenset(["six","seven","eight"])]), - set([frozenset(["nine"])])] + data = [ + set([frozenset(["one", "two"]), frozenset(["three"])]), + set([frozenset(["four"]), frozenset(["five"])]), + set([frozenset(["six", "seven", "eight"])]), + set([frozenset(["nine"])]), + ] ctype = "org.apache.cassandra.db.marshal.VectorType\ (org.apache.cassandra.db.marshal.SetType\ (org.apache.cassandra.db.marshal.SetType\ (org.apache.cassandra.db.marshal.AsciiType)), 4)" self._round_trip_test(data, ctype) # Map of maps (subtype without fixed size) - data = [{100:{1:"one",2:"two",3:"three"}},\ - {200:{4:"four",5:"five"}},\ - {300:{}},\ - {400:{6:"six"}}] + data = [ + {100: {1: "one", 2: "two", 3: "three"}}, + {200: {4: "four", 5: "five"}}, + {300: {}}, + {400: {6: "six"}}, + ] ctype = "org.apache.cassandra.db.marshal.VectorType\ (org.apache.cassandra.db.marshal.MapType\ (org.apache.cassandra.db.marshal.Int32Type,\ @@ -455,75 +609,228 @@ def test_round_trip_collection_types(self): def test_round_trip_vector_of_vectors(self): # Subytpes of subtypes with a fixed size - self._round_trip_test([[1.2, 3.4], [5.6, 7.8], [9.10, 11.12], [13.14, 15.16]], \ + self._round_trip_test( + [[1.2, 3.4], [5.6, 7.8], [9.10, 11.12], [13.14, 15.16]], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \ - (org.apache.cassandra.db.marshal.FloatType,2), 4)") + (org.apache.cassandra.db.marshal.FloatType,2), 4)", + ) # Subytpes of subtypes without a fixed size - self._round_trip_test([["one", "two"], ["three", "four"], ["five", "six"], ["seven", "eight"]], \ + self._round_trip_test( + [["one", "two"], ["three", "four"], ["five", "six"], ["seven", "eight"]], "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \ - (org.apache.cassandra.db.marshal.AsciiType,2), 4)") + (org.apache.cassandra.db.marshal.AsciiType,2), 4)", + ) # parse_casstype_args() is tested above... we're explicitly concerned about cql_parapmeterized_type() output here def test_cql_parameterized_type(self): # Base vector functionality - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - assert ctype.cql_parameterized_type() == "org.apache.cassandra.db.marshal.VectorType" + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ) + assert ( + ctype.cql_parameterized_type() + == "org.apache.cassandra.db.marshal.VectorType" + ) # Test vector-of-vectors inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type) + ) inner_parsed_type = "org.apache.cassandra.db.marshal.VectorType" - assert ctype.cql_parameterized_type() == "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (inner_parsed_type) + assert ( + ctype.cql_parameterized_type() + == "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (inner_parsed_type) + ) def test_serialization_fixed_size_too_small(self): - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") - with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4"): + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)" + ) + with pytest.raises( + ValueError, + match="Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4", + ): ctype.serialize([1.2, 3.4, 5.6, 7.8], 0) def test_serialization_fixed_size_too_big(self): - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type float and dimension 4, observed sequence of length 5"): + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ) + with pytest.raises( + ValueError, + match="Expected sequence of size 4 for vector of type float and dimension 4, observed sequence of length 5", + ): ctype.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) def test_serialization_variable_size_too_small(self): - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") - with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4"): + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)" + ) + with pytest.raises( + ValueError, + match="Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4", + ): ctype.serialize([1, 2, 3, 4], 0) def test_serialization_variable_size_too_big(self): - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") - with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5"): + ctype = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)" + ) + with pytest.raises( + ValueError, + match="Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5", + ): ctype.serialize([1, 2, 3, 4, 5], 0) def test_deserialization_fixed_size_too_small(self): - ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + ctype_four = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ) ctype_four_bytes = ctype_four.serialize([1.2, 3.4, 5.6, 7.8], 0) - ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") - with pytest.raises(ValueError, match="Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead"): + ctype_five = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)" + ) + with pytest.raises( + ValueError, + match="Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead", + ): ctype_five.deserialize(ctype_four_bytes, 0) def test_deserialization_fixed_size_too_big(self): - ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") + ctype_five = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)" + ) ctype_five_bytes = ctype_five.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) - ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - with pytest.raises(ValueError, match="Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead"): + ctype_four = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ) + with pytest.raises( + ValueError, + match="Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead", + ): ctype_four.deserialize(ctype_five_bytes, 0) def test_deserialization_variable_size_too_small(self): - ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + ctype_four = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)" + ) ctype_four_bytes = ctype_four.serialize([1, 2, 3, 4], 0) - ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") - with pytest.raises(ValueError, match="Error reading additional data during vector deserialization after successfully adding 4 elements"): + ctype_five = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)" + ) + with pytest.raises( + ValueError, + match="Error reading additional data during vector deserialization after successfully adding 4 elements", + ): ctype_five.deserialize(ctype_four_bytes, 0) def test_deserialization_variable_size_too_big(self): - ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") + ctype_five = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)" + ) ctype_five_bytes = ctype_five.serialize([1, 2, 3, 4, 5], 0) - ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") - with pytest.raises(ValueError, match="Additional bytes remaining after vector deserialization completed"): + ctype_four = parse_casstype_args( + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)" + ) + with pytest.raises( + ValueError, + match="Additional bytes remaining after vector deserialization completed", + ): ctype_four.deserialize(ctype_five_bytes, 0) + def test_vector_cython_deserializer_variable_size_subtype(self): + """ + Test that DesVectorType falls back gracefully for variable-size subtypes. + Variable-size types (e.g. UTF8Type) are not supported by the Cython fast path + and should raise ValueError from DesVectorType._deserialize_generic. + The pure Python VectorType.deserialize handles these correctly. + + Note: This test is forward-looking — it validates the Cython deserializer + that is introduced in a companion PR. The skipTest guards below ensure + the test is silently skipped when the extension is not yet compiled. + + @since 3.x + @expected_result Cython deserializer raises ValueError for variable-size subtypes; + pure Python path correctly deserializes them + + @test_category data_types:vector + """ + try: + from cassandra.deserializers import find_deserializer, DesVectorType + except ImportError: + self.skipTest("Cython deserializers not available (no compiled extension)") + + vt_text = VectorType.apply_parameters(["UTF8Type", 3], {}) + des_text = find_deserializer(vt_text) + if not isinstance(des_text, DesVectorType): + self.skipTest( + "DesVectorType not available (Cython VectorType deserializer not compiled)" + ) + self.assertEqual(des_text.__class__.__name__, "DesVectorType") + + # Cython path should raise for variable-size subtypes + data = vt_text.serialize(["abc", "def", "ghi"], 5) + with self.assertRaises(ValueError) as cm: + des_text.deserialize_bytes(data, 5) + self.assertIn("variable-size subtype", str(cm.exception)) + + # Pure Python path should work correctly + result = vt_text.deserialize(data, 5) + self.assertEqual(result, ["abc", "def", "ghi"]) + + def test_vector_numpy_large_deserialization(self): + """ + Test that large vectors (>= 32 elements) use the numpy deserialization path + and return correct results for all supported numeric types. + + @since 3.x + @expected_result Large vectors are correctly deserialized (via numpy when available) + + @test_category data_types:vector + """ + import struct + + vector_size = 64 # >= 32 threshold for numpy path + + # Float vector + float_data = list(range(vector_size)) + float_values = [float(x) for x in float_data] + vt_float = VectorType.apply_parameters(["FloatType", vector_size], {}) + packed = struct.pack(">%df" % vector_size, *float_values) + result = vt_float.deserialize(packed, 5) + self.assertEqual(len(result), vector_size) + for i in range(vector_size): + self.assertAlmostEqual(result[i], float_values[i], places=5) + + # Double vector + double_values = [float(x) * 1.1 for x in range(vector_size)] + vt_double = VectorType.apply_parameters(["DoubleType", vector_size], {}) + packed = struct.pack(">%dd" % vector_size, *double_values) + result = vt_double.deserialize(packed, 5) + self.assertEqual(len(result), vector_size) + for i in range(vector_size): + self.assertAlmostEqual(result[i], double_values[i], places=10) + + # Int32 vector + int32_values = list(range(vector_size)) + vt_int32 = VectorType.apply_parameters(["Int32Type", vector_size], {}) + packed = struct.pack(">%di" % vector_size, *int32_values) + result = vt_int32.deserialize(packed, 5) + self.assertEqual(result, int32_values) + + # Int64/Long vector + int64_values = list(range(vector_size)) + vt_int64 = VectorType.apply_parameters(["LongType", vector_size], {}) + packed = struct.pack(">%dq" % vector_size, *int64_values) + result = vt_int64.deserialize(packed, 5) + self.assertEqual(result, int64_values) + + # ShortType skipped: serial_size() returns None (pre-existing bug), + # so VectorType.deserialize takes the variable-size path which fails. + # ShortType struct.unpack works for small vectors via _vector_struct. + ZERO = datetime.timedelta(0) @@ -558,8 +865,7 @@ def test_month_rounding_creation_failure(self): feb_stamp = ms_timestamp_from_datetime( datetime.datetime(2018, 2, 25, 18, 59, 59, 0) ) - dr = DateRange(OPEN_BOUND, - DateRangeBound(feb_stamp, DateRangePrecision.MONTH)) + dr = DateRange(OPEN_BOUND, DateRangeBound(feb_stamp, DateRangePrecision.MONTH)) dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) assert dt.day == 28 @@ -567,155 +873,146 @@ def test_month_rounding_creation_failure(self): feb_stamp_leap_year = ms_timestamp_from_datetime( datetime.datetime(2016, 2, 25, 18, 59, 59, 0) ) - dr = DateRange(OPEN_BOUND, - DateRangeBound(feb_stamp_leap_year, DateRangePrecision.MONTH)) + dr = DateRange( + OPEN_BOUND, DateRangeBound(feb_stamp_leap_year, DateRangePrecision.MONTH) + ) dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) assert dt.day == 29 def test_decode_precision(self): - assert DateRangeType._decode_precision(6) == 'MILLISECOND' + assert DateRangeType._decode_precision(6) == "MILLISECOND" def test_decode_precision_error(self): with pytest.raises(ValueError): DateRangeType._decode_precision(-1) def test_encode_precision(self): - assert DateRangeType._encode_precision('SECOND') == 5 + assert DateRangeType._encode_precision("SECOND") == 5 def test_encode_precision_error(self): with pytest.raises(ValueError): - DateRangeType._encode_precision('INVALID') + DateRangeType._encode_precision("INVALID") def test_deserialize_single_value(self): - serialized = (int8_pack(0) + - int64_pack(self.timestamp) + - int8_pack(3)) - assert DateRangeType.deserialize(serialized, 5) == util.DateRange(value=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), - precision='HOUR') + serialized = int8_pack(0) + int64_pack(self.timestamp) + int8_pack(3) + assert DateRangeType.deserialize(serialized, 5) == util.DateRange( + value=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), + precision="HOUR", + ) ) def test_deserialize_closed_range(self): - serialized = (int8_pack(1) + - int64_pack(self.timestamp) + - int8_pack(2) + - int64_pack(self.timestamp) + - int8_pack(6)) + serialized = ( + int8_pack(1) + + int64_pack(self.timestamp) + + int8_pack(2) + + int64_pack(self.timestamp) + + int8_pack(6) + ) assert DateRangeType.deserialize(serialized, 5) == util.DateRange( lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 0, 0), - precision='DAY' + value=datetime.datetime(2017, 2, 1, 0, 0), precision="DAY" ), upper_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), - precision='MILLISECOND' - ) + precision="MILLISECOND", + ), ) def test_deserialize_open_high(self): - serialized = (int8_pack(2) + - int64_pack(self.timestamp) + - int8_pack(3)) + serialized = int8_pack(2) + int64_pack(self.timestamp) + int8_pack(3) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 0), - precision='HOUR' + value=datetime.datetime(2017, 2, 1, 15, 0), precision="HOUR" ), - upper_bound=util.OPEN_BOUND + upper_bound=util.OPEN_BOUND, ) def test_deserialize_open_low(self): - serialized = (int8_pack(3) + - int64_pack(self.timestamp) + - int8_pack(4)) + serialized = int8_pack(3) + int64_pack(self.timestamp) + int8_pack(4) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( lower_bound=util.OPEN_BOUND, upper_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 42, 20, 1000), - precision='MINUTE' - ) + precision="MINUTE", + ), ) def test_deserialize_single_open(self): - assert util.DateRange(value=util.OPEN_BOUND) == DateRangeType.deserialize(int8_pack(5), 5) + assert util.DateRange(value=util.OPEN_BOUND) == DateRangeType.deserialize( + int8_pack(5), 5 + ) def test_serialize_single_value(self): - serialized = (int8_pack(0) + - int64_pack(self.timestamp) + - int8_pack(5)) + serialized = int8_pack(0) + int64_pack(self.timestamp) + int8_pack(5) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( value=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12), - precision='SECOND' + value=datetime.datetime(2017, 2, 1, 15, 42, 12), precision="SECOND" ) ) def test_serialize_closed_range(self): - serialized = (int8_pack(1) + - int64_pack(self.timestamp) + - int8_pack(5) + - int64_pack(self.timestamp) + - int8_pack(0)) + serialized = ( + int8_pack(1) + + int64_pack(self.timestamp) + + int8_pack(5) + + int64_pack(self.timestamp) + + int8_pack(0) + ) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12), - precision='SECOND' + value=datetime.datetime(2017, 2, 1, 15, 42, 12), precision="SECOND" ), upper_bound=util.DateRangeBound( - value=datetime.datetime(2017, 12, 31), - precision='YEAR' - ) + value=datetime.datetime(2017, 12, 31), precision="YEAR" + ), ) def test_serialize_open_high(self): - serialized = (int8_pack(2) + - int64_pack(self.timestamp) + - int8_pack(2)) + serialized = int8_pack(2) + int64_pack(self.timestamp) + int8_pack(2) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1), - precision='DAY' + value=datetime.datetime(2017, 2, 1), precision="DAY" ), - upper_bound=util.OPEN_BOUND + upper_bound=util.OPEN_BOUND, ) def test_serialize_open_low(self): - serialized = (int8_pack(2) + - int64_pack(self.timestamp) + - int8_pack(3)) + serialized = int8_pack(2) + int64_pack(self.timestamp) + int8_pack(3) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15), - precision='HOUR' + value=datetime.datetime(2017, 2, 1, 15), precision="HOUR" ), - upper_bound=util.OPEN_BOUND + upper_bound=util.OPEN_BOUND, ) def test_deserialize_both_open(self): - serialized = (int8_pack(4)) + serialized = int8_pack(4) deserialized = DateRangeType.deserialize(serialized, 5) assert deserialized == util.DateRange( - lower_bound=util.OPEN_BOUND, - upper_bound=util.OPEN_BOUND + lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND ) def test_serialize_single_open(self): - serialized = DateRangeType.serialize(util.DateRange( - value=util.OPEN_BOUND, - ), 5) + serialized = DateRangeType.serialize( + util.DateRange( + value=util.OPEN_BOUND, + ), + 5, + ) assert int8_pack(5) == serialized def test_serialize_both_open(self): - serialized = DateRangeType.serialize(util.DateRange( - lower_bound=util.OPEN_BOUND, - upper_bound=util.OPEN_BOUND - ), 5) + serialized = DateRangeType.serialize( + util.DateRange(lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND), 5 + ) assert int8_pack(4) == serialized def test_failure_to_serialize_no_value_object(self): @@ -725,14 +1022,19 @@ def test_failure_to_serialize_no_value_object(self): def test_failure_to_serialize_no_bounds_object(self): class no_bounds_object(object): value = lower_bound = None + with pytest.raises(ValueError): DateRangeType.serialize(no_bounds_object, 5) def test_serialized_value_round_trip(self): - vals = [b'\x01\x00\x00\x01%\xe9a\xf9\xd1\x06\x00\x00\x01v\xbb>o\xff\x00', - b'\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00'] + vals = [ + b"\x01\x00\x00\x01%\xe9a\xf9\xd1\x06\x00\x00\x01v\xbb>o\xff\x00", + b"\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00", + ] for serialized in vals: - assert serialized == DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) + assert serialized == DateRangeType.serialize( + DateRangeType.deserialize(serialized, 0), 0 + ) def test_serialize_zero_datetime(self): """ @@ -746,10 +1048,13 @@ def test_serialize_zero_datetime(self): @test_category data_types """ - DateRangeType.serialize(util.DateRange( - lower_bound=(datetime.datetime(1970, 1, 1), 'YEAR'), - upper_bound=(datetime.datetime(1970, 1, 1), 'YEAR') - ), 5) + DateRangeType.serialize( + util.DateRange( + lower_bound=(datetime.datetime(1970, 1, 1), "YEAR"), + upper_bound=(datetime.datetime(1970, 1, 1), "YEAR"), + ), + 5, + ) def test_deserialize_zero_datetime(self): """ @@ -764,10 +1069,14 @@ def test_deserialize_zero_datetime(self): @test_category data_types """ DateRangeType.deserialize( - (int8_pack(1) + - int64_pack(0) + int8_pack(0) + - int64_pack(0) + int8_pack(0)), - 5 + ( + int8_pack(1) + + int64_pack(0) + + int8_pack(0) + + int64_pack(0) + + int8_pack(0) + ), + 5, ) @@ -801,8 +1110,10 @@ def test_deserialize_date_range_milliseconds(self): for i in range(1000): lower_value = self.starting_lower_value + i upper_value = self.starting_upper_value + i - dr = DateRange(DateRangeBound(lower_value, DateRangePrecision.MILLISECOND), - DateRangeBound(upper_value, DateRangePrecision.MILLISECOND)) + dr = DateRange( + DateRangeBound(lower_value, DateRangePrecision.MILLISECOND), + DateRangeBound(upper_value, DateRangePrecision.MILLISECOND), + ) assert lower_value == dr.lower_bound.milliseconds assert upper_value == dr.upper_bound.milliseconds @@ -821,13 +1132,15 @@ def truncate_last_figures(number, n=3): """ Truncates last n digits of a number """ - return int(str(number)[:-n] + '0' * n) + return int(str(number)[:-n] + "0" * n) for i in range(1000): lower_value = self.starting_lower_value + i * 900 upper_value = self.starting_upper_value + i * 900 - dr = DateRange(DateRangeBound(lower_value, DateRangePrecision.SECOND), - DateRangeBound(upper_value, DateRangePrecision.SECOND)) + dr = DateRange( + DateRangeBound(lower_value, DateRangePrecision.SECOND), + DateRangeBound(upper_value, DateRangePrecision.SECOND), + ) assert truncate_last_figures(lower_value) == dr.lower_bound.milliseconds upper_value = truncate_last_figures(upper_value) + 999 @@ -843,12 +1156,14 @@ def test_deserialize_date_range_minutes(self): @test_category data_types """ - self._deserialize_date_range({"second": 0, "microsecond": 0}, - DateRangePrecision.MINUTE, - # This lambda function given a truncated date adds - # one day minus one microsecond in microseconds - lambda x: x + 59 * 1000 + 999, - lambda original_value, i: original_value + i * 900 * 50) + self._deserialize_date_range( + {"second": 0, "microsecond": 0}, + DateRangePrecision.MINUTE, + # This lambda function given a truncated date adds + # one day minus one microsecond in microseconds + lambda x: x + 59 * 1000 + 999, + lambda original_value, i: original_value + i * 900 * 50, + ) def test_deserialize_date_range_hours(self): """ @@ -860,15 +1175,14 @@ def test_deserialize_date_range_hours(self): @test_category data_types """ - self._deserialize_date_range({"minute": 0, "second": 0, "microsecond": 0}, - DateRangePrecision.HOUR, - # This lambda function given a truncated date adds - # one hour minus one microsecond in microseconds - lambda x: x + - 59 * 60 * 1000 + - 59 * 1000 + - 999, - lambda original_value, i: original_value + i * 900 * 50 * 60) + self._deserialize_date_range( + {"minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.HOUR, + # This lambda function given a truncated date adds + # one hour minus one microsecond in microseconds + lambda x: x + 59 * 60 * 1000 + 59 * 1000 + 999, + lambda original_value, i: original_value + i * 900 * 50 * 60, + ) def test_deserialize_date_range_day(self): """ @@ -880,16 +1194,14 @@ def test_deserialize_date_range_day(self): @test_category data_types """ - self._deserialize_date_range({"hour": 0, "minute": 0, "second": 0, "microsecond": 0}, - DateRangePrecision.DAY, - # This lambda function given a truncated date adds - # one day minus one microsecond in microseconds - lambda x: x + - 23 * 60 * 60 * 1000 + - 59 * 60 * 1000 + - 59 * 1000 + - 999, - lambda original_value, i: original_value + i * 900 * 50 * 60 * 24) + self._deserialize_date_range( + {"hour": 0, "minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.DAY, + # This lambda function given a truncated date adds + # one day minus one microsecond in microseconds + lambda x: x + 23 * 60 * 60 * 1000 + 59 * 60 * 1000 + 59 * 1000 + 999, + lambda original_value, i: original_value + i * 900 * 50 * 60 * 24, + ) @unittest.skip("This is currently failig, see PYTHON-912") def test_deserialize_date_range_month(self): @@ -902,6 +1214,7 @@ def test_deserialize_date_range_month(self): @test_category data_types """ + def get_upper_bound(seconds): """ function that given a truncated date in seconds from the epoch returns that same date @@ -914,10 +1227,13 @@ def get_upper_bound(seconds): dt = dt + datetime.timedelta(days=32) dt = dt.replace(day=1) - datetime.timedelta(microseconds=1) return int((dt - self.epoch).total_seconds() * 1000) - self._deserialize_date_range({"day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0}, - DateRangePrecision.MONTH, - get_upper_bound, - lambda original_value, i: original_value + i * 900 * 50 * 60 * 24 * 30) + + self._deserialize_date_range( + {"day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.MONTH, + get_upper_bound, + lambda original_value, i: original_value + i * 900 * 50 * 60 * 24 * 30, + ) def test_deserialize_date_range_year(self): """ @@ -929,6 +1245,7 @@ def test_deserialize_date_range_year(self): @test_category data_types """ + def get_upper_bound(seconds): """ function that given a truncated date in seconds from the epoch returns that same date @@ -944,14 +1261,31 @@ def get_upper_bound(seconds): diff = time.mktime(dt.timetuple()) - time.mktime(self.epoch.timetuple()) return diff * 1000 + 999 # This doesn't work for big values because it loses precision - #return int((dt - self.epoch).total_seconds() * 1000) - self._deserialize_date_range({"month": 1, "day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0}, - DateRangePrecision.YEAR, - get_upper_bound, - lambda original_value, i: original_value + i * 900 * 50 * 60 * 24 * 30 * 12 * 7) - - def _deserialize_date_range(self, truncate_kwargs, precision, - round_up_truncated_upper_value, increment_loop_variable): + # return int((dt - self.epoch).total_seconds() * 1000) + + self._deserialize_date_range( + { + "month": 1, + "day": 1, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + }, + DateRangePrecision.YEAR, + get_upper_bound, + lambda original_value, i: ( + original_value + i * 900 * 50 * 60 * 24 * 30 * 12 * 7 + ), + ) + + def _deserialize_date_range( + self, + truncate_kwargs, + precision, + round_up_truncated_upper_value, + increment_loop_variable, + ): """ This functions iterates over several DateRange objects determined by lower_value upper_value which are given as a value that represents seconds since the epoch. @@ -994,8 +1328,10 @@ def truncate_date(number): upper_value = increment_loop_variable(self.starting_upper_value, i) # Inside the __init__ for DateRange the rounding up and down should happen - dr = DateRange(DateRangeBound(lower_value, precision), - DateRangeBound(upper_value, precision)) + dr = DateRange( + DateRangeBound(lower_value, precision), + DateRangeBound(upper_value, precision), + ) # We verify that rounded value corresponds with what we would expect assert truncate_date(lower_value) == dr.lower_bound.milliseconds @@ -1017,11 +1353,18 @@ def test_host_order(self): @test_category data_types """ - hosts = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in - ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] - hosts_equal = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in - ("127.0.0.1", "127.0.0.1")] - hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4())] + hosts = [ + Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) + for addr in ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4") + ] + hosts_equal = [ + Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) + for addr in ("127.0.0.1", "127.0.0.1") + ] + hosts_equal_conviction = [ + Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4()), + ] check_sequence_consistency(hosts) check_sequence_consistency(hosts_equal, equal=True) check_sequence_consistency(hosts_equal_conviction, equal=True) @@ -1036,7 +1379,12 @@ def test_date_order(self): @test_category data_types """ - dates_from_string = [Date("2017-01-01"), Date("2017-01-05"), Date("2017-01-09"), Date("2017-01-13")] + dates_from_string = [ + Date("2017-01-01"), + Date("2017-01-05"), + Date("2017-01-09"), + Date("2017-01-13"), + ] dates_from_string_equal = [Date("2017-01-01"), Date("2017-01-01")] check_sequence_consistency(dates_from_string) check_sequence_consistency(dates_from_string_equal, equal=True) @@ -1044,33 +1392,49 @@ def test_date_order(self): date_format = "%Y-%m-%d" dates_from_value = [ - Date((datetime.datetime.strptime(dtstr, date_format) - - datetime.datetime(1970, 1, 1)).days) + Date( + ( + datetime.datetime.strptime(dtstr, date_format) + - datetime.datetime(1970, 1, 1) + ).days + ) for dtstr in ("2017-01-02", "2017-01-06", "2017-01-10", "2017-01-14") ] dates_from_value_equal = [Date(1), Date(1)] check_sequence_consistency(dates_from_value) check_sequence_consistency(dates_from_value_equal, equal=True) - dates_from_datetime = [Date(datetime.datetime.strptime(dtstr, date_format)) - for dtstr in ("2017-01-03", "2017-01-07", "2017-01-11", "2017-01-15")] - dates_from_datetime_equal = [Date(datetime.datetime.strptime("2017-01-01", date_format)), - Date(datetime.datetime.strptime("2017-01-01", date_format))] + dates_from_datetime = [ + Date(datetime.datetime.strptime(dtstr, date_format)) + for dtstr in ("2017-01-03", "2017-01-07", "2017-01-11", "2017-01-15") + ] + dates_from_datetime_equal = [ + Date(datetime.datetime.strptime("2017-01-01", date_format)), + Date(datetime.datetime.strptime("2017-01-01", date_format)), + ] check_sequence_consistency(dates_from_datetime) check_sequence_consistency(dates_from_datetime_equal, equal=True) dates_from_date = [ - Date(datetime.datetime.strptime(dtstr, date_format).date()) for dtstr in - ("2017-01-04", "2017-01-08", "2017-01-12", "2017-01-16") + Date(datetime.datetime.strptime(dtstr, date_format).date()) + for dtstr in ("2017-01-04", "2017-01-08", "2017-01-12", "2017-01-16") + ] + dates_from_date_equal = [ + datetime.datetime.strptime(dtstr, date_format) + for dtstr in ("2017-01-09", "2017-01-9") ] - dates_from_date_equal = [datetime.datetime.strptime(dtstr, date_format) for dtstr in - ("2017-01-09", "2017-01-9")] check_sequence_consistency(dates_from_date) check_sequence_consistency(dates_from_date_equal, equal=True) - check_sequence_consistency(self._shuffle_lists(dates_from_string, dates_from_value, - dates_from_datetime, dates_from_date)) + check_sequence_consistency( + self._shuffle_lists( + dates_from_string, + dates_from_value, + dates_from_datetime, + dates_from_date, + ) + ) def test_timer_order(self): """ @@ -1087,20 +1451,33 @@ def test_timer_order(self): check_sequence_consistency(time_from_int) check_sequence_consistency(time_from_int_equal, equal=True) - time_from_datetime = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) - for us in (2, 5, 8, 11)] - time_from_datetime_equal = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) - for us in (1, 1)] + time_from_datetime = [ + Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (2, 5, 8, 11) + ] + time_from_datetime_equal = [ + Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (1, 1) + ] check_sequence_consistency(time_from_datetime) check_sequence_consistency(time_from_datetime_equal, equal=True) - time_from_string = [Time("00:00:00.000003000"), Time("00:00:00.000006000"), - Time("00:00:00.000009000"), Time("00:00:00.000012000")] - time_from_string_equal = [Time("00:00:00.000004000"), Time("00:00:00.000004000")] + time_from_string = [ + Time("00:00:00.000003000"), + Time("00:00:00.000006000"), + Time("00:00:00.000009000"), + Time("00:00:00.000012000"), + ] + time_from_string_equal = [ + Time("00:00:00.000004000"), + Time("00:00:00.000004000"), + ] check_sequence_consistency(time_from_string) check_sequence_consistency(time_from_string_equal, equal=True) - check_sequence_consistency(self._shuffle_lists(time_from_int, time_from_datetime, time_from_string)) + check_sequence_consistency( + self._shuffle_lists(time_from_int, time_from_datetime, time_from_string) + ) def test_token_order(self): """