From 5c6da4430b29fef36456bdee8cb6c1fd2e83c447 Mon Sep 17 00:00:00 2001
From: fecet
Date: Sat, 14 Dec 2024 20:03:59 +0800
Subject: [PATCH] feat: allow commit transaction

Let `LanceDataset.commit` accept a `Transaction` in place of a
`LanceOperation`. A `Transaction` is forwarded to the new
`_Dataset.commit_transaction` binding, which configures a `CommitBuilder`
with the same commit options (commit lock, storage options, v2 manifest
paths, detached, max retries) and executes the transaction. `read_version`
is neither validated nor forwarded on that path. Every other operation
keeps going through the existing `_Dataset.commit` call.

---
NOTE(review): the copy of this patch under review had lost its line breaks
and the angle-bracket generic arguments in the Rust hunk (`Option>`,
`Option,`, `PyResult {`, `as Arc`, `is_instance_of::()`). The line layout
below was rebuilt to match the hunk counts and the diffstat. The generic
arguments were restored from how each value is used in the body; please
confirm them against python/src/dataset.rs before applying, in particular
`u32` for `max_retries`, `HashMap<String, String>` for `storage_options`
and `Arc<dyn CommitHandler>` for the commit handler. The `From:` header is
also missing an e-mail address, which `git am` requires.

 python/python/lance/dataset.py | 46 ++++++++++++++----------
 python/src/dataset.rs          | 64 ++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 18 deletions(-)

diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 8f0f6daf8a..1dc8987096 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2101,7 +2101,7 @@ def _commit(
     @staticmethod
     def commit(
         base_uri: Union[str, Path, LanceDataset],
-        operation: LanceOperation.BaseOperation,
+        operation: LanceOperation.BaseOperation | Transaction,
         read_version: Optional[int] = None,
         commit_lock: Optional[CommitLock] = None,
         storage_options: Optional[Dict[str, str]] = None,
@@ -2206,24 +2206,34 @@ def commit(
                     f"commit_lock must be a function, got {type(commit_lock)}"
                 )
 
-        if read_version is None and not isinstance(
-            operation, (LanceOperation.Overwrite, LanceOperation.Restore)
-        ):
-            raise ValueError(
-                "read_version is required for all operations except "
-                "Overwrite and Restore"
+        if isinstance(operation, Transaction):
+            new_ds = _Dataset.commit_transaction(
+                base_uri,
+                operation,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
+            )
+        else:
+            if read_version is None and not isinstance(
+                operation, (LanceOperation.Overwrite, LanceOperation.Restore)
+            ):
+                raise ValueError(
+                    "read_version is required for all operations except "
+                    "Overwrite and Restore"
+                )
+            new_ds = _Dataset.commit(
+                base_uri,
+                operation._to_inner(),
+                read_version,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
             )
-
-        new_ds = _Dataset.commit(
-            base_uri,
-            operation._to_inner(),
-            read_version,
-            commit_lock,
-            storage_options=storage_options,
-            enable_v2_manifest_paths=enable_v2_manifest_paths,
-            detached=detached,
-            max_retries=max_retries,
-        )
         ds = LanceDataset.__new__(LanceDataset)
         ds._storage_options = storage_options
         ds._ds = new_ds
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index f55c0646ba..6685ee721c 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1498,6 +1498,70 @@ impl Dataset {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[staticmethod]
+    #[pyo3(signature = (
+        dest,
+        transaction,
+        commit_lock = None,
+        storage_options = None,
+        enable_v2_manifest_paths = None,
+        detached = None,
+        max_retries = None
+    ))]
+    fn commit_transaction<'py>(
+        dest: &Bound<'py, PyAny>,
+        transaction: &Bound<'py, PyAny>,
+        commit_lock: Option<&Bound<'py, PyAny>>,
+        storage_options: Option<HashMap<String, String>>,
+        enable_v2_manifest_paths: Option<bool>,
+        detached: Option<bool>,
+        max_retries: Option<u32>,
+    ) -> PyResult<Self> {
+        let object_store_params = storage_options.as_ref().map(|storage_options| ObjectStoreParams {
+            storage_options: Some(storage_options.clone()),
+            ..Default::default()
+        });
+
+        let commit_handler = commit_lock.map(|commit_lock| {
+            Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py()))) as Arc<dyn CommitHandler>
+        });
+
+        let py = dest.py();
+
+        let dest = if dest.is_instance_of::<Self>() {
+            let dataset: Self = dest.extract()?;
+            WriteDestination::Dataset(dataset.ds.clone())
+        } else {
+            WriteDestination::Uri(dest.extract()?)
+        };
+
+        let mut builder = CommitBuilder::new(dest)
+            .enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
+            .with_detached(detached.unwrap_or(false))
+            .with_max_retries(max_retries.unwrap_or(20));
+
+        if let Some(store_params) = object_store_params {
+            builder = builder.with_store_params(store_params);
+        }
+
+        if let Some(commit_handler) = commit_handler {
+            builder = builder.with_commit_handler(commit_handler);
+        }
+
+        let transaction = extract_transaction(&transaction)?;
+
+        let ds = RT
+            .block_on(Some(py), builder.execute(transaction))?
+            .map_err(|err| PyIOError::new_err(err.to_string()))?;
+
+        let uri = ds.uri().to_string();
+        Ok(Self {
+            ds: Arc::new(ds),
+            uri,
+        })
+    }
+
     #[staticmethod]
     #[pyo3(signature = (dest, transactions, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
     fn commit_batch<'py>(