Skip to content

Commit

Permalink
Table and model update logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kroenlein committed Oct 15, 2024
1 parent e57e092 commit 14025e5
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.11.0"
__version__ = "3.11.1"
1 change: 1 addition & 0 deletions src/citrine/informatics/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
from .predictor import *
from .node import *
from .attribute_accumulation_predictor import *
from .expression_predictor import *
from .graph_predictor import *
from .ingredient_fractions_predictor import *
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import List

from citrine._rest.resource import Resource
from citrine._serialization import properties as _properties
from citrine.informatics.descriptors import Descriptor
from citrine.informatics.predictors import PredictorNode

__all__ = ['AttributeAccumulationPredictor']


class AttributeAccumulationPredictor(Resource["AttributeAccumulationPredictor"], PredictorNode):
    """A predictor node that accumulates attributes from ancestor nodes.

    Parameters
    ----------
    name: str
        name of the configuration
    description: str
        the description of the predictor
    attributes: List[Descriptor]
        the attributes that are accumulated from ancestor nodes
    sequential: bool
        flag serialized under the ``sequential`` key.
        NOTE(review): presumably limits accumulation to sequential steps of the
        material history -- confirm against the platform documentation.

    """

    attributes = _properties.List(_properties.Object(Descriptor), 'attributes')
    sequential = _properties.Boolean('sequential')

    # Serialized under the 'type' key with the constant value 'AttributeAccumulation'.
    typ = _properties.String('type', default='AttributeAccumulation', deserializable=False)

    def __init__(self,
                 name: str,
                 *,
                 description: str,
                 attributes: List[Descriptor],
                 sequential: bool):
        self.name: str = name
        self.description: str = description
        self.attributes: List[Descriptor] = attributes
        self.sequential: bool = sequential

    def __str__(self):
        return f'<AttributeAccumulationPredictor {self.name!r}>'
69 changes: 69 additions & 0 deletions src/citrine/informatics/predictors/graph_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,72 @@ def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
path = self._path() + '/predict'
res = self._session.post_resource(path, predict_request.dump(), version=self._api_version)
return SinglePrediction.build(res)

def _convert_to_multistep(self) -> "GraphPredictor":
    """Build a version of this graph shaped like a MULTISTEP_MATERIALS graph.

    Two AttributeAccumulationPredictor nodes are appended to the existing nodes:
    one over the outputs of every AutoMLPredictor (sequential=True), and one over
    the AutoML inputs whose keys do not match a generated mean-property,
    ingredient-share or label-share feature name (sequential=False).

    Returns
    -------
    GraphPredictor
        a new predictor with the same name, description, training data and uid;
        the existing node objects are reused, not copied

    Raises
    ------
    ValueError
        if the graph already contains an AttributeAccumulationPredictor
    NotImplementedError
        if the graph contains a node type this conversion does not handle

    """
    from citrine.informatics.predictors import (
        AttributeAccumulationPredictor, MolecularStructureFeaturizer,
        LabelFractionsPredictor, SimpleMixturePredictor, IngredientFractionsPredictor,
        AutoMLPredictor, MeanPropertyPredictor, ChemicalFormulaFeaturizer
    )

    # Node types that contribute nothing to either accumulator.
    ignored_types = (
        SimpleMixturePredictor, ChemicalFormulaFeaturizer, MolecularStructureFeaturizer
    )

    outputs_by_key = {}
    inputs_by_key = {}
    generated_features = set()

    for node in self.predictors:
        if isinstance(node, AttributeAccumulationPredictor):
            raise ValueError("Graph already contains Attribute Accumulation nodes")

        if isinstance(node, AutoMLPredictor):
            outputs_by_key.update((desc.key, desc) for desc in node.outputs)
            inputs_by_key.update((desc.key, desc) for desc in node.inputs)
            continue

        if isinstance(node, MeanPropertyPredictor):
            generated_features.update(
                f"mean of property {prop.key} in {node.input_descriptor.key}"
                for prop in node.properties
            )
            continue

        if isinstance(node, IngredientFractionsPredictor):
            generated_features.update(
                f"{ingredient} share in {node.input_descriptor.key}"
                for ingredient in node.ingredients
            )
            continue

        if isinstance(node, LabelFractionsPredictor):
            generated_features.update(
                f"{label} share in {node.input_descriptor.key}"
                for label in node.labels
            )
            continue

        if isinstance(node, ignored_types):
            continue

        # IngredientsToFormulationRelation, ExpressionPredictor,
        # IngredientsToFormulationPredictor
        raise NotImplementedError(f"Unhandled predictor type: {type(node)}")

    accumulated_outputs = list(outputs_by_key.values())
    # Inputs that are themselves generated features are left out of the accumulator.
    accumulated_inputs = [
        desc for key, desc in inputs_by_key.items() if key not in generated_features
    ]

    output_accumulator = AttributeAccumulationPredictor(
        name="Output variable accumulation",
        description="Output variables encountered in the material history. "
                    "Only sequential mixing steps are considered.",
        attributes=accumulated_outputs,
        sequential=True
    )
    input_accumulator = AttributeAccumulationPredictor(
        name="Attribute accumulation",
        description="Parameters/conditions encountered in the material history. "
                    "Most recent values are selected first.",
        attributes=accumulated_inputs,
        sequential=False
    )

    converted = GraphPredictor(
        name=self.name,
        description=self.description,
        predictors=self.predictors + [output_accumulator, input_accumulator],
        training_data=self.training_data
    )
    converted.uid = self.uid
    return converted
