Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PNE-6367] Add support for feature effects. #982

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.11.6"
__version__ = "3.12.0"
75 changes: 75 additions & 0 deletions src/citrine/informatics/feature_effects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Dict
from uuid import UUID

from citrine._rest.resource import Resource
from citrine._serialization import properties


class ShapleyMaterial(Resource):
"""The feature effect of a material."""

material_id = properties.UUID('material_id', serializable=False)
value = properties.Float('value', serializable=False)


class ShapleyFeature(Resource):
"""All feature effects for this feature by material."""

feature = properties.String('feature', serializable=False)
materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
serializable=False)

@property
def material_dict(self) -> Dict[UUID, float]:
"""Presents the feature's effects as a dictionary by material."""
return {material.material_id: material.value for material in self.materials}


class ShapleyOutput(Resource):
"""All feature effects for this output by feature."""

output = properties.String('output', serializable=False)
features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)

@property
def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
"""Presents the output's feature effects as a dictionary by feature."""
return {feature.feature: feature.material_dict for feature in self.features}


class FeatureEffects(Resource):
"""Captures information about the feature effects associated with a predictor."""

predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
status = properties.String('metadata.status', serializable=False)
failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
serializable=False)

outputs = properties.List(properties.Object(ShapleyOutput), 'resultobj', serializable=False)

@classmethod
def _pre_build(cls, data: dict) -> Dict:
shapley = data["result"]
material_ids = shapley["materials"]

outputs = []
for output, feature_dict in shapley["outputs"].items():
features = []
for feature, values in feature_dict.items():
items = zip(material_ids, values)
materials = [{"material_id": mid, "value": value} for mid, value in items]
features.append({
"feature": feature,
"materials": materials
})

outputs.append({"output": output, "features": features})

data["resultobj"] = outputs
return data

@property
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
"""Presents the feature effects as a dictionary by output."""
return {output.output: output.feature_dict for output in self.outputs}
11 changes: 10 additions & 1 deletion src/citrine/informatics/predictors/graph_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from citrine._session import Session
from citrine._utils.functions import format_escaped_url
from citrine.informatics.data_sources import DataSource
from citrine.informatics.feature_effects import FeatureEffects
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.predictors import PredictorNode, Predictor
from citrine.informatics.reports import Report
from citrine.resources.report import ReportResource

__all__ = ['GraphPredictor']
Expand Down Expand Up @@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict:
}

@property
def report(self):
def report(self) -> Report:
"""Fetch the predictor report."""
if self.uid is None or self._session is None or self._project_id is None \
or getattr(self, "version", None) is None:
Expand All @@ -113,6 +115,13 @@ def report(self):
report_resource = ReportResource(self._project_id, self._session)
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)

@property
def feature_effects(self) -> FeatureEffects:
"""Retrieve the feature effects for all outputs in the predictor's training data.."""
path = self._path() + '/shapley/query'
response = self._session.post_resource(path, {}, version=self._api_version)
return FeatureEffects.build(response)

def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
"""Make a one-off prediction with this predictor."""
path = self._path() + '/predict'
Expand Down
1 change: 1 addition & 0 deletions src/citrine/resources/table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
The query used to define the materials underpinning this table
generation_algorithm: TableFromGemdQueryAlgorithm
Which algorithm was used to generate the config based on the GemdQuery results

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My local flake8 was complaining about this file, despite me not touching it. 🤷‍♂️

"""

# FIXME (DML): rename this (this is dependent on the server side)
Expand Down
27 changes: 25 additions & 2 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tests for citrine.informatics.predictors."""
import uuid
import pytest
import mock
import pytest
import uuid
from random import random

from citrine.informatics.data_sources import GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
Expand All @@ -12,6 +13,10 @@
from citrine.informatics.predictors.single_prediction import SinglePrediction
from citrine.informatics.design_candidate import DesignMaterial

from tests.utils.factories import FeatureEffectsResponseFactory
from tests.utils.session import FakeCall, FakeSession


w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
Expand Down Expand Up @@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
prediction_out = graph_predictor.predict(request)
assert prediction_out.dump() == prediction_in.dump()
assert session.post_resource.call_count == 1


def test_feature_effects(graph_predictor):
feature_effects_response = FeatureEffectsResponseFactory()
feature_effects_as_dict = feature_effects_response.pop("_result_as_dict")

session = FakeSession()
session.set_response(feature_effects_response)

graph_predictor._session = session
graph_predictor._project_id = uuid.uuid4()

fe = graph_predictor.feature_effects

expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
f"/versions/{graph_predictor.version}/shapley/query"
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
assert fe.as_dict == feature_effects_as_dict
40 changes: 40 additions & 0 deletions tests/utils/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
id = factory.Faker('uuid4')
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)


class FeatureEffectsResponseResultFactory(factory.DictFactory):
materials = factory.List([
factory.Faker('uuid4', cast_to=None),
factory.Faker('uuid4', cast_to=None),
factory.Faker('uuid4', cast_to=None)
])
outputs = factory.Dict({
"output1": factory.Dict({
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
}),
"output2": factory.Dict({
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]),
"feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
})
})

class FeatureEffectsMetadataFactory(factory.DictFactory):
predictor_id = factory.Faker('uuid4')
predictor_version = factory.Faker('random_digit_not_null')
created = factory.SubFactory(UserTimestampDataFactory)
updated = factory.SubFactory(UserTimestampDataFactory)
status = 'SUCCEEDED'


class FeatureEffectsResponseFactory(factory.DictFactory):
query = {} # Presently, querying from the SDK is not allowed.
metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
result = factory.SubFactory(FeatureEffectsResponseResultFactory)
_result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))


def _expand_condensed(result_obj):
whole_dict = {}
for output, feature_dict in result_obj["outputs"].items():
whole_dict[output] = {}
for feature, values in feature_dict.items():
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
return whole_dict