Skip to content

Commit

Permalink
Merge pull request #839 from CitrineInformatics/PLA-12272/deprecate-t…
Browse files Browse the repository at this point in the history
…raining-data

PLA-12272: Deprecate training_data on non-graphs
  • Loading branch information
Sean Friedowitz authored Apr 3, 2023
2 parents 9b7189e + aeac353 commit 1637906
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.11.0'
__version__ = '2.11.1'
4 changes: 3 additions & 1 deletion src/citrine/informatics/predictors/auto_ml_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ def __init__(self,
self.description: str = description
self.inputs: List[Descriptor] = inputs
self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST}
self.training_data: List[DataSource] = training_data or []
self.outputs = outputs

self._check_deprecated_training_data(training_data)
self.training_data: List[DataSource] = training_data or []

def __str__(self):
return '<AutoMLPredictor {!r}>'.format(self.name)
2 changes: 2 additions & 0 deletions src/citrine/informatics/predictors/mean_property_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def __init__(self,
self.impute_properties: bool = impute_properties
self.label: Optional[str] = label
self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties

self._check_deprecated_training_data(training_data)
self.training_data: List[DataSource] = training_data or []

def __str__(self):
Expand Down
15 changes: 14 additions & 1 deletion src/citrine/informatics/predictors/predictor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Type
import warnings
from typing import Optional, Type, List
from uuid import UUID

from citrine._rest.asynchronous_object import AsynchronousObject
from citrine._serialization import properties
from citrine._serialization.polymorphic_serializable import PolymorphicSerializable
from citrine._session import Session
from citrine.informatics.data_sources import DataSource
from citrine.resources.report import ReportResource
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
from citrine.informatics.predictors.single_prediction import SinglePrediction
Expand Down Expand Up @@ -110,3 +112,14 @@ def predict(self,
path = self._path() + '/predict'
res = self._session.post_resource(path, predict_request.dump(), version=self._api_version)
return SinglePrediction.build(res)

@classmethod
def _check_deprecated_training_data(cls, training_data: Optional[List[DataSource]]) -> None:
if training_data is not None:
warnings.warn(
f"The field `training_data` on {cls.__name__} predictors is deprecated "
"and will be removed in version 3.0.0. Include training data for all "
"sub-predictors on the parent GraphPredictor. Existing training data "
"on this predictor will be moved to the parent GraphPredictor upon registration.",
DeprecationWarning
)
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self,
training_data: Optional[List[DataSource]] = None):
self.name: str = name
self.description: str = description

self._check_deprecated_training_data(training_data)
self.training_data: List[DataSource] = training_data or []

if input_descriptor is not None:
Expand Down
14 changes: 5 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def valid_auto_ml_predictor_data(valid_gem_data_source_dict):
inputs=[x.dump()],
outputs=[z.dump()],
estimators=[AutoMLEstimator.RANDOM_FOREST.value],
training_data=[valid_gem_data_source_dict]
training_data=[]
)
return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance))

Expand Down Expand Up @@ -395,7 +395,7 @@ def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metr
@pytest.fixture
def valid_ing_formulation_predictor_data():
"""Produce valid data used for tests."""
from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor
from citrine.informatics.descriptors import RealDescriptor
instance = dict(
type='IngredientsToSimpleMixture',
name='Ingredients to formulation predictor',
Expand All @@ -416,7 +416,6 @@ def valid_ing_formulation_predictor_data():
def valid_generalized_mean_property_predictor_data():
"""Produce valid data used for tests."""
from citrine.informatics.descriptors import FormulationDescriptor
from citrine.informatics.data_sources import GemTableDataSource
formulation_descriptor = FormulationDescriptor.hierarchical()
instance = dict(
type='GeneralizedMeanProperty',
Expand All @@ -425,7 +424,6 @@ def valid_generalized_mean_property_predictor_data():
input=formulation_descriptor.dump(),
properties=['density'],
p=2,
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump()],
impute_properties=True,
default_properties={'density': 1.0},
label='solvent'
Expand All @@ -437,7 +435,6 @@ def valid_generalized_mean_property_predictor_data():
def valid_mean_property_predictor_data():
"""Produce valid data used for tests."""
from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor
from citrine.informatics.data_sources import GemTableDataSource
formulation_descriptor = FormulationDescriptor.flat()
density = RealDescriptor(key='density', lower_bound=0, upper_bound=100, units='g/cm^3')
instance = dict(
Expand All @@ -447,10 +444,10 @@ def valid_mean_property_predictor_data():
input=formulation_descriptor.dump(),
properties=[density.dump()],
p=2,
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump()],
impute_properties=True,
default_properties={'density': 1.0},
label='solvent'
label='solvent',
training_data=[]
)
return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance))

