diff --git a/lightwood/__about__.py b/lightwood/__about__.py index edff5943a..7fccf38f3 100644 --- a/lightwood/__about__.py +++ b/lightwood/__about__.py @@ -1,6 +1,6 @@ __title__ = 'lightwood' __package_name__ = 'lightwood' -__version__ = '23.6.4.0' +__version__ = '23.7.1.0' __description__ = "Lightwood is a toolkit for automatic machine learning model building" __email__ = "community@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/lightwood/analysis/analyze.py b/lightwood/analysis/analyze.py index 87e21b9c9..7f2546218 100644 --- a/lightwood/analysis/analyze.py +++ b/lightwood/analysis/analyze.py @@ -1,5 +1,6 @@ from typing import Dict, List, Tuple, Optional +import numpy as np from dataprep_ml import StatisticalAnalysis from lightwood.helpers.log import log @@ -8,7 +9,7 @@ from lightwood.analysis.base import BaseAnalysisBlock from lightwood.data.encoded_ds import EncodedDs from lightwood.encoder.text.pretrained import PretrainedLangEncoder -from lightwood.api.types import ModelAnalysis, TimeseriesSettings, PredictionArguments +from lightwood.api.types import ModelAnalysis, ProblemDefinition, PredictionArguments def model_analyzer( @@ -17,7 +18,7 @@ def model_analyzer( train_data: EncodedDs, stats_info: StatisticalAnalysis, target: str, - tss: TimeseriesSettings, + pdef: ProblemDefinition, dtype_dict: Dict[str, str], accuracy_functions, ts_analysis: Dict, @@ -39,6 +40,7 @@ def model_analyzer( runtime_analyzer = {} data_type = dtype_dict[target] + tss = pdef.timeseries_settings # retrieve encoded data representations encoded_train_data = train_data @@ -46,47 +48,56 @@ def model_analyzer( data = encoded_val_data.data_frame input_cols = list([col for col in data.columns if col != target]) - # predictive task - is_numerical = data_type in (dtype.integer, dtype.float, dtype.num_tsarray, dtype.quantity) - is_classification = data_type in (dtype.categorical, dtype.binary, dtype.cat_tsarray) - is_multi_ts = tss.is_timeseries and tss.horizon > 1 - has_pretrained_text_enc = any([isinstance(enc, PretrainedLangEncoder) - for enc in encoded_train_data.encoders.values()]) - - # raw predictions for validation dataset - args = {} if not is_classification else {"predict_proba": True} - filtered_df = encoded_val_data.data_frame - normal_predictions = None - - if len(analysis_blocks) > 0: - normal_predictions = predictor(encoded_val_data, args=PredictionArguments.from_dict(args)) - normal_predictions = normal_predictions.set_index(encoded_val_data.data_frame.index) - - # ------------------------- # - # Run analysis blocks, both core and user-defined - # ------------------------- # - kwargs = { - 'predictor': predictor, - 'target': target, - 'input_cols': input_cols, - 'dtype_dict': dtype_dict, - 'normal_predictions': normal_predictions, - 'data': filtered_df, - 'train_data': train_data, - 'encoded_val_data': encoded_val_data, - 'is_classification': is_classification, - 'is_numerical': is_numerical, - 'is_multi_ts': is_multi_ts, - 'stats_info': stats_info, - 'tss': tss, - 'ts_analysis': ts_analysis, - 'accuracy_functions': accuracy_functions, - 'has_pretrained_text_enc': has_pretrained_text_enc - } - - for block in analysis_blocks: - log.info("The block %s is now running its analyze() method", block.__class__.__name__) - runtime_analyzer = block.analyze(runtime_analyzer, **kwargs) + if not pdef.embedding_only: + # predictive task + is_numerical = data_type in (dtype.integer, dtype.float, dtype.num_tsarray, dtype.quantity) + is_classification = data_type in (dtype.categorical, dtype.binary, 
dtype.cat_tsarray) + is_multi_ts = tss.is_timeseries and tss.horizon > 1 + has_pretrained_text_enc = any([isinstance(enc, PretrainedLangEncoder) + for enc in encoded_train_data.encoders.values()]) + + # raw predictions for validation dataset + args = {} if not is_classification else {"predict_proba": True} + normal_predictions = None + + if len(analysis_blocks) > 0: + if tss.is_timeseries: + # we retrieve the first entry per group (closest to supervision cutoff) + if tss.group_by: + encoded_val_data.data_frame['__mdb_val_idx'] = np.arange(len(encoded_val_data)) + idxs = encoded_val_data.data_frame.groupby(by=tss.group_by).first()['__mdb_val_idx'].values + encoded_val_data.data_frame = encoded_val_data.data_frame.iloc[idxs, :] + if encoded_val_data.cache_built: + encoded_val_data.X_cache = encoded_val_data.X_cache[idxs, :] + encoded_val_data.Y_cache = encoded_val_data.Y_cache[idxs, :] + normal_predictions = predictor(encoded_val_data, args=PredictionArguments.from_dict(args)) + normal_predictions = normal_predictions.set_index(encoded_val_data.data_frame.index) + + # ------------------------- # + # Run analysis blocks, both core and user-defined + # ------------------------- # + kwargs = { + 'predictor': predictor, + 'target': target, + 'input_cols': input_cols, + 'dtype_dict': dtype_dict, + 'normal_predictions': normal_predictions, + 'data': encoded_val_data.data_frame, + 'train_data': train_data, + 'encoded_val_data': encoded_val_data, + 'is_classification': is_classification, + 'is_numerical': is_numerical, + 'is_multi_ts': is_multi_ts, + 'stats_info': stats_info, + 'tss': tss, + 'ts_analysis': ts_analysis, + 'accuracy_functions': accuracy_functions, + 'has_pretrained_text_enc': has_pretrained_text_enc + } + + for block in analysis_blocks: + log.info("The block %s is now running its analyze() method", block.__class__.__name__) + runtime_analyzer = block.analyze(runtime_analyzer, **kwargs) # ------------------------- # # Populate ModelAnalysis object diff --git a/lightwood/api/json_ai.py b/lightwood/api/json_ai.py index b19f3b249..b77d1f143 100644 --- a/lightwood/api/json_ai.py +++ b/lightwood/api/json_ai.py @@ -1,19 +1,20 @@ # TODO: _add_implicit_values unit test ensures NO changes for a fully specified file. 
+import inspect from copy import deepcopy + +from type_infer.dtype import dtype from type_infer.base import TypeInformation from dataprep_ml import StatisticalAnalysis +from lightwood.helpers.log import log from lightwood.helpers.templating import call, inline_dict, align -from lightwood.helpers.templating import _consolidate_analysis_blocks -from type_infer.dtype import dtype +from lightwood.helpers.templating import _consolidate_analysis_blocks, _add_cls_kwarg from lightwood.api.types import ( JsonAI, ProblemDefinition, ) -import inspect -from lightwood.helpers.log import log from lightwood.__about__ import __version__ as lightwood_version - +import lightwood.ensemble # For custom modules, we create a module loader with necessary imports below IMPORT_EXTERNAL_DIRS = """ @@ -535,29 +536,29 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI: problem_definition = json_ai.problem_definition tss = problem_definition.timeseries_settings is_ts = tss.is_timeseries + # tsa_val = "self.ts_analysis" if is_ts else None # TODO: remove + mixers = json_ai.model['args']['submodels'] # Add implicit ensemble arguments - json_ai.model["args"]["target"] = json_ai.model["args"].get("target", "$target") - json_ai.model["args"]["data"] = json_ai.model["args"].get("data", "encoded_test_data") - json_ai.model["args"]["mixers"] = json_ai.model["args"].get("mixers", "$mixers") - json_ai.model["args"]["fit"] = json_ai.model["args"].get("fit", True) - json_ai.model["args"]["args"] = json_ai.model["args"].get("args", "$pred_args") # TODO correct? - - # @TODO: change this to per-parameter basis and signature inspection - if json_ai.model["module"] in ("BestOf", "ModeEnsemble", "WeightedMeanEnsemble"): - json_ai.model["args"]["accuracy_functions"] = json_ai.model["args"].get("accuracy_functions", - "$accuracy_functions") - - if json_ai.model["module"] in ("BestOf", "TsStackedEnsemble", "WeightedMeanEnsemble"): - tsa_val = "self.ts_analysis" if is_ts else None - json_ai.model["args"]["ts_analysis"] = json_ai.model["args"].get("ts_analysis", tsa_val) + param_pairs = { + 'target': json_ai.model["args"].get("target", "$target"), + 'data': json_ai.model["args"].get("data", "encoded_test_data"), + 'mixers': json_ai.model["args"].get("mixers", "$mixers"), + 'fit': json_ai.model["args"].get("fit", True), + 'args': json_ai.model["args"].get("args", "$pred_args"), + 'accuracy_functions': json_ai.model["args"].get("accuracy_functions", "$accuracy_functions"), + 'ts_analysis': json_ai.model["args"].get("ts_analysis", "self.ts_analysis" if is_ts else None), + 'dtype_dict': json_ai.model["args"].get("dtype_dict", "$dtype_dict"), + } + ensemble_cls = getattr(lightwood.ensemble, json_ai.model["module"]) + filtered_params = {} + for p_name, p_value in param_pairs.items(): + _add_cls_kwarg(ensemble_cls, filtered_params, p_name, p_value) - if json_ai.model["module"] in ("MeanEnsemble", "ModeEnsemble", "StackedEnsemble", "TsStackedEnsemble", - "WeightedMeanEnsemble"): - json_ai.model["args"]["dtype_dict"] = json_ai.model["args"].get("dtype_dict", "$dtype_dict") + json_ai.model["args"] = filtered_params + json_ai.model["args"]['submodels'] = mixers # add mixers back in # Add implicit mixer arguments - mixers = json_ai.model['args']['submodels'] for i in range(len(mixers)): if not mixers[i].get("args", False): mixers[i]["args"] = {} @@ -685,7 +686,7 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI: "module": "model_analyzer", "args": { "stats_info": "$statistical_analysis", - "tss": "$problem_definition.timeseries_settings", + 
"pdef": "$problem_definition", "accuracy_functions": "$accuracy_functions", "predictor": "$ensemble", "data": "encoded_test_data", @@ -1170,7 +1171,12 @@ def code_from_json_ai(json_ai: JsonAI) -> str: # Prepare mixers log.info(f'[Learn phase 6/{n_phases}] - Mixer training') -self.fit(enc_train_test) +if not self.problem_definition.embedding_only: + self.fit(enc_train_test) +else: + self.mixers = [] + self.ensemble = Embedder(self.target, mixers=list(), data=enc_train_test['train']) + self.supports_proba = self.ensemble.supports_proba # Analyze the ensemble log.info(f'[Learn phase 7/{n_phases}] - Ensemble analysis') @@ -1221,9 +1227,17 @@ def code_from_json_ai(json_ai: JsonAI) -> str: encoded_data = encoded_ds.get_encoded_data(include_target=False) log.info(f'[Predict phase 3/{{n_phases}}] - Calling ensemble') -df = self.ensemble(encoded_ds, args=self.pred_args) +if self.pred_args.return_embedding: + embedder = Embedder(self.target, mixers=list(), data=encoded_ds) + df = embedder(encoded_ds, args=self.pred_args) +else: + df = self.ensemble(encoded_ds, args=self.pred_args) -if not self.pred_args.all_mixers: +if not(any( + [self.pred_args.all_mixers, + self.pred_args.return_embedding, + self.problem_definition.embedding_only] + )): log.info(f'[Predict phase 4/{{n_phases}}] - Analyzing output') df, global_insights = {call(json_ai.explainer)} self.global_insights = {{**self.global_insights, **global_insights}} diff --git a/lightwood/api/types.py b/lightwood/api/types.py index aa193b19d..5cdff7fe5 100644 --- a/lightwood/api/types.py +++ b/lightwood/api/types.py @@ -185,6 +185,7 @@ class ProblemDefinition: timeseries_settings: TimeseriesSettings anomaly_detection: bool use_default_analysis: bool + embedding_only: bool dtype_dict: Optional[dict] ignore_features: List[str] fit_on_all: bool @@ -220,6 +221,7 @@ def from_dict(obj: Dict): ignore_features = obj.get('ignore_features', []) fit_on_all = obj.get('fit_on_all', True) use_default_analysis = obj.get('use_default_analysis', True) + embedding_only = obj.get('embedding_only', False) strict_mode = obj.get('strict_mode', True) seed_nr = obj.get('seed_nr', 1) problem_definition = ProblemDefinition( @@ -237,6 +239,7 @@ def from_dict(obj: Dict): dtype_dict=dtype_dict, ignore_features=ignore_features, use_default_analysis=use_default_analysis, + embedding_only=embedding_only, fit_on_all=fit_on_all, strict_mode=strict_mode, seed_nr=seed_nr @@ -453,6 +456,7 @@ class PredictionArguments: simple_ts_bounds: bool = False time_format: str = '' force_ts_infer: bool = False + return_embedding: bool = False @staticmethod def from_dict(obj: Dict): @@ -474,6 +478,7 @@ def from_dict(obj: Dict): simple_ts_bounds = obj.get('simple_ts_bounds', PredictionArguments.simple_ts_bounds) time_format = obj.get('time_format', PredictionArguments.time_format) force_ts_infer = obj.get('force_ts_infer', PredictionArguments.force_ts_infer) + return_embedding = obj.get('return_embedding', PredictionArguments.return_embedding) pred_args = PredictionArguments( predict_proba=predict_proba, @@ -485,6 +490,7 @@ def from_dict(obj: Dict): simple_ts_bounds=simple_ts_bounds, time_format=time_format, force_ts_infer=force_ts_infer, + return_embedding=return_embedding, ) return pred_args diff --git a/lightwood/data/timeseries_transform.py b/lightwood/data/timeseries_transform.py index 950c1ccb3..135447210 100644 --- a/lightwood/data/timeseries_transform.py +++ b/lightwood/data/timeseries_transform.py @@ -109,6 +109,7 @@ def transform_timeseries( secondary_type_dict[oby] = dtype_dict[oby] 
original_df[f'__mdb_original_{oby}'] = original_df[oby] + original_df = _ts_to_obj(original_df, [oby] + tss.historical_columns) group_lengths = [] if len(gb_arr) > 0: df_arr = [] @@ -136,30 +137,30 @@ def transform_timeseries( make_preds = [True for _ in range(len(df_arr[i]))] df_arr[i]['__make_predictions'] = make_preds - if len(original_df) > 500: + if len(df_arr) > 1 and len(original_df) > 5000: # @TODO: restore possibility to override this with args - nr_procs = get_nr_procs(original_df) + biggest_sub_df = df_arr[np.argmax(group_lengths)] + nr_procs = min(get_nr_procs(biggest_sub_df), len(df_arr)) log.info(f'Using {nr_procs} processes to reshape.') - pool = mp.Pool(processes=nr_procs) - # Make type `object` so that dataframe cells can be python lists - df_arr = pool.map(partial(_ts_to_obj, historical_columns=[oby] + tss.historical_columns), df_arr) - df_arr = pool.map( - partial( - _ts_add_previous_rows, order_cols=[oby] + tss.historical_columns, window=window), - df_arr) - df_arr = pool.map(partial(_ts_add_future_target, target=target, horizon=tss.horizon, - data_dtype=tss.target_type, mode=mode), - df_arr) - - if tss.use_previous_target: + with mp.Pool(processes=nr_procs) as pool: df_arr = pool.map( - partial(_ts_add_previous_target, target=target, window=tss.window), - df_arr) - pool.close() - pool.join() + partial(_ts_add_previous_rows, order_cols=[oby] + tss.historical_columns, window=window), + df_arr + ) + + df_arr = pool.map( + partial(_ts_add_future_target, target=target, horizon=tss.horizon, + data_dtype=tss.target_type, mode=mode), + df_arr + ) + + if tss.use_previous_target: + df_arr = pool.map( + partial(_ts_add_previous_target, target=target, window=tss.window), + df_arr + ) else: for i in range(n_groups): - df_arr[i] = _ts_to_obj(df_arr[i], historical_columns=[oby] + tss.historical_columns) df_arr[i] = _ts_add_previous_rows(df_arr[i], order_cols=[oby] + tss.historical_columns, window=window) df_arr[i] = _ts_add_future_target(df_arr[i], target=target, horizon=tss.horizon, diff --git a/lightwood/encoder/__init__.py b/lightwood/encoder/__init__.py index ab104cb40..e71623d62 100644 --- a/lightwood/encoder/__init__.py +++ b/lightwood/encoder/__init__.py @@ -9,6 +9,7 @@ from lightwood.encoder.text.short import ShortTextEncoder from lightwood.encoder.text.vocab import VocabularyEncoder from lightwood.encoder.text.rnn import RnnEncoder as TextRnnEncoder +from lightwood.encoder.categorical.simple_label import SimpleLabelEncoder from lightwood.encoder.categorical.onehot import OneHotEncoder from lightwood.encoder.categorical.binary import BinaryEncoder from lightwood.encoder.categorical.autoencoder import CategoricalAutoEncoder @@ -23,5 +24,5 @@ __all__ = ['BaseEncoder', 'DatetimeEncoder', 'Img2VecEncoder', 'NumericEncoder', 'TsNumericEncoder', 'TsArrayNumericEncoder', 'ShortTextEncoder', 'VocabularyEncoder', 'TextRnnEncoder', 'OneHotEncoder', 'CategoricalAutoEncoder', 'TimeSeriesEncoder', 'ArrayEncoder', 'MultiHotEncoder', 'TsCatArrayEncoder', - 'NumArrayEncoder', 'CatArrayEncoder', + 'NumArrayEncoder', 'CatArrayEncoder', 'SimpleLabelEncoder', 'PretrainedLangEncoder', 'BinaryEncoder', 'DatetimeNormalizerEncoder', 'MFCCEncoder'] diff --git a/lightwood/encoder/categorical/__init__.py b/lightwood/encoder/categorical/__init__.py index 82e613ddd..ee4819744 100644 --- a/lightwood/encoder/categorical/__init__.py +++ b/lightwood/encoder/categorical/__init__.py @@ -1,5 +1,6 @@ from lightwood.encoder.categorical.onehot import OneHotEncoder +from lightwood.encoder.categorical.simple_label 
import SimpleLabelEncoder from lightwood.encoder.categorical.multihot import MultiHotEncoder from lightwood.encoder.categorical.autoencoder import CategoricalAutoEncoder -__all__ = ['OneHotEncoder', 'MultiHotEncoder', 'CategoricalAutoEncoder'] +__all__ = ['OneHotEncoder', 'SimpleLabelEncoder', 'MultiHotEncoder', 'CategoricalAutoEncoder'] diff --git a/lightwood/encoder/categorical/autoencoder.py b/lightwood/encoder/categorical/autoencoder.py index b26e5ca14..c96d1fe72 100644 --- a/lightwood/encoder/categorical/autoencoder.py +++ b/lightwood/encoder/categorical/autoencoder.py @@ -3,6 +3,7 @@ import torch from torch.utils.data import DataLoader from lightwood.mixer.helpers.ranger import Ranger +from lightwood.encoder.categorical.simple_label import SimpleLabelEncoder from lightwood.encoder.categorical.onehot import OneHotEncoder from lightwood.encoder.categorical.gym import Gym from lightwood.encoder.base import BaseEncoder @@ -24,13 +25,14 @@ class CategoricalAutoEncoder(BaseEncoder): is_trainable_encoder: bool = True def __init__( - self, - stop_after: float = 3600, - is_target: bool = False, - max_encoded_length: int = 100, - desired_error: float = 0.01, - batch_size: int = 200, - device: str = '', + self, + stop_after: float = 3600, + is_target: bool = False, + max_encoded_length: int = 100, + desired_error: float = 0.01, + batch_size: int = 200, + device: str = '', + input_encoder: str = None ): """ :param stop_after: Stops training with provided time limit (sec) @@ -39,6 +41,7 @@ def __init__( :param desired_error: Threshold for reconstruction accuracy error :param batch_size: Minimum batch size while training :param device: Name of the device that get_device_from_name will attempt to use + :param input_encoder: one of `OneHotEncoder` or `SimpleLabelEncoder` to force usage of the underlying input encoder. Note that OHE does not scale for categorical features with high cardinality, while SLE can but is less accurate overall. """ # noqa super().__init__(is_target) self.is_prepared = False @@ -49,8 +52,9 @@ def __init__( self.net = None self.encoder = None self.decoder = None - self.onehot_encoder = OneHotEncoder(is_target=self.is_target) + self.input_encoder = None # TBD at prepare() self.device_type = device + self.input_encoder = input_encoder # Training details self.batch_size = batch_size @@ -64,26 +68,36 @@ def prepare(self, train_priming_data: pd.Series, dev_priming_data: pd.Series): :param train_priming_data: Input training data :param dev_priming_data: Input dev data (Not supported currently) """ # noqa + if self.is_prepared: raise Exception('You can only call "prepare" once for a given encoder.') if self.is_target: - log.warning( - 'You are trying to use an autoencoder for the target value! \ - This is very likely a bad idea' - ) + log.warning('You are trying to use an autoencoder for the target value! This is very likely a bad idea.') - log.info( - 'Preparing a categorical autoencoder, this may take up to ' - + str(self.stop_after) - + " seconds." - ) + error_msg = f'Provided an invalid input encoder ({self.input_encoder}), please use either `OneHotEncoder` or `SimpleLabelEncoder`.' 
# noqa
+        if self.input_encoder is not None:
+            assert self.input_encoder in ('OneHotEncoder', 'SimpleLabelEncoder'), error_msg
+
+        log.info('Preparing a categorical autoencoder.')
+
+        if self.input_encoder == 'SimpleLabelEncoder' or \
+                (self.input_encoder is None and train_priming_data.nunique() > 500):
+            log.info('Deploying SimpleLabelEncoder for CategoricalAutoEncoder input.')
+            self.input_encoder = SimpleLabelEncoder(is_target=self.is_target)
+            input_len = self.input_encoder.output_size
+            self.output_size = 32
+            net_shape = [input_len, 128, 64, self.output_size, input_len]
+        else:
+            log.info('Deploying OneHotEncoder for CategoricalAutoEncoder input.')
+            self.input_encoder = OneHotEncoder(is_target=self.is_target)
+            net_shape = None  # determined later in _prepare_catae(), once the OHE output size is known

         train_loader, dev_loader = self._prepare_AE_input(
             train_priming_data, dev_priming_data
         )

-        best_model = self._prepare_catae(train_loader, dev_loader)
+        best_model = self._prepare_catae(train_loader, dev_loader, net_shape=net_shape)
         self.net = best_model.to(self.net.device)

         modules = [
@@ -92,9 +106,9 @@ def prepare(self, train_priming_data: pd.Series, dev_priming_data: pd.Series):
             if type(module) != torch.nn.Sequential and type(module) != DefaultNet
         ]

-        self.encoder = torch.nn.Sequential(*modules[0:2]).eval()
-        self.decoder = torch.nn.Sequential(*modules[2:3]).eval()
-        log.info('Categorical autoencoder ready')
+        self.encoder = torch.nn.Sequential(*modules[0:-1]).eval()
+        self.decoder = torch.nn.Sequential(*modules[-1:]).eval()
+        log.info('Categorical autoencoder ready.')

         self.is_prepared = True

@@ -106,11 +120,13 @@ def encode(self, column_data: Iterable[str]) -> torch.Tensor:

         :returns: An embedding for each sample in original input
         """ # noqa
-        oh_encoded_tensor = self.onehot_encoder.encode(column_data)
+        encoded_tensor = self.input_encoder.encode(column_data)

         with torch.no_grad():
-            oh_encoded_tensor = oh_encoded_tensor.to(self.net.device)
-            embeddings = self.encoder(oh_encoded_tensor)
+            encoded_tensor = encoded_tensor.to(self.net.device)
+            if len(encoded_tensor.shape) < 2:
+                encoded_tensor = encoded_tensor.unsqueeze(-1)
+            embeddings = self.encoder(encoded_tensor)
         return embeddings.to('cpu')

     def decode(self, encoded_data: torch.Tensor) -> List[str]:
@@ -125,12 +141,12 @@ def decode(self, encoded_data: torch.Tensor) -> List[str]:
         """ # noqa
         with torch.no_grad():
             encoded_data = encoded_data.to(self.net.device)
-            oh_encoded_tensor = self.decoder(encoded_data)
-            oh_encoded_tensor = oh_encoded_tensor.to('cpu')
-            return self.onehot_encoder.decode(oh_encoded_tensor)
+            encoded_tensor = self.decoder(encoded_data)
+            encoded_tensor = encoded_tensor.to('cpu')
+            return self.input_encoder.decode(encoded_tensor)

     def _prepare_AE_input(
-            self, train_priming_data: pd.Series, dev_priming_data: pd.Series
+        self, train_priming_data: pd.Series, dev_priming_data: pd.Series
     ) -> Tuple[DataLoader, DataLoader]:
         """
         Creates the data loaders for the CatAE model inputs.

         Expected inputs are generally of form `pd.Series`
Expected inputs are generally of form `pd.Series` @@ -150,7 +166,7 @@ def _prepare_AE_input( random.seed(len(priming_data)) # Prepare a one-hot encoder for CatAE inputs - self.onehot_encoder.prepare(priming_data) + self.input_encoder.prepare(priming_data) self.batch_size = max(min(self.batch_size, int(len(priming_data) / 50)), 1) train_loader = DataLoader( @@ -164,19 +180,36 @@ def _prepare_AE_input( return train_loader, dev_loader - def _prepare_catae(self, train_loader: DataLoader, dev_loader: DataLoader): + def _prepare_catae(self, train_loader: DataLoader, dev_loader: DataLoader, net_shape=None): """ Trains the CatAE using Lightwood's `Gym` class. :param train_loader: Training dataset Loader :param dev_loader: Validation set DataLoader """ # noqa - input_len = self.onehot_encoder.output_size - - self.net = DefaultNet(shape=[input_len, self.output_size, input_len], device=self.device_type) + if net_shape is None: + input_len = self.input_encoder.output_size + net_shape = [input_len, self.output_size, input_len] + + self.net = DefaultNet(shape=net_shape, device=self.device_type) + + if isinstance(self.input_encoder, OneHotEncoder): + criterion = torch.nn.CrossEntropyLoss() + desired_error = self.desired_error + elif isinstance(self.input_encoder, SimpleLabelEncoder): + criterion = torch.nn.MSELoss() + desired_error = 1e-9 + else: + raise Exception(f'[CatAutoEncoder] Input encoder of type {type(self.input_encoder)} is not supported!') - criterion = torch.nn.CrossEntropyLoss() - optimizer = Ranger(self.net.parameters()) + if isinstance(self.input_encoder, OneHotEncoder): + optimizer = Ranger(self.net.parameters()) + output_encoder = self._encoder_targets + max_time = self.stop_after + else: + optimizer = Ranger(self.net.parameters(), weight_decay=1e-2) + output_encoder = self._label_targets + max_time = 60 * 2 gym = Gym( model=self.net, @@ -185,15 +218,15 @@ def _prepare_catae(self, train_loader: DataLoader, dev_loader: DataLoader): loss_criterion=criterion, device=self.net.device, name=self.name, - input_encoder=self.onehot_encoder.encode, - output_encoder=self._encoder_targets, + input_encoder=self.input_encoder.encode, + output_encoder=output_encoder, ) best_model, _, _ = gym.fit( train_loader, dev_loader, - desired_error=self.desired_error, - max_time=self.stop_after, + desired_error=desired_error, + max_time=max_time, eval_every_x_epochs=1, max_unimproving_models=5, ) @@ -201,10 +234,18 @@ def _prepare_catae(self, train_loader: DataLoader, dev_loader: DataLoader): return best_model def _encoder_targets(self, data): - """""" - oh_encoded_categories = self.onehot_encoder.encode(data) - target = oh_encoded_categories.cpu().numpy() + """ Encodes target data with a OHE encoder """ + encoded_categories = self.input_encoder.encode(data) + target = encoded_categories.cpu().numpy() target_indexes = np.where(target > 0)[1] targets_c = torch.LongTensor(target_indexes) labels = targets_c.to(self.net.device) return labels + + def _label_targets(self, data): + """ Encodes target data with a label encoder """ + data = pd.Series(data) + enc = self.input_encoder.encode(data) + if len(enc.shape) < 2: + enc = enc.unsqueeze(-1) + return enc diff --git a/lightwood/encoder/categorical/gym.py b/lightwood/encoder/categorical/gym.py index 8e90eb5b4..7300d0c7d 100644 --- a/lightwood/encoder/categorical/gym.py +++ b/lightwood/encoder/categorical/gym.py @@ -4,6 +4,7 @@ import numpy as np from lightwood.helpers.torch import LightwoodAutocast +from lightwood.helpers.log import log class Gym: @@ -46,6 +47,8 @@ def 
fit(self, train_data_loader, test_data_loader, desired_error, max_time, call with LightwoodAutocast(): if self.input_encoder is not None: input = self.input_encoder(input) + if len(input.shape) < 2: + input = input.unsqueeze(-1) if self.output_encoder is not None: real = self.output_encoder(real) @@ -67,6 +70,8 @@ def fit(self, train_data_loader, test_data_loader, desired_error, max_time, call running_loss += loss.item() error = running_loss / (i + 1) + # end of epoch checks + log.debug(f'Categorical AutoEncoder train loss at epoch {epoch}: {round(error, 9)}') if epoch % eval_every_x_epochs == 0: if test_data_loader is not None: test_running_loss = 0.0 @@ -98,6 +103,7 @@ def fit(self, train_data_loader, test_data_loader, desired_error, max_time, call with torch.no_grad(): loss = custom_test_func(self.model, data, self) + log.debug(f'Categorical AutoEncoder val loss at epoch {epoch}: {round(loss, 9)}') test_running_loss += loss.item() test_error = test_running_loss / (i + 1) else: diff --git a/lightwood/encoder/categorical/simple_label.py b/lightwood/encoder/categorical/simple_label.py new file mode 100644 index 000000000..5cd2767e7 --- /dev/null +++ b/lightwood/encoder/categorical/simple_label.py @@ -0,0 +1,67 @@ +from typing import List, Union +from collections import defaultdict +import pandas as pd +import numpy as np +import torch + +from lightwood.encoder.base import BaseEncoder +from lightwood.helpers.constants import _UNCOMMON_WORD + + +class SimpleLabelEncoder(BaseEncoder): + """ + Simple encoder that assigns a unique integer to every observed label. + + Allocates an `unknown` label by default to index 0. + + Labels must be exact matches between inference and training (e.g. no .lower() on strings is performed here). + """ # noqa + + def __init__(self, is_target=False, normalize=True) -> None: + super().__init__(is_target) + self.label_map = defaultdict(int) # UNK category maps to 0 + self.inv_label_map = {} # invalid encoded values are mapped to None in `decode` + self.output_size = 1 + self.n_labels = None + self.normalize = normalize + + def prepare(self, priming_data: Union[list, pd.Series]) -> None: + if not isinstance(priming_data, pd.Series): + priming_data = pd.Series(priming_data) + + for i, v in enumerate(priming_data.unique()): + if v is not None: + self.label_map[str(v)] = int(i + 1) # leave 0 for UNK + self.n_labels = len(self.label_map) + for k, v in self.label_map.items(): + self.inv_label_map[v] = k + self.is_prepared = True + + def encode(self, data: Union[tuple, np.ndarray, pd.Series], normalize=True) -> torch.Tensor: + """ + :param normalize: can be used to temporarily return unnormalized values + """ + if not isinstance(data, pd.Series): + data = pd.Series(data) # specific to the Gym class - remove once deprecated! 
+ if isinstance(data, np.ndarray): + data = pd.Series(data) + + data = data.astype(str) + encoded = torch.Tensor(data.map(self.label_map)) + + if normalize and self.normalize: + encoded /= self.n_labels + if len(encoded.shape) < 2: + encoded = encoded.unsqueeze(-1) + + return encoded + + def decode(self, encoded_values: torch.Tensor, normalize=True) -> List[object]: + """ + :param normalize: can be used to temporarily return unnormalized values + """ + if normalize and self.normalize: + encoded_values *= self.n_labels + values = encoded_values.long().squeeze().tolist() # long() as inv_label_map expects an int key + values = [self.inv_label_map.get(v, _UNCOMMON_WORD) for v in values] + return values diff --git a/lightwood/ensemble/__init__.py b/lightwood/ensemble/__init__.py index 8507774f3..00078b9a4 100644 --- a/lightwood/ensemble/__init__.py +++ b/lightwood/ensemble/__init__.py @@ -1,4 +1,6 @@ from lightwood.ensemble.base import BaseEnsemble +from lightwood.ensemble.identity import IdentityEnsemble +from lightwood.ensemble.embed import Embedder from lightwood.ensemble.best_of import BestOf from lightwood.ensemble.mean_ensemble import MeanEnsemble from lightwood.ensemble.mode_ensemble import ModeEnsemble @@ -7,4 +9,4 @@ from lightwood.ensemble.weighted_mean_ensemble import WeightedMeanEnsemble __all__ = ['BaseEnsemble', 'BestOf', 'MeanEnsemble', 'ModeEnsemble', 'WeightedMeanEnsemble', 'StackedEnsemble', - 'TsStackedEnsemble'] + 'TsStackedEnsemble', 'Embedder', 'IdentityEnsemble'] diff --git a/lightwood/ensemble/embed.py b/lightwood/ensemble/embed.py new file mode 100644 index 000000000..c296397d1 --- /dev/null +++ b/lightwood/ensemble/embed.py @@ -0,0 +1,23 @@ +from typing import List +import pandas as pd + +from lightwood.mixer.base import BaseMixer +from lightwood.ensemble.base import BaseEnsemble +from lightwood.api.types import PredictionArguments +from lightwood.data.encoded_ds import EncodedDs + + +class Embedder(BaseEnsemble): + """ + This ensemble acts as a simple embedder that bypasses all mixers. + When called, it will return the encoded representation of the data stored in (or generated by) an EncodedDs object. + """ # noqa + def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs) -> None: + super().__init__(target, list(), data) + self.embedding_size = data.get_encoded_data(include_target=False).shape[-1] + self.prepared = True + + def __call__(self, ds: EncodedDs, args: PredictionArguments = None) -> pd.DataFrame: + # shape: (B, self.embedding_size) + encoded_representations = ds.get_encoded_data(include_target=False).numpy() + return pd.DataFrame(encoded_representations) diff --git a/lightwood/ensemble/identity.py b/lightwood/ensemble/identity.py new file mode 100644 index 000000000..1ee29374c --- /dev/null +++ b/lightwood/ensemble/identity.py @@ -0,0 +1,36 @@ +from typing import List +import pandas as pd + +from lightwood.mixer.base import BaseMixer +from lightwood.ensemble.base import BaseEnsemble +from lightwood.api.types import PredictionArguments +from lightwood.data.encoded_ds import EncodedDs + + +class IdentityEnsemble(BaseEnsemble): + """ + This ensemble performs no aggregation. User can define an "active mixer" and calling the ensemble will call said mixer. + + Ideal for use cases with single mixers where (potentially expensive) evaluation runs are done internally, as in `BestOf`. 
+ """ # noqa + + def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs, args: PredictionArguments) -> None: + super().__init__(target, mixers, data=data) + self._active_mixer = 0 + single_row_ds = EncodedDs(data.encoders, data.data_frame.iloc[[0]], data.target) + _ = self.mixers[self._active_mixer](single_row_ds, args)['prediction'] # prime mixer for storage, needed because NHitsMixer.model (neuralforecast.NHITS) is not serializable without this, oddly enough. Eventually, check this again and remove if possible! # noqa + self.prepared = True + + def __call__(self, ds: EncodedDs, args: PredictionArguments = None) -> pd.DataFrame: + assert self.prepared + mixer = self.mixers[self.active_mixer] + return mixer(ds, args=args) + + @property + def active_mixer(self): + return self._active_mixer + + @active_mixer.setter + def active_mixer(self, idx): + assert 0 <= idx < len(self.mixers), f'The ensemble has {len(self.mixers)} mixers, please provide a valid index.' + self._active_mixer = idx diff --git a/lightwood/helpers/parallelism.py b/lightwood/helpers/parallelism.py index c1dd95cd8..c18141ff8 100644 --- a/lightwood/helpers/parallelism.py +++ b/lightwood/helpers/parallelism.py @@ -19,11 +19,11 @@ def get_nr_procs(df=None): return 1 else: available_mem = psutil.virtual_memory().available - max_per_proc_usage = 0.2 * pow(10, 9) + max_per_proc_usage = 2 * pow(10, 8) if df is not None: max_per_proc_usage += df.memory_usage(index=True, deep=True).sum() - proc_count = int(min(mp.cpu_count(), available_mem // max_per_proc_usage)) - 1 + proc_count = min(mp.cpu_count(), available_mem // max_per_proc_usage) - 1 return max(proc_count, 1) diff --git a/lightwood/helpers/templating.py b/lightwood/helpers/templating.py index 26b15579d..4bbb3650f 100644 --- a/lightwood/helpers/templating.py +++ b/lightwood/helpers/templating.py @@ -1,5 +1,7 @@ +from typing import Callable from collections import deque +import inspect import numpy as np from type_infer.dtype import dtype @@ -131,3 +133,13 @@ def _consolidate_analysis_blocks(jsonai, key): sorted_blocks.append(block_objs[idx2block[idx]]) return sorted_blocks + + +def _add_cls_kwarg(cls: Callable, kwargs: dict, key: str, value): + """ + Adds arguments to the `kwargs` dictionary if the key-value pair is valid for the `cls` class signature. 
+ """ + if key in [p.name for p in inspect.signature(cls).parameters.values()]: + kwargs[key] = value + + return kwargs diff --git a/lightwood/mixer/nhits.py b/lightwood/mixer/nhits.py index 7d22c4764..a443705a7 100644 --- a/lightwood/mixer/nhits.py +++ b/lightwood/mixer/nhits.py @@ -54,6 +54,7 @@ def __init__( self.dtype_dict = dtype_dict self.ts_analysis = ts_analysis self.grouped_by = ['__default'] if not ts_analysis['tss'].group_by else ts_analysis['tss'].group_by + self.group_boundaries = {} # stores last observed timestamp per series self.train_args = train_args.get('trainer_args', {}) if train_args else {} self.train_args['early_stop_patience_steps'] = self.train_args.get('early_stop_patience_steps', 10) self.conf_level = self.train_args.pop('conf_level', [90]) @@ -93,7 +94,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: oby_col = self.ts_analysis["tss"].order_by gby = self.ts_analysis["tss"].group_by if self.ts_analysis["tss"].group_by else [] df = deepcopy(cat_ds.data_frame) - Y_df = self._make_initial_df(df) + Y_df = self._make_initial_df(df, mode='train') + self.group_boundaries = self._set_boundary(Y_df, gby) if gby: n_time = df[gby].value_counts().min() else: @@ -130,9 +132,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: log.info('Successfully trained N-HITS forecasting model.') def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None: - # TODO: reimplement this with automatic novel-row differential self.hyperparam_search = False - self.fit(dev_data, train_data) # TODO: add support for passing args (e.g. n_epochs) + self.fit(train_data, dev_data) # TODO: add support for passing args (e.g. n_epochs) self.prepared = True def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs], @@ -183,7 +184,13 @@ def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs], ydf['confidence'] = level / 100 return ydf - def _make_initial_df(self, df): + def _make_initial_df(self, df, mode='inference'): + """ + Prepares a dataframe for the NHITS model according to what neuralforecast expects. + + If a per-group boundary exists, this method additionally drops out all observations prior to the cutoff. + """ # noqa + oby_col = self.ts_analysis["tss"].order_by df = df.sort_values(by=f'__mdb_original_{oby_col}') df[f'__mdb_parsed_{oby_col}'] = df.index @@ -198,4 +205,32 @@ def _make_initial_df(self, df): else: Y_df['unique_id'] = '__default' - return Y_df.reset_index() + Y_df = Y_df.reset_index() + + # filter if boundary exists + if mode == 'train' and self.group_boundaries: + filtered = [] + grouped = Y_df.groupby(by='unique_id') + for group, sdf in grouped: + if group in self.group_boundaries: + sdf = sdf[sdf['ds'].gt(self.group_boundaries[group])] + if sdf.shape[0] > 0: + filtered.append(sdf) + if filtered: + Y_df = pd.concat(filtered) + + return Y_df + + @staticmethod + def _set_boundary(df: pd.DataFrame, gby: list) -> Dict[str, object]: + """ + Finds last observation for every series in a pre-sorted `df` given a `gby` list of columns to group by. 
+ """ + if not gby: + group_boundaries = {'__default': df.iloc[-1]['ds']} + else: + # could use groupby().transform('max'), but we leverage pre-sorting instead + grouped_df = df.groupby(by='unique_id', as_index=False).last() + group_boundaries = grouped_df[['unique_id', 'ds']].set_index('unique_id').to_dict()['ds'] + + return group_boundaries diff --git a/lightwood/mixer/sktime.py b/lightwood/mixer/sktime.py index 666e09631..7ebabcfd9 100644 --- a/lightwood/mixer/sktime.py +++ b/lightwood/mixer/sktime.py @@ -1,4 +1,3 @@ -import inspect import importlib from copy import deepcopy from datetime import datetime @@ -13,6 +12,7 @@ from sktime.forecasting.statsforecast import StatsForecastAutoARIMA as AutoARIMA from lightwood.helpers.log import log +from lightwood.helpers.templating import _add_cls_kwarg from lightwood.mixer.base import BaseMixer from lightwood.api.types import PredictionArguments from lightwood.data.encoded_ds import EncodedDs, ConcatedEncodedDs @@ -164,7 +164,7 @@ def _fit(self, data): options['freq'] = self.freq for k, v in options.items(): - kwargs = self._add_forecaster_kwarg(model_class, kwargs, k, v) + kwargs = _add_cls_kwarg(model_class, kwargs, k, v) model_pipeline = [] @@ -337,15 +337,6 @@ def _get_best_model(self, trial, train_data, test_data): log.info(f'Trial got error: {error}') return error - def _add_forecaster_kwarg(self, forecaster: BaseForecaster, kwargs: dict, key: str, value): - """ - Adds arguments to the `kwargs` dictionary if the key-value pair is valid for the `forecaster` class signature. - """ - if key in [p.name for p in inspect.signature(forecaster).parameters.values()]: - kwargs[key] = value - - return kwargs - def _transform_index_to_datetime(self, series, series_oby, freq): series_oby = np.array([np.array(lst) for lst in series_oby]) start = datetime.utcfromtimestamp(np.min(series_oby[series_oby != np.min(series_oby)])) diff --git a/tests/integration/basic/test_embedding.py b/tests/integration/basic/test_embedding.py new file mode 100644 index 000000000..7d37e9c94 --- /dev/null +++ b/tests/integration/basic/test_embedding.py @@ -0,0 +1,29 @@ +import unittest +import pandas as pd +from tests.utils.timing import train_and_check_time_aim +from lightwood.api.types import ProblemDefinition +from lightwood.api.high_level import predictor_from_problem + + +class TestEmbeddingPredictor(unittest.TestCase): + def test_0_embedding_at_inference_time(self): + df = pd.read_csv('tests/data/hdi.csv') + pdef = ProblemDefinition.from_dict({'target': 'Development Index', 'time_aim': 10}) + predictor = predictor_from_problem(df, pdef) + train_and_check_time_aim(predictor, df, ignore_time_aim=True) + predictions = predictor.predict(df, args={'return_embedding': True}) + + self.assertTrue(predictions.shape[0] == len(df)) + self.assertTrue(predictions.shape[1] != 1) # embedding dimension + + def test_1_embedding_only_at_creation(self): + df = pd.read_csv('tests/data/hdi.csv') + target = 'Development Index' + pdef = ProblemDefinition.from_dict({'target': target, 'time_aim': 10, 'embedding_only': True}) + predictor = predictor_from_problem(df, pdef) + train_and_check_time_aim(predictor, df, ignore_time_aim=True) + predictions = predictor.predict(df) + + self.assertTrue(predictions.shape[0] == len(df)) + self.assertTrue(predictions.shape[1] == predictor.ensemble.embedding_size) + self.assertTrue(len(predictor.mixers) == 0) diff --git a/tests/unit_tests/encoder/categorical/test_autoencoder.py b/tests/unit_tests/encoder/categorical/test_autoencoder.py index 
4d6da1b86..3f74e0ec1 100644 --- a/tests/unit_tests/encoder/categorical/test_autoencoder.py +++ b/tests/unit_tests/encoder/categorical/test_autoencoder.py @@ -22,10 +22,10 @@ def create_test_data(self, test_data_rel_size=0.33): random.seed(2) np_random = np.random.default_rng(seed=2) - int_categories = np_random.integers(low=1, high=20, size=nb_int_categories) + int_categories = np_random.integers(low=1, high=nb_int_categories, size=nb_int_categories) str_categories = [ ''.join(random.choices(string.ascii_uppercase + string.digits, k=random.randint(7, 8))) - for category_i in range(nb_categories - nb_int_categories) + for _ in range(nb_categories - nb_int_categories) ] categories = list(int_categories) + str_categories category_sizes = np_random.integers(low=1, high=max_category_size, size=nb_categories) @@ -41,28 +41,7 @@ def create_test_data(self, test_data = priming_data[:test_data_size] return priming_data, test_data - def create_test_data_old(self): - random.seed(2) - cateogries = [''.join(random.choices(string.ascii_uppercase + string.digits, - k=random.randint(7, 8))) for x in range(500)] - for i in range(len(cateogries)): - if i % 10 == 0: - cateogries[i] = random.randint(1, 20) - - priming_data = [] - test_data = [] - for category in cateogries: - times = random.randint(1, 50) - for i in range(times): - priming_data.append(category) - if i % 3 == 0 or i == 1: - test_data.append(category) - - random.shuffle(priming_data) - random.shuffle(test_data) - return priming_data, test_data - - def test_autoencoder(self): + def test_autoencoder_ohe(self): """ Checks reconstruction accuracy above 70% for a set of categories, length 8, for up to 500 unique categories (actual around 468). """ # noqa @@ -83,6 +62,25 @@ def test_autoencoder(self): print(f'Categorical encoder accuracy for: {encoder_accuracy} on testing dataset') self.assertTrue(encoder_accuracy > 0.70) + def test_autoencoder_label(self): + """ + Checks reconstruction accuracy above 60%, less strict than OHE, because it is over a larger number of categories (1000). 
+ """ # noqa + log.setLevel(logging.DEBUG) + torch.manual_seed(2) + + priming_data, test_data = self.create_test_data(nb_categories=1000, nb_int_categories=1000) + + enc = CategoricalAutoEncoder(stop_after=20) + + enc.prepare(pd.Series(priming_data), pd.Series(priming_data)) + encoded_data = enc.encode(test_data) + decoded_data = enc.decode(encoded_data) + + encoder_accuracy = accuracy_score(list(map(str, test_data)), list(map(str, decoded_data))) + print(f'Categorical encoder accuracy for: {encoder_accuracy} on testing dataset') + self.assertTrue(encoder_accuracy > 0.60) + def check_encoder_on_device(self, device): priming_data, _ = self.create_test_data(nb_categories=8, nb_int_categories=3, diff --git a/tests/unit_tests/encoder/categorical/test_label.py b/tests/unit_tests/encoder/categorical/test_label.py new file mode 100644 index 000000000..8dc1e23c1 --- /dev/null +++ b/tests/unit_tests/encoder/categorical/test_label.py @@ -0,0 +1,59 @@ +import unittest +from torch import Tensor +import pandas as pd +from lightwood.encoder.categorical.simple_label import ( + SimpleLabelEncoder, +) +from lightwood.helpers.constants import _UNCOMMON_WORD + + +class TestLabel(unittest.TestCase): + """ Test the label encoder """ + + def test_encode_and_decode(self): + """ + Tests encoder end to end + + Checks: + (1) UNKS are assigned to 0th index + (2) Nones or unrecognized categories are both handled + (3) The decode/encode map order is the same + """ # noqa + data = pd.Series(['category 1', 'category 3', 'category 4', None, 'category 3']) + test_data = pd.Series(['unseen', 'category 4', 'category 1', 'category 3', None]) + n_points = data.nunique() + + ytest = [ + _UNCOMMON_WORD, + 'category 4', + 'category 1', + 'category 3', + _UNCOMMON_WORD, + ] + + enc = SimpleLabelEncoder() + enc.prepare(data) + + # Check the encoded patterns correct + encoded_data = enc.encode(data) + print(encoded_data) + self.assertTrue( + ( + encoded_data + == Tensor( + [ + 1 / n_points, # category 1 + 2 / n_points, # category 3 + 3 / n_points, # category 4 + 0 / n_points, # None + 2 / n_points, # category 3 + ] + ).reshape(-1, 1) + ).all() + ) + + # Check the decoded patterns correct + decoded_data = enc.decode(enc.encode(test_data)) + print(decoded_data) + for i in range(len(ytest)): + self.assertTrue(decoded_data[i] == ytest[i]) diff --git a/tests/unit_tests/encoder/text/test_short.py b/tests/unit_tests/encoder/text/test_short.py index 1cb669118..087ac4ac4 100644 --- a/tests/unit_tests/encoder/text/test_short.py +++ b/tests/unit_tests/encoder/text/test_short.py @@ -104,8 +104,13 @@ def test_non_smallvocab_target_auto_mode(self): test_data = random.sample(priming_data, len(priming_data) // 5) enc = ShortTextEncoder(is_target=True) - enc.prepare(priming_data) + enc.cae.input_encoder = 'OneHotEncoder!' # check that invalid input encoder triggers exception + self.assertRaises(AssertionError, enc.prepare, priming_data) + + # train as usual (note, for this test we pick OHE to focus on the CAE's accuracy) + enc.cae.input_encoder = 'OneHotEncoder' + enc.prepare(priming_data) assert enc.is_target is True # _combine is expected to be 'concat' when is_target is True