fix(ingestion/prefect-plugin): fixes
dushayntAW committed Jun 12, 2024
1 parent f03f48a commit 1bf3fce
Showing 5 changed files with 94 additions and 29 deletions.
52 changes: 40 additions & 12 deletions metadata-ingestion-modules/prefect-plugin/README.md
@@ -42,12 +42,21 @@ env | `str` | *PROD* | The environment that all assets produced by this orchestrator
platform_instance | `str` | *None* | The instance of the platform that all assets produced by this recipe belong to. For more detail please refer [here](https://datahubproject.io/docs/platform-instances/).

```python
+import asyncio
from prefect_datahub.datahub_emitter import DatahubEmitter
-DatahubEmitter(
-    datahub_rest_url="http://localhost:8080",
-    env="PROD",
-    platform_instance="local_prefect"
-).save("BLOCK-NAME-PLACEHOLDER")


+async def save_datahub_emitter():
+    datahub_emitter = DatahubEmitter(
+        datahub_rest_url="http://localhost:8080",
+        env="PROD",
+        platform_instance="local_prefect",
+    )

+    await datahub_emitter.save("datahub-block-7", overwrite=True)


+asyncio.run(save_datahub_emitter())
```

Congrats! You can now load the saved block to use your configurations in your Flow code:
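For reference, here is a minimal sketch of loading the saved block on its own, mirroring the async pattern used in the examples in this commit (the block name `datahub-block-7` is the one saved above):

```python
import asyncio

from prefect_datahub.datahub_emitter import DatahubEmitter


async def load_datahub_emitter():
    # "datahub-block-7" is the block name used in the save example above.
    datahub_emitter = DatahubEmitter()
    return datahub_emitter.load("datahub-block-7")


datahub_emitter = asyncio.run(load_datahub_emitter())
```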
@@ -72,25 +81,44 @@ DatahubEmitter.load("BLOCK-NAME-PLACEHOLDER")
After installing `prefect-datahub` and [saving the configuration](#saving-configurations-to-a-block), you can easily use it within your Prefect workflows to help you emit metadata events, as shown below!

```python
+import asyncio
+
from prefect import flow, task
-from prefect_datahub.dataset import Dataset

from prefect_datahub.datahub_emitter import DatahubEmitter
+from prefect_datahub.entities import Dataset


+async def load_datahub_emitter():
+    datahub_emitter = DatahubEmitter()
+    return datahub_emitter.load("datahub-block-7")


@task(name="Extract", description="Extract the data")
def extract():
    data = "This is data"
    return data

-datahub_emitter = DatahubEmitter.load("MY_BLOCK_NAME")

@task(name="Transform", description="Transform the data")
-def transform(data):
+def transform(data, datahub_emitter):
    data = data.split(" ")
    datahub_emitter.add_task(
-        inputs=[Dataset("snowflake", "mydb.schema.tableA")],
-        outputs=[Dataset("snowflake", "mydb.schema.tableC")],
+        inputs=[Dataset("snowflake", "mydb.schema.tableX")],
+        outputs=[Dataset("snowflake", "mydb.schema.tableY")],
    )
    return data

-@flow(name="ETL flow", description="Extract transform load flow")

+@flow(name="ETL", description="Extract transform load flow")
def etl():
-    data = transform("This is data")
+    datahub_emitter = asyncio.run(load_datahub_emitter())
+    data = extract()
+    data = transform(data, datahub_emitter)
    datahub_emitter.emit_flow()


etl()
```

**Note**: To emit the tasks, you must also emit the flow; otherwise nothing will be emitted.
4 changes: 2 additions & 2 deletions metadata-ingestion-modules/prefect-plugin/build.gradle
@@ -55,15 +55,15 @@ task lint(type: Exec, dependsOn: installDev) {
  commandLine 'bash', '-c',
    "source ${venv_name}/bin/activate && set -x && " +
    "black --check --diff src/ tests/ && " +
-    // "isort --check --diff src/ tests/ && " +
+    "isort --check --diff src/ tests/ && " +
    "flake8 --count --statistics src/ tests/ && " +
    "mypy --show-traceback --show-error-codes src/ tests/"
}
task lintFix(type: Exec, dependsOn: installDev) {
  commandLine 'bash', '-x', '-c',
    "source ${venv_name}/bin/activate && " +
    "black src/ tests/ && " +
-    // "isort src/ tests/ && " +
+    "isort src/ tests/ && " +
    "flake8 src/ tests/ && " +
    "mypy src/ tests/ "
}
@@ -1,11 +1,11 @@
"""Datahub Emitter classes used to emit prefect metadata to Datahub REST."""

import asyncio
-import datahub.emitter.mce_builder as builder
import traceback
from typing import Any, Dict, List, Optional, cast
from uuid import UUID

+import datahub.emitter.mce_builder as builder
from datahub.api.entities.datajob import DataFlow, DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import (
    DataProcessInstance,
@@ -21,7 +21,6 @@
from prefect.blocks.core import Block
from prefect.client import cloud, orchestration
from prefect.client.schemas import FlowRun, TaskRun, Workspace

from prefect.client.schemas.objects import Flow
from prefect.context import FlowRunContext, TaskRunContext
from prefect.settings import PREFECT_API_URL
@@ -128,16 +127,16 @@ def _get_workspace(self) -> Optional[str]:
                "command 'prefect cloud login'."
            )
            return None

        current_workspace_id = PREFECT_API_URL.value().split("/")[-1]
        workspaces: List[Workspace] = asyncio.run(
            cloud.get_cloud_client().read_workspaces()
        )

        for workspace in workspaces:
            if str(workspace.workspace_id) == current_workspace_id:
                return workspace.workspace_name

        return None

    async def _get_flow_run_graph(self, flow_run_id: str) -> Optional[List[Dict]]:
@@ -285,8 +284,10 @@ async def get_flow(flow_id: UUID) -> Flow:
            name=flow_run_ctx.flow.name,
            platform_instance=self.platform_instance,
        )

        dataflow.description = flow_run_ctx.flow.description
        dataflow.tags = set(flow.tags)

        flow_property_bag: Dict[str, str] = {}
        flow_property_bag[ID] = str(flow.id)
        flow_property_bag[CREATED] = str(flow.created)
@@ -305,12 +306,14 @@
            ON_CANCELLATION,
            ON_CRASHED,
        ]

        for key in allowed_flow_keys:
            if (
                hasattr(flow_run_ctx.flow, key)
                and getattr(flow_run_ctx.flow, key) is not None
            ):
                flow_property_bag[key] = repr(getattr(flow_run_ctx.flow, key))

        dataflow.properties = flow_property_bag

        return dataflow
@@ -331,13 +334,16 @@ def _emit_tasks(
            workspace_name Optional(str): The prefect cloud workpace name.
        """
        assert flow_run_ctx.flow_run

        graph_json = asyncio.run(
            self._get_flow_run_graph(str(flow_run_ctx.flow_run.id))
        )

        if graph_json is None:
            return

        task_run_key_map: Dict[str, str] = {}

        for prefect_future in flow_run_ctx.task_run_futures:
            if prefect_future.task_run is not None:
                task_run_key_map[
@@ -351,13 +357,16 @@
                data_flow_urn=str(dataflow.urn),
                job_id=task_run_key_map[node[ID]],
            )

            datajob: Optional[DataJob] = None

            if str(datajob_urn) in self.datajobs_to_emit:
                datajob = cast(DataJob, self.datajobs_to_emit[str(datajob_urn)])
            else:
                datajob = self._generate_datajob(
                    flow_run_ctx=flow_run_ctx, task_key=task_run_key_map[node[ID]]
                )

            if datajob is not None:
                for each in node[UPSTREAM_DEPENDENCIES]:
                    upstream_task_urn = DataJobUrn.create_from_ids(
@@ -390,8 +399,10 @@ def _emit_flow_run(self, dataflow: DataFlow, flow_run_id: UUID) -> None:

        async def get_flow_run(flow_run_id: UUID) -> FlowRun:
            client = orchestration.get_client()

            if not hasattr(client, "read_flow_run"):
                raise ValueError("Client does not support async read_flow_run method")

            response = client.read_flow_run(flow_run_id=flow_run_id)

            if asyncio.iscoroutine(response):
@@ -407,9 +418,11 @@ async def get_flow_run(flow_run_id: UUID) -> FlowRun:
            dpi_id = f"{self.platform_instance}.{flow_run.name}"
        else:
            dpi_id = flow_run.name

        dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dpi_id)

        dpi_property_bag: Dict[str, str] = {}

        allowed_flow_run_keys = [
            ID,
            CREATED,
@@ -423,9 +436,11 @@ async def get_flow_run(flow_run_id: UUID) -> FlowRun:
            TAGS,
            RUN_COUNT,
        ]

        for key in allowed_flow_run_keys:
            if hasattr(flow_run, key) and getattr(flow_run, key) is not None:
                dpi_property_bag[key] = str(getattr(flow_run, key))

        dpi.properties.update(dpi_property_bag)

        if flow_run.start_time is not None:
@@ -451,8 +466,10 @@ def _emit_task_run(

        async def get_task_run(task_run_id: UUID) -> TaskRun:
            client = orchestration.get_client()

            if not hasattr(client, "read_task_run"):
                raise ValueError("Client does not support async read_task_run method")

            response = client.read_task_run(task_run_id=task_run_id)

            if asyncio.iscoroutine(response):
@@ -468,6 +485,7 @@ async def get_task_run(task_run_id: UUID) -> TaskRun:
            dpi_id = f"{self.platform_instance}.{flow_run_name}.{task_run.name}"
        else:
            dpi_id = f"{flow_run_name}.{task_run.name}"

        dpi = DataProcessInstance.from_datajob(
            datajob=datajob,
            id=dpi_id,
@@ -476,6 +494,7 @@ async def get_task_run(task_run_id: UUID) -> TaskRun:
        )

        dpi_property_bag: Dict[str, str] = {}

        allowed_task_run_keys = [
            ID,
            FLOW_RUN_ID,
@@ -489,9 +508,11 @@ async def get_task_run(task_run_id: UUID) -> TaskRun:
            TAGS,
            RUN_COUNT,
        ]

        for key in allowed_task_run_keys:
            if hasattr(task_run, key) and getattr(task_run, key) is not None:
                dpi_property_bag[key] = str(getattr(task_run, key))

        dpi.properties.update(dpi_property_bag)

        state_result_map: Dict[str, InstanceRunResult] = {
@@ -564,13 +585,14 @@ def etl():
        """
        flow_run_ctx = FlowRunContext.get()
        task_run_ctx = TaskRunContext.get()

        assert flow_run_ctx
        assert task_run_ctx

        datajob = self._generate_datajob(
            flow_run_ctx=flow_run_ctx, task_run_ctx=task_run_ctx
        )

        if datajob is not None:
            if inputs is not None:
                datajob.inlets.extend(self._entities_to_urn_list(inputs))
@@ -604,7 +626,7 @@ def etl():
        """
        try:
            flow_run_ctx = FlowRunContext.get()

            assert flow_run_ctx
            assert flow_run_ctx.flow_run
@@ -1,9 +1,14 @@
+import asyncio

from prefect import flow, task

from prefect_datahub.datahub_emitter import DatahubEmitter
from prefect_datahub.entities import Dataset

-datahub_emitter = DatahubEmitter().load("datahub-block")

+async def load_datahub_emitter():
+    datahub_emitter = DatahubEmitter()
+    return datahub_emitter.load("datahub-block-7")


@task(name="Extract", description="Extract the data")
@@ -13,7 +18,7 @@ def extract():


@task(name="Transform", description="Transform the data")
-def transform(data):
+def transform(data, datahub_emitter):
    data = data.split(" ")
    datahub_emitter.add_task(
        inputs=[Dataset("snowflake", "mydb.schema.tableX")],
@@ -24,8 +29,9 @@ def transform(data):

@flow(name="ETL", description="Extract transform load flow")
def etl():
+    datahub_emitter = asyncio.run(load_datahub_emitter())
    data = extract()
-    data = transform(data)
+    data = transform(data, datahub_emitter)
    datahub_emitter.emit_flow()
@@ -1,7 +1,16 @@
+import asyncio

from prefect_datahub.datahub_emitter import DatahubEmitter

-DatahubEmitter(
-    datahub_rest_url="http://localhost:8080",
-    env="DEV",
-    platform_instance="local_prefect",
-).save("datahub-block", overwrite=True)

+async def save_datahub_emitter():
+    datahub_emitter = DatahubEmitter(
+        datahub_rest_url="http://localhost:8080",
+        env="PROD",
+        platform_instance="local_prefect",
+    )

+    await datahub_emitter.save("datahub-block-7", overwrite=True)


+asyncio.run(save_datahub_emitter())