Expand Down Expand Up @@ -524,12 +521,11 @@ def invalid_predictor_data():
@pytest.fixture
def valid_simple_mixture_predictor_data():
"""Produce valid data used for tests."""
from citrine.informatics.data_sources import GemTableDataSource
instance = dict(
type='SimpleMixture',
name='Simple mixture predictor',
description='simple mixture description',
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump()]
training_data=[]
)
return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance))

Expand Down
48 changes: 35 additions & 13 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def auto_ml() -> AutoMLPredictor:
name='AutoML Predictor',
description='Predicts z from inputs w and x',
inputs=[w, x],
outputs=[z],
training_data=[data_source]
outputs=[z]
)


Expand All @@ -102,8 +101,7 @@ def auto_ml_no_outputs() -> AutoMLPredictor:
name='AutoML Predictor',
description='Predicts z from inputs w and x',
inputs=[w, x],
outputs=[],
training_data=[data_source]
outputs=[]
)


Expand All @@ -113,8 +111,7 @@ def auto_ml_multiple_outputs() -> AutoMLPredictor:
name='AutoML Predictor',
description='Predicts z from inputs w and x',
inputs=[w, x],
outputs=[z, y],
training_data=[data_source]
outputs=[z, y]
)


Expand All @@ -125,7 +122,7 @@ def graph_predictor() -> GraphPredictor:
name='Graph predictor',
description='description',
predictors=[uuid.uuid4(), uuid.uuid4()],
training_data=[data_source]
training_data=[data_source, formulation_data_source]
)


Expand Down Expand Up @@ -169,7 +166,6 @@ def mean_property_predictor() -> MeanPropertyPredictor:
input_descriptor=flat_formulation,
properties=[density, chain_type],
p=2.5,
training_data=[formulation_data_source],
impute_properties=True,
default_properties={'density': 1.0, 'Chain Type': 'Gaussian Coil'},
label='solvent'
Expand All @@ -182,7 +178,6 @@ def simple_mixture_predictor() -> SimpleMixturePredictor:
return SimpleMixturePredictor(
name='Simple mixture predictor',
description='Computes mean ingredient properties',
training_data=[formulation_data_source]
)


Expand Down Expand Up @@ -229,7 +224,7 @@ def test_graph_initialization(graph_predictor):
assert graph_predictor.name == 'Graph predictor'
assert graph_predictor.description == 'description'
assert len(graph_predictor.predictors) == 2
assert graph_predictor.training_data == [data_source]
assert graph_predictor.training_data == [data_source, formulation_data_source]
assert str(graph_predictor) == '<GraphPredictor \'Graph predictor\'>'


Expand Down Expand Up @@ -286,7 +281,6 @@ def test_auto_ml(auto_ml):
assert auto_ml.name == "AutoML Predictor"
assert auto_ml.description == "Predicts z from inputs w and x"
assert auto_ml.inputs == [w, x]
assert auto_ml.training_data == [data_source]
assert auto_ml.dump()['instance']['outputs'] == [z.dump()]

assert str(auto_ml) == "<AutoMLPredictor 'AutoML Predictor'>"
Expand Down Expand Up @@ -353,7 +347,6 @@ def test_mean_property_initialization(mean_property_predictor):
assert mean_property_predictor.properties == [density, chain_type]
assert mean_property_predictor.p == 2.5
assert mean_property_predictor.impute_properties == True
assert mean_property_predictor.training_data == [formulation_data_source]
assert mean_property_predictor.default_properties == {'density': 1.0, 'Chain Type': 'Gaussian Coil'}
assert mean_property_predictor.label == 'solvent'
expected_str = '<MeanPropertyPredictor \'Mean property predictor\'>'
Expand Down Expand Up @@ -385,7 +378,6 @@ def test_simple_mixture_predictor_initialization(simple_mixture_predictor):
assert simple_mixture_predictor.name == 'Simple mixture predictor'
assert simple_mixture_predictor.input_descriptor.key == FormulationKey.HIERARCHICAL.value
assert simple_mixture_predictor.output_descriptor.key == FormulationKey.FLAT.value
assert simple_mixture_predictor.training_data == [formulation_data_source]
expected_str = '<SimpleMixturePredictor \'Simple mixture predictor\'>'
assert str(simple_mixture_predictor) == expected_str

Expand All @@ -407,6 +399,7 @@ def test_status(valid_label_fractions_predictor_data, auto_ml):
predictor = LabelFractionsPredictor.build(valid_label_fractions_predictor_data)
assert predictor.succeeded() and not predictor.in_progress() and not predictor.failed()


def test_single_predict(graph_predictor):
"""Ensures we get a prediction back from a simple predict call"""
session = mock.Mock()
Expand All @@ -426,6 +419,35 @@ def test_single_predict(graph_predictor):
assert session.post_resource.call_count == 1


def test_deprecated_training_data():
with pytest.warns(DeprecationWarning):
AutoMLPredictor(
name="AutoML",
description="",
inputs=[x, y],
outputs=[z],
training_data=[data_source]
)

with pytest.warns(DeprecationWarning):
MeanPropertyPredictor(
name="SimpleMixture",
description="",
input_descriptor=flat_formulation,
properties=[x, y, z],
p=1.0,
impute_properties=True,
training_data=[data_source]
)

with pytest.warns(DeprecationWarning):
SimpleMixturePredictor(
name="Warning",
description="Description",
training_data=[data_source]
)


def test_formulation_deprecations():
with pytest.warns(DeprecationWarning):
SimpleMixturePredictor(
Expand Down
10 changes: 4 additions & 6 deletions tests/serialization/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from citrine.informatics.predictors import *


def test_simple_legacy_deserialization(valid_auto_ml_predictor_data):
def test_auto_ml_deserialization(valid_auto_ml_predictor_data):
"""Ensure that a deserialized SimplePredictor looks sane."""
predictor: AutoMLPredictor = AutoMLPredictor.build(valid_auto_ml_predictor_data)
assert predictor.name == 'AutoML predictor'
Expand All @@ -18,11 +18,10 @@ def test_simple_legacy_deserialization(valid_auto_ml_predictor_data):
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
assert len(predictor.outputs) == 1
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
assert len(predictor.training_data) == 1
assert predictor.training_data[0].table_id == UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762')
assert len(predictor.training_data) == 0


def test_polymorphic_legacy_deserialization(valid_auto_ml_predictor_data):
def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data):
"""Ensure that a polymorphically deserialized SimplePredictor looks sane."""
predictor: AutoMLPredictor = Predictor.build(valid_auto_ml_predictor_data)
assert predictor.name == 'AutoML predictor'
Expand All @@ -31,8 +30,7 @@ def test_polymorphic_legacy_deserialization(valid_auto_ml_predictor_data):
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
assert len(predictor.outputs) == 1
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
assert len(predictor.training_data) == 1
assert predictor.training_data[0].table_id == UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762')
assert len(predictor.training_data) == 0


def test_legacy_serialization(valid_auto_ml_predictor_data):
Expand Down

0 comments on commit 1637906

Please sign in to comment.