class _RewriteFiles(_SnapshotProducer["_RewriteFiles"]):
    """A snapshot producer that atomically swaps one set of data files for another.

    Commits a snapshot with the ``replace`` operation: the requested data files
    are marked ``DELETED``, the replacement files are added, and any manifest
    untouched by the rewrite is carried over unchanged.
    """

    def __init__(self, operation: Operation, transaction: Transaction, io: FileIO, snapshot_properties: dict[str, str]):
        super().__init__(operation, transaction, io, snapshot_properties=snapshot_properties)

    def _commit(self) -> UpdatesAndRequirements:
        # A rewrite with nothing to add or remove is a no-op: produce no snapshot.
        if not self._deleted_data_files and not self._added_data_files:
            return (), ()

        # Resolve the requested deletions against the entries actually present
        # in the table's manifests.
        deleted_entries = self._deleted_entries()
        found_deleted_files = {entry.data_file for entry in deleted_entries}

        # Abort when the caller asked to delete a file the table does not contain.
        if len(found_deleted_files) != len(self._deleted_data_files):
            raise ValueError("Cannot delete files that are not present in the table")

        added_records = sum(f.record_count for f in self._added_data_files)
        deleted_records = sum(entry.data_file.record_count for entry in deleted_entries)

        # A replace may compact or shrink data (e.g. drop soft-deleted rows),
        # but it must never introduce rows that were not there before.
        if added_records > deleted_records:
            raise ValueError(f"Invalid replace: records added ({added_records}) exceeds records removed ({deleted_records})")

        return super()._commit()

    def _deleted_entries(self) -> list[ManifestEntry]:
        """Build DELETED manifest entries for every requested data file found in the parent snapshot."""
        if self._parent_snapshot_id is None:
            # No parent snapshot means there is nothing that could be deleted.
            return []

        previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
        if previous_snapshot is None:
            raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")

        def _collect(manifest: ManifestFile) -> list[ManifestEntry]:
            # Only DATA files participate in the rewrite; delete files pass through untouched.
            return [
                ManifestEntry.from_args(
                    status=ManifestEntryStatus.DELETED,
                    snapshot_id=self.snapshot_id,
                    sequence_number=entry.sequence_number,
                    file_sequence_number=entry.file_sequence_number,
                    data_file=entry.data_file,
                )
                for entry in manifest.fetch_manifest_entry(self._io, discard_deleted=True)
                if entry.data_file.content == DataFileContent.DATA and entry.data_file in self._deleted_data_files
            ]

        # Scan manifests concurrently; each manifest fetch is independent I/O.
        executor = ExecutorFactory.get_or_create()
        per_manifest = executor.map(_collect, previous_snapshot.manifests(self._io))
        return list(itertools.chain(*per_manifest))

    def _existing_manifests(self) -> list[ManifestFile]:
        """Carry forward prior manifests, rewriting only those that reference deleted files."""
        existing_files: list[ManifestFile] = []
        snapshot = self._transaction.table_metadata.snapshot_by_name(name=self._target_branch)
        if snapshot is None:
            return existing_files

        for manifest_file in snapshot.manifests(io=self._io):
            entries_to_write: set[ManifestEntry] = set()
            found_deleted_entries: set[ManifestEntry] = set()

            for entry in manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True):
                if entry.data_file in self._deleted_data_files:
                    found_deleted_entries.add(entry)
                else:
                    entries_to_write.add(entry)

            if not found_deleted_entries:
                # Untouched manifest: reuse it as-is to avoid a needless rewrite.
                existing_files.append(manifest_file)
                continue

            if not entries_to_write:
                # Every entry in this manifest was deleted; drop the manifest entirely.
                continue

            # Rewrite the manifest, keeping the surviving entries as EXISTING
            # (they retain their original snapshot and sequence numbers).
            with self.new_manifest_writer(self.spec(manifest_file.partition_spec_id)) as writer:
                for entry in entries_to_write:
                    writer.add_entry(
                        ManifestEntry.from_args(
                            status=ManifestEntryStatus.EXISTING,
                            snapshot_id=entry.snapshot_id,
                            sequence_number=entry.sequence_number,
                            file_sequence_number=entry.file_sequence_number,
                            data_file=entry.data_file,
                        )
                    )
            existing_files.append(writer.to_manifest_file())
        return existing_files
manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True): + if entry.data_file in self._deleted_data_files: + found_deleted_entries.add(entry) + else: + entries_to_write.add(entry) + + if len(found_deleted_entries) == 0: + existing_files.append(manifest_file) + continue + + if len(entries_to_write) == 0: + continue + + with self.new_manifest_writer(self.spec(manifest_file.partition_spec_id)) as writer: + for entry in entries_to_write: + writer.add_entry( + ManifestEntry.from_args( + status=ManifestEntryStatus.EXISTING, + snapshot_id=entry.snapshot_id, + sequence_number=entry.sequence_number, + file_sequence_number=entry.file_sequence_number, + data_file=entry.data_file, + ) + ) + existing_files.append(writer.to_manifest_file()) + return existing_files + + class UpdateSnapshot: _transaction: Transaction _io: FileIO @@ -724,6 +814,14 @@ def delete(self) -> _DeleteFiles: snapshot_properties=self._snapshot_properties, ) + def replace(self) -> _RewriteFiles: + return _RewriteFiles( + operation=Operation.REPLACE, + transaction=self._transaction, + io=self._io, + snapshot_properties=self._snapshot_properties, + ) + class _ManifestMergeManager(Generic[U]): _target_size_bytes: int diff --git a/tests/table/test_replace.py b/tests/table/test_replace.py new file mode 100644 index 0000000000..bab1dfacdd --- /dev/null +++ b/tests/table/test_replace.py @@ -0,0 +1,458 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
def test_replace_internally(catalog: Catalog) -> None:
    """End-to-end check of the internal replace API on a fresh table."""
    catalog.create_namespace("default")
    table = catalog.create_table(
        identifier="default.test_replace",
        schema=Schema(),
    )

    # File that the replace will remove.
    file_to_delete = DataFile.from_args(
        file_path="s3://bucket/test/data/deleted.parquet",
        file_format=FileFormat.PARQUET,
        partition=Record(),
        record_count=100,
        file_size_in_bytes=1024,
        content=DataFileContent.DATA,
    )
    file_to_delete.spec_id = 0

    # File that stays completely untouched.
    file_to_keep = DataFile.from_args(
        file_path="s3://bucket/test/data/kept.parquet",
        file_format=FileFormat.PARQUET,
        partition=Record(),
        record_count=50,
        file_size_in_bytes=512,
        content=DataFileContent.DATA,
    )
    file_to_keep.spec_id = 0

    # Replacement file.
    file_to_add = DataFile.from_args(
        file_path="s3://bucket/test/data/added.parquet",
        file_format=FileFormat.PARQUET,
        partition=Record(),
        record_count=100,
        file_size_in_bytes=1024,
        content=DataFileContent.DATA,
    )
    file_to_add.spec_id = 0

    # Seed the table with both the file to delete and the file to keep.
    with table.transaction() as tx:
        with tx.update_snapshot().fast_append() as append_snapshot:
            append_snapshot.append_data_file(file_to_delete)
            append_snapshot.append_data_file(file_to_keep)

    old_snapshot = cast(Snapshot, table.current_snapshot())
    old_snapshot_id = old_snapshot.snapshot_id
    old_sequence_number = cast(int, old_snapshot.sequence_number)

    # Exercise the internal replace API.
    with table.transaction() as tx:
        with tx.update_snapshot().replace() as rewrite:
            rewrite.delete_data_file(file_to_delete)
            rewrite.append_data_file(file_to_add)

    snapshot = cast(Snapshot, table.current_snapshot())
    summary = cast(Summary, snapshot.summary)

    # 1. A fresh, unique snapshot id.
    assert snapshot.snapshot_id is not None
    assert snapshot.snapshot_id != old_snapshot_id

    # 2. Parent points at the previous snapshot.
    assert snapshot.parent_snapshot_id == old_snapshot_id

    # 3. Sequence number advances by exactly one.
    assert snapshot.sequence_number == old_sequence_number + 1

    # 4. Operation type is "replace".
    assert summary["operation"] == Operation.REPLACE

    # 5. Manifest list path exists and is a string path.
    assert snapshot.manifest_list is not None
    assert isinstance(snapshot.manifest_list, str)

    # 6. Summary counts are accurate.
    assert summary["added-data-files"] == "1"
    assert summary["deleted-data-files"] == "1"
    assert summary["added-records"] == "100"
    assert summary["deleted-records"] == "100"
    assert summary["total-records"] == "150"

    # Fetch every entry from the new snapshot's manifests.
    entries: list[ManifestEntry] = []
    for manifest in snapshot.manifests(table.io):
        entries.extend(manifest.fetch_manifest_entry(table.io, discard_deleted=False))

    # Expect exactly one ADDED, one DELETED, and one EXISTING entry.
    assert len(entries) == 3

    added_entries = [e for e in entries if e.status == ManifestEntryStatus.ADDED]
    assert len(added_entries) == 1
    assert added_entries[0].data_file.file_path == file_to_add.file_path
    assert added_entries[0].snapshot_id == snapshot.snapshot_id

    deleted_entries = [e for e in entries if e.status == ManifestEntryStatus.DELETED]
    assert len(deleted_entries) == 1
    assert deleted_entries[0].data_file.file_path == file_to_delete.file_path
    assert deleted_entries[0].snapshot_id == snapshot.snapshot_id

    # EXISTING entries keep the snapshot id of the commit that added them.
    existing_entries = [e for e in entries if e.status == ManifestEntryStatus.EXISTING]
    assert len(existing_entries) == 1
    assert existing_entries[0].data_file.file_path == file_to_keep.file_path
    assert existing_entries[0].snapshot_id == old_snapshot_id
