Automated metadata generation using genAI #1670

Draft · wants to merge 8 commits into main
9 changes: 9 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/enums.py
@@ -0,0 +1,9 @@
from dataall.base.api.constants import GraphQLEnumMapper


class MetadataGenerationTargets(GraphQLEnumMapper):
"""Describes the s3_datasets metadata generation targets"""

Table = 'Table'
Folder = 'Folder'
S3_Dataset = 'S3_Dataset'
backend/dataall/modules/s3_datasets/api/dataset/input_types.py
@@ -47,6 +47,7 @@
],
)


DatasetPresignedUrlInput = gql.InputType(
name='DatasetPresignedUrlInput',
arguments=[
@@ -58,6 +59,14 @@

CrawlerInput = gql.InputType(name='CrawlerInput', arguments=[gql.Argument(name='prefix', type=gql.String)])

TableSampleData = gql.InputType(
name='TableSampleData',
arguments=[
gql.Argument(name='fields', type=gql.ArrayType(gql.String)),
gql.Argument(name='rows', type=gql.ArrayType(gql.String)),
],
)

ImportDatasetInput = gql.InputType(
name='ImportDatasetInput',
arguments=[
Expand Down
19 changes: 14 additions & 5 deletions backend/dataall/modules/s3_datasets/api/dataset/mutations.py
@@ -1,17 +1,15 @@
from dataall.base.api import gql
from dataall.modules.s3_datasets.api.dataset.input_types import (
ModifyDatasetInput,
NewDatasetInput,
ImportDatasetInput,
)
from dataall.modules.s3_datasets.api.dataset.input_types import ModifyDatasetInput, NewDatasetInput, ImportDatasetInput
from dataall.modules.s3_datasets.api.dataset.resolvers import (
create_dataset,
update_dataset,
generate_dataset_access_token,
delete_dataset,
import_dataset,
start_crawler,
generate_metadata,
)
from dataall.modules.s3_datasets.api.dataset.enums import MetadataGenerationTargets

createDataset = gql.MutationField(
name='createDataset',
@@ -68,3 +66,14 @@
resolver=start_crawler,
type=gql.Ref('GlueCrawler'),
)
generateMetadata = gql.MutationField(
name='generateMetadata',
args=[
gql.Argument(name='resourceUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='targetType', type=gql.NonNullableType(MetadataGenerationTargets.toGraphQLEnum())),
gql.Argument(name='metadataTypes', type=gql.NonNullableType(gql.ArrayType(gql.String))),
gql.Argument(name='tableSampleData', type=gql.Ref('TableSampleData')),
],
type=gql.ArrayType(gql.Ref('DatasetMetadata')),
resolver=generate_metadata,
)
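For orientation, a minimal sketch of how a client could invoke the new mutation. The GraphQL document mirrors the arguments and DatasetMetadata fields declared in this PR; the graphql_client helper, the URIs, and the metadata type names are hypothetical placeholders.

# Illustrative only, not part of this PR.
GENERATE_METADATA_MUTATION = """
mutation generateMetadata(
  $resourceUri: String!
  $targetType: MetadataGenerationTargets!
  $metadataTypes: [String]!
  $tableSampleData: TableSampleData
) {
  generateMetadata(
    resourceUri: $resourceUri
    targetType: $targetType
    metadataTypes: $metadataTypes
    tableSampleData: $tableSampleData
  ) {
    targetUri
    targetType
    label
    description
    tags
    topics
  }
}
"""

response = graphql_client.execute(  # hypothetical client helper
    GENERATE_METADATA_MUTATION,
    variables={
        'resourceUri': 'abcd1234',  # must satisfy the 8-char lowercase/digit check added in resolvers.py
        'targetType': 'Table',
        'metadataTypes': ['Description', 'Tags'],  # assumed values; see MetadataGenerationTypes in dataset_enums
        'tableSampleData': {'fields': ['id', 'amount'], 'rows': ['1,9.99', '2,15.00']},
    },
)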
17 changes: 17 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/queries.py
@@ -4,6 +4,8 @@
get_dataset_assume_role_url,
get_file_upload_presigned_url,
list_datasets_owned_by_env_group,
list_dataset_tables_folders,
read_sample_data,
)

getDataset = gql.QueryField(
@@ -45,3 +47,18 @@
resolver=list_datasets_owned_by_env_group,
test_scope='Dataset',
)
listDatasetTablesFolders = gql.QueryField(
name='listDatasetTablesFolders',
args=[
gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='filter', type=gql.Ref('DatasetFilter')),
],
type=gql.Ref('DatasetItemsSearchResult'),
resolver=list_dataset_tables_folders,
)
listSampleData = gql.QueryField(
name='listSampleData',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
type=gql.Ref('QueryPreviewResult'), # returns the table preview fields and rows
resolver=read_sample_data,
) # returns sample rows the caller can pass back into generateMetadata as tableSampleData; mirrors the existing table preview API
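The comments above describe the intended round trip: preview sample rows, then pass them back into generateMetadata. A rough sketch of that flow, assuming QueryPreviewResult exposes fields and rows and reusing the hypothetical graphql_client helper from the earlier sketch:

# Hypothetical flow; response shape and helper are assumptions, not part of this PR.
preview = graphql_client.execute(
    'query listSampleData($tableUri: String!) { listSampleData(tableUri: $tableUri) { fields rows } }',
    variables={'tableUri': 'abcd1234'},  # placeholder table URI
)
sample = preview['data']['listSampleData']

graphql_client.execute(
    GENERATE_METADATA_MUTATION,  # sketched after mutations.py above
    variables={
        'resourceUri': 'abcd1234',
        'targetType': 'Table',
        'metadataTypes': ['Description'],
        'tableSampleData': {'fields': sample['fields'], 'rows': sample['rows']},
    },
)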
60 changes: 59 additions & 1 deletion backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
@@ -1,5 +1,5 @@
import logging

import re
from dataall.base.api.context import Context
from dataall.base.feature_toggle_checker import is_feature_enabled
from dataall.base.utils.expiration_util import Expiration
@@ -11,6 +11,9 @@
from dataall.modules.s3_datasets.db.dataset_models import S3Dataset
from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, ConfidentialityClassification
from dataall.modules.s3_datasets.services.dataset_service import DatasetService
from dataall.modules.s3_datasets.services.dataset_table_service import DatasetTableService
from dataall.modules.s3_datasets.services.dataset_location_service import DatasetLocationService
from dataall.modules.s3_datasets.services.dataset_enums import MetadataGenerationTargets, MetadataGenerationTypes

log = logging.getLogger(__name__)

@@ -156,6 +159,49 @@ def list_datasets_owned_by_env_group(
return DatasetService.list_datasets_owned_by_env_group(environmentUri, groupUri, filter)


@is_feature_enabled('modules.s3_datasets.features.generate_metadata_ai.active')
def generate_metadata(
context: Context,
source: S3Dataset,
resourceUri: str,
targetType: str,
metadataTypes: list,
tableSampleData: dict = None,
):
RequestValidator.validate_uri(param_name='resourceUri', param_value=resourceUri)
if any(metadata_type not in [item.value for item in MetadataGenerationTypes] for metadata_type in metadataTypes):
raise InvalidInput(
'metadataType',
metadataTypes,
f'a list of allowed values {[item.value for item in MetadataGenerationTypes]}',
)
if targetType == MetadataGenerationTargets.S3_Dataset.value:
return DatasetService.generate_metadata_for_dataset(uri=resourceUri, metadata_types=metadataTypes)
elif targetType == MetadataGenerationTargets.Table.value:
return DatasetTableService.generate_metadata_for_table(
uri=resourceUri, metadata_types=metadataTypes, sample_data=tableSampleData or {}
)
elif targetType == MetadataGenerationTargets.Folder.value:
return DatasetLocationService.generate_metadata_for_folder(uri=resourceUri, metadata_types=metadataTypes)
else:
raise Exception('Unsupported target type for metadata generation')


def read_sample_data(context: Context, source: S3Dataset, tableUri: str):
RequestValidator.validate_uri(param_name='tableUri', param_value=tableUri)
return DatasetTableService.preview(uri=tableUri)


def update_dataset_metadata(context: Context, source: S3Dataset, resourceUri: str, input: dict = None):
return DatasetService.update_dataset(uri=resourceUri, data=input)


def list_dataset_tables_folders(context: Context, source: S3Dataset, datasetUri: str, filter: dict = None):
if not filter:
filter = {}
return DatasetService.list_dataset_tables_folders(uri=datasetUri, filter=filter)


class RequestValidator:
@staticmethod
def validate_creation_request(data):
@@ -200,6 +246,18 @@ def validate_share_expiration_request(data):
'is of invalid type',
)

@staticmethod
def validate_uri(param_name: str, param_value: str):
if not param_value:
raise RequiredParameter(param_name)
pattern = r'^[a-z0-9]{8}$'
if not re.match(pattern, param_value):
raise InvalidInput(
param_name=param_name,
param_value=param_value,
constraint='8 characters long and contain only lowercase letters and numbers',
)

@staticmethod
def validate_import_request(data):
RequestValidator.validate_creation_request(data)
36 changes: 36 additions & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/types.py
@@ -135,3 +135,39 @@
gql.Field(name='status', type=gql.String),
],
)

DatasetMetadata = gql.ObjectType(
name='DatasetMetadata',
fields=[
gql.Field(name='targetUri', type=gql.String),
gql.Field(name='targetType', type=gql.String),
gql.Field(name='label', type=gql.String),
gql.Field(name='description', type=gql.String),
gql.Field(name='tags', type=gql.ArrayType(gql.String)),
gql.Field(name='topics', type=gql.ArrayType(gql.String)),
],
)

DatasetItem = gql.ObjectType(
name='DatasetItem',
fields=[
gql.Field(name='name', type=gql.String),
gql.Field(name='targetType', type=gql.String),
gql.Field(name='targetUri', type=gql.String),
],
)

DatasetItemsSearchResult = gql.ObjectType(
name='DatasetItemsSearchResult',
fields=[
gql.Field(name='count', type=gql.Integer),
gql.Field(name='nodes', type=gql.ArrayType(DatasetItem)),
gql.Field(name='pageSize', type=gql.Integer),
gql.Field(name='nextPage', type=gql.Integer),
gql.Field(name='pages', type=gql.Integer),
gql.Field(name='page', type=gql.Integer),
gql.Field(name='previousPage', type=gql.Integer),
gql.Field(name='hasNext', type=gql.Boolean),
gql.Field(name='hasPrevious', type=gql.Boolean),
],
)
backend/dataall/modules/s3_datasets/api/table_column/input_types.py
@@ -18,3 +18,11 @@
gql.Argument('topics', gql.Integer),
],
)
SubitemDescription = gql.InputType(
name='SubitemDescriptionInput',
arguments=[
gql.Argument(name='label', type=gql.String),
gql.Argument(name='description', type=gql.String),
gql.Argument(name='subitem_id', type=gql.String),
],
)
backend/dataall/modules/s3_datasets/api/table_column/mutations.py
@@ -1,5 +1,9 @@
from dataall.base.api import gql
from dataall.modules.s3_datasets.api.table_column.resolvers import sync_table_columns, update_table_column
from dataall.modules.s3_datasets.api.table_column.resolvers import (
sync_table_columns,
update_table_column,
batch_update_table_columns_description,
)

syncDatasetTableColumns = gql.MutationField(
name='syncDatasetTableColumns',
Expand All @@ -18,3 +22,9 @@
type=gql.Ref('DatasetTableColumn'),
resolver=update_table_column,
)
batchUpdateDatasetTableColumn = gql.MutationField(
name='batchUpdateDatasetTableColumn',
args=[gql.Argument(name='columns', type=gql.ArrayType(gql.Ref('SubitemDescriptionInput')))],
type=gql.String,
resolver=batch_update_table_columns_description,
)
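The columns argument takes a list of SubitemDescriptionInput values (defined in the input_types change above). An illustrative payload shape, with made-up identifiers and text:

# Illustrative only: identifiers and descriptions are placeholders.
columns = [
    {'subitem_id': 'colA1234', 'label': 'customer_id', 'description': 'Unique customer identifier'},
    {'subitem_id': 'colB5678', 'label': 'amount', 'description': 'Order total amount'},
]
# These would be sent as the `columns` variable of batchUpdateDatasetTableColumn.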
backend/dataall/modules/s3_datasets/api/table_column/resolvers.py
@@ -41,3 +41,9 @@ def update_table_column(context: Context, source, columnUri: str = None, input:

description = input.get('description', 'No description provided')
return DatasetColumnService.update_table_column_description(column_uri=columnUri, description=description)


def batch_update_table_columns_description(context: Context, source, columns):
if columns is None:
return None
return DatasetColumnService.batch_update_table_columns_description(columns=columns)
101 changes: 101 additions & 0 deletions backend/dataall/modules/s3_datasets/aws/bedrock_metadata_client.py
@@ -0,0 +1,101 @@
import logging
import os

from dataall.base.db import exceptions
from dataall.base.aws.sts import SessionHelper
from typing import List, Optional
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_aws import ChatBedrock as BedrockChat
from langchain_core.output_parsers import JsonOutputParser

log = logging.getLogger(__name__)

METADATA_GENERATION_DATASET_TEMPLATE_PATH = os.path.join(
os.path.dirname(__file__), 'bedrock_prompts', 'metadata_generation_dataset_template.txt'
)
METADATA_GENERATION_TABLE_TEMPLATE_PATH = os.path.join(
os.path.dirname(__file__), 'bedrock_prompts', 'metadata_generation_table_template.txt'
)
METADATA_GENERATION_FOLDER_TEMPLATE_PATH = os.path.join(
os.path.dirname(__file__), 'bedrock_prompts', 'metadata_generation_folder_template.txt'
)


class MetadataOutput(BaseModel):
tags: Optional[List[str]] = None
description: Optional[str] = None
label: Optional[str] = None
topics: Optional[List[str]] = None
columns_metadata: Optional[List[dict]] = None


class BedrockClient:
def __init__(self):
session = SessionHelper.get_session()
self._client = session.client('bedrock-runtime', region_name=os.getenv('AWS_REGION', 'eu-west-1'))
model_id = 'eu.anthropic.claude-3-5-sonnet-20240620-v1:0'
model_kwargs = {
'max_tokens': 4096,
'temperature': 0.5,
'top_k': 250,
'top_p': 0.5,
'stop_sequences': ['\n\nHuman'],
}
self._model = BedrockChat(client=self._client, model_id=model_id, model_kwargs=model_kwargs)

def invoke_model_dataset_metadata(self, metadata_types, dataset, tables, folders):
try:
prompt_template = PromptTemplate.from_file(METADATA_GENERATION_DATASET_TEMPLATE_PATH)
parser = JsonOutputParser(pydantic_object=MetadataOutput)
chain = prompt_template | self._model | parser
context = {
'metadata_types': metadata_types,
'dataset_label': dataset.label,
'description': dataset.description,
'tags': dataset.tags,
'topics': dataset.topics,
'table_names': [t.label for t in tables],
'table_descriptions': [t.description for t in tables],
'folder_names': [f.label for f in folders],
}
return chain.invoke(context)
except Exception as e:
log.error(f'Failed to generate dataset metadata with Bedrock: {e}')
raise e

def invoke_model_table_metadata(self, metadata_types, table, columns, sample_data, generate_columns_metadata=False):
try:
prompt_template = PromptTemplate.from_file(METADATA_GENERATION_TABLE_TEMPLATE_PATH)
parser = JsonOutputParser(pydantic_object=MetadataOutput)
chain = prompt_template | self._model | parser
context = {
'metadata_types': metadata_types,
'generate_columns_metadata': generate_columns_metadata,
'label': table.label,
'description': table.description,
'tags': table.tags,
'topics': table.topics,
'column_labels': [c.label for c in columns],
'column_descriptions': [c.description for c in columns],
'sample_data': sample_data,
}
return chain.invoke(context)
except Exception as e:
log.error(f'Failed to generate table metadata with Bedrock: {e}')
raise e

def invoke_model_folder_metadata(self, metadata_types, folder, files):
try:
prompt_template = PromptTemplate.from_file(METADATA_GENERATION_FOLDER_TEMPLATE_PATH)
parser = JsonOutputParser(pydantic_object=MetadataOutput)
chain = prompt_template | self._model | parser
context = {
'metadata_types': metadata_types,
'label': folder.label,
'description': folder.description,
'tags': folder.tags,
'topics': folder.topics,
'file_names': files,
}
return chain.invoke(context)
except Exception as e:
log.error(f'Failed to generate folder metadata with Bedrock: {e}')
raise e
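A rough sketch of how the service layer might drive this client. Only the BedrockClient methods come from the code above; the table and column objects and the metadata type names are assumed to match what the prompt context reads (label, description, tags, topics).

# Hypothetical usage; `table` and `columns` are placeholders for ORM objects
# looked up by the service layer, which is not shown in this file.
client = BedrockClient()
suggestion = client.invoke_model_table_metadata(
    metadata_types=['Description', 'Tags'],
    table=table,        # object exposing label/description/tags/topics
    columns=columns,    # iterable of column objects exposing label/description
    sample_data={'fields': ['id', 'amount'], 'rows': ['1,9.99']},
    generate_columns_metadata=True,
)
# `suggestion` is the model output parsed by JsonOutputParser against MetadataOutput,
# e.g. {'label': ..., 'description': ..., 'tags': [...], 'columns_metadata': [...]}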