diff --git a/airflow/dags/ingestion/ask-astro-load.py b/airflow/dags/ingestion/ask-astro-load.py
index 0cdb554c..4e79a6cb 100644
--- a/airflow/dags/ingestion/ask-astro-load.py
+++ b/airflow/dags/ingestion/ask-astro-load.py
@@ -265,7 +265,7 @@ def extract_astro_blogs():
         task(ask_astro_weaviate_hook.ingest_data, retries=10)
         .partial(
             class_name=WEAVIATE_CLASS,
-            existing="upsert",
+            existing="skip",
             doc_key="docLink",
             batch_params={"batch_size": 1000},
             verbose=True,
@@ -276,7 +276,7 @@ def extract_astro_blogs():
     _import_baseline = task(ask_astro_weaviate_hook.import_baseline, trigger_rule="none_failed")(
         seed_baseline_url=seed_baseline_url,
         class_name=WEAVIATE_CLASS,
-        existing="upsert",
+        existing="error",
        doc_key="docLink",
         uuid_column="id",
         vector_column="vector",
diff --git a/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py b/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py
index 34764ad8..b1af62a0 100644
--- a/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py
+++ b/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py
@@ -22,6 +22,7 @@ class AskAstroWeaviateHook(WeaviateHook):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.batch_errors = []
         self.logger = logging.getLogger("airflow.task")
         self.client = self.get_client()
 
@@ -211,6 +212,7 @@ def batch_ingest(
         vector_column: str | None = None,
         batch_params: dict = {},
         verbose: bool = False,
+        tenant: str | None = None,
     ) -> (list, Any):
         """
         Processes the DataFrame and batches the data for ingestion into Weaviate.
@@ -218,63 +220,70 @@
 
         :param df: DataFrame containing the data to be ingested.
         :param class_name: The name of the class in Weaviate to which data will be ingested.
         :param uuid_column: Name of the column containing the UUID.
+        :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert' or 'error').
         :param vector_column: Name of the column containing the vector data.
         :param batch_params: Parameters for batch configuration.
-        :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert').
-        :param verbose: Whether to print verbose output.
+        :param verbose: Whether to log verbose output.
+        :param tenant: The tenant to which the object will be added.
         """
-        batch = self.client.batch.configure(**batch_params)
-        batch_errors = []
-        for row_id, row in df.iterrows():
-            data_object = row.to_dict()
-            uuid = data_object.pop(uuid_column)
-            vector = data_object.pop(vector_column, None)
-
-            try:
-                if self.client.data_object.exists(uuid=uuid, class_name=class_name) is True:
-                    if existing == "skip":
-                        if verbose is True:
-                            self.logger.warning(f"UUID {uuid} exists. Skipping.")
-                        continue
-                    elif existing == "replace":
-                        # Default for weaviate is replace existing
-                        if verbose is True:
-                            self.logger.warning(f"UUID {uuid} exists. Overwriting.")
-
-            except Exception as e:
-                if verbose:
-                    self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
-                batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
-                continue
-
-            try:
-                added_row = batch.add_data_object(
-                    class_name=class_name, uuid=uuid, data_object=data_object, vector=vector
-                )
-                if verbose is True:
-                    self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")
-
-            except Exception as e:
-                if verbose:
-                    self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
-                batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
+        # Configure the batch with a callback so per-object errors are reported back to this hook on __exit__.
+        if not batch_params.get("callback"):
+            batch_params.update({"callback": self.process_batch_errors})
+
+        self.client.batch.configure(**batch_params)
+
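+        # NOTE: errors reported through the callback accumulate in self.batch_errors on the
+        # hook instance; exiting the context manager flushes any objects still buffered in the batch.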
Error: {e}") - batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + # configuration for context manager for __exit__ method to callback on errors for weaviate batch ingestion. + if not batch_params.get("callback"): + batch_params.update({"callback": self.process_batch_errors}) + + self.client.batch.configure(**batch_params) + + with self.client.batch as batch: + for row_id, row in df.iterrows(): + data_object = row.to_dict() + uuid = data_object.pop(uuid_column) + vector = data_object.pop(vector_column, None) + + try: + if self.client.data_object.exists(uuid=uuid, class_name=class_name): + if existing == "error": + raise AirflowException(f"Ingest of UUID {uuid} failed. Object exists.") + + if existing == "skip": + if verbose is True: + self.logger.warning(f"UUID {uuid} exists. Skipping.") + continue + elif existing == "replace": + # Default for weaviate is replace existing + if verbose is True: + self.logger.warning(f"UUID {uuid} exists. Overwriting.") + except AirflowException as e: + if verbose: + self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}") + self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + break + except Exception as e: + if verbose: + self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}") + self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + continue - results = batch.create_objects() + try: + added_row = batch.add_data_object( + class_name=class_name, uuid=uuid, data_object=data_object, vector=vector, tenant=tenant + ) + if verbose is True: + self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.") - if len(results) > 0: - batch_errors += self.process_batch_errors(results=results, verbose=verbose) + except Exception as e: + if verbose: + self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}") + self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) - return batch_errors + return self.batch_errors - def process_batch_errors(self, results: list, verbose: bool) -> list: + def process_batch_errors(self, results: list, verbose: bool = True) -> None: """ Processes the results from batch operation and collects any errors. :param results: Results from the batch operation. :param verbose: Flag to enable verbose logging. """ - errors = [] for item in results: if "errors" in item["result"]: item_error = {"uuid": item["id"], "errors": item["result"]["errors"]} @@ -282,22 +291,22 @@ def process_batch_errors(self, results: list, verbose: bool) -> list: self.logger.info( f"Error occurred in batch process for {item['id']} with error {item['result']['errors']}" ) - errors.append(item_error) - return errors + self.batch_errors.append(item_error) def handle_upsert_rollback( - self, objects_to_upsert: pd.DataFrame, batch_errors: list, class_name: str, verbose: bool - ) -> list: + self, objects_to_upsert: pd.DataFrame, class_name: str, verbose: bool, tenant: str | None = None + ) -> tuple[list, set]: """ Handles rollback of inserts in case of errors during upsert operation. :param objects_to_upsert: Dictionary of objects to upsert. :param class_name: Name of the class in Weaviate. :param verbose: Flag to enable verbose logging. + :param tenant: The tenant to which the object will be added. 
""" rollback_errors = [] - error_uuids = {error["uuid"] for error in batch_errors} + error_uuids = {error["uuid"] for error in self.batch_errors} objects_to_upsert["rollback_doc"] = objects_to_upsert.objects_to_insert.apply( lambda x: any(error_uuids.intersection(x)) @@ -315,9 +324,11 @@ def handle_upsert_rollback( for uuid in rollback_objects: try: - if self.client.data_object.exists(uuid=uuid, class_name=class_name): + if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant): self.logger.info(f"Removing id {uuid} for rollback.") - self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL") + self.client.data_object.delete( + uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL" + ) elif verbose: self.logger.info(f"UUID {uuid} does not exist. Skipping deletion during rollback.") except Exception as e: @@ -325,20 +336,36 @@ def handle_upsert_rollback( if verbose: self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}") - for uuid in delete_objects: + return rollback_errors, delete_objects + + def handle_successful_upsert( + self, objects_to_remove: list, class_name: str, verbose: bool, tenant: str | None = None + ) -> list: + """ + Handles removal of previous objects after successful upsert. + + :param objects_to_remove: If there were errors rollback will generate a list of successfully inserted objects. + If not set, assume all objects inserted successfully and delete all objects_to_upsert['objects_to_delete'] + :param class_name: Name of the class in Weaviate. + :param verbose: Flag to enable verbose logging. + :param tenant: The tenant to which the object will be added. + """ + deletion_errors = [] + for uuid in objects_to_remove: try: - if self.client.data_object.exists(uuid=uuid, class_name=class_name): + if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant): if verbose: self.logger.info(f"Deleting id {uuid} for successful upsert.") - self.client.data_object.delete(uuid=uuid, class_name=class_name) + self.client.data_object.delete( + uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL" + ) elif verbose: self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.") except Exception as e: - rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + deletion_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) if verbose: self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}") - - return rollback_errors + return deletion_errors def ingest_data( self, @@ -350,6 +377,7 @@ def ingest_data( vector_column: str = None, batch_params: dict = None, verbose: bool = True, + tenant: str | None = None, ) -> list: """ Ingests data into Weaviate, handling upserts and rollbacks, and returns a list of objects that failed to import. @@ -367,11 +395,14 @@ def ingest_data( :param vector_column: Column with embedding vectors for pre-embedded data. :param batch_params: Additional parameters for Weaviate batch configuration. :param verbose: Flag to enable verbose output during the ingestion process. + :param tenant: The tenant to which the object will be added. """ global objects_to_upsert - if existing not in ["skip", "replace", "upsert"]: - raise AirflowException("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert'") + if existing not in ["skip", "replace", "upsert", "error"]: + raise AirflowException( + "Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert', 'error'." 
+        if existing == "upsert":
+            if self.batch_errors:
+                self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.")
+                rollback_errors, objects_to_remove = self.handle_upsert_rollback(
+                    objects_to_upsert=objects_to_upsert, class_name=class_name, verbose=verbose, tenant=tenant
+                )
+
+                deletion_errors = self.handle_successful_upsert(
+                    objects_to_remove=objects_to_remove, class_name=class_name, verbose=verbose, tenant=tenant
+                )
 
-            if len(rollback_errors) > 0:
-                self.logger.error("Errors encountered during rollback.")
-                raise AirflowException("Errors encountered during rollback.")
+                rollback_errors += deletion_errors
+
+                if rollback_errors:
+                    self.logger.error("Errors encountered during rollback.")
+                    self.logger.error("\n".join(map(str, rollback_errors)))
+                    raise AirflowException("Errors encountered during rollback.")
+            else:
+                removal_errors = self.handle_successful_upsert(
+                    objects_to_remove={item for sublist in objects_to_upsert.objects_to_delete for item in sublist},
+                    class_name=class_name,
+                    verbose=verbose,
+                    tenant=tenant,
+                )
+                if removal_errors:
+                    self.logger.error("Errors encountered during removal.")
+                    self.logger.error("\n".join(map(str, removal_errors)))
+                    raise AirflowException("Errors encountered during removal.")
 
-        if batch_errors:
+        if self.batch_errors:
             self.logger.error("Errors encountered during ingest.")
+            self.logger.error("\n".join(map(str, self.batch_errors)))
             raise AirflowException("Errors encountered during ingest.")
 
     def _query_objects(self, value: Any, doc_key: str, class_name: str, uuid_column: str) -> set: