From 8a162282ffcbcab5bf2a4464864f6f9f14675d75 Mon Sep 17 00:00:00 2001 From: Austin Noto-Moniz Date: Wed, 20 Nov 2024 15:13:18 -0500 Subject: [PATCH] [PNE-6367] Add support for feature effects. The payload comes in as a condensed format, so it's expanded in order to constructed nested lists of objects for clarity and ease of use. Additionally, the hierarchy of data is flipped to more closely match how it will be used by our customers. To that end, 'as_dict' is provided to ease importing it into a pandas DataFrame for whatever processing and analysis the customer desires. --- src/citrine/__version__.py | 2 +- src/citrine/informatics/feature_effects.py | 75 +++++++++++++++++++ .../informatics/predictors/graph_predictor.py | 11 ++- src/citrine/resources/table_config.py | 1 + tests/informatics/test_predictors.py | 27 ++++++- tests/utils/factories.py | 40 ++++++++++ 6 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 src/citrine/informatics/feature_effects.py diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 92bbba870..d1a7f1e0d 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.11.6" +__version__ = "3.12.0" diff --git a/src/citrine/informatics/feature_effects.py b/src/citrine/informatics/feature_effects.py new file mode 100644 index 000000000..55fb63683 --- /dev/null +++ b/src/citrine/informatics/feature_effects.py @@ -0,0 +1,75 @@ +from typing import Dict +from uuid import UUID + +from citrine._rest.resource import Resource +from citrine._serialization import properties + + +class ShapleyMaterial(Resource): + """The feature effect of a material.""" + + material_id = properties.UUID('material_id', serializable=False) + value = properties.Float('value', serializable=False) + + +class ShapleyFeature(Resource): + """All feature effects for this feature by material.""" + + feature = properties.String('feature', serializable=False) + materials = properties.List(properties.Object(ShapleyMaterial), 'materials', + serializable=False) + + @property + def material_dict(self) -> Dict[UUID, float]: + """Presents the feature's effects as a dictionary by material.""" + return {material.material_id: material.value for material in self.materials} + + +class ShapleyOutput(Resource): + """All feature effects for this output by feature.""" + + output = properties.String('output', serializable=False) + features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False) + + @property + def feature_dict(self) -> Dict[str, Dict[UUID, float]]: + """Presents the output's feature effects as a dictionary by feature.""" + return {feature.feature: feature.material_dict for feature in self.features} + + +class FeatureEffects(Resource): + """Captures information about the feature effects associated with a predictor.""" + + predictor_id = properties.UUID('metadata.predictor_id', serializable=False) + predictor_version = properties.Integer('metadata.predictor_version', serializable=False) + status = properties.String('metadata.status', serializable=False) + failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason', + serializable=False) + + outputs = properties.List(properties.Object(ShapleyOutput), 'resultobj', serializable=False) + + @classmethod + def _pre_build(cls, data: dict) -> Dict: + shapley = data["result"] + material_ids = shapley["materials"] + + outputs = [] + for output, feature_dict in shapley["outputs"].items(): + features = [] + for feature, values in feature_dict.items(): + items = zip(material_ids, values) + materials = [{"material_id": mid, "value": value} for mid, value in items] + features.append({ + "feature": feature, + "materials": materials + }) + + outputs.append({"output": output, "features": features}) + + data["resultobj"] = outputs + return data + + @property + def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]: + """Presents the feature effects as a dictionary by output.""" + return {output.output: output.feature_dict for output in self.outputs} diff --git a/src/citrine/informatics/predictors/graph_predictor.py b/src/citrine/informatics/predictors/graph_predictor.py index 2e2cebdca..324a2fbb0 100644 --- a/src/citrine/informatics/predictors/graph_predictor.py +++ b/src/citrine/informatics/predictors/graph_predictor.py @@ -7,9 +7,11 @@ from citrine._session import Session from citrine._utils.functions import format_escaped_url from citrine.informatics.data_sources import DataSource +from citrine.informatics.feature_effects import FeatureEffects from citrine.informatics.predictors.single_predict_request import SinglePredictRequest from citrine.informatics.predictors.single_prediction import SinglePrediction from citrine.informatics.predictors import PredictorNode, Predictor +from citrine.informatics.reports import Report from citrine.resources.report import ReportResource __all__ = ['GraphPredictor'] @@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict: } @property - def report(self): + def report(self) -> Report: """Fetch the predictor report.""" if self.uid is None or self._session is None or self._project_id is None \ or getattr(self, "version", None) is None: @@ -113,6 +115,13 @@ def report(self): report_resource = ReportResource(self._project_id, self._session) return report_resource.get(predictor_id=self.uid, predictor_version=self.version) + @property + def feature_effects(self) -> FeatureEffects: + """Retrieve the feature effects for all outputs in the predictor's training data..""" + path = self._path() + '/shapley/query' + response = self._session.post_resource(path, {}, version=self._api_version) + return FeatureEffects.build(response) + def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction: """Make a one-off prediction with this predictor.""" path = self._path() + '/predict' diff --git a/src/citrine/resources/table_config.py b/src/citrine/resources/table_config.py index 3a5726709..1421ef82e 100644 --- a/src/citrine/resources/table_config.py +++ b/src/citrine/resources/table_config.py @@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]): The query used to define the materials underpinning this table generation_algorithm: TableFromGemdQueryAlgorithm Which algorithm was used to generate the config based on the GemdQuery results + """ # FIXME (DML): rename this (this is dependent on the server side) diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index 0645fe4a7..570c96514 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -1,7 +1,8 @@ """Tests for citrine.informatics.predictors.""" -import uuid -import pytest import mock +import pytest +import uuid +from random import random from citrine.informatics.data_sources import GemTableDataSource from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \ @@ -12,6 +13,10 @@ from citrine.informatics.predictors.single_prediction import SinglePrediction from citrine.informatics.design_candidate import DesignMaterial +from tests.utils.factories import FeatureEffectsResponseFactory +from tests.utils.session import FakeCall, FakeSession + + w = IntegerDescriptor("w", lower_bound=0, upper_bound=100) x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") @@ -485,3 +490,21 @@ def test_single_predict(graph_predictor): prediction_out = graph_predictor.predict(request) assert prediction_out.dump() == prediction_in.dump() assert session.post_resource.call_count == 1 + + +def test_feature_effects(graph_predictor): + feature_effects_response = FeatureEffectsResponseFactory() + feature_effects_as_dict = feature_effects_response.pop("_result_as_dict") + + session = FakeSession() + session.set_response(feature_effects_response) + + graph_predictor._session = session + graph_predictor._project_id = uuid.uuid4() + + fe = graph_predictor.feature_effects + + expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \ + f"/versions/{graph_predictor.version}/shapley/query" + assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] + assert fe.as_dict == feature_effects_as_dict diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 04da2e27a..83cf1fee4 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory): id = factory.Faker('uuid4') data = factory.SubFactory(AnalysisWorkflowDataDataFactory) metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory) + + +class FeatureEffectsResponseResultFactory(factory.DictFactory): + materials = factory.List([ + factory.Faker('uuid4', cast_to=None), + factory.Faker('uuid4', cast_to=None), + factory.Faker('uuid4', cast_to=None) + ]) + outputs = factory.Dict({ + "output1": factory.Dict({ + "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) + }), + "output2": factory.Dict({ + "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]), + "feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) + }) + }) + +class FeatureEffectsMetadataFactory(factory.DictFactory): + predictor_id = factory.Faker('uuid4') + predictor_version = factory.Faker('random_digit_not_null') + created = factory.SubFactory(UserTimestampDataFactory) + updated = factory.SubFactory(UserTimestampDataFactory) + status = 'SUCCEEDED' + + +class FeatureEffectsResponseFactory(factory.DictFactory): + query = {} # Presently, querying from the SDK is not allowed. + metadata = factory.SubFactory(FeatureEffectsMetadataFactory) + result = factory.SubFactory(FeatureEffectsResponseResultFactory) + _result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result)) + + +def _expand_condensed(result_obj): + whole_dict = {} + for output, feature_dict in result_obj["outputs"].items(): + whole_dict[output] = {} + for feature, values in feature_dict.items(): + whole_dict[output][feature] = dict(zip(result_obj["materials"], values)) + return whole_dict