diff --git a/framework/feature_factory/__init__.py b/framework/feature_factory/__init__.py index dda842c..753c373 100644 --- a/framework/feature_factory/__init__.py +++ b/framework/feature_factory/__init__.py @@ -5,12 +5,14 @@ from framework.feature_factory.feature import Feature, FeatureSet, Multiplier from framework.configobj import ConfigObj from framework.feature_factory.helpers import Helpers +from framework.feature_factory.agg_granularity import AggregationGranularity import re import logging import datetime import inspect from collections import OrderedDict from typing import List +from enum import Enum logger = logging.getLogger(__name__) @@ -19,7 +21,7 @@ class Feature_Factory(): def __init__(self): self.helpers = Helpers() - def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[FeatureSet], withTrendsForFeatures: List[FeatureSet] = None): + def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[FeatureSet], withTrendsForFeatures: List[FeatureSet] = None, granularityEnum: Enum = None): """ Appends features to incoming df. The features columns and groupby cols will be deduped and validated. If there's a group by, the groupby cols will be applied before appending features. @@ -50,10 +52,13 @@ def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[Featur # valid_result, undef_cols = self.helpers._validate_col(df, *groupBy_cols) # assert valid_result, "groupby cols {} are not defined in df columns {}".format(undef_cols, df.columns) + granularity_validator = AggregationGranularity(granularityEnum) if granularityEnum else None for feature in features: assert True if ((len(feature.aggs) > 0) and (len( groupBy_cols) > 0) or feature.agg_func is None) else False, "{} has either aggs or groupBys " \ "but not both, ensure both are present".format(feature.name) + if granularity_validator: + granularity_validator.validate(feature, groupBy_cols) # feature_cols.append(feature.assembled_column) # feature_cols.append(F.col(feature.output_alias)) agg_cols += [agg_col for agg_col in feature.aggs] @@ -76,7 +81,7 @@ def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[Featur # new_df = df.select(*df.columns + feature_cols) return final_df - def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names = [], withTrendsForFeatures: List[FeatureSet] = None): + def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names = [], withTrendsForFeatures: List[FeatureSet] = None, granularityEnum: Enum = None): """ Appends features to incoming df. The features columns and groupby cols will be deduped and validated. If there's a group by, the groupby cols will be applied before appending features. @@ -89,5 +94,5 @@ def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names # dct = self._get_all_features(catalog_cls) dct = catalog_cls.get_all_features() fs = FeatureSet(dct) - return self.append_features(df, groupBy_cols, [fs], withTrendsForFeatures) + return self.append_features(df, groupBy_cols, [fs], withTrendsForFeatures, granularityEnum) diff --git a/framework/feature_factory/agg_granularity.py b/framework/feature_factory/agg_granularity.py new file mode 100644 index 0000000..9639561 --- /dev/null +++ b/framework/feature_factory/agg_granularity.py @@ -0,0 +1,26 @@ +from enum import IntEnum, EnumMeta +from pyspark.sql.column import Column + +class AggregationGranularity: + def __init__(self, granularity: EnumMeta) -> None: + assert isinstance(granularity, EnumMeta), "Granularity should be of type Enum." + self.granularity = granularity + + + def validate(self, feat, groupby_list): + if not feat.agg_granularity: + return None + min_granularity_level = float("inf") + for level in groupby_list: + if isinstance(level, str): + try: + level = self.granularity[level] + except: + print(f"{level} is not part of {self.granularity}") + continue + if isinstance(level, Column): + continue + min_granularity_level = min(min_granularity_level, level.value) + assert min_granularity_level <= feat.agg_granularity.value, f"Required granularity for {feat.name} is {feat.agg_granularity}" + + \ No newline at end of file diff --git a/framework/feature_factory/feature.py b/framework/feature_factory/feature.py index 2ed522c..8f1e541 100644 --- a/framework/feature_factory/feature.py +++ b/framework/feature_factory/feature.py @@ -21,7 +21,8 @@ def __init__(self, _agg_func=None, _agg_alias:str=None, _kind="multipliable", - _is_temporary=False): + _is_temporary=False, + _agg_granularity=None): """ :param _name: name of the feature @@ -47,6 +48,7 @@ def __init__(self, self.kind = _kind self.is_temporary = _is_temporary self.names = None + self.agg_granularity = _agg_granularity def set_feature_name(self, name: str): self._name = name @@ -135,7 +137,8 @@ def create(cls, base_col: Column, filter: List[Column] = [], negative_value=None, agg_func=None, - agg_alias: str = None): + agg_alias: str = None, + agg_granularity: str = None): return Feature( _name = "", @@ -143,7 +146,8 @@ def create(cls, base_col: Column, _filter = filter, _negative_value = negative_value, _agg_func = agg_func, - _agg_alias = agg_alias) + _agg_alias = agg_alias, + _agg_granularity = agg_granularity) class FeatureSet: diff --git a/test/test_feature_catalog.py b/test/test_feature_catalog.py index 41cb62d..4b12c84 100644 --- a/test/test_feature_catalog.py +++ b/test/test_feature_catalog.py @@ -8,6 +8,7 @@ from pyspark.sql.types import StructType from test.local_spark_singleton import SparkSingleton from framework.feature_factory.catalog import CatalogBase +from enum import IntEnum class CommonCatalog(CatalogBase): total_sales = Feature.create( @@ -18,15 +19,22 @@ class CommonCatalog(CatalogBase): total_quants = Feature.create(base_col=f.col("ss_quantity"), agg_func=f.sum) +class Granularity(IntEnum): + PRODUCT_ID = 1, + PRODUCT_DIVISION = 2, + COUNTRY = 3 + class SalesCatalog(CommonCatalog): _valid_sales_filter = f.col("ss_net_paid") > 0 total_sales = Feature.create( base_col=CommonCatalog.total_sales, filter=_valid_sales_filter, - agg_func=f.sum + agg_func=f.sum, + agg_granularity=Granularity.PRODUCT_DIVISION ) - + + def generate_sales_catalog(CommonCatalog): class SalesCatalog(CommonCatalog): _valid_sales_filter = f.col("ss_net_paid") > 0 @@ -40,10 +48,13 @@ class SalesCatalog(CommonCatalog): class TestSalesCatalog(unittest.TestCase): def setUp(self): - with open("test/data/sales_store_schema.json") as f: - sales_schema = StructType.fromJson(json.load(f)) - self.sales_df = SparkSingleton.get_instance().read.csv("test/data/sales_store_tpcds.csv", schema=sales_schema, header=True) - + with open("test/data/sales_store_schema.json") as fp: + sales_schema = StructType.fromJson(json.load(fp)) + df = SparkSingleton.get_instance().read.csv("test/data/sales_store_tpcds.csv", schema=sales_schema, header=True) + self.sales_df = df.withColumn("PRODUCT_ID", f.lit("product"))\ + .withColumn("PRODUCT_DIVISION", f.lit("division"))\ + .withColumn("COUNTRY", f.lit("country")) + def test_append_catalog(self): customer_id = f.col("ss_customer_sk").alias("customer_id") ff = Feature_Factory() @@ -57,4 +68,10 @@ def test_common_catalog(self): salesCatalogClass = generate_sales_catalog(CommonCatalog=CommonCatalog) df = ff.append_catalog(self.sales_df, [customer_id], salesCatalogClass) assert df.count() > 0 - assert "total_sales" in df.columns and "total_quants" in df.columns \ No newline at end of file + assert "total_sales" in df.columns and "total_quants" in df.columns + + def test_granularity(self): + customer_id = f.col("ss_customer_sk").alias("customer_id") + ff = Feature_Factory() + df = ff.append_catalog(self.sales_df, [customer_id, "PRODUCT_ID"], SalesCatalog, granularityEnum=Granularity) + assert df.count() > 0