partition=Record(), + record_count=10, + file_size_in_bytes=100, + content=DataFileContent.DATA, + ) + file_b.spec_id = 0 + + file_c = DataFile.from_args( + file_path="s3://bucket/test/data/c.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=10, + file_size_in_bytes=100, + content=DataFileContent.DATA, + ) + file_c.spec_id = 0 + + # Commit 1: Append file A (Creates Manifest 1) + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_a) + + # Commit 2: Append file B (Creates Manifest 2) + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_b) + + snapshot_before = cast(Snapshot, table.current_snapshot()) + manifests_before = snapshot_before.manifests(table.io) + assert len(manifests_before) == 2 + + # Identify which manifest belongs to file_b and file_a + manifest_b_path = None + manifest_a_path = None + for m in manifests_before: + entries = m.fetch_manifest_entry(table.io, discard_deleted=False) + if any(e.data_file.file_path == file_b.file_path for e in entries): + manifest_b_path = m.manifest_path + if any(e.data_file.file_path == file_a.file_path for e in entries): + manifest_a_path = m.manifest_path + + assert manifest_b_path is not None + assert manifest_a_path is not None + + # Commit 3: Replace file A with file C + with table.transaction() as tx: + with tx.update_snapshot().replace() as rewrite: + rewrite.delete_data_file(file_a) + rewrite.append_data_file(file_c) + + snapshot_after = cast(Snapshot, table.current_snapshot()) + assert snapshot_after is not None + manifests_after = snapshot_after.manifests(table.io) + + # We expect 3 manifests: + # 1. The reused one for file B + # 2. The newly rewritten one marking file A as DELETED + # 3. 
The new one for file C (ADDED) + assert len(manifests_after) == 3 + + manifest_paths_after = [m.manifest_path for m in manifests_after] + + # ASSERTION 1: The untouched manifest is completely reused (the path matches exactly) + assert manifest_b_path in manifest_paths_after + + # ASSERTION 2: File A's old manifest is NOT reused (since it was rewritten to change status to DELETED) + assert manifest_a_path not in manifest_paths_after + + +def test_replace_empty_files(catalog: Catalog) -> None: + # Setup a basic table using the catalog fixture + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_empty", + schema=Schema(), + ) + + # Replacing empty lists should not throw errors, but should produce no changes. + with table.transaction() as tx: + with tx.update_snapshot().replace(): + pass # Entering and exiting the context manager without adding/deleting + + # History should be completely empty since no files were rewritten + assert len(table.history()) == 0 + assert table.current_snapshot() is None + + +def test_replace_missing_file_abort(catalog: Catalog) -> None: + # Setup a basic table + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_missing", + schema=Schema(), + ) + + fake_data_file = DataFile.from_args( + file_path="s3://bucket/test/data/does_not_exist.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + fake_data_file.spec_id = 0 + + new_data_file = DataFile.from_args( + file_path="s3://bucket/test/data/new.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + new_data_file.spec_id = 0 + + # Ensure it aborts when trying to replace a file that isn't in the table + with pytest.raises(ValueError, match="Cannot delete files that are not present in the table"): + with 
table.transaction() as tx: + with tx.update_snapshot().replace() as rewrite: + rewrite.delete_data_file(fake_data_file) + rewrite.append_data_file(new_data_file) + + +def test_replace_invariant_violation(catalog: Catalog) -> None: + # Setup a basic table + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_invariant", + schema=Schema(), + ) + + file_to_delete = DataFile.from_args( + file_path="s3://bucket/test/data/deleted.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + file_to_delete.spec_id = 0 + + # Create a new file with MORE records than the one we are deleting + too_many_records_file = DataFile.from_args( + file_path="s3://bucket/test/data/too_many.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=101, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + too_many_records_file.spec_id = 0 + + # Initially append to have something to replace + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_to_delete) + + # Ensure it enforces the invariant: records added <= records removed + with pytest.raises(ValueError, match=r"Invalid replace: records added \(101\) exceeds records removed \(100\)"): + with table.transaction() as tx: + with tx.update_snapshot().replace() as rewrite: + rewrite.delete_data_file(file_to_delete) + rewrite.append_data_file(too_many_records_file) + + +def test_replace_allows_shrinking_for_soft_deletes(catalog: Catalog) -> None: + # Setup a basic table + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_shrink", + schema=Schema(), + ) + + # Old data file has 100 records + file_to_delete = DataFile.from_args( + file_path="s3://bucket/test/data/deleted.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + 
record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + file_to_delete.spec_id = 0 + + # New data file only has 90 records (simulating 10 records were soft-deleted) + shrunk_file_to_add = DataFile.from_args( + file_path="s3://bucket/test/data/shrunk.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=90, + file_size_in_bytes=900, + content=DataFileContent.DATA, + ) + shrunk_file_to_add.spec_id = 0 + + # Initially append + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_to_delete) + + # This should succeed without throwing an invariant violation + with table.transaction() as tx: + with tx.update_snapshot().replace() as rewrite: + rewrite.delete_data_file(file_to_delete) + rewrite.append_data_file(shrunk_file_to_add) + + snapshot = cast(Snapshot, table.current_snapshot()) + summary = cast(Summary, snapshot.summary) + + assert summary["operation"] == Operation.REPLACE + assert summary["added-records"] == "90" + assert summary["deleted-records"] == "100" + + +def test_replace_passes_through_delete_manifests(catalog: Catalog) -> None: + # Setup a basic table + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_delete_manifests", + schema=Schema(), + properties={"format-version": "2"}, + ) + + # 1. Data file we will replace + file_a = DataFile.from_args( + file_path="s3://bucket/test/data/a.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=10, + file_size_in_bytes=100, + content=DataFileContent.DATA, + ) + file_a.spec_id = 0 + + # 2. 
A Position Delete file (representing row-level deletes) + file_a_deletes = DataFile.from_args( + file_path="s3://bucket/test/data/a_deletes.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=2, + file_size_in_bytes=50, + content=DataFileContent.POSITION_DELETES, + ) + file_a_deletes.spec_id = 0 + + # 3. Data file we are adding as a replacement + file_b = DataFile.from_args( + file_path="s3://bucket/test/data/b.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=10, + file_size_in_bytes=100, + content=DataFileContent.DATA, + ) + file_b.spec_id = 0 + + # Commit 1: Append the data file + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_a) + + # Commit 2: Append the delete file + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_a_deletes) + + # Find the path of the delete manifest so we can verify it survives + snapshot_before = cast(Snapshot, table.current_snapshot()) + manifests_before = snapshot_before.manifests(table.io) + + delete_manifest_path = None + for m in manifests_before: + entries = m.fetch_manifest_entry(table.io, discard_deleted=False) + if any(e.data_file.file_path == file_a_deletes.file_path for e in entries): + delete_manifest_path = m.manifest_path + break + + assert delete_manifest_path is not None + + # Commit 3: Replace data file A with data file B + with table.transaction() as tx: + with tx.update_snapshot().replace() as rewrite: + rewrite.delete_data_file(file_a) + rewrite.append_data_file(file_b) + + # Verify the delete manifest was passed through unchanged + snapshot_after = cast(Snapshot, table.current_snapshot()) + assert snapshot_after is not None + manifests_after = snapshot_after.manifests(table.io) + manifest_paths_after = [m.manifest_path for m in manifests_after] + + assert delete_manifest_path in 
manifest_paths_after diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py index cfdc516227..7f78a7546d 100644 --- a/tests/table/test_snapshots.py +++ b/tests/table/test_snapshots.py @@ -398,8 +398,8 @@ def test_merge_snapshot_summaries_overwrite_summary() -> None: def test_invalid_operation() -> None: with pytest.raises(ValueError) as e: - update_snapshot_summaries(summary=Summary(Operation.REPLACE)) - assert "Operation not implemented: Operation.REPLACE" in str(e.value) + update_snapshot_summaries(summary=Summary.model_construct(operation="unknown_operation")) + assert "Operation not implemented: unknown_operation" in str(e.value) def test_invalid_type() -> None: