Add the changes to handle upsert to remove document and context manag… #177

Merged
merged 2 commits on Nov 28, 2023
airflow/dags/ingestion/ask-astro-load.py (4 changes: 2 additions & 2 deletions)

@@ -265,7 +265,7 @@ def extract_astro_blogs():
         task(ask_astro_weaviate_hook.ingest_data, retries=10)
         .partial(
             class_name=WEAVIATE_CLASS,
-            existing="upsert",
+            existing="skip",
             doc_key="docLink",
             batch_params={"batch_size": 1000},
             verbose=True,
@@ -276,7 +276,7 @@ def extract_astro_blogs():
     _import_baseline = task(ask_astro_weaviate_hook.import_baseline, trigger_rule="none_failed")(
         seed_baseline_url=seed_baseline_url,
         class_name=WEAVIATE_CLASS,
-        existing="upsert",
+        existing="error",
         doc_key="docLink",
         uuid_column="id",
         vector_column="vector",

@@ -22,6 +22,7 @@ class AskAstroWeaviateHook(WeaviateHook):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.batch_errors = []
         self.logger = logging.getLogger("airflow.task")
         self.client = self.get_client()

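Why batch_errors moves to instance state: the v3 weaviate-client invokes the batch callback internally on every flush and discards its return value, so process_batch_errors (wired up in batch_ingest below) can only surface failures through shared state on the hook. A minimal sketch of that callback pattern, separate from this PR's code and with an illustrative URL:

import weaviate

client = weaviate.Client("http://localhost:8080")  # URL illustrative
collected = []

def on_flush(results):
    # weaviate-client passes a list of per-object result dicts; the return value is ignored.
    collected.extend(r for r in results if "errors" in r.get("result", {}))

client.batch.configure(batch_size=100, callback=on_flush)
with client.batch as batch:
    batch.add_data_object({"title": "doc"}, class_name="Docs")
# __exit__ flushes the remaining queue; any failures are now in `collected`.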
@@ -211,93 +212,101 @@ def batch_ingest(
         vector_column: str | None = None,
         batch_params: dict = {},
         verbose: bool = False,
+        tenant: str | None = None,
     ) -> (list, Any):
         """
         Processes the DataFrame and batches the data for ingestion into Weaviate.

         :param df: DataFrame containing the data to be ingested.
         :param class_name: The name of the class in Weaviate to which data will be ingested.
         :param uuid_column: Name of the column containing the UUID.
+        :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert' or 'error').
         :param vector_column: Name of the column containing the vector data.
         :param batch_params: Parameters for batch configuration.
-        :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert').
-        :param verbose: Whether to print verbose output.
+        :param verbose: Whether to log verbose output.
+        :param tenant: The tenant to which the object will be added.
         """
-        batch = self.client.batch.configure(**batch_params)
-        batch_errors = []
-
-        for row_id, row in df.iterrows():
-            data_object = row.to_dict()
-            uuid = data_object.pop(uuid_column)
-            vector = data_object.pop(vector_column, None)
-
-            try:
-                if self.client.data_object.exists(uuid=uuid, class_name=class_name) is True:
-                    if existing == "skip":
-                        if verbose is True:
-                            self.logger.warning(f"UUID {uuid} exists. Skipping.")
-                        continue
-                    elif existing == "replace":
-                        # Default for weaviate is replace existing
-                        if verbose is True:
-                            self.logger.warning(f"UUID {uuid} exists. Overwriting.")
-
-            except Exception as e:
-                if verbose:
-                    self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
-                batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
-                continue
-
-            try:
-                added_row = batch.add_data_object(
-                    class_name=class_name, uuid=uuid, data_object=data_object, vector=vector
-                )
-                if verbose is True:
-                    self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")
-
-            except Exception as e:
-                if verbose:
-                    self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
-                batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
+        # Ensure a callback is configured so the batch context manager reports errors
+        # from each flush, including the final flush on __exit__.
+        if not batch_params.get("callback"):
+            batch_params.update({"callback": self.process_batch_errors})
+
+        self.client.batch.configure(**batch_params)
+
+        with self.client.batch as batch:
+            for row_id, row in df.iterrows():
+                data_object = row.to_dict()
+                uuid = data_object.pop(uuid_column)
+                vector = data_object.pop(vector_column, None)
+
+                try:
+                    if self.client.data_object.exists(uuid=uuid, class_name=class_name):
+                        if existing == "error":
+                            raise AirflowException(f"Ingest of UUID {uuid} failed. Object exists.")
+
+                        if existing == "skip":
+                            if verbose is True:
+                                self.logger.warning(f"UUID {uuid} exists. Skipping.")
+                            continue
+                        elif existing == "replace":
+                            # Default for weaviate is replace existing
+                            if verbose is True:
+                                self.logger.warning(f"UUID {uuid} exists. Overwriting.")
+                except AirflowException as e:
+                    if verbose:
+                        self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
+                    self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
+                    break
+                except Exception as e:
+                    if verbose:
+                        self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
+                    self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
+                    continue

-        results = batch.create_objects()
+                try:
+                    added_row = batch.add_data_object(
+                        class_name=class_name, uuid=uuid, data_object=data_object, vector=vector, tenant=tenant
+                    )
+                    if verbose is True:
+                        self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")

-        if len(results) > 0:
-            batch_errors += self.process_batch_errors(results=results, verbose=verbose)
+                except Exception as e:
+                    if verbose:
+                        self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
+                    self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})

-        return batch_errors
+        return self.batch_errors

-    def process_batch_errors(self, results: list, verbose: bool) -> list:
+    def process_batch_errors(self, results: list, verbose: bool = True) -> None:
         """
         Processes the results from batch operation and collects any errors.

         :param results: Results from the batch operation.
         :param verbose: Flag to enable verbose logging.
         """
-        errors = []
         for item in results:
             if "errors" in item["result"]:
                 item_error = {"uuid": item["id"], "errors": item["result"]["errors"]}
                 if verbose:
                     self.logger.info(
                         f"Error occurred in batch process for {item['id']} with error {item['result']['errors']}"
                     )
-                errors.append(item_error)
-        return errors
+                self.batch_errors.append(item_error)

     def handle_upsert_rollback(
-        self, objects_to_upsert: pd.DataFrame, batch_errors: list, class_name: str, verbose: bool
-    ) -> list:
+        self, objects_to_upsert: pd.DataFrame, class_name: str, verbose: bool, tenant: str | None = None
+    ) -> tuple[list, set]:
         """
         Handles rollback of inserts in case of errors during upsert operation.

         :param objects_to_upsert: DataFrame of objects to upsert.
         :param class_name: Name of the class in Weaviate.
         :param verbose: Flag to enable verbose logging.
+        :param tenant: The tenant to which the object will be added.
         """
         rollback_errors = []

-        error_uuids = {error["uuid"] for error in batch_errors}
+        error_uuids = {error["uuid"] for error in self.batch_errors}

         objects_to_upsert["rollback_doc"] = objects_to_upsert.objects_to_insert.apply(
             lambda x: any(error_uuids.intersection(x))
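
The rewritten batch_ingest leans on the v3 client's batch context manager: add_data_object queues objects, the client flushes them in batch_size chunks, and every flush invokes the configured callback with a list of per-object results. A sketch of the payload process_batch_errors iterates over — field layout as in weaviate-client v3, values invented:

results = [
    {"id": "1f0b9a...", "result": {"status": "SUCCESS"}},
    {
        "id": "9c2d47...",
        "result": {"errors": {"error": [{"message": "invalid vector length"}]}},
    },
]
# For each item whose result carries an "errors" key, process_batch_errors appends
# {"uuid": item["id"], "errors": item["result"]["errors"]} to self.batch_errors.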
@@ -315,30 +324,48 @@ def handle_upsert_rollback(

         for uuid in rollback_objects:
             try:
-                if self.client.data_object.exists(uuid=uuid, class_name=class_name):
+                if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant):
                     self.logger.info(f"Removing id {uuid} for rollback.")
-                    self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL")
+                    self.client.data_object.delete(
+                        uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL"
+                    )
+                elif verbose:
+                    self.logger.info(f"UUID {uuid} does not exist. Skipping deletion during rollback.")
             except Exception as e:
                 rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
                 if verbose:
                     self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}")

-        for uuid in delete_objects:
+        return rollback_errors, delete_objects
+
+    def handle_successful_upsert(
+        self, objects_to_remove: list, class_name: str, verbose: bool, tenant: str | None = None
+    ) -> list:
+        """
+        Handles removal of previous objects after successful upsert.
+
+        :param objects_to_remove: Objects to delete. After a partial failure, the rollback step
+            passes the old objects of the documents that did insert successfully; on a fully
+            successful upsert this is all of objects_to_upsert['objects_to_delete'].
+        :param class_name: Name of the class in Weaviate.
+        :param verbose: Flag to enable verbose logging.
+        :param tenant: The tenant to which the object will be added.
+        """
+        deletion_errors = []
+        for uuid in objects_to_remove:
             try:
-                if self.client.data_object.exists(uuid=uuid, class_name=class_name):
+                if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant):
                     if verbose:
                         self.logger.info(f"Deleting id {uuid} for successful upsert.")
-                    self.client.data_object.delete(uuid=uuid, class_name=class_name)
+                    self.client.data_object.delete(
+                        uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL"
+                    )
+                elif verbose:
+                    self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.")
             except Exception as e:
-                rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
+                deletion_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
                 if verbose:
                     self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}")

-        return rollback_errors
+        return deletion_errors

     def ingest_data(
         self,
@@ -350,6 +377,7 @@ def ingest_data(
         vector_column: str = None,
         batch_params: dict = None,
         verbose: bool = True,
+        tenant: str | None = None,
     ) -> list:
         """
         Ingests data into Weaviate, handling upserts and rollbacks, and returns a list of objects that failed to import.
@@ -367,11 +395,14 @@ def ingest_data(
         :param vector_column: Column with embedding vectors for pre-embedded data.
         :param batch_params: Additional parameters for Weaviate batch configuration.
         :param verbose: Flag to enable verbose output during the ingestion process.
+        :param tenant: The tenant to which the object will be added.
         """

         global objects_to_upsert
-        if existing not in ["skip", "replace", "upsert"]:
-            raise AirflowException("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert'")
+        if existing not in ["skip", "replace", "upsert", "error"]:
+            raise AirflowException(
+                "Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert', 'error'."
+            )

         df = pd.concat(dfs, ignore_index=True)

@@ -380,7 +411,7 @@ def ingest_data(
             df=df, class_name=class_name, vector_column=vector_column, uuid_column=uuid_column
         )

-        if existing == "upsert":
+        if existing == "upsert" or existing == "skip":
             objects_to_upsert = self.identify_upsert_targets(
                 df=df, class_name=class_name, doc_key=doc_key, uuid_column=uuid_column
             )
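
identify_upsert_targets is not part of this diff, but the columns the surrounding code reads from its result (objects_to_insert above, objects_to_delete further down) imply one bookkeeping row per document. A hypothetical illustration — the doc_key column name and all values are invented:

import pandas as pd

objects_to_upsert = pd.DataFrame(
    {
        "docLink": ["https://example.com/post-1"],            # doc_key column (assumed)
        "objects_to_insert": [["uuid-new-1", "uuid-new-2"]],  # chunk UUIDs being written
        "objects_to_delete": [["uuid-old-1"]],                # superseded chunk UUIDs
    }
)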
@@ -392,28 +423,49 @@ def ingest_data(

self.logger.info(f"Passing {len(df)} objects for ingest.")

batch_errors = self.batch_ingest(
self.batch_ingest(
df=df,
class_name=class_name,
uuid_column=uuid_column,
vector_column=vector_column,
batch_params=batch_params,
existing=existing,
verbose=verbose,
tenant=tenant,
)

if existing == "upsert" and batch_errors:
self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.")
rollback_errors = self.handle_upsert_rollback(
objects_to_upsert=objects_to_upsert, batch_errors=batch_errors, class_name=class_name, verbose=verbose
)
if existing == "upsert":
if self.batch_errors:
self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.")
rollback_errors, objects_to_remove = self.handle_upsert_rollback(
objects_to_upsert=objects_to_upsert, class_name=class_name, verbose=verbose
)

deletion_errors = self.handle_successful_upsert(
objects_to_remove=objects_to_remove, class_name=class_name, verbose=verbose
)

if len(rollback_errors) > 0:
self.logger.error("Errors encountered during rollback.")
raise AirflowException("Errors encountered during rollback.")
rollback_errors += deletion_errors

if rollback_errors:
self.logger.error("Errors encountered during rollback.")
self.logger.error("\n".join(rollback_errors))
raise AirflowException("Errors encountered during rollback.")
else:
removal_errors = self.handle_successful_upsert(
objects_to_remove={item for sublist in objects_to_upsert.objects_to_delete for item in sublist},
class_name=class_name,
verbose=verbose,
tenant=tenant,
)
if removal_errors:
self.logger.error("Errors encountered during removal.")
self.logger.error("\n".join(removal_errors))
raise AirflowException("Errors encountered during removal.")

if batch_errors:
if self.batch_errors:
self.logger.error("Errors encountered during ingest.")
self.logger.error("\n".join(self.batch_errors))
raise AirflowException("Errors encountered during ingest.")

def _query_objects(self, value: Any, doc_key: str, class_name: str, uuid_column: str) -> set:
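
Putting the pieces together, a hedged sketch of calling the hook on the full upsert path — the connection id, class name, and DataFrame are placeholders, not taken from this PR:

hook = AskAstroWeaviateHook(conn_id="weaviate_default")  # conn_id assumed
hook.ingest_data(
    dfs=[df],
    class_name="Docs",
    existing="upsert",
    doc_key="docLink",
    uuid_column="id",
    vector_column="vector",
    batch_params={"batch_size": 1000},
    verbose=True,
)
# Success path: after a clean batch, handle_successful_upsert deletes every UUID in
# objects_to_upsert["objects_to_delete"], i.e. the superseded document chunks.
# Failure path: handle_upsert_rollback deletes the partially inserted chunks of the
# failed documents, old chunks are removed only for documents that fully succeeded,
# and ingest_data raises AirflowException so the Airflow task fails.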