Skip to content

Commit

Permalink
feat(ingest/mssql): include stored procedure lineage (datahub-project…
Browse files Browse the repository at this point in the history
…#11912)

Co-authored-by: Harshal Sheth <[email protected]>
  • Loading branch information
mayurinehate and hsheth2 authored Nov 22, 2024
1 parent 86b8175 commit c3f9a92
Show file tree
Hide file tree
Showing 21 changed files with 4,074 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class StoredProcedure:
flow: Union[MSSQLJob, MSSQLProceduresContainer]
type: str = "STORED_PROCEDURE"
source: str = "mssql"
code: Optional[str] = None

@property
def full_type(self) -> str:
Expand Down
151 changes: 113 additions & 38 deletions metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import StructuredLogLevel
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.sql.mssql.job_models import (
JobStep,
Expand All @@ -36,6 +38,9 @@
ProcedureParameter,
StoredProcedure,
)
from datahub.ingestion.source.sql.mssql.stored_procedure_lineage import (
generate_procedure_lineage,
)
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
SqlWorkUnit,
Expand All @@ -51,6 +56,7 @@
StringTypeClass,
UnionTypeClass,
)
from datahub.utilities.file_backed_collections import FileBackedList

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,6 +105,10 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
default=False,
description="Enable to convert the SQL Server assets urns to lowercase",
)
include_lineage: bool = Field(
default=True,
description="Enable lineage extraction for stored procedures",
)