2 changes: 2 additions & 0 deletions src/citrine/informatics/predictors/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PredictorNode(PolymorphicSerializable["PredictorNode"], Predictor):
@classmethod
def get_type(cls, data) -> Type['PredictorNode']:
"""Return the subtype."""
from .attribute_accumulation_predictor import AttributeAccumulationPredictor
from .expression_predictor import ExpressionPredictor
from .molecular_structure_featurizer import MolecularStructureFeaturizer
from .ingredients_to_formulation_predictor import IngredientsToFormulationPredictor
Expand All @@ -30,6 +31,7 @@ def get_type(cls, data) -> Type['PredictorNode']:
from .chemical_formula_featurizer import ChemicalFormulaFeaturizer
type_dict = {
"AnalyticExpression": ExpressionPredictor,
"AttributeAccumulation": AttributeAccumulationPredictor,
"MoleculeFeaturizer": MolecularStructureFeaturizer,
"IngredientsToSimpleMixture": IngredientsToFormulationPredictor,
"MeanProperty": MeanPropertyPredictor,
Expand Down
25 changes: 24 additions & 1 deletion src/citrine/resources/table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from citrine.gemtables.variables import (
Variable, IngredientIdentifierByProcessTemplateAndName, IngredientQuantityByProcessAndName,
IngredientQuantityDimension, IngredientIdentifierInOutput, IngredientQuantityInOutput,
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput,
AttributeByTemplateAndObjectTemplate, LocalAttributeAndObject
)

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -429,6 +430,28 @@ def add_all_ingredients_in_output(self, *,
new_config.version_uid = copy(self.version_uid)
return new_config

def _convert_to_multistep(self) -> "TableConfig":
    """Return a copy of this TableConfig shaped like a MULTISTEP_MATERIALS config.

    The copy is made by dumping and rebuilding this config. Each
    AttributeByTemplateAndObjectTemplate variable in the copy is replaced by a
    LocalAttributeAndObject carrying the same name, headers, templates,
    attribute constraints and type selector; all other variables are kept
    as they are. The copy's generation algorithm is set to MULTISTEP_MATERIALS.
    """
    converted: TableConfig = TableConfig.build(self.dump())

    localized = []
    for variable in converted.variables:
        if isinstance(variable, AttributeByTemplateAndObjectTemplate):
            variable = LocalAttributeAndObject(
                name=variable.name,
                headers=variable.headers,
                template=variable.attribute_template,
                object_template=variable.object_template,
                attribute_constraints=variable.attribute_constraints,
                type_selector=variable.type_selector,
            )
        localized.append(variable)

    converted.variables = localized
    converted.generation_algorithm = TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS
    return converted


class TableConfigCollection(Collection[TableConfig]):
"""Represents the collection of all Table Configs associated with a project."""
Expand Down
61 changes: 61 additions & 0 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,28 @@ def ingredient_fractions_predictor() -> IngredientFractionsPredictor:
)


@pytest.fixture
def input_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor:
    """Accumulation node over the inputs of the ``auto_ml`` fixture (sequential=False)."""
    accumulated = auto_ml.inputs
    predictor = AttributeAccumulationPredictor(
        'Input accumulation predictor',
        description='Bubbles attributes up through the graph',
        attributes=accumulated,
        sequential=False,
    )
    return predictor


@pytest.fixture
def output_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor:
    """Accumulation node over the outputs of the ``auto_ml`` fixture (sequential=True)."""
    accumulated = auto_ml.outputs
    predictor = AttributeAccumulationPredictor(
        'Output accumulation predictor',
        description='Bubbles attributes up through the graph',
        attributes=accumulated,
        sequential=True,
    )
    return predictor


def test_simple_report(graph_predictor):
"""Ensures we get a report from a simple predictor post_build call"""
with pytest.raises(ValueError):
Expand Down Expand Up @@ -453,6 +475,17 @@ def test_ingredient_fractions_property_initialization(ingredient_fractions_predi
assert str(ingredient_fractions_predictor) == expected_str


def test_attribute_accumulation_predictor_initialization(input_accumulation_predictor, output_accumulation_predictor):
    """Check attribute counts and string form of both accumulation predictor fixtures."""
    # (predictor, expected number of accumulated attributes)
    cases = [
        (input_accumulation_predictor, 2),
        (output_accumulation_predictor, 1),
    ]
    for node, attribute_count in cases:
        assert len(node.attributes) == attribute_count
        assert str(node) == f"<AttributeAccumulationPredictor '{node.name}'>"


def test_status(graph_predictor, valid_graph_predictor_data):
"""Ensure we can check the status of predictor validation."""
# A locally built predictor should be "False" for all status checks
Expand Down Expand Up @@ -485,3 +518,31 @@ def test_single_predict(graph_predictor):
prediction_out = graph_predictor.predict(request)
assert prediction_out.dump() == prediction_in.dump()
assert session.post_resource.call_count == 1

def test__convert_to_multistep(molecule_featurizer, auto_ml, mean_property_predictor, ingredient_fractions_predictor,
                               label_fractions_predictor, expression_predictor, output_accumulation_predictor,
                               input_accumulation_predictor):
    """Exercise GraphPredictor._convert_to_multistep on supported, converted and unsupported graphs."""
    nodes = [
        molecule_featurizer,
        auto_ml,
        mean_property_predictor,
        ingredient_fractions_predictor,
        label_fractions_predictor,
    ]
    original = GraphPredictor(
        name='Graph predictor',
        description='description',
        predictors=nodes,
        training_data=[data_source, formulation_data_source]
    )

    converted = original._convert_to_multistep()
    assert len(converted.predictors) == len(original.predictors) + 2

    # The output accumulator is appended first, then the input accumulator.
    accumulators = [
        node for node in converted.predictors
        if isinstance(node, AttributeAccumulationPredictor)
    ]
    assert accumulators[0].attributes == output_accumulation_predictor.attributes
    assert accumulators[1].attributes == input_accumulation_predictor.attributes

    # Converting an already-converted graph is rejected.
    with pytest.raises(ValueError):
        converted._convert_to_multistep()

    # Expression predictors are not handled by the conversion.
    with pytest.raises(NotImplementedError):
        GraphPredictor(
            name='Graph predictor',
            description='description',
            predictors=[expression_predictor],
            training_data=[data_source, formulation_data_source]
        )._convert_to_multistep()
20 changes: 19 additions & 1 deletion tests/resources/test_table_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
IngredientQuantityDimension, IngredientQuantityByProcessAndName, \
IngredientIdentifierByProcessTemplateAndName, TerminalMaterialIdentifier, \
IngredientQuantityInOutput, IngredientIdentifierInOutput, \
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput, AttributeByTemplateAndObjectTemplate, \
LocalAttribute, LocalAttributeAndObject
from citrine.resources.table_config import TableConfig, TableConfigCollection, TableBuildAlgorithm, \
TableFromGemdQueryAlgorithm
from citrine.resources.data_concepts import CITRINE_SCOPE
Expand Down Expand Up @@ -900,3 +901,20 @@ def test_update_unregistered_fail(collection, session):
def test_delete(collection):
with pytest.raises(NotImplementedError):
collection.delete(empty_defn().config_uid)


def test__convert_to_multistep():
    """Verify TableConfig._convert_to_multistep swaps variables one-for-one.

    After conversion the variable count must be unchanged and no
    AttributeByTemplateAndObjectTemplate variable may remain.
    """
    variables = [
        AttributeByTemplate("One", headers=["one"], template=uuid4()),
        AttributeByTemplateAndObjectTemplate("Two", headers=["two"], attribute_template=uuid4(), object_template=uuid4()),
        LocalAttribute("Three", headers=["three"], template=uuid4()),
        LocalAttributeAndObject("Four", headers=["four"], template=uuid4(), object_template=uuid4()),
    ]
    columns = [MeanColumn(data_source=v.name, target_units="") for v in variables]
    config: TableConfig = TableConfig.build(TableConfigDataFactory(
        variables=[v.dump() for v in variables],
        columns=[c.dump() for c in columns],
    ))
    updated = config._convert_to_multistep()
    # Compare the converted config against the original; the previous assertion
    # compared config.variables with itself and could never fail.
    assert len(updated.variables) == len(config.variables)
    assert not any(isinstance(x, AttributeByTemplateAndObjectTemplate) for x in updated.variables)

0 comments on commit 14025e5

Please sign in to comment.