diff --git a/.gitignore b/.gitignore
index 7043f0e7d4..fb6e37429b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,5 @@ htmlcov
 pyiceberg/avro/decoder_fast.c
 pyiceberg/avro/*.html
 pyiceberg/avro/*.so
+
+.ks/
\ No newline at end of file
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 0aff68520b..14697e3457 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -27,6 +27,7 @@
 from sortedcontainers import SortedList
 
+from pyiceberg.exceptions import CommitFailedException
 from pyiceberg.expressions import (
     AlwaysFalse,
     BooleanExpression,
@@ -60,6 +61,7 @@
     Snapshot,
     SnapshotSummaryCollector,
     Summary,
+    ancestors_of,
     update_snapshot_summaries,
 )
 from pyiceberg.table.update import (
@@ -250,6 +252,21 @@ def _commit(self) -> UpdatesAndRequirements:
         )
         location_provider = self._transaction._table.location_provider()
         manifest_list_file_path = location_provider.new_metadata_location(file_name)
+
+        # Get the starting and current snapshots, and validate that there are no conflicts between them
+        from pyiceberg.table import StagedTable
+
+        if not isinstance(self._transaction._table, StagedTable):
+            starting_snapshot = self._transaction.table_metadata.current_snapshot()
+            current_snapshot = self._transaction._table.refresh().metadata.current_snapshot()
+
+            if starting_snapshot is not None and current_snapshot is not None:
+                self._validate(starting_snapshot, current_snapshot)
+
+            # If the current snapshot is not the same as the parent snapshot, update the parent snapshot id
+            if current_snapshot is not None and current_snapshot.snapshot_id != self._parent_snapshot_id:
+                self._parent_snapshot_id = current_snapshot.snapshot_id
+
         with write_manifest_list(
             format_version=self._transaction.table_metadata.format_version,
             output_file=self._io.new_output(manifest_list_file_path),
@@ -278,6 +295,30 @@ def _commit(self) -> UpdatesAndRequirements:
             (AssertRefSnapshotId(snapshot_id=self._transaction.table_metadata.current_snapshot_id, ref="main"),),
         )
 
+    def _validate(self, starting_snapshot: Snapshot, current_snapshot: Snapshot) -> None:
+        # Define which concurrent operations are allowed for each type of operation
+        allowed_operations = {
+            Operation.APPEND: {Operation.APPEND, Operation.REPLACE, Operation.OVERWRITE, Operation.DELETE},
+            Operation.REPLACE: set(),
+            Operation.OVERWRITE: set(),
+            Operation.DELETE: set(),
+        }
+
+        # Walk the snapshots between the current snapshot and the starting (parent) snapshot
+        snapshots = ancestors_of(current_snapshot, self._transaction._table.metadata)
+
+        for snapshot in snapshots:
+            if snapshot.snapshot_id == starting_snapshot.snapshot_id:
+                break
+
+            snapshot_operation = snapshot.summary.operation if snapshot.summary is not None else None
+
+            if snapshot_operation not in allowed_operations[self._operation]:
+                raise CommitFailedException(
+                    f"Operation {snapshot_operation} is not allowed when performing {self._operation}. "
+                    "Check for overlaps or conflicts."
+                )
+
     @property
     def snapshot_id(self) -> int:
         return self._snapshot_id
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 3d36ffcf31..f4ade53b76 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -31,7 +31,7 @@
 from pytest_mock.plugin import MockerFixture
 
 from pyiceberg.catalog import Catalog
-from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
 from pyiceberg.io import FileIO
 from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, schema_to_pyarrow
 from pyiceberg.manifest import DataFile
@@ -901,6 +901,81 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file
     assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_conflict_delete_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_conflict"
+    tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
+    tbl1.append(arrow_table_with_null)
+    tbl2 = session_catalog.load_table(identifier)
+
+    tbl1.delete("string == 'z'")
+
+    with pytest.raises(
+        CommitFailedException, match="Operation .* is not allowed when performing .*. Check for overlaps or conflicts."
+    ):
+        # tbl2 isn't aware of the commit by tbl1
+        tbl2.delete("string == 'z'")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_conflict_delete_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_conflict"
+    tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
+    tbl1.append(arrow_table_with_null)
+    tbl2 = session_catalog.load_table(identifier)
+
+    # This is allowed
+    tbl1.delete("string == 'z'")
+    tbl2.append(arrow_table_with_null)
+
+    # verify against expected table
+    arrow_table_expected = arrow_table_with_null[:2]
+    assert tbl1.scan().to_arrow() == arrow_table_expected
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_conflict_append_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_conflict"
+    tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
+    tbl1.append(arrow_table_with_null)
+    tbl2 = session_catalog.load_table(identifier)
+
+    tbl1.append(arrow_table_with_null)
+
+    with pytest.raises(
+        CommitFailedException, match="Operation .* is not allowed when performing .*. Check for overlaps or conflicts."
+    ):
+        # tbl2 isn't aware of the commit by tbl1
+        tbl2.delete("string == 'z'")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [2])
+def test_conflict_append_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_conflict"
+    tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
+    tbl1.append(arrow_table_with_null)
+    tbl2 = session_catalog.load_table(identifier)
+
+    tbl1.append(arrow_table_with_null)
+    tbl2.append(arrow_table_with_null)
+
+    # verify against expected table
+    arrow_table_expected = pa.concat_tables([arrow_table_with_null, arrow_table_with_null, arrow_table_with_null])
+    assert tbl1.scan().to_arrow() == arrow_table_expected
+
+
 @pytest.mark.integration
 def test_add_files_hour_transform(session_catalog: Catalog) -> None:
     identifier = "default.test_add_files_hour_transform"
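For context, here is a minimal usage sketch of the behaviour this change enforces (not part of the diff): a second, stale handle to the same table now fails its commit with CommitFailedException instead of committing over a concurrent change. The catalog name and table identifier below are placeholders, and the reload-and-retry at the end is just one way a caller might recover, not something this patch does automatically.

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.exceptions import CommitFailedException

# Placeholder catalog/table names for illustration only.
catalog = load_catalog("default")
tbl1 = catalog.load_table("default.test_conflict")
tbl2 = catalog.load_table("default.test_conflict")  # independent, soon-to-be-stale handle

tbl1.delete("string == 'z'")  # commits a DELETE snapshot

try:
    # tbl2 still points at the pre-delete snapshot, so the new _validate() rejects its commit
    tbl2.delete("string == 'z'")
except CommitFailedException:
    # One possible recovery: reload the table to pick up tbl1's snapshot and retry
    tbl2 = catalog.load_table("default.test_conflict")
    tbl2.delete("string == 'z'")
```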