@pydantic.validator("uri_args")
def passwords_match(cls, v, values, **kwargs):
Expand Down Expand Up @@ -161,6 +171,7 @@ def __init__(self, config: SQLServerConfig, ctx: PipelineContext):
self.current_database = None
self.table_descriptions: Dict[str, str] = {}
self.column_descriptions: Dict[str, str] = {}
self.stored_procedures: FileBackedList[StoredProcedure] = FileBackedList()
if self.config.include_descriptions:
for inspector in self.get_inspectors():
db_name: str = self.get_db_name(inspector)
Expand Down Expand Up @@ -374,7 +385,7 @@ def loop_jobs(
def loop_job_steps(
self, job: MSSQLJob, job_steps: Dict[str, Any]
) -> Iterable[MetadataWorkUnit]:
for step_id, step_data in job_steps.items():
for _step_id, step_data in job_steps.items():
step = JobStep(
job_name=job.formatted_name,
step_name=step_data["step_name"],
Expand Down Expand Up @@ -412,37 +423,44 @@ def loop_stored_procedures( # noqa: C901
if procedures:
yield from self.construct_flow_workunits(data_flow=data_flow)
for procedure in procedures:
upstream = self._get_procedure_upstream(conn, procedure)
downstream = self._get_procedure_downstream(conn, procedure)
data_job = MSSQLDataJob(
entity=procedure,
)
# TODO: because of this upstream and downstream are more dependencies,
# can't be used as DataJobInputOutput.
# Should be reorganized into lineage.
data_job.add_property("procedure_depends_on", str(upstream.as_property))
data_job.add_property(
"depending_on_procedure", str(downstream.as_property)
)
procedure_definition, procedure_code = self._get_procedure_code(
conn, procedure
)
if procedure_definition:
data_job.add_property("definition", procedure_definition)
if sql_config.include_stored_procedures_code and procedure_code:
data_job.add_property("code", procedure_code)
procedure_inputs = self._get_procedure_inputs(conn, procedure)
properties = self._get_procedure_properties(conn, procedure)
data_job.add_property(
"input parameters", str([param.name for param in procedure_inputs])
)
for param in procedure_inputs:
data_job.add_property(
f"parameter {param.name}", str(param.properties)
)
for property_name, property_value in properties.items():
data_job.add_property(property_name, str(property_value))
yield from self.construct_job_workunits(data_job)
yield from self._process_stored_procedure(conn, procedure)

def _process_stored_procedure(
    self, conn: Connection, procedure: StoredProcedure
) -> Iterable[MetadataWorkUnit]:
    """Build the data job for one stored procedure and yield its work units.

    Dependencies, definition/code, input parameters and extra properties are
    fetched through ``conn`` and attached to the job as custom properties.
    The procedure body is also stored on ``procedure.code``; when
    ``include_lineage`` is enabled the procedure is queued in
    ``self.stored_procedures`` for later lineage generation.

    Args:
        conn: Connection handed to the ``_get_procedure_*`` helpers.
        procedure: Stored procedure to describe; its ``code`` attribute is
            overwritten with the fetched body.

    Yields:
        The work units produced by ``construct_job_workunits`` with
        ``include_lineage=False``.
    """
    depends_on = self._get_procedure_upstream(conn, procedure)
    dependents = self._get_procedure_downstream(conn, procedure)
    job = MSSQLDataJob(entity=procedure)

    # TODO: because of this upstream and downstream are more dependencies,
    # can't be used as DataJobInputOutput.
    # Should be reorganized into lineage.
    job.add_property("procedure_depends_on", str(depends_on.as_property))
    job.add_property("depending_on_procedure", str(dependents.as_property))

    definition, body = self._get_procedure_code(conn, procedure)
    # Keep the body on the procedure itself; lineage generation reads it later.
    procedure.code = body
    if definition:
        job.add_property("definition", definition)
    if body and self.config.include_stored_procedures_code:
        job.add_property("code", body)

    parameters = self._get_procedure_inputs(conn, procedure)
    extra_properties = self._get_procedure_properties(conn, procedure)

    parameter_names = [parameter.name for parameter in parameters]
    job.add_property("input parameters", str(parameter_names))
    for parameter in parameters:
        job.add_property(f"parameter {parameter.name}", str(parameter.properties))
    for key, value in extra_properties.items():
        job.add_property(key, str(value))

    if self.config.include_lineage:
        # These will be used to construct lineage
        self.stored_procedures.append(procedure)

    # For stored procedure lineage is ingested later
    yield from self.construct_job_workunits(job, include_lineage=False)

@staticmethod
def _get_procedure_downstream(
Expand Down Expand Up @@ -546,8 +564,8 @@ def _get_procedure_code(
code_list.append(row["Text"])
if code_slice_text in re.sub(" +", " ", row["Text"].lower()).strip():
code_slice_index = index
definition = "\n".join(code_list[:code_slice_index])
code = "\n".join(code_list[code_slice_index:])
definition = "".join(code_list[:code_slice_index])
code = "".join(code_list[code_slice_index:])
except ResourceClosedError:
logger.warning(
"Connection was closed from procedure '%s'",
Expand Down Expand Up @@ -602,16 +620,18 @@ def _get_stored_procedures(
def construct_job_workunits(
    self,
    data_job: MSSQLDataJob,
    include_lineage: bool = True,
) -> Iterable[MetadataWorkUnit]:
    """Yield the metadata work units describing ``data_job``.

    The job-info aspect is always emitted. The input/output aspect is
    emitted exactly once, and only when ``include_lineage`` is true, so a
    caller that ingests lineage separately can pass ``include_lineage=False``
    to suppress it.

    Args:
        data_job: Job whose ``urn`` and aspect properties are emitted.
        include_lineage: Whether to emit ``as_datajob_input_output_aspect``.

    Yields:
        One work unit per emitted aspect.
    """
    yield MetadataChangeProposalWrapper(
        entityUrn=data_job.urn,
        aspect=data_job.as_datajob_info_aspect,
    ).as_workunit()

    # Emit the input/output aspect only on request; emitting it
    # unconditionally would make the include_lineage flag a no-op.
    if include_lineage:
        yield MetadataChangeProposalWrapper(
            entityUrn=data_job.urn,
            aspect=data_job.as_datajob_input_output_aspect,
        ).as_workunit()
    # TODO: Add SubType when it appears

def construct_flow_workunits(
Expand Down Expand Up @@ -664,3 +684,58 @@ def get_identifier(
if self.config.convert_urns_to_lowercase
else qualified_table_name
)

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
    """Yield the parent source's work units, then stored procedure lineage.

    Lineage for every procedure collected in ``self.stored_procedures`` is
    generated after the parent's work units. Each procedure is handled
    inside ``self.report.report_exc`` at WARN level.
    """
    yield from super().get_workunits_internal()

    # This is done at the end so that we will have access to tables
    # from all databases in schema_resolver and discovered_tables
    for stored_procedure in self.stored_procedures:
        with self.report.report_exc(
            message="Failed to parse stored procedure lineage",
            context=stored_procedure.full_name,
            level=StructuredLogLevel.WARN,
        ):
            job_urn = MSSQLDataJob(entity=stored_procedure).urn
            lineage_mcps = generate_procedure_lineage(
                schema_resolver=self.schema_resolver,
                procedure=stored_procedure,
                procedure_job_urn=job_urn,
                is_temp_table=self.is_temp_table,
            )
            yield from auto_workunit(lineage_mcps)

def is_temp_table(self, name: str) -> bool:
    """Decide whether ``name`` should be treated as a temporary table.

    A name is considered temporary when its last dot-separated part starts
    with ``#``, or when it passes the configured database/schema/table
    patterns but is absent from ``self.discovered_datasets``.

    Args:
        name: Dot-separated table reference. The ``#`` check works for any
            number of parts; the discovered-datasets check needs at least
            ``db.schema.table``.

    Returns:
        True if the table is inferred to be temporary. False otherwise,
        including when the name cannot be parsed (a warning is logged).
    """
    try:
        parts = name.split(".")
        table_name = parts[-1]

        # Check the '#' prefix before touching schema/db parts: a short
        # name such as "#tmp" has none, and indexing them first would
        # raise and wrongly report the table as non-temporary.
        if table_name.startswith("#"):
            return True

        schema_name = parts[-2]
        db_name = parts[-3]

        # This is also a temp table if
        # 1. this name would be allowed by the dataset patterns, and
        # 2. we have a list of discovered tables, and
        # 3. it's not in the discovered tables list
        if (
            self.config.database_pattern.allowed(db_name)
            and self.config.schema_pattern.allowed(schema_name)
            and self.config.table_pattern.allowed(name)
            and self.standardize_identifier_case(name)
            not in self.discovered_datasets
        ):
            logger.debug(f"inferred as temp table {name}")
            return True

    except Exception:
        # Best effort: an unparsable name is reported and treated as a
        # regular (non-temporary) table rather than failing ingestion.
        logger.warning(f"Error parsing table name {name} ")
    return False

def standardize_identifier_case(self, table_ref_str: str) -> str:
    """Return ``table_ref_str`` lowercased if ``convert_urns_to_lowercase`` is set.

    When the config flag is falsy the string is returned unchanged.
    """
    if self.config.convert_urns_to_lowercase:
        return table_ref_str.lower()
    return table_ref_str
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
from typing import Callable, Iterable, Optional

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.source.sql.mssql.job_models import StoredProcedure
from datahub.metadata.schema_classes import DataJobInputOutputClass
from datahub.sql_parsing.datajob import to_datajob_input_output
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.split_statements import split_statements
from datahub.sql_parsing.sql_parsing_aggregator import (
ObservedQuery,
SqlParsingAggregator,
)

logger = logging.getLogger(__name__)


def parse_procedure_code(
    *,
    schema_resolver: SchemaResolver,
    default_db: Optional[str],
    default_schema: Optional[str],
    code: str,
    is_temp_table: Callable[[str], bool],
    raise_: bool = False,
) -> Optional[DataJobInputOutputClass]:
    """Derive a data job input/output aspect from stored procedure code.

    The code is split into statements, each statement is registered with a
    lineage-only ``SqlParsingAggregator`` as an observed query, and the
    aggregator's generated metadata is converted with
    ``to_datajob_input_output``.

    Args:
        schema_resolver: Supplies platform, env and schemas to the aggregator.
        default_db: Default database attached to every observed query.
        default_schema: Default schema attached to every observed query.
        code: Stored procedure body to split and parse.
        is_temp_table: Predicate passed through to the aggregator.
        raise_: If true, raise when any observed query failed.

    Returns:
        Whatever ``to_datajob_input_output`` produces for the generated
        metadata (annotated as optional).

    Raises:
        ValueError: If ``raise_`` is true and the aggregator reports failed
            observed queries.
    """
    lineage_aggregator = SqlParsingAggregator(
        platform=schema_resolver.platform,
        env=schema_resolver.env,
        schema_resolver=schema_resolver,
        generate_lineage=True,
        generate_queries=False,
        generate_usage_statistics=False,
        generate_operations=False,
        generate_query_subject_fields=False,
        generate_query_usage_statistics=False,
        is_temp_table=is_temp_table,
    )

    for statement in split_statements(code):
        # TODO: We should take into account `USE x` statements.
        observed_query = ObservedQuery(
            default_db=default_db,
            default_schema=default_schema,
            query=statement,
        )
        lineage_aggregator.add_observed_query(observed=observed_query)

    failed_count = lineage_aggregator.report.num_observed_queries_failed
    if raise_ and failed_count:
        logger.info(lineage_aggregator.report.as_string())
        raise ValueError(f"Failed to parse {failed_count} queries.")

    return to_datajob_input_output(
        mcps=list(lineage_aggregator.gen_metadata()),
        ignore_extra_mcps=True,
    )


# Is procedure handling generic enough to be added to SqlParsingAggregator?
def generate_procedure_lineage(
    *,
    schema_resolver: SchemaResolver,
    procedure: StoredProcedure,
    procedure_job_urn: str,
    is_temp_table: Callable[[str], bool] = lambda _: False,
    raise_: bool = False,
) -> Iterable[MetadataChangeProposalWrapper]:
    """Yield the lineage proposal for a stored procedure, if one can be built.

    Nothing is yielded when the procedure has no code or when parsing the
    code produces no input/output aspect.

    Args:
        schema_resolver: Forwarded to ``parse_procedure_code``.
        procedure: Procedure whose ``code``, ``db`` and ``schema`` are parsed.
        procedure_job_urn: Entity URN the resulting aspect is attached to.
        is_temp_table: Predicate forwarded to ``parse_procedure_code``;
            defaults to treating no table as temporary.
        raise_: Forwarded to ``parse_procedure_code``.

    Yields:
        At most one ``MetadataChangeProposalWrapper`` carrying the aspect.
    """
    if not procedure.code:
        return

    input_output_aspect = parse_procedure_code(
        schema_resolver=schema_resolver,
        default_db=procedure.db,
        default_schema=procedure.schema,
        code=procedure.code,
        is_temp_table=is_temp_table,
        raise_=raise_,
    )
    if not input_output_aspect:
        return

    yield MetadataChangeProposalWrapper(
        entityUrn=procedure_job_urn,
        aspect=input_output_aspect,
    )
12 changes: 10 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str)
platform_instance=self.config.platform_instance,
env=self.config.env,
)
self.discovered_datasets: Set[str] = set()
self._view_definition_cache: MutableMapping[str, str]
if self.config.use_file_backed_cache:
self._view_definition_cache = FileBackedDict[str]()
Expand Down Expand Up @@ -831,8 +832,9 @@ def _process_table(
self._classify(dataset_name, schema, table, data_reader, schema_metadata)

dataset_snapshot.aspects.append(schema_metadata)
if self.config.include_view_lineage:
if self._save_schema_to_resolver():
self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
self.discovered_datasets.add(dataset_name)
db_name = self.get_db_name(inspector)

yield from self.add_table_to_schema_container(
Expand Down Expand Up @@ -1126,8 +1128,9 @@ def _process_view(
columns,
canonical_schema=schema_fields,
)
if self.config.include_view_lineage:
if self._save_schema_to_resolver():
self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
self.discovered_datasets.add(dataset_name)
description, properties, _ = self.get_table_properties(inspector, schema, view)
try:
view_definition = inspector.get_view_definition(view, schema)
Expand Down Expand Up @@ -1190,6 +1193,11 @@ def _process_view(
domain_registry=self.domain_registry,
)

def _save_schema_to_resolver(self):
return self.config.include_view_lineage or (
hasattr(self.config, "include_lineage") and self.config.include_lineage
)

def _run_sql_parser(
self, view_identifier: str, query: str, schema_resolver: SchemaResolver
) -> Optional[SqlParsingResult]:
Expand Down
Loading

0 comments on commit c3f9a92

Please sign in to comment.