From 3d0fa675fb86dba3dab3f016d468ed09ba6de20f Mon Sep 17 00:00:00 2001 From: Brad Hunter Date: Thu, 14 Apr 2016 15:57:30 -0700 Subject: [PATCH] Version 1.0 Candidate WIP --- core/build.gradle | 12 +- .../com/airbnb/aerosolve/core/Example.java | 45 ++ .../airbnb/aerosolve/core/FeatureVector.java | 123 ++++ .../aerosolve/core/features/BasicFamily.java | 179 +++++ .../core/features/BasicMultiFamilyVector.java | 323 +++++++++ .../aerosolve/core/features/DenseVector.java | 169 +++++ .../aerosolve/core/features/Family.java | 39 ++ .../aerosolve/core/features/FamilyVector.java | 10 + .../aerosolve/core/features/Feature.java | 70 ++ .../core/features/FeatureFamily.java | 31 - .../aerosolve/core/features/FeatureGen.java | 39 -- .../core/features/FeatureRegistry.java | 34 + .../aerosolve/core/features/FeatureValue.java | 9 + .../core/features/FeatureValueEntry.java | 10 + .../core/features/FeatureVectorGen.java | 74 -- .../aerosolve/core/features/Features.java | 157 ----- .../aerosolve/core/features/FloatFamily.java | 20 - .../features/GenericNamingConvention.java | 148 ++++ .../core/features/InputGenerator.java | 44 ++ .../{FeatureMapping.java => InputSchema.java} | 13 +- .../core/features/MultiFamilyVector.java | 114 ++++ .../core/features/NamingConvention.java | 22 + .../core/features/SimpleExample.java | 92 +++ .../core/features/SimpleFeatureValue.java | 24 + .../features/SimpleFeatureValueEntry.java | 43 ++ .../aerosolve/core/features/SparseVector.java | 78 +++ .../aerosolve/core/features/StringFamily.java | 34 - .../core/function/AbstractFunction.java | 58 -- .../aerosolve/core/function/Function.java | 30 - .../aerosolve/core/function/Spline.java | 199 ------ .../core/functions/AbstractFunction.java | 64 ++ .../aerosolve/core/functions/Function.java | 26 + .../{function => functions}/FunctionUtil.java | 61 +- .../core/{function => functions}/Linear.java | 50 +- .../MultiDimensionPoint.java | 8 +- .../MultiDimensionSpline.java | 88 +-- 
.../aerosolve/core/functions/Spline.java | 197 ++++++ .../aerosolve/core/images/HOGFeature.java | 2 - .../aerosolve/core/images/HSVFeature.java | 4 +- .../core/images/ImageFeatureExtractor.java | 22 +- .../aerosolve/core/images/LBPFeature.java | 1 - .../aerosolve/core/images/RGBFeature.java | 1 - .../aerosolve/core/models/AbstractModel.java | 34 +- .../aerosolve/core/models/AdditiveModel.java | 350 +++++----- .../core/models/BoostedStumpsModel.java | 73 +- .../core/models/DecisionTreeModel.java | 94 +-- .../aerosolve/core/models/ForestModel.java | 47 +- .../core/models/FullRankLinearModel.java | 129 ++-- .../aerosolve/core/models/KDTreeModel.java | 53 +- .../aerosolve/core/models/KernelModel.java | 56 +- .../aerosolve/core/models/LinearModel.java | 222 +++--- .../core/models/LowRankLinearModel.java | 103 ++- .../aerosolve/core/models/MaxoutModel.java | 204 +++--- .../aerosolve/core/models/MlpModel.java | 105 ++- .../airbnb/aerosolve/core/models/Model.java | 4 +- .../aerosolve/core/models/ModelFactory.java | 67 +- .../aerosolve/core/models/NDTreeModel.java | 30 +- .../aerosolve/core/models/SplineModel.java | 243 +++---- .../aerosolve/core/scoring/ModelConfig.java | 2 +- .../aerosolve/core/scoring/ModelScorer.java | 37 +- .../ApproximatePercentileTransform.java | 121 ++-- .../core/transforms/BucketFloatTransform.java | 48 -- .../core/transforms/BucketTransform.java | 39 ++ .../core/transforms/CapFloatTransform.java | 30 - .../core/transforms/CapTransform.java | 21 + .../ConvertStringCaseTransform.java | 38 +- .../core/transforms/CrossTransform.java | 113 ++- .../CustomLinearLogQuantizeTransform.java | 123 ---- .../CustomMultiscaleQuantizeTransform.java | 82 --- .../core/transforms/CutFloatTransform.java | 66 -- .../core/transforms/CutTransform.java | 36 + .../core/transforms/DateDiffTransform.java | 72 +- .../core/transforms/DateValTransform.java | 114 ++-- .../transforms/DecisionTreeTransform.java | 79 ++- .../DefaultStringTokenizerTransform.java | 114 ++-- 
.../DeleteFloatFeatureFamilyTransform.java | 40 -- .../DeleteFloatFeatureTransform.java | 34 - .../DeleteStringFeatureColumnTransform.java | 4 - .../DeleteStringFeatureFamilyTransform.java | 41 -- .../DeleteStringFeatureTransform.java | 46 -- .../core/transforms/DeleteTransform.java | 118 ++++ .../core/transforms/DenseTransform.java | 117 ++++ .../core/transforms/DivideTransform.java | 84 +-- .../transforms/FloatCrossFloatTransform.java | 78 --- .../transforms/FloatToDenseTransform.java | 80 --- .../transforms/KdtreeContinuousTransform.java | 90 --- .../core/transforms/KdtreeTransform.java | 122 ++-- .../core/transforms/LegacyNames.java | 16 + .../LinearLogQuantizeTransform.java | 138 ---- .../core/transforms/ListTransform.java | 51 +- .../core/transforms/MathFloatTransform.java | 86 --- .../core/transforms/MathTransform.java | 87 +++ .../aerosolve/core/transforms/ModelAware.java | 13 + .../MoveFloatToStringAndFloatTransform.java | 93 --- .../MoveFloatToStringTransform.java | 89 --- .../core/transforms/MoveTransform.java | 90 +++ .../MultiscaleGridContinuousTransform.java | 78 +-- .../MultiscaleGridQuantizeTransform.java | 66 +- .../MultiscaleMoveFloatToStringTransform.java | 64 -- .../MultiscaleQuantizeTransform.java | 65 -- .../core/transforms/NearestTransform.java | 66 +- .../transforms/NormalizeFloatTransform.java | 40 -- .../core/transforms/NormalizeTransform.java | 38 ++ .../transforms/NormalizeUtf8Transform.java | 43 +- .../core/transforms/ProductTransform.java | 54 +- .../core/transforms/QuantizeTransform.java | 238 ++++++- .../ReplaceAllStringsTransform.java | 60 +- .../core/transforms/SelfCrossTransform.java | 48 -- .../transforms/StringCrossFloatTransform.java | 43 -- .../core/transforms/StuffIdTransform.java | 61 +- .../core/transforms/StumpTransform.java | 110 +-- .../core/transforms/SubtractTransform.java | 69 +- .../aerosolve/core/transforms/Transform.java | 12 +- .../core/transforms/TransformFactory.java | 65 +- .../core/transforms/Transformer.java 
| 128 +--- .../core/transforms/WtaTransform.java | 151 ++-- .../base/BaseFeaturesTransform.java | 144 ++++ .../base/BoundedFeaturesTransform.java | 34 + .../base/ConfigurableTransform.java | 205 ++++++ .../transforms/base/DualFamilyTransform.java | 40 ++ .../transforms/base/DualFeatureTransform.java | 60 ++ .../base/OtherFeatureTransform.java | 51 ++ .../base/SingleFeatureTransform.java | 79 +++ .../core/transforms/base/StringTransform.java | 41 ++ .../core/transforms/types/FloatTransform.java | 42 -- .../transforms/types/StringTransform.java | 69 -- .../airbnb/aerosolve/core/util/DateUtil.java | 3 +- .../com/airbnb/aerosolve/core/util/Debug.java | 143 +--- .../core/util/FeatureDictionary.java | 25 +- .../core/util/FeatureVectorUtil.java | 60 +- .../aerosolve/core/util/FloatVector.java | 8 +- .../core/util/KNearestNeighborsOptions.java | 25 +- ...ySensitiveHashSparseFeatureDictionary.java | 126 ++-- .../util/MinKernelDenseFeatureDictionary.java | 67 +- .../core/util/ReinforcementLearning.java | 43 +- .../aerosolve/core/util/StringDictionary.java | 40 +- .../aerosolve/core/util/SupportVector.java | 18 +- .../com/airbnb/aerosolve/core/util/Util.java | 292 ++++---- .../airbnb/aerosolve/core/util/Weibull.java | 5 +- core/src/main/thrift/MLSchema.thrift | 19 +- .../core/features/FeatureGenTest.java | 28 - .../core/features/FeatureMappingTest.java | 27 - .../aerosolve/core/features/FeaturesTest.java | 186 ----- .../core/features/InputGenerationTest.java | 232 +++++++ .../aerosolve/core/function/SplineTest.java | 197 ------ .../{function => functions}/LinearTest.java | 42 +- .../MultiDimensionPointTest.java | 8 +- .../MultiDimensionSplineTest.java | 2 +- .../aerosolve/core/functions/SplineTest.java | 197 ++++++ .../aerosolve/core/images/HOGFeatureTest.java | 5 +- .../aerosolve/core/images/HSVFeatureTest.java | 4 +- .../images/ImageFeatureExtractorTest.java | 43 +- .../aerosolve/core/images/LBPFeatureTest.java | 4 +- .../aerosolve/core/images/RGBFeatureTest.java | 5 +- 
.../core/models/AdditiveModelPerfTest.java | 90 +++ .../core/models/AdditiveModelTest.java | 205 +++--- .../core/models/KDTreeModelTest.java | 7 +- .../core/models/LinearModelTest.java | 93 ++- .../core/models/LowRankLinearModelTest.java | 105 +-- .../aerosolve/core/models/MlpModelTest.java | 76 ++- .../core/models/NDTreeModelTest.java | 6 +- .../core/scoring/ModelScorerTest.java | 65 +- .../ApproximatePercentileTransformTest.java | 81 +-- .../core/transforms/BaseTransformTest.java | 123 ++++ .../transforms/BucketFloatTransformTest.java | 81 --- .../core/transforms/BucketTransformTest.java | 52 ++ .../transforms/CapFloatTransformTest.java | 76 +-- .../ConvertStringCaseTransformTest.java | 111 +-- .../core/transforms/CrossTransformTest.java | 80 +-- .../CustomLinearLogQuantizeTransformTest.java | 98 +-- ...CustomMultiscaleQuantizeTransformTest.java | 199 +++--- .../transforms/CutFloatTransformTest.java | 114 ---- .../core/transforms/CutTransformTest.java | 89 +++ .../transforms/DateDiffTransformTest.java | 59 +- .../core/transforms/DateValTransformTest.java | 145 ++-- .../transforms/DecisionTreeTransformTest.java | 95 ++- .../DefaultStringTokenizerTransformTest.java | 189 ++---- .../DeleteFeatureTransformTest.java | 46 ++ ...DeleteFloatFeatureFamilyTransformTest.java | 100 +-- .../DeleteFloatFeatureTransformTest.java | 72 -- ...eleteStringFeatureFamilyTransformTest.java | 95 +-- .../DeleteStringFeatureTransformTest.java | 70 +- ...sformTest.java => DenseTransformTest.java} | 28 +- .../core/transforms/DivideTransformTest.java | 51 +- .../FloatCrossFloatTransformTest.java | 100 +-- .../KdtreeContinuousTransformTest.java | 69 +- .../core/transforms/KdtreeTransformTest.java | 65 +- .../LinearLogQuantizeTransformTest.java | 94 +-- .../core/transforms/ListTransformTest.java | 71 +- .../transforms/MathFloatTransformTest.java | 96 ++- .../core/transforms/ModelTransformsTest.java | 110 ++- ...oveFloatToStringAndFloatTransformTest.java | 123 ++-- 
.../MoveFloatToStringTransformTest.java | 73 +- ...MultiscaleGridContinuousTransformTest.java | 75 +- .../MultiscaleGridQuantizeTransformTest.java | 68 +- ...tiscaleMoveFloatToStringTransformTest.java | 66 +- .../MultiscaleQuantizeTransformTest.java | 86 +-- .../core/transforms/NearestTransformTest.java | 76 +-- .../NormalizeFloatTransformTest.java | 43 +- .../NormalizeUtf8TransformTest.java | 160 ++--- .../core/transforms/ProductTransformTest.java | 61 +- .../transforms/QuantizeTransformTest.java | 66 +- .../ReplaceAllStringsTransformTest.java | 92 +-- .../transforms/SelfCrossTransformTest.java | 67 +- .../StringCrossFloatTransformTest.java | 78 +-- .../core/transforms/StuffIdTransformTest.java | 69 +- .../core/transforms/StumpTransformTest.java | 75 +- .../transforms/SubtractTransformTest.java | 50 +- .../transforms/TransformTestingHelper.java | 133 +++- .../core/transforms/WtaTransformTest.java | 80 +-- .../aerosolve/core/util/DateUtilTest.java | 3 +- .../core/util/FeatureDictionaryTest.java | 166 ++--- .../core/util/FeatureVectorUtilTest.java | 34 +- .../aerosolve/core/util/FloatVectorTest.java | 8 - .../core/util/ReinforcementLearningTest.java | 78 ++- .../core/util/StringDictionaryTest.java | 56 +- .../core/util/SupportVectorTest.java | 15 +- .../airbnb/aerosolve/core/util/UtilTest.java | 87 ++- .../src/test/resources/income_prediction.conf | 7 +- thrift-cli.gradle | 2 +- training/build.gradle | 2 + .../training/AdditiveModelTrainer.scala | 279 ++++---- .../training/BoostedForestTrainer.scala | 118 ++-- .../training/BoostedStumpsTrainer.scala | 87 ++- .../training/DecisionTreeTrainer.scala | 150 ++-- .../aerosolve/training/Evaluation.scala | 61 +- .../aerosolve/training/FeatureSelection.scala | 88 +-- .../aerosolve/training/ForestTrainer.scala | 40 +- .../training/FullRankLinearTrainer.scala | 198 +++--- .../aerosolve/training/GradientUtils.scala | 112 ++- .../training/HistogramCalibrator.scala | 2 +- .../airbnb/aerosolve/training/KDTree.scala | 18 +- 
.../aerosolve/training/KernelTrainer.scala | 75 +- .../training/LinearRankerTrainer.scala | 642 +++++++++--------- .../training/LinearRankerUtils.scala | 205 ++---- .../training/LowRankLinearTrainer.scala | 154 ++--- .../aerosolve/training/MaxoutTrainer.scala | 133 ++-- .../aerosolve/training/MlpModelTrainer.scala | 211 +++--- .../airbnb/aerosolve/training/NDTree.scala | 22 +- .../aerosolve/training/SplineTrainer.scala | 201 +++--- .../aerosolve/training/TrainingUtils.scala | 156 ++--- .../training/pipeline/EvalUtil.scala | 32 +- .../training/pipeline/ExampleUtil.scala | 10 +- .../training/pipeline/GenericPipeline.scala | 137 ++-- .../training/pipeline/PipelineUtil.scala | 8 +- .../training/AdditiveModelTrainerTest.scala | 31 +- .../training/BoostedStumpsModelTest.scala | 16 +- .../training/DecisionTreeTrainerTest.scala | 82 +-- .../training/ForestTrainerTest.scala | 49 +- .../training/FullRankLinearModelTest.scala | 23 +- .../aerosolve/training/KDTreeTest.scala | 28 +- .../training/KernelTrainerTest.scala | 24 +- .../LinearClassificationTrainerTest.scala | 31 +- ...earLogisticClassificationTrainerTest.scala | 31 +- .../training/LinearRankerTrainerTest.scala | 66 +- .../LinearRegressionTrainerTest.scala | 39 +- .../training/LowRankLinearTrainerTest.scala | 25 +- .../training/MaxoutTrainerTest.scala | 44 +- .../training/MlpModelTrainerTest.scala | 27 +- .../aerosolve/training/NDTreeTest.scala | 44 +- .../training/SplineRankingTrainerTest.scala | 43 +- .../training/SplineTrainerTest.scala | 95 ++- .../training/TrainingTestHelper.scala | 138 ++-- .../training/TrainingUtilsTest.scala | 32 +- .../training/pipeline/EvalUtilTest.scala | 15 +- .../pipeline/GenericPipelineTest.scala | 69 +- .../pipeline/PipelineTestingUtil.scala | 92 ++- .../training/pipeline/PipelineUtilTest.scala | 2 +- 268 files changed, 10170 insertions(+), 10671 deletions(-) create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/Example.java create mode 100644 
core/src/main/java/com/airbnb/aerosolve/core/FeatureVector.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/BasicFamily.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/BasicMultiFamilyVector.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/DenseVector.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/Family.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FamilyVector.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/Feature.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureFamily.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureGen.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureRegistry.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValue.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValueEntry.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FeatureVectorGen.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/Features.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/FloatFamily.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/GenericNamingConvention.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/InputGenerator.java rename core/src/main/java/com/airbnb/aerosolve/core/features/{FeatureMapping.java => InputSchema.java} (87%) create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/MultiFamilyVector.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/NamingConvention.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/SimpleExample.java create mode 100644 
core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValue.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValueEntry.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/SparseVector.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/features/StringFamily.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/function/AbstractFunction.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/function/Function.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/function/Spline.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/functions/AbstractFunction.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/functions/Function.java rename core/src/main/java/com/airbnb/aerosolve/core/{function => functions}/FunctionUtil.java (53%) rename core/src/main/java/com/airbnb/aerosolve/core/{function => functions}/Linear.java (65%) rename core/src/main/java/com/airbnb/aerosolve/core/{function => functions}/MultiDimensionPoint.java (94%) rename core/src/main/java/com/airbnb/aerosolve/core/{function => functions}/MultiDimensionSpline.java (68%) create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/functions/Spline.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/CapFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/CapTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransform.java delete mode 100644 
core/src/main/java/com/airbnb/aerosolve/core/transforms/CutFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/CutTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureColumnTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/DenseTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/LegacyNames.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MathFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MathTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/ModelAware.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransform.java create mode 100644 
core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/BaseFeaturesTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/BoundedFeaturesTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/ConfigurableTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFamilyTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFeatureTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/OtherFeatureTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/SingleFeatureTransform.java create mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/base/StringTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/types/FloatTransform.java delete mode 100644 core/src/main/java/com/airbnb/aerosolve/core/transforms/types/StringTransform.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/features/FeatureGenTest.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/features/FeatureMappingTest.java delete mode 100644 
core/src/test/java/com/airbnb/aerosolve/core/features/FeaturesTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/features/InputGenerationTest.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/function/SplineTest.java rename core/src/test/java/com/airbnb/aerosolve/core/{function => functions}/LinearTest.java (58%) rename core/src/test/java/com/airbnb/aerosolve/core/{function => functions}/MultiDimensionPointTest.java (94%) rename core/src/test/java/com/airbnb/aerosolve/core/{function => functions}/MultiDimensionSplineTest.java (98%) create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/functions/SplineTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelPerfTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/BaseTransformTest.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransformTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketTransformTest.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/CutFloatTransformTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/CutTransformTest.java create mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFeatureTransformTest.java delete mode 100644 core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransformTest.java rename core/src/test/java/com/airbnb/aerosolve/core/transforms/{FloatToDenseTransformTest.java => DenseTransformTest.java} (90%) diff --git a/core/build.gradle b/core/build.gradle index 1fe1cfe1..e6b20a52 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -40,7 +40,17 @@ dependencies { compile 'com.typesafe:config:1.3.0' compile 'org.slf4j:slf4j-api:1.7.7' compile 'joda-time:joda-time:2.5' - compile 'org.projectlombok:lombok:1.14.8' + compile 'org.projectlombok:lombok:1.16.8' compile 
'org.apache.commons:commons-lang3:3.4' + compile 'org.reflections:reflections:0.9.9' + compile 'it.unimi.dsi:fastutil:7.0.12' + + //Validation + compile 'org.hibernate:hibernate-validator:5.2.4.Final' + compile 'javax.el:javax.el-api:2.2.4' + compile 'org.glassfish.web:javax.el:2.2.4' + compile 'org.hibernate:hibernate-validator-cdi:5.2.4.Final' + testCompile 'org.slf4j:slf4j-simple:1.7.7' + testCompile 'junit:junit:4.11' } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/Example.java b/core/src/main/java/com/airbnb/aerosolve/core/Example.java new file mode 100644 index 00000000..cc5c4476 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/Example.java @@ -0,0 +1,45 @@ +package com.airbnb.aerosolve.core; + +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.Transformer; +import java.util.Iterator; + +/** + * + */ +public interface Example extends Iterable { + MultiFamilyVector context(); + + MultiFamilyVector createVector(); + + MultiFamilyVector addToExample(MultiFamilyVector vector); + + Example transform(Transformer transformer, + AbstractModel model); + + default Example transform(Transformer transformer) { + return transform(transformer, null); + } + + /** + * Returns the only MultiFamilyVector in this Example. + * + * If the Example contains nothing or more than one thing, this will throw an + * IllegalStateException. + * + * (Brad): Lots of code assumes the Example has only one item. Ideally, we should remove that + * assumption and then this method. This method helps us find code paths making that assumption. 
+ */ + default MultiFamilyVector only() { + Iterator iterator = iterator(); + if (!iterator.hasNext()) { + throw new IllegalStateException("Called only() on an Example which contains nothing"); + } + MultiFamilyVector result = iterator.next(); + if (iterator.hasNext()) { + throw new IllegalStateException("Called only() on an Example containing more than one vector."); + } + return result; + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/FeatureVector.java b/core/src/main/java/com/airbnb/aerosolve/core/FeatureVector.java new file mode 100644 index 00000000..e3b58160 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/FeatureVector.java @@ -0,0 +1,123 @@ +package com.airbnb.aerosolve.core; + +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.SimpleFeatureValue; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Consumer; + +/** + * When iterating a FeatureVector it may not always return a new copy of each value. If you want + * to save the values returned by the iterator, use the entry set instead. + */ +public interface FeatureVector extends Object2DoubleMap, Iterable { + FeatureRegistry registry(); + + default void putString(Feature feature) { + put(feature, 1.0d); + } + + default Iterator fastIterator() { + return iterator(); + } + + /** + * Use this if you intend to store the values. Don't use foreach. 
+ */ + default Set featureValueEntrySet() { + return Sets.newHashSet(iterator()); + } + + @Override + default void forEach(Consumer action) { + Iterator iter = fastIterator(); + while (iter.hasNext()) { + action.accept(iter.next()); + } + } + + default double get(String familyName, String featureName) { + Feature feature = registry().feature(familyName, featureName); + return getDouble(feature); + } + + default boolean containsKey(String familyName, String featureName) { + Feature feature = registry().feature(familyName, featureName); + return containsKey(feature); + } + + // TODO (Brad): This kind of breaks the abstraction. Do all Features have families? + default void put(String familyName, String featureName, double value) { + Feature feature = registry().feature(familyName, featureName); + put(feature, value); + } + + default void putString(String familyName, String featureName) { + Feature feature = registry().feature(familyName, featureName); + putString(feature); + } + + default Iterable withDropout(double dropout) { + return Iterables.filter(this, f -> ThreadLocalRandom.current().nextDouble() >= dropout); + } + + default Iterable iterateMatching(List features) { + Preconditions.checkNotNull(features, "Cannot iterate all features when features is null"); + return () -> new FeatureSetIterator(this, features); + } + + default double[] denseArray() { + double[] result = new double[size()]; + int i = 0; + for (FeatureValue value : this) { + result[i] = value.value(); + i++; + } + return result; + } + + class FeatureSetIterator implements Iterator { + + private final FeatureVector vector; + private final List features; + private SimpleFeatureValue entry = SimpleFeatureValue.of(null, 0.0d); + private int index = -1; + private int nextIndex = 0; + + public FeatureSetIterator(FeatureVector vector, List features) { + this.vector = vector; + this.features = features; + } + + @Override + public boolean hasNext() { + if (index >= nextIndex) { + nextIndex++; + 
while(nextIndex < features.size() && !vector.containsKey(features.get(nextIndex))) { + nextIndex++; + } + } + return nextIndex < features.size(); + } + + @Override + public FeatureValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + entry.feature(features.get(nextIndex)); + entry.value(vector.getDouble(entry.feature())); + index = nextIndex; + return entry; + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/BasicFamily.java b/core/src/main/java/com/airbnb/aerosolve/core/features/BasicFamily.java new file mode 100644 index 00000000..34cd4f52 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/BasicFamily.java @@ -0,0 +1,179 @@ +package com.airbnb.aerosolve.core.features; + +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.Getter; +import lombok.Synchronized; +import lombok.Value; +import lombok.experimental.Accessors; + +/** + * + */ +@Accessors(fluent = true, chain = true) +public class BasicFamily implements Family, Serializable { + + private final int hashCode; + private Map featuresByName; + @Getter + private final String name; + @Getter + private final int index; + private final AtomicInteger featureCount; + private Feature[] featuresByIndex; + @Getter + private boolean isDense = false; + private Map> crosses; + + BasicFamily(String name, int index) { + Preconditions.checkNotNull(name, "All Families must have a name"); + this.name = name; + this.index = index; + this.featureCount = new AtomicInteger(0); + this.crosses = new Object2ObjectOpenHashMap<>(); + this.hashCode = 
name.hashCode(); + } + + @Override + public void markDense() { + if (featuresByName != null && !featuresByName.isEmpty()) { + throw new IllegalStateException("Tried to make a family dense but it already has features" + + " defined by name. Some code probably thinks it's sparse."); + } + isDense = true; + } + + @Override + public Feature feature(String featureName) { + if (isDense) { + // Note that it's not recommended to use this method if the type is DENSE. + // Just call feature(int). It will be faster. + Integer index = Ints.tryParse(featureName); + if (index == null) { + throw new IllegalArgumentException(String.format( + "Could not parse %s to a valid integer for lookup in a dense family: %s. Dense families " + + "do not have names for each feature.", featureName, name())); + } + return feature(index); + } + if (featuresByName == null) { + featuresByName = new ConcurrentHashMap<>(allocationSize()); + } + Feature feature = featuresByName.computeIfAbsent( + featureName, + innerName -> new Feature(this, innerName, featureCount.getAndIncrement()) + ); + if (featuresByIndex == null || feature.index() >= featuresByIndex.length) { + resizeFeaturesByIndex(feature.index()); + } + if (featuresByIndex[feature.index()] == null) { + featuresByIndex[feature.index()] = feature; + } + return feature; + } + + @Override + public Feature feature(int index) { + if (featuresByIndex == null || index >= featuresByIndex.length) { + if (isDense) { + resizeFeaturesByIndex(index); + } else { + return null; + } + } + if (isDense && featuresByIndex[index] == null) { + featuresByIndex[index] = new Feature(this, String.valueOf(index), index); + } + return featuresByIndex[index]; + } + + @Synchronized + private void resizeFeaturesByIndex(int index) { + if (featuresByIndex == null) { + featuresByIndex = new Feature[Family.allocationSize(index + 1)]; + return; + } + // We check outside and inside this method because it's synchronized and this can change + // between when we intend to enter and 
when we actually enter. + if (index < featuresByIndex.length) { + return; + } + // Need to resize. + int length = featuresByIndex.length; + while (index >= length) { + length = length * 2; + } + featuresByIndex = Arrays.copyOf(featuresByIndex, length); + } + + @Override + public Feature cross(Feature left, Feature right, String separator) { + Map rightMap = crosses.get(left); + if (rightMap == null) { + rightMap = new Object2ObjectOpenHashMap<>(); + crosses.put(left, rightMap); + } + FeatureJoin[] joinArr = rightMap.get(right); + if (joinArr == null) { + joinArr = new FeatureJoin[2]; + rightMap.put(right, joinArr); + } + int i; + for (i = 0; i < joinArr.length; i++) { + FeatureJoin join = joinArr[i]; + if (join == null) { + break; + } + if (join.separator().equals(separator)) { + return join.feature(); + } + } + Feature feature = feature(left.name() + separator + right.name()); + FeatureJoin newJoin = new FeatureJoin(feature, separator); + if (i >= joinArr.length) { + joinArr = Arrays.copyOf(joinArr, joinArr.length * 2); + rightMap.put(right, joinArr); + } + joinArr[i] = newJoin; + return feature; + } + + public int size() { + return featureCount.get(); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof Family)) { + return false; + } + return name.equals(((Family) obj).name()); + } + + @Override + public String toString() { + return name; + } + + @Value + private static class FeatureJoin { + private final Feature feature; + private final String separator; + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/BasicMultiFamilyVector.java b/core/src/main/java/com/airbnb/aerosolve/core/features/BasicMultiFamilyVector.java new file mode 100644 index 00000000..61b9f779 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/BasicMultiFamilyVector.java @@ -0,0 +1,323 @@ +package 
com.airbnb.aerosolve.core.features; + +import com.airbnb.aerosolve.core.ThriftFeatureVector; +import com.google.common.primitives.Doubles; +import it.unimi.dsi.fastutil.objects.AbstractObject2DoubleMap; +import it.unimi.dsi.fastutil.objects.AbstractObjectIterator; +import it.unimi.dsi.fastutil.objects.AbstractObjectSet; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import it.unimi.dsi.fastutil.objects.ObjectSet; +import java.io.Serializable; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; + +/** + * + */ +@SuppressWarnings("unchecked") +public class BasicMultiFamilyVector extends AbstractObject2DoubleMap + implements MultiFamilyVector, Serializable { + + private final FeatureRegistry registry; + // TODO (Brad): In training the index of the families may be different due to the registries + // being on different machines. This may not work. 
+ private FamilyVector[] families; + private int totalSize; + + public BasicMultiFamilyVector(FeatureRegistry registry) { //featureRegistry.familyCount() + this.families = new FamilyVector[registry.familyCapacity()]; + this.totalSize = 0; + this.registry = registry; + } + + public BasicMultiFamilyVector(BasicMultiFamilyVector other) { + this.families = Arrays.copyOf(other.families, other.families.length); + this.totalSize = other.totalSize; + this.registry = other.registry; + } + + public BasicMultiFamilyVector(ThriftFeatureVector tmp, FeatureRegistry registry) { + this(registry); + if (tmp.getStringFeatures() != null) { + for (Map.Entry> entry : tmp.getStringFeatures().entrySet()) { + Family family = registry.family(entry.getKey()); + for (String featureName : entry.getValue()) { + putString(family.feature(featureName)); + } + } + } + if (tmp.getDenseFeatures() != null) { + for (Map.Entry> entry : tmp.getDenseFeatures().entrySet()) { + Family family = registry.family(entry.getKey()); + putDense(family, Doubles.toArray(entry.getValue())); + } + } + if (tmp.getFloatFeatures() != null) { + for (Map.Entry> entry : tmp.getFloatFeatures().entrySet()) { + Family family = registry.family(entry.getKey()); + for (Map.Entry value : entry.getValue().entrySet()) { + put(family.feature(value.getKey()), (double) value.getValue()); + } + } + } + } + + @Override + public void applyContext(MultiFamilyVector context) { + // TODO (Brad): For now, we copy values in. Ideally, we would instead have this back the vector + // and incur the hit on lookup to save memory and time. 
+ for (FamilyVector vector : context.families()) { + if (vector instanceof DenseVector) { + if (!contains(vector.family())) { + putDense(vector.family(), vector.denseArray()); + } + } else { + for (FeatureValue value : vector) { + if (!containsKey(value.feature())) { + put(value.feature(), value.value()); + } + } + } + } + } + + @Override + public Set families() { + Set familySet = new HashSet<>(); + for (FamilyVector vector : families) { + if (vector != null) { + familySet.add(vector); + } + } + return familySet; + } + + @Override + public void putString(Feature feature) { + put(feature, 1.0d); + } + + @Override + public FamilyVector putDense(Family family, double[] values) { + DenseVector fVec = (DenseVector) family(family, true); + fVec.denseArray(values); + return fVec; + } + + @Override + public double put(Feature feature, double value) { + FamilyVector familyVector = family(feature.family(), false); + return familyVector.put(feature, value); + } + + private FamilyVector family(Family family, boolean isDense) { + if (families.length <= family.index()) { + int newNumFamilies = families.length * 2; + families = Arrays.copyOf(families, newNumFamilies); + } + FamilyVector fVec = families[family.index()]; + if (fVec == null) { + if (isDense) { + fVec = new DenseVector(family, registry); + } else { + // TODO (Brad): Maybe have a small one and a large one? 
+ fVec = new SparseVector(family, registry); + } + families[family.index()] = fVec; + } + return fVec; + } + + @Override + public FeatureRegistry registry() { + return registry; + } + + @Override + public double removeDouble(Object key) { + if (!(key instanceof Feature)) { + return 0.0d; + } + Feature feature = (Feature) key; + FamilyVector fVec = families[feature.family().index()]; + if (fVec != null) { + double result = fVec.removeDouble(feature); + if (fVec.size() == 0) { + remove(feature.family()); + } + return result; + } + return 0.0d; + } + + @Override + public FamilyVector remove(Family family) { + if (family.index() >= families.length) { + return null; + } + + FamilyVector fVec = families[family.index()]; + families[family.index()] = null; + return fVec; + } + + @Override + public FamilyVector get(Family family) { + if (family.index() >= families.length) { + return null; + } + return families[family.index()]; + } + + @Override + public boolean containsKey(Object k) { + if (!(k instanceof Feature)) { + return false; + } + FamilyVector vec = get(((Feature) k).family()); + return vec != null && vec.containsKey(k); + } + + @Override + public boolean contains(Family family) { + return family.index() < families.length && families[family.index()] != null; + } + + @Override + public int size() { + int size = 0; + for (FamilyVector fVec : families) { + if (fVec != null) { + size += fVec.size(); + } + } + return size; + } + + @Override + public double getDouble(Object key) { + if (key instanceof Feature) { + Feature feature = (Feature) key; + if (feature.family().index() < families.length) { + FamilyVector fVec = families[feature.family().index()]; + if (fVec != null) { + return fVec.getDouble(feature); + } + } + } + return 0.0d; + } + + @Override + public MultiFamilyVector withFamilyDropout(double dropout) { + // TODO (Brad): Copying to a new vector is slow. It would be nice if we could make views on + // existing vectors with predicates. 
+ MultiFamilyVector newVector = new BasicMultiFamilyVector(registry); + for (FamilyVector vector : families()) { + if (vector.family().isDense()) { + if (ThreadLocalRandom.current().nextDouble() >= dropout) { + newVector.putDense(vector.family(), vector.denseArray()); + } + } else { + for (FeatureValue value : vector.withDropout(dropout)) { + newVector.put(value.feature(), value.value()); + } + } + } + return newVector; + } + + @Override + public Iterator iterator() { + return (Iterator) new FastMultiFamilyVectorIterator(true); + } + + @Override + public Iterator fastIterator() { + return (Iterator) new FastMultiFamilyVectorIterator(true); + } + + @Override + public ObjectSet> object2DoubleEntrySet() { + return new FastFeatureVectorEntrySet(); + } + + public class FastFeatureVectorEntrySet extends AbstractObjectSet> implements + FastEntrySet { + + @Override + public ObjectIterator> iterator() { + return (ObjectIterator) new FastMultiFamilyVectorIterator(false); + } + + @Override + public int size() { + return BasicMultiFamilyVector.this.size(); + } + + @Override + public ObjectIterator> fastIterator() { + return (ObjectIterator) new FastMultiFamilyVectorIterator(true); + } + } + + private class FastMultiFamilyVectorIterator extends + AbstractObjectIterator { + private int index = -1; + private Iterator iterator; + private final boolean useFast; + private Iterator nextIterator; + private int nextIndex = -1; + + public FastMultiFamilyVectorIterator(boolean useFast) { + this.useFast = useFast; + } + + private void nextIterator() { + nextIndex = index; + while ((nextIterator == null || !nextIterator.hasNext()) && nextIndex < families.length - 1) { + nextIndex++; + if (families[nextIndex] != null) { + nextIterator = useFast + ? 
(Iterator) families[nextIndex].fastIterator() + : (Iterator) families[nextIndex].iterator(); + } + } + } + + @Override + public boolean hasNext() { + if (iterator != null && iterator.hasNext()) { + return true; + } + if (nextIterator != null) { + return nextIterator.hasNext(); + } + nextIterator(); + return nextIterator != null && nextIterator.hasNext(); + } + + @Override + public FeatureValueEntry next() { + if (iterator != null && iterator.hasNext()) { + return iterator.next(); + } + if (nextIterator == null) { + nextIterator(); + } + iterator = nextIterator; + index = nextIndex; + nextIterator = null; + if (iterator == null) { + throw new NoSuchElementException(); + } + return iterator.next(); + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/DenseVector.java b/core/src/main/java/com/airbnb/aerosolve/core/features/DenseVector.java new file mode 100644 index 00000000..1a9d4aeb --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/DenseVector.java @@ -0,0 +1,169 @@ +package com.airbnb.aerosolve.core.features; + +import it.unimi.dsi.fastutil.objects.AbstractObject2DoubleMap; +import it.unimi.dsi.fastutil.objects.AbstractObjectIterator; +import it.unimi.dsi.fastutil.objects.AbstractObjectSet; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import it.unimi.dsi.fastutil.objects.ObjectSet; +import java.io.Serializable; +import java.util.Iterator; +import java.util.NoSuchElementException; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +/** + * + */ +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +public class DenseVector extends AbstractObject2DoubleMap + implements FamilyVector, Serializable { + + private final Family family; + private final FeatureRegistry registry; + private double[] denseArray; + + public DenseVector(Family family, FeatureRegistry registry) { + family.markDense(); + this.family = family; + this.registry = 
registry; + } + + @Override + public ObjectSet> object2DoubleEntrySet() { + return new DenseVectorEntrySet(); + } + + @Override + public int size() { + return denseArray.length; + } + + @Override + public double getDouble(Object key) { + if (key instanceof Feature) { + int featureIndex = ((Feature) key).index(); + if (featureIndex < denseArray.length) { + return denseArray[featureIndex]; + } + } + return 0.0d; + } + + @Override + public boolean containsKey(Object k) { + return k instanceof Feature && ((Feature) k).index() < denseArray.length; + } + + @Override + public double put(Feature key, double value) { + throw new UnsupportedOperationException("Don't put to a dense vector. Call setValues()."); + } + + @Override + public double removeDouble(Object key) { + throw new UnsupportedOperationException("Cannot remove values from a DenseFamilyVector"); + } + + @Override + public Iterator iterator() { + return fastIterator(); + } + + @Override + public Iterator fastIterator() { + return (Iterator) new FastDenseVectorIterator(); + } + + private class DenseVectorEntrySet extends AbstractObjectSet> implements + FastEntrySet { + @Override + public ObjectIterator> iterator() { + return fastIterator(); + } + + @Override + public int size() { + return denseArray.length; + } + + @Override + public ObjectIterator> fastIterator() { + return (ObjectIterator) new FastDenseVectorIterator(); + } + } + + private class FastDenseVectorIterator extends AbstractObjectIterator { + + private int index = -1; + private FastEntry entry = new FastEntry(); + + @Override + public boolean hasNext() { + return index < denseArray.length - 1; + } + + @Override + public FeatureValueEntry next() { + if (hasNext()) { + index++; + } else { + throw new NoSuchElementException(); + } + return entry; + } + + private class FastEntry implements FeatureValueEntry { + + @Override + @Deprecated + public Double getValue() { + return denseArray[index]; + } + + @Override + public double setValue(double value) { + 
throw new UnsupportedOperationException(); + } + + @Override + public double value() { + return denseArray[index]; + } + + @Override + public double getDoubleValue() { + return value(); + } + + @Override + public Feature getKey() { + return family.feature(index); + } + + @Override + public Double setValue(Double value) { + throw new UnsupportedOperationException(); + } + + @Override + public Feature feature() { + return getKey(); + } + + @Override + public void feature(Feature feature) { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return feature() + "::" + value(); + } + } + } + + +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/Family.java b/core/src/main/java/com/airbnb/aerosolve/core/features/Family.java new file mode 100644 index 00000000..5da28b7d --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/Family.java @@ -0,0 +1,39 @@ +package com.airbnb.aerosolve.core.features; + +/** + * + */ +public interface Family { + + int MIN_ALLOCATION = 4; + int MAX_ALLOCATION = 64; + + int index(); + + String name(); + + void markDense(); + + Feature feature(String featureName); + + int size(); + + default int allocationSize() { + return allocationSize(size()); + } + + static int allocationSize(int size) { + int max = MIN_ALLOCATION; + while (max < size && max < MAX_ALLOCATION) + { + max = max << 1; + } + return max; + } + + Feature feature(int index); + + Feature cross(Feature left, Feature right, String separator); + + boolean isDense(); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FamilyVector.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FamilyVector.java new file mode 100644 index 00000000..93e5202a --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/FamilyVector.java @@ -0,0 +1,10 @@ +package com.airbnb.aerosolve.core.features; + +import com.airbnb.aerosolve.core.FeatureVector; + +/** + * + */ +public interface 
FamilyVector extends FeatureVector { + Family family(); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/Feature.java b/core/src/main/java/com/airbnb/aerosolve/core/features/Feature.java new file mode 100644 index 00000000..a23ec665 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/Feature.java @@ -0,0 +1,70 @@ +package com.airbnb.aerosolve.core.features; + +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import java.io.Serializable; + +/** + * + */ +public class Feature implements Comparable, Serializable { + private final Family family; + private final String name; + private final int hashCode; + private final int index; + + // Package private because only Family can create features. + Feature(Family family, String name, int index) { + Preconditions.checkNotNull(family, "Family cannot be null for features"); + Preconditions.checkNotNull(name, "Name cannot be null for features"); + this.family = family; + this.name = name; + + this.index = index; + this.hashCode = 31 * family.hashCode() + name.hashCode(); + } + + public Family family() { + return family; + } + + public String name() { + return name; + } + + public int index() { + return index; + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Feature)) { + return false; + } + Feature feature = (Feature) o; + return Objects.equal(family, feature.family) && + feature.name().equals(this.name()); + } + + @Override + public int compareTo(Feature other) { + int famCompare = this.family().index() - other.family().index(); + if (famCompare != 0) { + return famCompare; + } + return this.index() - other.index(); + } + + @Override + public String toString() { + return String.format("%s :: %s", family, name); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureFamily.java 
b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureFamily.java deleted file mode 100644 index 0c67d3e0..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureFamily.java +++ /dev/null @@ -1,31 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import lombok.Getter; - -public abstract class FeatureFamily { - @Getter - private final String familyName; - - public FeatureFamily(String familyName) { - this.familyName = familyName; - } - - protected boolean isMyFamily(String name) { - return true; - } - - protected String nameTransform(String name) { - return name; - } - - public boolean add(String name, Object feature) { - if (isMyFamily(name)) { - put(nameTransform(name), (T) feature); - return true; - } else { - return false; - } - } - - protected abstract void put(String name, T feature); -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureGen.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureGen.java deleted file mode 100644 index 9fb5fc6a..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureGen.java +++ /dev/null @@ -1,39 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -/* - use Float.MIN_VALUE as NULL for the float feature. 
- */ -public class FeatureGen { - private final FeatureMapping mapping; - private Object[] values; - - public FeatureGen(FeatureMapping mapping) { - this.mapping = mapping; - values = new Object[mapping.getNames().length]; - } - - public void add(float[] features, Object c) { - FeatureMapping.Entry e = mapping.getMapping().get(c); - assert(e.length == features.length); - // can't do System.arraycopy(features, 0, values, e.start, e.length); - // due to Float.MIN_VALUE means NULL - for (int i = 0; i < e.length; i++) { - if (features[i] != Float.MIN_VALUE) { - values[i + e.start] = new Double(features[i]); - } - } - } - - public void add(Object[] features, Object c) { - FeatureMapping.Entry e = mapping.getMapping().get(c); - assert(e.length == features.length); - System.arraycopy(features, 0, values, e.start, e.length); - } - - public Features gen() { - return Features.builder(). - names(mapping.getNames()). - values(values). - build(); - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureRegistry.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureRegistry.java new file mode 100644 index 00000000..e4347a05 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureRegistry.java @@ -0,0 +1,34 @@ +package com.airbnb.aerosolve.core.features; + +import java.io.Serializable; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * + */ +public class FeatureRegistry implements Serializable { + public static final int MIN_PREDICTED_FAMILIES = 10; + private final Map families; + private final AtomicInteger familyCount; + + public FeatureRegistry() { + familyCount = new AtomicInteger(0); + families = new ConcurrentHashMap<>(familyCapacity()); + } + + public int familyCapacity() { + return Math.max(MIN_PREDICTED_FAMILIES, familyCount.get()); + } + + public Feature feature(String familyName, String featureName) { + return 
family(familyName).feature(featureName); + } + + public Family family(String familyName) { + return families.computeIfAbsent( + familyName, + name -> new BasicFamily(name, familyCount.getAndIncrement())); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValue.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValue.java new file mode 100644 index 00000000..0f58e27a --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValue.java @@ -0,0 +1,9 @@ +package com.airbnb.aerosolve.core.features; + +/** + * + */ +public interface FeatureValue { + double value(); + Feature feature(); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValueEntry.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValueEntry.java new file mode 100644 index 00000000..bd8192d2 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureValueEntry.java @@ -0,0 +1,10 @@ +package com.airbnb.aerosolve.core.features; + +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; + +/** + * + */ +public interface FeatureValueEntry extends Object2DoubleMap.Entry, FeatureValue { + void feature(Feature feature); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureVectorGen.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureVectorGen.java deleted file mode 100644 index 33ebcd03..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureVectorGen.java +++ /dev/null @@ -1,74 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.*; - -/* - Generate Example from input features and defined featureFamily. - refer to ModelScorerTest.java as how to use FeatureVectorGen - */ -public class FeatureVectorGen { - - // TODO add a new function to consider dense feature. 
- public static FeatureVector toFeatureVector(Features features, - List stringFamilies, - List floatFamilies) { - FeatureVector featureVector = new FeatureVector(); - // Set string features. - final Map> stringFeatures = new HashMap<>(); - featureVector.setStringFeatures(stringFeatures); - setBIAS(stringFeatures); - - for (StringFamily featureFamily : stringFamilies) { - stringFeatures.put(featureFamily.getFamilyName(), featureFamily.getFeatures()); - } - - final Map> floatFeatures = new HashMap<>(); - featureVector.setFloatFeatures(floatFeatures); - for (FloatFamily featureFamily : floatFamilies) { - floatFeatures.put(featureFamily.getFamilyName(), featureFamily.getFeatures()); - } - - for (int i = 0; i < features.names.length; ++i) { - Object feature = features.values[i]; - - if (feature != null) { - // Integer type = features.types[i]; - String name = features.names[i]; - if (feature instanceof Double || feature instanceof Float || - feature instanceof Integer || feature instanceof Long) { - for (FloatFamily featureFamily : floatFamilies) { - if (featureFamily.add(name, feature)) break; - } - } else if (feature instanceof String) { - for (StringFamily featureFamily : stringFamilies) { - if (featureFamily.add(name, feature)) break; - } - } else if (feature instanceof Boolean){ - for (StringFamily featureFamily : stringFamilies) { - if (featureFamily.add(name, (Boolean) feature)) break; - } - } - } - } - return featureVector; - } - - public static Example toSingleFeatureVectorExample(Features features, - List stringFamilies, - List floatFamilies) { - Example example = new Example(); - FeatureVector featureVector = toFeatureVector( - features, stringFamilies, floatFamilies); - example.addToExample(featureVector); - return example; - } - - protected static void setBIAS(final Map> stringFeatures) { - final Set bias = new HashSet<>(); - bias.add("B"); - stringFeatures.put("BIAS", bias); - } -} diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/features/Features.java b/core/src/main/java/com/airbnb/aerosolve/core/features/Features.java deleted file mode 100644 index e47879db..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/Features.java +++ /dev/null @@ -1,157 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.google.common.annotations.VisibleForTesting; -import lombok.experimental.Builder; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; - -import java.lang.reflect.Field; -import java.util.*; - -@Builder @Slf4j -public class Features { - public final static String LABEL = "LABEL"; - public final static String LABEL_FEATURE_NAME = ""; - public final static String MISS = "MISS"; - // In RAW case, don't append feature name - public final static String RAW = "RAW"; - private final static char FAMILY_SEPARATOR = '_'; - private final static char TRUE_FEATURE = 'T'; - private final static char FALSE_FEATURE = 'F'; - - public final String[] names; - public final Object[] values; - - /* - Util function to get features for FeatureMapping - */ - public static List getGenericSortedFeatures(Class c) { - return getGenericSortedFeatures(c.getDeclaredFields()); - } - - public static List getGenericSortedFeatures(Field[] fields) { - List features = new ArrayList<>(); - - for (Field field : fields) { - features.add(field.getName()); - } - // Sort the non-amenity features alphabetically - Collections.sort(features); - return features; - } - - // TODO make it more generic, for example, taking care of dense feature - public Example toExample(boolean isMultiClass) { - assert (names.length == values.length); - if (names.length != values.length) { - throw new RuntimeException("names.length != values.length"); - } - Example 
example = new Example(); - FeatureVector featureVector = new FeatureVector(); - example.addToExample(featureVector); - - // Set string features. - final Map> stringFeatures = new HashMap<>(); - featureVector.setStringFeatures(stringFeatures); - - final Map> floatFeatures = new HashMap<>(); - featureVector.setFloatFeatures(floatFeatures); - - final Set bias = new HashSet<>(); - final Set missing = new HashSet<>(); - bias.add("B"); - stringFeatures.put("BIAS", bias); - stringFeatures.put(MISS, missing); - - for (int i = 0; i < names.length; i++) { - String name = names[i]; - Object value = values[i]; - if (value == null) { - missing.add(name); - } else { - Pair feature = getFamily(name); - if (value instanceof String) { - String str = (String) value; - if (isMultiClass && isLabel(feature)) { - addMultiClassLabel(str, floatFeatures); - } else { - addStringFeature(str, feature, stringFeatures); - } - } else if (value instanceof Boolean) { - Boolean b = (Boolean) value; - addBoolFeature(b, feature, stringFeatures); - } else { - addNumberFeature((Number) value, feature, floatFeatures); - } - } - } - return example; - } - - @VisibleForTesting - static void addNumberFeature( - Number value, Pair featurePair, Map> floatFeatures) { - Map feature = Util.getOrCreateFloatFeature(featurePair.getLeft(), floatFeatures); - feature.put(featurePair.getRight(), value.doubleValue()); - } - - @VisibleForTesting - static void addBoolFeature( - Boolean b, Pair featurePair, Map> stringFeatures) { - Set feature = Util.getOrCreateStringFeature(featurePair.getLeft(), stringFeatures); - String featureName = featurePair.getRight(); - char str = (b.booleanValue()) ? 
TRUE_FEATURE : FALSE_FEATURE; - feature.add(featureName + ':' + str); - } - - @VisibleForTesting - static void addStringFeature( - String str, Pair featurePair, Map> stringFeatures) { - Set feature = Util.getOrCreateStringFeature(featurePair.getLeft(), stringFeatures); - String featureName = featurePair.getRight(); - if (featureName.equals(RAW)) { - feature.add(str); - } else { - feature.add(featureName + ":" + str); - } - } - - @VisibleForTesting - static void addMultiClassLabel(String str, Map> floatFeatures) { - String[] labels = str.split(","); - for (String s: labels) { - String[] labelTokens = s.split(":"); - if (labelTokens.length != 2) { - throw new RuntimeException(String.format( - "MultiClass LABEL \"%s\" not in format [label1]:[weight1],...!", str)); - } - Map feature = Util.getOrCreateFloatFeature(LABEL, floatFeatures); - feature.put(labelTokens[0], Double.valueOf(labelTokens[1])); - } - } - - static boolean isLabel(Pair feature) { - return feature.getRight().equals(LABEL_FEATURE_NAME); - } - - @VisibleForTesting - static Pair getFamily(String name) { - int pos = name.indexOf(FAMILY_SEPARATOR); - if (pos == -1) { - if (name.compareTo(LABEL) == 0) { - return new ImmutablePair<>(LABEL, LABEL_FEATURE_NAME) ; - } else { - throw new RuntimeException("Column name not in FAMILY_NAME format or is not LABEL! " + name); - } - } else if (pos == 0) { - throw new RuntimeException("Column name can't prefix with _! 
" + name); - } else { - return new ImmutablePair<>(name.substring(0, pos), - name.substring(pos + 1)); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FloatFamily.java b/core/src/main/java/com/airbnb/aerosolve/core/features/FloatFamily.java deleted file mode 100644 index 7269dca7..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/FloatFamily.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import lombok.Getter; - -import java.util.HashMap; -import java.util.Map; - -public class FloatFamily extends FeatureFamily { - @Getter - private final Map features; - - public FloatFamily(String familyName) { - super(familyName); - features = new HashMap<>(); - } - - protected void put(String name, Double feature) { - features.put(nameTransform(name), feature); - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/GenericNamingConvention.java b/core/src/main/java/com/airbnb/aerosolve/core/features/GenericNamingConvention.java new file mode 100644 index 00000000..940131df --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/GenericNamingConvention.java @@ -0,0 +1,148 @@ +package com.airbnb.aerosolve.core.features; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import lombok.Synchronized; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +/** + * + */ +public class GenericNamingConvention implements NamingConvention, Serializable { + + public final static String LABEL = "LABEL"; + public final static String LABEL_FEATURE_NAME = ""; + public final static String MISS = "MISS"; + // In RAW case, don't append feature 
name + public final static String RAW = "RAW"; + private final static char FAMILY_SEPARATOR = '_'; + private final static char TRUE_FEATURE = 'T'; + private final static char FALSE_FEATURE = 'F'; + public static final char NAME_SEPARATOR = ':'; + + public static GenericNamingConvention INSTANCE; + + private final Function> labelSplitFunction; + + public GenericNamingConvention( + Function> labelSplitFunction) { + this.labelSplitFunction = labelSplitFunction; + } + + public static GenericNamingConvention instance() { + if (INSTANCE == null) { + setInstance(); + } + return INSTANCE; + } + + @Synchronized + private static void setInstance() { + INSTANCE = new GenericNamingConvention(GenericNamingConvention::genericLabelSplitFunction); + } + + @Override + public NamingConventionResult features(String name, Object value, FeatureRegistry registry) { + Preconditions.checkNotNull(name, "Cannot create a feature from a null name"); + Feature feature; + if (value == null) { + feature = registry.feature(MISS, name); + } else if (value instanceof double[]) { + return NamingConventionResult.builder() + .denseFeatures(ImmutableMap.of(registry.family(name), (double[]) value)) + .build(); + } else if (name.equals(LABEL)) { + if (value instanceof String) { + Set values = extractFeatureValuesFromLabel((String) value, registry); + return NamingConventionResult.builder() + .doubleFeatures(values) + .build(); + } + feature = registry.feature(LABEL, LABEL_FEATURE_NAME); + } else { + feature = parseFeature(name, value, registry); + } + checkState(feature != null, "Could not find a way to parse feature name %s for value %s", + name, value); + return createNamingConventionResult(feature, value); + } + + private Set extractFeatureValuesFromLabel(String value, FeatureRegistry registry) { + // TODO(Brad): This could be a bit of a problem if a non-multiclass feature had a String + // as its label value and that String contained a comma. But that doesn't seem like it makes + // a ton of sense. 
Anyway, that's why I made the labelSplitFunction an input. But I don't + // know if it helps much. + Family labelFamily = registry.family(LABEL); + Map results = labelSplitFunction.apply(value); + Set values = new HashSet<>(); + for (Map.Entry entry : results.entrySet()) { + values.add(new SimpleFeatureValue(labelFamily.feature(entry.getKey()), + entry.getValue())); + } + return values; + } + + // TODO (Brad): This is really gross. Let's change the way we handle labels to something that + // doesn't involve such tortuous naming conventions. + public static Map genericLabelSplitFunction(String label) { + String[] labels = label.split(","); + Map results = new HashMap<>(); + if (labels.length == 1) { + // This shouldn't really happen. If it's not comma-delimited, it should be a double value. + // But this should work for String labels just in case. + results.put(label, 1.0); + return results; + } + for (String s : labels) { + String[] labelTokens = s.split(":"); + checkArgument(labelTokens.length == 2, + "MultiClass LABEL \"%s\" not in format [label1]:[weight1],...!", label); + results.put(labelTokens[0], Double.valueOf(labelTokens[1])); + } + return results; + } + + private static Feature parseFeature(String name, Object value, FeatureRegistry registry) { + int pos = name.indexOf(FAMILY_SEPARATOR); + checkArgument(pos > 0, + "Column name %s is invalid. It must either be %s or start with a family " + + "name followed by %s and a feature name.", name, LABEL, NAME_SEPARATOR); + String familyName = name.substring(0, pos); + String featureName = name.substring(pos + 1); + if (value instanceof String) { + if (featureName.equals(RAW)) { + featureName = (String) value; + } else { + featureName = featureName + NAME_SEPARATOR + value; + } + } else if (value instanceof Boolean) { + featureName = featureName + NAME_SEPARATOR + ((boolean)value ? 
TRUE_FEATURE : FALSE_FEATURE); + } + return registry.feature(familyName, featureName); + } + + private static NamingConventionResult createNamingConventionResult(Feature feature, + Object value) { + if (value instanceof Number) { + double val = ((Number) value).doubleValue(); + return NamingConventionResult.builder() + .doubleFeatures(ImmutableSet.of(new SimpleFeatureValue(feature, val))) + .build(); + } else { + return NamingConventionResult.builder() + .stringFeatures(ImmutableSet.of(feature)) + .build(); + } + } + + +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/InputGenerator.java b/core/src/main/java/com/airbnb/aerosolve/core/features/InputGenerator.java new file mode 100644 index 00000000..fd24f3bf --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/InputGenerator.java @@ -0,0 +1,44 @@ +package com.airbnb.aerosolve.core.features; + +/* + use Float.MIN_VALUE as NULL for the float feature. + */ +// TODO (Brad): I'm not sure this class belongs in core. It's very specific to one way we generate +// our feature vectors from external data. It seems like if we just tracked the String[] per class +// through some external mechanism, we could just call MultiFamilyVector.putAll(String[], Object[]) +// every time we got a new Object[] or double[]. +// For now, it's easy to make it work but I'm going to discuss with Julian before merging. 
+public class InputGenerator { + private final InputSchema mapping; + private Object[] values; + + public InputGenerator(InputSchema mapping) { + this.mapping = mapping; + values = new Object[mapping.getNames().length]; + } + + public void add(double[] features, Object c) { + InputSchema.Entry e = mapping.getMapping().get(c); + assert(e.length == features.length); + // can't do System.arraycopy(features, 0, values, e.start, e.length); + // due to Float.MIN_VALUE means NULL + for (int i = 0; i < e.length; i++) { + if (features[i] != Float.MIN_VALUE) { + values[i + e.start] = features[i]; + } + } + } + + public void add(Object[] features, Object c) { + InputSchema.Entry e = mapping.getMapping().get(c); + assert(e.length == features.length); + System.arraycopy(features, 0, values, e.start, e.length); + } + + /** + * Load the schema and values in this generator into a MultiFamilyVector. + */ + public MultiFamilyVector load(MultiFamilyVector vector) { + return vector.putAll(mapping.getNames(), values); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureMapping.java b/core/src/main/java/com/airbnb/aerosolve/core/features/InputSchema.java similarity index 87% rename from core/src/main/java/com/airbnb/aerosolve/core/features/FeatureMapping.java rename to core/src/main/java/com/airbnb/aerosolve/core/features/InputSchema.java index ffecc82a..ec519160 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/FeatureMapping.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/InputSchema.java @@ -1,16 +1,19 @@ package com.airbnb.aerosolve.core.features; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Getter; -import java.util.*; - /* Features coming from differenct sources, and output as array[], FeatureMapping helps to save incoming feature in the right index of final output array[] it use incoming feature names array as key to 
locate the index. refer to ModelScorerTest.java as how to use FeatureMapping */ -public class FeatureMapping { +public class InputSchema { public final static int DEFAULT_SIZE = 100; @Getter private String[] names; @@ -23,11 +26,11 @@ public static final class Entry { int length; } - public FeatureMapping() { + public InputSchema() { this(DEFAULT_SIZE); } - public FeatureMapping(int size) { + public InputSchema(int size) { nameList = new ArrayList<>(size); mapping = new HashMap<>(size); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/MultiFamilyVector.java b/core/src/main/java/com/airbnb/aerosolve/core/features/MultiFamilyVector.java new file mode 100644 index 00000000..d057317f --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/MultiFamilyVector.java @@ -0,0 +1,114 @@ +package com.airbnb.aerosolve.core.features; + +import com.airbnb.aerosolve.core.FeatureVector; +import java.util.Map; +import java.util.Set; + +import javax.validation.constraints.NotNull; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * + */ +public interface MultiFamilyVector extends FeatureVector { + + FamilyVector putDense(Family family, double[] values); + + FamilyVector remove(Family family); + + FamilyVector get(Family family); + + boolean contains(Family family); + + void applyContext(MultiFamilyVector context); + + Set families(); + + default MultiFamilyVector merge(MultiFamilyVector vector) { + for (FamilyVector familyVector : vector.families()) { + // Assuming we want to keep any existing values. 
+ if (familyVector.family().isDense()) { + putDense(familyVector.family(), familyVector.denseArray()); + } else { + for (FeatureValue value : familyVector) { + put(value.feature(), value.value()); + } + } + } + return this; + } + + default FamilyVector putDense(String familyName, double[] values) { + Family family = registry().family(familyName); + return putDense(family, values); + } + + default FamilyVector remove(String familyName) { + Family family = registry().family(familyName); + return remove(family); + } + + default FamilyVector get(String familyName) { + Family family = registry().family(familyName); + return get(family); + } + + default boolean contains(String familyName) { + Family family = registry().family(familyName); + return contains(family); + } + + default MultiFamilyVector putAllObjects(Map features, + NamingConvention namingConvention) { + return putAll(features.keySet().toArray(new String[features.size()]), + features.values().toArray(), + namingConvention); + + } + + default MultiFamilyVector putAllObjects(Map features) { + return putAllObjects(features, GenericNamingConvention.instance()); + } + + default MultiFamilyVector putAll(@NotNull String[] names, @NotNull Object[] values, + NamingConvention namingConvention) { + checkNotNull(names, "Names cannot be null when putting to a MultiFamilyVector"); + checkNotNull(values, "Values cannot be null when putting to a MultiFamilyVector"); + checkArgument(names.length == values.length, + "When putting arrays to a MultiFamilyVector, the names and values" + + " were of different sizes. 
Names size: %d Values size: %d", + names.length, values.length); + for (int i = 0; i < names.length; i++) { + NamingConvention.NamingConventionResult result = + namingConvention.features(names[i], values[i], registry()); + if (result.getDenseFeatures() != null) { + for (Map.Entry denseEntry : result.getDenseFeatures().entrySet()) { + putDense(denseEntry.getKey(), denseEntry.getValue()); + } + } + if (result.getStringFeatures() != null) { + for (Feature feature : result.getStringFeatures()) { + putString(feature); + } + } + if (result.getDoubleFeatures() != null) { + for (FeatureValue value : result.getDoubleFeatures()) { + put(value.feature(), value.value()); + } + } + } + return this; + } + + default MultiFamilyVector putAll(String[] names, Object[] values) { + return putAll(names, values, GenericNamingConvention.instance()); + } + + default int numFamilies() { + return families().size(); + } + + MultiFamilyVector withFamilyDropout(double dropout); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/NamingConvention.java b/core/src/main/java/com/airbnb/aerosolve/core/features/NamingConvention.java new file mode 100644 index 00000000..646c61ba --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/NamingConvention.java @@ -0,0 +1,22 @@ +package com.airbnb.aerosolve.core.features; + +import java.util.Map; +import java.util.Set; +import lombok.Builder; +import lombok.Value; + +/** + * + */ +public interface NamingConvention { + + NamingConventionResult features(String name, Object value, FeatureRegistry registry); + + @Value + @Builder + class NamingConventionResult { + private final Set stringFeatures; + private final Set doubleFeatures; + private final Map denseFeatures; + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleExample.java b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleExample.java new file mode 100644 index 00000000..ccc76093 --- /dev/null +++ 
b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleExample.java @@ -0,0 +1,92 @@ +package com.airbnb.aerosolve.core.features; + +import com.airbnb.aerosolve.core.Example; +import com.airbnb.aerosolve.core.ThriftExample; +import com.airbnb.aerosolve.core.ThriftFeatureVector; +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.transforms.Transformer; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * + */ +public class SimpleExample implements Example, Serializable { + + + + private final MultiFamilyVector context; + private final List vectors; + private final FeatureRegistry registry; + + public SimpleExample(FeatureRegistry registry) { + this.context = new BasicMultiFamilyVector(registry); + this.vectors = new ArrayList<>(); + this.registry = registry; + } + + public SimpleExample(ThriftExample example, FeatureRegistry registry) { + this.registry = registry; + if (example.getContext() != null) { + this.context = new BasicMultiFamilyVector(example.getContext(), registry); + } else { + this.context = new BasicMultiFamilyVector(registry); + } + this.vectors = new ArrayList<>(example.getExampleSize()); + if (example.getExample() != null) { + for (ThriftFeatureVector vector : example.getExample()) { + vectors.add(new BasicMultiFamilyVector(vector, registry)); + } + } + } + + @Override + public MultiFamilyVector context() { + return context; + } + + @Override + public MultiFamilyVector createVector() { + return addToExample(new BasicMultiFamilyVector(registry)); + } + + @Override + public MultiFamilyVector addToExample(MultiFamilyVector vector) { + vectors.add(vector); + return vector; + } + + @Override + public Iterator iterator() { + return vectors.iterator(); + } + + @Override + public Example transform(Transformer transformer, AbstractModel model) { + // TODO (Brad): Enable immutability by handling the return value. 
+ if (transformer.getContextTransform() != null) { + transformer.getContextTransform().apply(context); + } + for (MultiFamilyVector vector : vectors) { + if (transformer.getItemTransform() != null) { + transformer.getItemTransform().apply(vector); + } + if (transformer.getCombinedTransform() != null) { + // TODO (Brad): REVISIT THIS BEFORE MERGING!! + // Ideally we wouldn't copy the context into each vector but instead have it back the + // vector as a fallback. This would avoid the costs of copying and the memory usage. + // But it might be slower. Need to test. + // This changes the semantics of context. Previously an entire context family would + // overwrite an item family. This is a merge that favors features + // in the item. I think this is a better semantics as it's more useful in the cases I've + // seen but I'm not sure what other cases people have encountered. + vector.applyContext(context); + transformer.getCombinedTransform().apply(vector); + } + } + return this; + } + +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValue.java b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValue.java new file mode 100644 index 00000000..0b493b37 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValue.java @@ -0,0 +1,24 @@ +package com.airbnb.aerosolve.core.features; + +import java.io.Serializable; +import lombok.Data; +import lombok.experimental.Accessors; + +/** + * + */ +@Data +@Accessors(fluent = true, chain = false) +public class SimpleFeatureValue implements FeatureValue, Serializable { + protected Feature feature; + protected double value; + + SimpleFeatureValue(Feature feature, double value) { + this.feature = feature; + this.value = value; + } + + public static SimpleFeatureValue of(Feature feature, double value) { + return new SimpleFeatureValue(feature, value); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValueEntry.java 
b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValueEntry.java new file mode 100644 index 00000000..59e96e6f --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/SimpleFeatureValueEntry.java @@ -0,0 +1,43 @@ +package com.airbnb.aerosolve.core.features; + +import java.io.Serializable; + +/** + * + */ +public class SimpleFeatureValueEntry extends SimpleFeatureValue + implements FeatureValueEntry, Serializable { + + SimpleFeatureValueEntry(Feature feature, double value) { + super(feature, value); + } + + @Override + @Deprecated + public Double getValue() { + return value; + } + + @Override + public double setValue(double value) { + this.value = value; + return this.value; + } + + @Override + public double getDoubleValue() { + return value; + } + + @Override + public Feature getKey() { + return feature; + } + + @Override + public Double setValue(Double value) { + return setValue((double)value); + } + + +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/SparseVector.java b/core/src/main/java/com/airbnb/aerosolve/core/features/SparseVector.java new file mode 100644 index 00000000..83412077 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/features/SparseVector.java @@ -0,0 +1,78 @@ +package com.airbnb.aerosolve.core.features; + +import it.unimi.dsi.fastutil.objects.AbstractObject2DoubleMap; +import it.unimi.dsi.fastutil.objects.AbstractObjectIterator; +import it.unimi.dsi.fastutil.objects.AbstractObjectSet; +import it.unimi.dsi.fastutil.objects.AbstractReference2DoubleMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import it.unimi.dsi.fastutil.objects.ObjectSet; +import java.io.Serializable; +import java.util.Iterator; +import java.util.NoSuchElementException; +import lombok.Getter; +import lombok.Synchronized; +import lombok.experimental.Accessors; + +/** + * + */ 
+public class SparseVector extends Object2DoubleOpenHashMap implements FamilyVector, + Serializable { + + @Getter + @Accessors(fluent = true) + private final Family family; + @Getter + @Accessors(fluent = true) + private final FeatureRegistry registry; + + public SparseVector(Family family, FeatureRegistry registry) { + super(family.allocationSize()); + this.family = family; + this.registry = registry; + } + @Override + public Iterator iterator() { + // TODO (Brad): We use fast for speed and memory usage when using foreach. But it's a bit + // dangerous because it mutates the instance returned from the iterator. + return new SparseVectorIterator(this.object2DoubleEntrySet().iterator(), false); + } + + @Override + public Iterator fastIterator() { + return new SparseVectorIterator(this.object2DoubleEntrySet().fastIterator(), true); + } + + private class SparseVectorIterator implements Iterator { + + private final boolean goFast; + private final ObjectIterator> iterator; + private final SimpleFeatureValueEntry value; + + public SparseVectorIterator(ObjectIterator> iterator, + boolean goFast) { + this.iterator = iterator; + this.goFast = goFast; + this.value = new SimpleFeatureValueEntry(null, 0.0d); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public FeatureValueEntry next() { + Entry entry = iterator.next(); + if (goFast) { + value.feature(entry.getKey()); + value.value(entry.getDoubleValue()); + return value; + } else { + return new SimpleFeatureValueEntry(entry.getKey(), entry.getDoubleValue()); + } + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/features/StringFamily.java b/core/src/main/java/com/airbnb/aerosolve/core/features/StringFamily.java deleted file mode 100644 index ea2d46ec..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/features/StringFamily.java +++ /dev/null @@ -1,34 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import lombok.Getter; - -import 
java.util.HashSet; -import java.util.Set; - -public class StringFamily extends FeatureFamily { - @Getter - private final Set features; - - public StringFamily(String familyName) { - super(familyName); - features = new HashSet<>(); - } - - @Override - protected void put(String name, String feature) { - features.add(feature); - } - - public boolean add(String name, Boolean feature) { - String value = getBooleanFeatureAsString(name, feature); - return add(name, value); - } - - protected String getBooleanFeatureAsString(String name, Boolean feature) { - if (feature) { - return name + ":T"; - } else { - return name + ":F"; - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/AbstractFunction.java b/core/src/main/java/com/airbnb/aerosolve/core/function/AbstractFunction.java deleted file mode 100644 index 5f4efd27..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/AbstractFunction.java +++ /dev/null @@ -1,58 +0,0 @@ -package com.airbnb.aerosolve.core.function; - -import com.airbnb.aerosolve.core.FunctionForm; -import com.airbnb.aerosolve.core.ModelRecord; -import lombok.Getter; -import lombok.Setter; - -import java.lang.reflect.InvocationTargetException; -import java.util.Arrays; -import java.util.List; - -/** - * Base class for functions - */ -public abstract class AbstractFunction implements Function { - @Getter - @Setter - protected float[] weights; - @Getter - protected float minVal; - @Getter - protected float maxVal; - - @Override - public String toString() { - return String.format("minVal=%f, maxVal=%f, weights=%s", - minVal, maxVal, Arrays.toString(weights)); - } - - @Override - public float evaluate(List values) { - throw new RuntimeException("method not implemented"); - } - - @Override - public void update(float delta, List values){ - throw new RuntimeException("method not implemented"); - } - - public static Function buildFunction(ModelRecord record) { - FunctionForm funcForm = record.getFunctionForm(); - try { - 
return (Function) Class.forName("com.airbnb.aerosolve.core.function." + - funcForm.name()).getDeclaredConstructor(ModelRecord.class).newInstance(record); - } catch (InstantiationException e) { - e.printStackTrace(); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } catch (InvocationTargetException e) { - e.printStackTrace(); - } catch (NoSuchMethodException e) { - e.printStackTrace(); - } catch (ClassNotFoundException e) { - e.printStackTrace(); - } - throw new RuntimeException("no such function " + funcForm.name()); - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/Function.java b/core/src/main/java/com/airbnb/aerosolve/core/function/Function.java deleted file mode 100644 index 9d20bba0..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/Function.java +++ /dev/null @@ -1,30 +0,0 @@ -package com.airbnb.aerosolve.core.function; - -import com.airbnb.aerosolve.core.ModelRecord; - -import java.io.Serializable; -import java.util.List; - -public interface Function extends Serializable { - // TODO rename numBins to something else, since it's a Spline specific thing - Function aggregate(Iterable functions, float scale, int numBins); - - float evaluate(float ... x); - // TODO change all float to double - float evaluate(List values); - - void update(float delta, float ... 
values); - void update(float delta, List values); - - ModelRecord toModelRecord(String featureFamily, String featureName); - - void setPriors(float[] params); - - void LInfinityCap(float cap); - - float LInfinityNorm(); - - void resample(int newBins); - - void smooth(double tolerance); -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/Spline.java b/core/src/main/java/com/airbnb/aerosolve/core/function/Spline.java deleted file mode 100644 index 84707f38..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/Spline.java +++ /dev/null @@ -1,199 +0,0 @@ -package com.airbnb.aerosolve.core.function; - -import com.airbnb.aerosolve.core.FunctionForm; -import com.airbnb.aerosolve.core.ModelRecord; -import com.google.common.primitives.Floats; - -import java.util.ArrayList; -import java.util.List; - -// A piecewise linear spline implementation supporting updates. -public class Spline extends AbstractFunction { - private static final long serialVersionUID = 5166347177557768302L; - - private int numBins; - private float scale; - private float binSize; - private float binScale; - - public Spline(float minVal, float maxVal, float[] weights) { - setupSpline(minVal, maxVal, weights); - } - - public Spline(float minVal, float maxVal, int numBins) { - if (maxVal <= minVal) { - maxVal = minVal + 1.0f; - } - setupSpline(minVal, maxVal, new float[numBins]); - } - - /* - Generates new weights[] from numBins - */ - public float[] weightsByNumBins(int numBins) { - if (numBins == this.numBins) { - return weights; - } else { - return newWeights(numBins); - } - } - - private float[] newWeights(int numBins) { - assert (numBins != this.numBins); - float[] newWeights = new float[numBins]; - float scale = 1.0f / (numBins - 1.0f); - float diff = maxVal - minVal; - for (int i = 0; i < numBins; i++) { - float t = i * scale; - float x = diff * t + minVal; - newWeights[i] = evaluate(x); - } - return newWeights; - } - - // A constructor from model record - public 
Spline(ModelRecord record) { - this.minVal = (float) record.getMinVal(); - this.maxVal = (float) record.getMaxVal(); - List weightVec = record.getWeightVector(); - this.numBins = weightVec.size(); - this.weights = new float[this.numBins]; - for (int j = 0; j < numBins; j++) { - this.weights[j] = weightVec.get(j).floatValue(); - } - float diff = Math.max(maxVal - minVal, 1e-10f); - this.scale = 1.0f / diff; - this.binSize = diff / (numBins - 1.0f); - this.binScale = 1.0f / binSize; - } - - private void setupSpline(float minVal, float maxVal, float[] weights) { - this.weights = weights; - this.numBins = weights.length; - this.minVal = minVal; - this.maxVal = maxVal; - float diff = Math.max(maxVal - minVal, 1e-10f); - this.scale = 1.0f / diff; - this.binSize = diff / (numBins - 1.0f); - this.binScale = 1.0f / binSize; - } - - @Override - public Function aggregate(Iterable functions, float scale, int numBins) { - float[] aggWeights = new float[numBins]; - - for (Function fun : functions) { - Spline spline = (Spline) fun; - float[] w = spline.weightsByNumBins(numBins); - for (int i = 0; i < numBins; i++) { - aggWeights[i] += scale * w[i]; - } - } - return new Spline(minVal, maxVal, aggWeights); - } - - @Override - public float evaluate(float... x) { - int bin = getBin(x[0]); - if (bin == numBins - 1) { - return weights[numBins - 1]; - } - float t = getBinT(x[0], bin); - t = Math.max(0.0f, Math.min(1.0f, t)); - float result = (1.0f - t) * weights[bin] + t * weights[bin + 1]; - return result; - } - - @Override - public void update(float delta, float... 
values) { - float x = values[0]; - int bin = getBin(x); - if (bin == numBins - 1) { - weights[numBins - 1] += delta; - } else { - float t = getBinT(x, bin); - t = Math.max(0.0f, Math.min(1.0f, t)); - weights[bin] += (1.0f - t) * delta; - weights[bin + 1] += t * delta; - } - } - - @Override - public ModelRecord toModelRecord(String featureFamily, String featureName) { - ModelRecord record = new ModelRecord(); - record.setFunctionForm(FunctionForm.Spline); - record.setFeatureFamily(featureFamily); - record.setFeatureName(featureName); - ArrayList arrayList = new ArrayList(); - for (int i = 0; i < weights.length; i++) { - arrayList.add((double) weights[i]); - } - record.setWeightVector(arrayList); - record.setMinVal(minVal); - record.setMaxVal(maxVal); - return record; - } - - @Override - public void resample(int newBins) { - if (newBins != numBins) { - setupSpline(minVal, maxVal, newWeights(newBins)); - } - } - - // Returns the lower bound bin - public int getBin(float x) { - int bin = (int) Math.floor((x - minVal) * scale * (numBins - 1)); - bin = Math.max(0, Math.min(numBins - 1, bin)); - return bin; - } - - // Returns the t value in the bin (0, 1) - public float getBinT(float x, int bin) { - float lowerX = bin * binSize + minVal; - float t = (x - lowerX) * binScale; - t = Math.max(0.0f, Math.min(1.0f, t)); - return t; - } - - public float L1Norm() { - float sum = 0.0f; - for (int i = 0; i < weights.length; i++) { - sum += Math.abs(weights[i]); - } - return sum; - } - - @Override - public float LInfinityNorm() { - return Math.max(Floats.max(weights), Math.abs(Floats.min(weights))); - } - - @Override - public void LInfinityCap(float cap) { - if (cap <= 0.0f) return; - float currentNorm = this.LInfinityNorm(); - if (currentNorm > cap) { - float scale = cap / currentNorm; - for (int i = 0; i < weights.length; i++) { - weights[i] *= scale; - } - } - } - - @Override - public void setPriors(float[] params) { - float start = params[0]; - float end = params[1]; - // fit a 
line based on the input starting weight and ending weight - for (int i = 0; i < numBins; i++) { - float t = i / (numBins - 1.0f); - weights[i] = ((1.0f - t) * start + t * end); - } - } - - @Override - public void smooth(double tolerance) { - FunctionUtil.smooth(tolerance, weights); - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/functions/AbstractFunction.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/AbstractFunction.java new file mode 100644 index 00000000..313bf5d0 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/AbstractFunction.java @@ -0,0 +1,64 @@ +package com.airbnb.aerosolve.core.functions; + +import com.airbnb.aerosolve.core.FunctionForm; +import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.util.Util; +import java.lang.reflect.Constructor; +import java.util.Map; +import lombok.Getter; +import lombok.Setter; + +import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; + +/** + * Base class for functions + */ +public abstract class AbstractFunction implements Function { + @Getter + @Setter + protected double[] weights; + @Getter + protected double minVal; + @Getter + protected double maxVal; + + private static Map> FUNCTION_CONSTRUCTORS; + + @Override + public String toString() { + return String.format("minVal=%f, maxVal=%f, weights=%s", + minVal, maxVal, Arrays.toString(weights)); + } + + public static Function buildFunction(ModelRecord record) { + FunctionForm funcForm = record.getFunctionForm(); + if (FUNCTION_CONSTRUCTORS == null) { + loadFunctionConstructors(); + } + String name = funcForm.name().toLowerCase(); + Constructor constructor = + FUNCTION_CONSTRUCTORS.get(funcForm.name().toLowerCase()); + if (constructor == null) { + throw new IllegalArgumentException( + String.format("No function exists with name %s", name)); + } + try { + return constructor.newInstance(record); + } catch (InstantiationException | IllegalAccessException | 
InvocationTargetException e) { + throw new IllegalStateException( + String.format("There was an error instantiating Function of type %s : %s", + name, e.getMessage()), e); + } + } + + private static synchronized void loadFunctionConstructors() { + if (FUNCTION_CONSTRUCTORS != null) { + return; + } + FUNCTION_CONSTRUCTORS = Util.loadConstructorsFromPackage(AbstractFunction.class, + "com.airbnb.aerosolve.core.functions", + "", + ModelRecord.class); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/functions/Function.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/Function.java new file mode 100644 index 00000000..dd9b2b72 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/Function.java @@ -0,0 +1,26 @@ +package com.airbnb.aerosolve.core.functions; + +import com.airbnb.aerosolve.core.ModelRecord; + +import java.io.Serializable; + +public interface Function extends Serializable { + // TODO rename numBins to something else, since it's a Spline specific thing + Function aggregate(Iterable functions, double scale, int numBins); + + double evaluate(double ... x); + + void update(double delta, double ... 
values); + + ModelRecord toModelRecord(String featureFamily, String featureName); + + void setPriors(double[] params); + + void LInfinityCap(double cap); + + double LInfinityNorm(); + + void resample(int newBins); + + void smooth(double tolerance); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/FunctionUtil.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/FunctionUtil.java similarity index 53% rename from core/src/main/java/com/airbnb/aerosolve/core/function/FunctionUtil.java rename to core/src/main/java/com/airbnb/aerosolve/core/functions/FunctionUtil.java index e842224d..ea594ee1 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/FunctionUtil.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/FunctionUtil.java @@ -1,32 +1,31 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import java.util.Arrays; -import java.util.List; public class FunctionUtil { - public static float[] fitPolynomial(float[] data) { + public static double[] fitPolynomial(double[] data) { int numCoeff = 6; int iterations = numCoeff * 4; - float[] initial = new float[numCoeff]; + double[] initial = new double[numCoeff]; - float[] initialStep = new float[numCoeff]; - Arrays.fill(initialStep, 1.0f); + double[] initialStep = new double[numCoeff]; + Arrays.fill(initialStep, 1.0d); return optimize(1.0 / 512.0, iterations, initial, initialStep, - new ImmutablePair(-10.0f, 10.0f), data); + new ImmutablePair(-10.0d, 10.0d), data); } - public static float evaluatePolynomial(float[] coeff, float[] data, boolean overwrite) { + public static double evaluatePolynomial(double[] coeff, double[] data, boolean overwrite) { int len = data.length; - float err = 0; + double err = 0; long count = 0; for (int i = 0; i < len; i++) { - float t = (float) i / (len - 1); - float tinv = 1 - t; - float diracStart = (i 
== 0) ? coeff[0] : 0; - float diracEnd = (i == len - 1) ? coeff[1] : 0; + double t = (double) i / (len - 1); + double tinv = 1 - t; + double diracStart = (i == 0) ? coeff[0] : 0; + double diracEnd = (i == len - 1) ? coeff[1] : 0; double eval = coeff[2] * tinv * tinv * tinv + coeff[3] * 3.0 * tinv * tinv * t + coeff[4] * 3.0 * tinv * t * t + @@ -38,29 +37,29 @@ public static float evaluatePolynomial(float[] coeff, float[] data, boolean over count++; } if (overwrite) { - data[i] = (float) eval; + data[i] = eval; } } return err / count; } // CyclicCoordinateDescent - public static float[] optimize(double tolerance, int iterations, - float[] initial, float[] initialStep, - Pair bounds, float[] data) { - float[] best = initial; - float bestF = evaluatePolynomial(best, data, false); + public static double[] optimize(double tolerance, int iterations, + double[] initial, double[] initialStep, + Pair bounds, double[] data) { + double[] best = initial; + double bestF = evaluatePolynomial(best, data, false); int maxDim = initial.length; for (int i = 0; i < iterations; ++i) { for (int dim = 0; dim < maxDim; ++dim) { - float step = initialStep[dim]; + double step = initialStep[dim]; while (step > tolerance) { - float[] left = best.clone(); + double[] left = best.clone(); left[dim] = Math.max(bounds.getLeft(), best[dim] - step); - float leftF = evaluatePolynomial(left, data, false); - float[] right = best.clone(); + double leftF = evaluatePolynomial(left, data, false); + double[] right = best.clone(); right[dim] = Math.min(bounds.getRight(), best[dim] + step); - float rightF = evaluatePolynomial(right, data, false); + double rightF = evaluatePolynomial(right, data, false); if (leftF < bestF) { best = left; bestF = leftF; @@ -76,25 +75,17 @@ public static float[] optimize(double tolerance, int iterations, return best; } - public static float[] toFloat(List list) { - float[] result = new float[list.size()]; - for (int i = 0; i < result.length; i++) { - result[i] = 
list.get(i).floatValue(); - } - return result; - } - /* * @param tolerance if fitted array's deviation from weights is less than tolerance * use the fitted, otherwise keep original weights. * @param weights the curve you want to smooth * @return true if weights is modified by fitted curve. */ - public static boolean smooth(double tolerance, float[] weights) { + public static boolean smooth(double tolerance, double[] weights) { // TODO use apache math's PolynomialCurveFitter // compile 'org.apache.commons:commons-math3:3.6.1' - float[] best = FunctionUtil.fitPolynomial(weights); - float errAndCoeff = FunctionUtil.evaluatePolynomial(best, weights, false); + double[] best = FunctionUtil.fitPolynomial(weights); + double errAndCoeff = FunctionUtil.evaluatePolynomial(best, weights, false); if (errAndCoeff < tolerance) { FunctionUtil.evaluatePolynomial(best, weights, true); return true; diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/Linear.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/Linear.java similarity index 65% rename from core/src/main/java/com/airbnb/aerosolve/core/function/Linear.java rename to core/src/main/java/com/airbnb/aerosolve/core/functions/Linear.java index 318bdc1d..f07918a0 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/Linear.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/Linear.java @@ -1,8 +1,7 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelRecord; - import java.util.ArrayList; import java.util.List; @@ -17,20 +16,20 @@ public Linear(Linear other) { maxVal = other.getMaxVal(); } - public Linear(float minVal, float maxVal) { - this(minVal, maxVal, new float[2]); + public Linear(double minVal, double maxVal) { + this(minVal, maxVal, new double[2]); } - public Linear(float minVal, float maxVal, float[] weights) { + public Linear(double minVal, 
double maxVal, double[] weights) { this.weights = weights; this.minVal = minVal; this.maxVal = maxVal; } @Override - public Function aggregate(Iterable functions, float scale, int numBins) { + public Function aggregate(Iterable functions, double scale, int numBins) { int length = weights.length; - float[] aggWeights = new float[length]; + double[] aggWeights = new double[length]; for (Function fun: functions) { Linear linear = (Linear) fun; @@ -44,52 +43,51 @@ public Function aggregate(Iterable functions, float scale, int numBins public Linear(ModelRecord record) { List weightVec = record.getWeightVector(); int n = weightVec.size(); - weights = new float[2]; + weights = new double[2]; for (int j = 0; j < Math.min(n, 2); j++) { - weights[j] = weightVec.get(j).floatValue(); + weights[j] = weightVec.get(j); } - minVal = (float) record.getMinVal(); - maxVal = (float) record.getMaxVal(); + minVal = record.getMinVal(); + maxVal = record.getMaxVal(); } @Override - public void update(float delta, float ... values) { + public void update(double delta, double... values) { weights[0] += delta; weights[1] += delta * normalization(values[0]); } @Override - public void setPriors(float[] params) { + public void setPriors(double[] params) { weights[0] = params[0]; weights[1] = params[1]; } - @Override - public float evaluate(float ... x) { + public double evaluate(double... 
x) { return weights[0] + weights[1] * normalization(x[0]); } @Override public ModelRecord toModelRecord(String featureFamily, String featureName) { ModelRecord record = new ModelRecord(); - record.setFunctionForm(FunctionForm.Linear); + record.setFunctionForm(FunctionForm.LINEAR); record.setFeatureFamily(featureFamily); record.setFeatureName(featureName); record.setMinVal(minVal); record.setMaxVal(maxVal); - ArrayList arrayList = new ArrayList(); - arrayList.add((double) weights[0]); - arrayList.add((double) weights[1]); + ArrayList arrayList = new ArrayList<>(); + arrayList.add(weights[0]); + arrayList.add(weights[1]); record.setWeightVector(arrayList); return record; } @Override - public void LInfinityCap(float cap) { + public void LInfinityCap(double cap) { if (cap <= 0.0f) return; - float currentNorm = this.LInfinityNorm(); + double currentNorm = this.LInfinityNorm(); if (currentNorm > cap) { - float scale = cap / currentNorm; + double scale = cap / currentNorm; for (int i = 0; i < weights.length; i++) { weights[i] *= scale; } @@ -97,13 +95,13 @@ public void LInfinityCap(float cap) { } @Override - public float LInfinityNorm() { - float f0 = weights[0]; - float f1 = weights[0] + weights[1]; + public double LInfinityNorm() { + double f0 = weights[0]; + double f1 = weights[0] + weights[1]; return Math.max(Math.abs(f0), Math.abs(f1)); } - private float normalization(float x) { + private double normalization(double x) { if (minVal < maxVal) { return (x - minVal) / (maxVal - minVal); } else if (minVal == maxVal && maxVal != 0){ diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionPoint.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionPoint.java similarity index 94% rename from core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionPoint.java rename to core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionPoint.java index 84abee3a..c5d359c7 100644 --- 
a/core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionPoint.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionPoint.java @@ -1,4 +1,4 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.util.Util; import lombok.Getter; @@ -25,7 +25,7 @@ public MultiDimensionPoint(List coordinates) { public void updateWeight(double delta) { weight += delta; } - public void scaleWeight(float scale) { + public void scaleWeight(double scale) { weight *= scale; } @@ -104,11 +104,11 @@ public String toString() { return sb.toString(); } - public float getDistance(float[] coordinates) { + public double getDistance(double[] coordinates) { return Util.euclideanDistance(coordinates, this.coordinates); } - public float getDistance(List coordinates) { + public double getDistance(List coordinates) { return Util.euclideanDistance(coordinates, this.coordinates); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionSpline.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionSpline.java similarity index 68% rename from core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionSpline.java rename to core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionSpline.java index a4d3a275..71e92368 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/function/MultiDimensionSpline.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/MultiDimensionSpline.java @@ -1,10 +1,14 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.NDTreeNode; import com.airbnb.aerosolve.core.models.NDTreeModel; +import com.airbnb.aerosolve.core.features.DenseVector; +import 
com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.primitives.Doubles; import lombok.extern.slf4j.Slf4j; import java.util.*; @@ -22,7 +26,7 @@ public MultiDimensionSpline(NDTreeModel ndTreeModel) { Map, MultiDimensionPoint> pointsMap = new HashMap<>(); weights = new HashMap<>(); - NDTreeNode[] nodes = ndTreeModel.getNodes(); + NDTreeNode[] nodes = ndTreeModel.nodes(); for (int i = 0; i < nodes.length; i++) { NDTreeNode node = nodes[i]; if (node.getAxisIndex() == NDTreeModel.LEAF) { @@ -56,10 +60,10 @@ public MultiDimensionSpline(NDTreeModel ndTreeModel, List weights) { // Spline is multi scale, so it needs numBins // MultiDimensionSpline does not support multi scale. @Override - public Function aggregate(Iterable functions, float scale, int numBins) { + public Function aggregate(Iterable functions, double scale, int numBins) { // functions size == 1/scale int length = points.size(); - float[] aggWeights = new float[length]; + double[] aggWeights = new double[length]; for (Function fun: functions) { MultiDimensionSpline spline = (MultiDimensionSpline) fun; @@ -74,7 +78,7 @@ public Function aggregate(Iterable functions, float scale, int numBins } @Override - public float evaluate(float ... coordinates) { + public double evaluate(double ... coordinates) { List list = getNearbyPoints(coordinates); double[] distance = new double[list.size()]; double sum = 0; @@ -86,21 +90,8 @@ public float evaluate(float ... 
coordinates) { return score(list, distance, sum); } - @Override - public float evaluate(List coordinates) { - List list = getNearbyPoints(coordinates); - double[] distance = new double[list.size()]; - double sum = 0; - for (int i = 0; i < list.size(); i++) { - MultiDimensionPoint point = list.get(i); - distance[i] = point.getDistance(coordinates); - sum += distance[i]; - } - return score(list, distance, sum); - } - - private static float score(List list, double[] distance, double sum) { - float score = 0; + private static double score(List list, double[] distance, double sum) { + double score = 0; for (int i = 0; i < list.size(); i++) { MultiDimensionPoint point = list.get(i); score += point.getWeight() * (distance[i]/sum); @@ -109,7 +100,7 @@ private static float score(List list, double[] distance, do } @Override - public void update(float delta, float ... values) { + public void update(double delta, double ... values) { List list = getNearbyPoints(values); double[] distance = new double[list.size()]; double sum = 0; @@ -121,20 +112,7 @@ public void update(float delta, float ... 
values) { update(delta, list, distance, sum); } - @Override - public void update(float delta, List values){ - List list = getNearbyPoints(values); - double[] distance = new double[list.size()]; - double sum = 0; - for (int i = 0; i < list.size(); i++) { - MultiDimensionPoint point = list.get(i); - distance[i] = point.getDistance(values); - sum += distance[i]; - } - update(delta, list, distance, sum); - } - - private static void update(float delta, List list, double[] distance, double sum) { + private static void update(double delta, List list, double[] distance, double sum) { for (int i = 0; i < list.size(); i++) { MultiDimensionPoint point = list.get(i); point.updateWeight(delta * (distance[i]/sum)); @@ -144,10 +122,10 @@ private static void update(float delta, List list, double[] @Override public ModelRecord toModelRecord(String featureFamily, String featureName) { ModelRecord record = new ModelRecord(); - record.setFunctionForm(FunctionForm.MultiDimensionSpline); + record.setFunctionForm(FunctionForm.MULTI_DIMENSION_SPLINE); record.setFeatureFamily(featureFamily); record.setWeightVector(getWeightsFromList()); - record.setNdtreeModel(Arrays.asList(ndTreeModel.getNodes())); + record.setNdtreeModel(Arrays.asList(ndTreeModel.nodes())); return record; } @@ -176,7 +154,7 @@ public static List toDouble(List list) { return r; } - @Override public void setPriors(float[] params) { + @Override public void setPriors(double[] params) { assert (params.length == points.size()); for (int i = 0; i < points.size(); i++) { MultiDimensionPoint p = points.get(i); @@ -185,11 +163,11 @@ public static List toDouble(List list) { } @Override - public void LInfinityCap(float cap) { + public void LInfinityCap(double cap) { if (cap <= 0.0f) return; - float currentNorm = LInfinityNorm(); + double currentNorm = LInfinityNorm(); if (currentNorm > cap) { - float scale = cap / currentNorm; + double scale = cap / currentNorm; for (int i = 0; i < points.size(); i++) { 
points.get(i).scaleWeight(scale); } @@ -197,12 +175,12 @@ public void LInfinityCap(float cap) { } @Override - public float LInfinityNorm() { - return (float) Math.max(Collections.max(points).getWeight(), + public double LInfinityNorm() { + return (double) Math.max(Collections.max(points).getWeight(), Math.abs(Collections.min(points).getWeight())); } - private List getNearbyPoints(float ... coordinates) { + private List getNearbyPoints(double ... coordinates) { int index = ndTreeModel.leaf(coordinates); assert (index != -1 && weights.containsKey(index)); return weights.get(index); @@ -221,10 +199,10 @@ public void resample(int newBins) { @Override public void smooth(double tolerance) { if (!canDoSmooth()) return; - float[] weights = new float[points.size()]; + double[] weights = new double[points.size()]; for (int i = 0; i < points.size(); i++) { MultiDimensionPoint p = points.get(i); - weights[i] = (float) p.getWeight(); + weights[i] = (double) p.getWeight(); } if (FunctionUtil.smooth(tolerance, weights)) { for (int i = 0; i < points.size(); i++) { @@ -235,22 +213,6 @@ public void smooth(double tolerance) { } private boolean canDoSmooth() { - return ndTreeModel.getDimension() == 1; - } - - /* - This drop out is specific for MultiDimensionSpline - */ - public static Map> featureDropout( - FeatureVector featureVector, - double dropout) { - Map> denseFeatures = featureVector.getDenseFeatures(); - if (denseFeatures == null) return Collections.EMPTY_MAP; - Map> out = new HashMap<>(); - for (Map.Entry> feature : denseFeatures.entrySet()) { - if (Math.random() < dropout) continue; - out.put(feature.getKey(), feature.getValue()); - } - return out; + return ndTreeModel.dimension() == 1; } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/functions/Spline.java b/core/src/main/java/com/airbnb/aerosolve/core/functions/Spline.java new file mode 100644 index 00000000..74de55d5 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/functions/Spline.java @@ 
-0,0 +1,197 @@ +package com.airbnb.aerosolve.core.functions; + +import com.airbnb.aerosolve.core.FunctionForm; +import com.airbnb.aerosolve.core.ModelRecord; +import com.google.common.primitives.Doubles; +import java.util.ArrayList; +import java.util.List; + +// A piecewise linear spline implementation supporting updates. +public class Spline extends AbstractFunction { + private static final long serialVersionUID = 5166347177557768302L; + + private int numBins; + private double scale; + private double binSize; + private double binScale; + + public Spline(double minVal, double maxVal, double[] weights) { + setupSpline(minVal, maxVal, weights); + } + + public Spline(double minVal, double maxVal, int numBins) { + if (maxVal <= minVal) { + maxVal = minVal + 1.0d; + } + setupSpline(minVal, maxVal, new double[numBins]); + } + + /* + Generates new weights[] from numBins + */ + public double[] weightsByNumBins(int numBins) { + if (numBins == this.numBins) { + return weights; + } else { + return newWeights(numBins); + } + } + + private double[] newWeights(int numBins) { + assert (numBins != this.numBins); + double[] newWeights = new double[numBins]; + double scale = 1.0d / (numBins - 1.0d); + double diff = maxVal - minVal; + for (int i = 0; i < numBins; i++) { + double t = i * scale; + double x = diff * t + minVal; + newWeights[i] = evaluate(x); + } + return newWeights; + } + + // A constructor from model record + public Spline(ModelRecord record) { + this.minVal = record.getMinVal(); + this.maxVal = record.getMaxVal(); + List weightVec = record.getWeightVector(); + this.numBins = weightVec.size(); + this.weights = new double[this.numBins]; + for (int j = 0; j < numBins; j++) { + this.weights[j] = weightVec.get(j); + } + double diff = Math.max(maxVal - minVal, 1e-10d); + this.scale = 1.0d / diff; + this.binSize = diff / (numBins - 1.0d); + this.binScale = 1.0d / binSize; + } + + private void setupSpline(double minVal, double maxVal, double[] weights) { + this.weights = 
weights; + this.numBins = weights.length; + this.minVal = minVal; + this.maxVal = maxVal; + double diff = Math.max(maxVal - minVal, 1e-10d); + this.scale = 1.0d / diff; + this.binSize = diff / (numBins - 1.0d); + this.binScale = 1.0d / binSize; + } + + @Override + public Function aggregate(Iterable functions, double scale, int numBins) { + int length = weights.length; + double[] aggWeights = new double[length]; + + for (Function fun : functions) { + Spline spline = (Spline) fun; + double[] w = spline.weightsByNumBins(numBins); + for (int i = 0; i < numBins; i++) { + aggWeights[i] += scale * w[i]; + } + } + return new Spline(minVal, maxVal, aggWeights); + } + + @Override + public double evaluate(double... x) { + int bin = getBin(x[0]); + if (bin == numBins - 1) { + return weights[numBins - 1]; + } + double t = getBinT(x[0], bin); + return (1.0f - t) * weights[bin] + t * weights[bin + 1]; + } + + @Override + public void update(double delta, double... values) { + double x = values[0]; + int bin = getBin(x); + if (bin == numBins - 1) { + weights[numBins - 1] += delta; + } else { + double t = getBinT(x, bin); + t = Math.max(0.0d, Math.min(1.0d, t)); + weights[bin] += (1.0d - t) * delta; + weights[bin + 1] += t * delta; + } + } + + @Override + public ModelRecord toModelRecord(String featureFamily, String featureName) { + ModelRecord record = new ModelRecord(); + record.setFunctionForm(FunctionForm.SPLINE); + record.setFeatureFamily(featureFamily); + record.setFeatureName(featureName); + ArrayList arrayList = new ArrayList<>(); + for (double weight : weights) { + arrayList.add(weight); + } + record.setWeightVector(arrayList); + record.setMinVal(minVal); + record.setMaxVal(maxVal); + return record; + } + + @Override + public void resample(int newBins) { + if (newBins != numBins) { + setupSpline(minVal, maxVal, newWeights(newBins)); + } + } + + // Returns the lower bound bin + public int getBin(double x) { + int bin = (int) ((x - minVal) * scale * (numBins - 1)); + bin = 
Math.max(0, Math.min(numBins - 1, bin)); + return bin; + } + + // Returns the t value in the bin (0, 1) + public double getBinT(double x, int bin) { + double lowerX = bin * binSize + minVal; + double t = (x - lowerX) * binScale; + t = Math.max(0.0d, Math.min(1.0d, t)); + return t; + } + + public double L1Norm() { + double sum = 0.0d; + for (double weight : weights) { + sum += Math.abs(weight); + } + return sum; + } + + @Override + public double LInfinityNorm() { + return Math.max(Doubles.max(weights), Math.abs(Doubles.min(weights))); + } + + @Override + public void LInfinityCap(double cap) { + if (cap <= 0.0d) return; + double currentNorm = this.LInfinityNorm(); + if (currentNorm > cap) { + double scale = cap / currentNorm; + for (int i = 0; i < weights.length; i++) { + weights[i] *= scale; + } + } + } + + @Override + public void setPriors(double[] params) { + double start = params[0]; + double end = params[1]; + // fit a line based on the input starting weight and ending weight + for (int i = 0; i < numBins; i++) { + double t = i / (numBins - 1.0d); + weights[i] = ((1.0d - t) * start + t * end); + } + } + + @Override + public void smooth(double tolerance) { + FunctionUtil.smooth(tolerance, weights); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/images/HOGFeature.java b/core/src/main/java/com/airbnb/aerosolve/core/images/HOGFeature.java index eeba794d..b0c645f3 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/images/HOGFeature.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/images/HOGFeature.java @@ -1,8 +1,6 @@ package com.airbnb.aerosolve.core.images; import java.awt.image.BufferedImage; -import java.lang.Override; -import java.lang.Math; /* Creates a histogram of oriented gradients. 
diff --git a/core/src/main/java/com/airbnb/aerosolve/core/images/HSVFeature.java b/core/src/main/java/com/airbnb/aerosolve/core/images/HSVFeature.java index 78690604..c7bc5786 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/images/HSVFeature.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/images/HSVFeature.java @@ -1,9 +1,7 @@ package com.airbnb.aerosolve.core.images; +import java.awt.*; import java.awt.image.BufferedImage; -import java.awt.Color; -import java.lang.Override; -import java.lang.Math; /* Creates a histogram of Hue, saturation, value. diff --git a/core/src/main/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractor.java b/core/src/main/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractor.java index ff42e3d6..40dc1b2b 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractor.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractor.java @@ -1,10 +1,12 @@ package com.airbnb.aerosolve.core.images; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import java.awt.image.BufferedImage; import java.io.Serializable; -import java.util.*; - -import com.airbnb.aerosolve.core.FeatureVector; +import java.util.ArrayList; +import java.util.List; /* Calls all known features and adds them as dense features in a feature vector. 
@@ -31,17 +33,15 @@ private ImageFeatureExtractor() { features.add(new HSVFeature()); } - public FeatureVector getFeatureVector(BufferedImage image) { - FeatureVector featureVector = new FeatureVector(); - Map> denseFeatures = new HashMap<>(); - featureVector.setDenseFeatures(denseFeatures); + public MultiFamilyVector getFeatureVector(BufferedImage image, FeatureRegistry registry) { + MultiFamilyVector featureVector = new BasicMultiFamilyVector(registry); for (ImageFeature feature : features) { List values = feature.extractFeatureSPMK(image); - List dblValues = new ArrayList<>(); - for (Float f : values) { - dblValues.add(f.doubleValue()); + double[] dblValues = new double[values.size()]; + for (int i = 0; i < values.size(); i++) { + dblValues[i] = values.get(i); } - denseFeatures.put(feature.featureName(), dblValues); + featureVector.putDense(feature.featureName(), dblValues); } return featureVector; } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/images/LBPFeature.java b/core/src/main/java/com/airbnb/aerosolve/core/images/LBPFeature.java index 182d80d8..fc9408f8 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/images/LBPFeature.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/images/LBPFeature.java @@ -1,7 +1,6 @@ package com.airbnb.aerosolve.core.images; import java.awt.image.BufferedImage; -import java.lang.Override; /* Creates a histogram of local binary patterns diff --git a/core/src/main/java/com/airbnb/aerosolve/core/images/RGBFeature.java b/core/src/main/java/com/airbnb/aerosolve/core/images/RGBFeature.java index 82b311d6..68ea7245 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/images/RGBFeature.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/images/RGBFeature.java @@ -1,7 +1,6 @@ package com.airbnb.aerosolve.core.images; import java.awt.image.BufferedImage; -import java.lang.Override; /* Creates histograms of quantized RGB images. 
diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/AbstractModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/AbstractModel.java index e3080361..f8063f88 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/AbstractModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/AbstractModel.java @@ -4,22 +4,24 @@ import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.MulticlassScoringResult; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.util.FloatVector; -import lombok.Getter; -import lombok.Setter; - import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import java.util.Map; +import lombok.Getter; +import lombok.Setter; +import lombok.experimental.Accessors; + /** * Created by hector_yee on 8/25/14. * Base class for models */ - +@Accessors(fluent = true, chain = true) public abstract class AbstractModel implements Model, Serializable { private static final long serialVersionUID = -5011350794437028492L; @@ -29,14 +31,20 @@ public abstract class AbstractModel implements Model, Serializable { @Getter @Setter protected double slope = 1.0; + @Getter + protected final FeatureRegistry registry; + + public AbstractModel(FeatureRegistry registry) { + this.registry = registry; + } // Scores a single item. The transforms should already have been applied to // the context and item and combined item. - abstract public float scoreItem(FeatureVector combinedItem); + abstract public double scoreItem(FeatureVector combinedItem); // Debug scores a single item. These are explanations for why a model // came up with the score. 
- abstract public float debugScoreItem(FeatureVector combinedItem, + abstract public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder); abstract public List debugScoreComponents(FeatureVector combinedItem); @@ -47,7 +55,7 @@ abstract public float debugScoreItem(FeatureVector combinedItem, abstract public void save(BufferedWriter writer) throws IOException; // returns probability: 1 / (1 + exp(-(offset + scale * score)) - public double scoreProbability(float score) { + public double scoreProbability(double score) { return 1.0 / (1.0 + Math.exp(-(offset + slope * score))); } @@ -62,16 +70,16 @@ public ArrayList scoreItemMulticlass(FeatureVector comb public void scoreToProbability(ArrayList results) { FloatVector vec = new FloatVector(results.size()); for (int i = 0; i < results.size(); i++) { - vec.values[i] = (float) results.get(i).score; + vec.values[i] = (float) results.get(i).getScore(); } vec.softmax(); for (int i = 0; i < results.size(); i++) { - results.get(i).probability = vec.values[i]; + results.get(i).setProbability(vec.values[i]); } } // Optional method implemented by online updatable models e.g. 
Spline, RBF - public void onlineUpdate(float grad, float learningRate, Map> flatFeatures) { + public void onlineUpdate(double grad, double learningRate, FeatureVector vector) { assert(false); } @@ -98,4 +106,8 @@ public static float fobosUpdate( float step = (float) Math.max(0.0, Math.abs(wt) - l1Reg * etaTHalf); return sign * step; } + + public boolean needsFeature(Feature feature) { + return true; + } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/AdditiveModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/AdditiveModel.java index 2ad86ef7..b7e17dc1 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/AdditiveModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/AdditiveModel.java @@ -4,115 +4,146 @@ import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.function.AbstractFunction; -import com.airbnb.aerosolve.core.function.Function; -import com.airbnb.aerosolve.core.function.FunctionUtil; -import com.airbnb.aerosolve.core.function.MultiDimensionSpline; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.functions.AbstractFunction; +import com.airbnb.aerosolve.core.functions.Function; +import com.airbnb.aerosolve.core.functions.MultiDimensionSpline; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; import com.airbnb.aerosolve.core.util.Util; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; - +import com.google.common.primitives.Doubles; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; import 
java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; -import java.util.*; - -import static com.airbnb.aerosolve.core.function.FunctionUtil.toFloat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import lombok.Getter; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.tuple.Pair; // A generalized additive model with a parametric function per feature. // See http://en.wikipedia.org/wiki/Generalized_additive_model @Slf4j +@Accessors(fluent = true, chain = true) public class AdditiveModel extends AbstractModel { - private static final String DENSE_FAMILY = "dense"; + @Getter @Setter - private Map> weights = new HashMap<>(); + private Map weights = + new Object2ObjectOpenHashMap<>(); - // only MultiDimensionSpline using denseWeights - // whole dense features belongs to feature family DENSE_FAMILY - private Map denseWeights; + @Getter + private Map familyWeights = + new Object2ObjectOpenHashMap<>(); - private Map getOrCreateDenseWeights() { - if (denseWeights == null) { - denseWeights = weights.get(DENSE_FAMILY); - if (denseWeights == null) { - denseWeights = new HashMap<>(); - weights.put(DENSE_FAMILY, denseWeights); - } - } - return denseWeights; + public AdditiveModel(FeatureRegistry registry) { + super(registry); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - return scoreFlatFeatures(flatFeatures) + scoreDenseFeatures(combinedItem.getDenseFeatures()); + public boolean needsFeature(Feature feature) { + return weights.containsKey(feature); + } + + @Override + public double scoreItem(FeatureVector combinedItem) { + return scoreItemInternal(combinedItem, null, null); + } + + public double scoreItemInternal(FeatureVector combinedItem, + PriorityQueue> scores, + List scoreRecordsList) { + double sum = 
0.0d; + + if (combinedItem instanceof MultiFamilyVector && familyWeights != null + && !familyWeights.isEmpty()) { + MultiFamilyVector multiFamilyVector = ((MultiFamilyVector) combinedItem); + for (FamilyVector familyVector : multiFamilyVector.families()) { + Function familyFunction = familyWeights.get(familyVector.family()); + + if (familyFunction == null) { + sum += scoreVector(familyVector, scores, scoreRecordsList); + } else { + double[] val = familyVector.denseArray(); + double subscore = familyFunction.evaluate(val); + sum += subscore; + if (scores != null) { + String str = familyVector.family().name() + ":null=" + Arrays.toString(val) + + " = " + subscore + "
\n"; + scores.add(Pair.of(str, subscore)); + } + if (scoreRecordsList != null) { + DebugScoreRecord record = new DebugScoreRecord(); + record.setFeatureFamily(familyVector.family().name()); + record.setFeatureName(null); + record.setDenseFeatureValue(Doubles.asList(val)); + record.setFeatureWeight(subscore); + scoreRecordsList.add(record); + } + } + } + } else { + sum = scoreVector(combinedItem, scores, scoreRecordsList); + } + + return sum; } - public float scoreDenseFeatures(Map> denseFeatures) { - float sum = 0; - if (denseFeatures != null && !denseFeatures.isEmpty()) { - assert (denseWeights != null); - for (Map.Entry> feature : denseFeatures.entrySet()) { - String featureName = feature.getKey(); - Function fun = denseWeights.get(featureName); - sum += fun.evaluate(toFloat(feature.getValue())); + private double scoreVector(FeatureVector combinedItem, + PriorityQueue> scores, + List scoreRecordsList) { + double sum = 0.0d; + for (FeatureValue value : combinedItem) { + Function func = weights.get(value.feature()); + if (func == null) + continue; + double subscore = func.evaluate(value.value()); + sum += subscore; + if (scores != null) { + String str = value.feature().family().name() + ":" + value.feature().name() + + "=" + value.value() + " = " +subscore + "
\n"; + scores.add(Pair.of(str, subscore)); + } + if (scoreRecordsList != null) { + DebugScoreRecord record = new DebugScoreRecord(); + record.setFeatureFamily(value.feature().family().name()); + record.setFeatureName(value.feature().name()); + record.setFeatureValue(value.value()); + record.setFeatureWeight(subscore); + scoreRecordsList.add(record); } } return sum; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - - float sum = 0.0f; // order by the absolute value - PriorityQueue> scores = + PriorityQueue> scores = new PriorityQueue<>(100, new LinearModel.EntryComparator()); - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weights.get(featureFamily.getKey()); - if (familyWeightMap == null) - continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - Function func = familyWeightMap.get(feature.getKey()); - if (func == null) - continue; - float val = feature.getValue().floatValue(); - float subScore = func.evaluate(val); - sum += subScore; - String str = featureFamily.getKey() + ":" + feature.getKey() + "=" + val - + " = " + subScore + "
\n"; - scores.add(new AbstractMap.SimpleEntry(str, subScore)); - } - } - - Map> denseFeatures = combinedItem.getDenseFeatures(); - if (denseFeatures != null) { - assert (denseWeights != null); - for (Map.Entry> feature : denseFeatures.entrySet()) { - String featureName = feature.getKey(); - Function fun = denseWeights.get(featureName); - float[] val = toFloat(feature.getValue()); - float subScore = fun.evaluate(val); - sum += subScore; - String str = DENSE_FAMILY + ":" + featureName + "=" + val - + " = " + subScore + "
\n"; - scores.add(new AbstractMap.SimpleEntry(str, subScore)); - } - } + double sum = scoreItemInternal(combinedItem, scores, null); final int MAX_COUNT = 100; builder.append("Top scores ===>\n"); if (!scores.isEmpty()) { int count = 0; - float subsum = 0.0f; + double subsum = 0.0d; while (!scores.isEmpty()) { - Map.Entry entry = scores.poll(); - builder.append(entry.getKey()); - float val = entry.getValue(); + Map.Entry entry = scores.poll(); + String str = entry.getKey(); + builder.append(str); + double val = entry.getValue(); subsum += val; count = count + 1; if (count >= MAX_COUNT) { @@ -128,43 +159,8 @@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); List scoreRecordsList = new ArrayList<>(); - - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weights.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - Function func = familyWeightMap.get(feature.getKey()); - if (func == null) continue; - float val = feature.getValue().floatValue(); - float weight = func.evaluate(val); - DebugScoreRecord record = new DebugScoreRecord(); - record.setFeatureFamily(featureFamily.getKey()); - record.setFeatureName(feature.getKey()); - record.setFeatureValue(val); - record.setFeatureWeight(weight); - scoreRecordsList.add(record); - } - } - - Map> denseFeatures = combinedItem.getDenseFeatures(); - if (denseFeatures != null) { - assert (denseWeights != null); - for (Map.Entry> feature : denseFeatures.entrySet()) { - String featureName = feature.getKey(); - Function fun = denseWeights.get(featureName); - float[] val = toFloat(feature.getValue()); - float weight = fun.evaluate(val); - DebugScoreRecord record = new DebugScoreRecord(); - record.setFeatureFamily(DENSE_FAMILY); - record.setFeatureName(feature.getKey()); - 
record.setDenseFeatureValue(feature.getValue()); - record.setFeatureWeight(weight); - scoreRecordsList.add(record); - } - } - + scoreItemInternal(combinedItem, null, scoreRecordsList); return scoreRecordsList; } @@ -173,18 +169,20 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO long rows = header.getNumRecords(); slope = header.getSlope(); offset = header.getOffset(); - weights = new HashMap<>(); + weights = new Reference2ObjectOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = weights.get(family); - if (inner == null) { - inner = new HashMap<>(); - weights.put(family, inner); + Function function = AbstractFunction.buildFunction(record); + // TODO (Brad): Do we need to check the type? I did it for safety in case some null names have + // been serialized. But it's a bit brittle to use instanceof. 
+ if (name == null && function instanceof MultiDimensionSpline) { + familyWeights.put(registry.family(family), function); + } else { + weights.put(registry.feature(family, name), function); } - inner.put(name, AbstractFunction.buildFunction(record)); } } @@ -194,102 +192,62 @@ public void save(BufferedWriter writer) throws IOException { header.setModelType("additive"); header.setSlope(slope); header.setOffset(offset); - long count = 0; - for (Map.Entry> familyMap : weights.entrySet()) { - count += familyMap.getValue().size(); - } - header.setNumRecords(count); + header.setNumRecords(weights.size()); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : weights.entrySet()) { - String featureFamily = familyMap.getKey(); - for (Map.Entry feature : familyMap.getValue().entrySet()) { - Function func = feature.getValue(); - String featureName = feature.getKey(); - writer.write(Util.encode(func.toModelRecord(featureFamily, featureName))); - writer.newLine(); - } + for (Map.Entry entry : weights.entrySet()) { + Function func = entry.getValue(); + String featureName = entry.getKey().name(); + writer.write(Util.encode(func.toModelRecord(entry.getKey().family().name(), featureName))); + writer.newLine(); + } + // TODO (Brad): Talk to Julian. This changes the serialization from what he was using but + // I don't think it should be a problem since we're still training. 
+ for (Map.Entry entry : familyWeights.entrySet()) { + Function func = entry.getValue(); + writer.write(Util.encode(func.toModelRecord(entry.getKey().name(), null))); + writer.newLine(); } writer.flush(); } - public float scoreFlatFeatures(Map> flatFeatures) { - float sum = 0.0f; - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weights.get(featureFamily.getKey()); - if (familyWeightMap == null) { - // not important families/features are removed from model - log.debug("miss featureFamily {}", featureFamily.getKey()); - continue; - } - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - Function func = familyWeightMap.get(feature.getKey()); - if (func == null) - continue; - float val = feature.getValue().floatValue(); - sum += func.evaluate(val); - } + public void addFunction(Feature feature, Function function, boolean overwrite) { + if (function == null) { + throw new RuntimeException(feature + " function null"); } - return sum; - } - - public Map getOrCreateFeatureFamily(String featureFamily) { - Map featFamily = weights.get(featureFamily); - if (featFamily == null) { - featFamily = new HashMap(); - weights.put(featureFamily, featFamily); + if (overwrite || !weights.containsKey(feature)) { + weights.put(feature, function); } - return featFamily; } - public void addFunction(String featureFamily, String featureName, - Function function, boolean overwrite) { + public void addFunction(Family family, Function function, boolean overwrite) { if (function == null) { - throw new RuntimeException(featureFamily + " " + featureName + " function null"); + throw new RuntimeException(family + " function null"); } - Map featFamily = getOrCreateFeatureFamily(featureFamily); - if (overwrite || !featFamily.containsKey(featureName)) { - featFamily.put(featureName, function); + if (overwrite || !familyWeights.containsKey(family)) { + familyWeights.put(family, function); } } - public void addDenseFunction(String featureName, NDTreeModel 
ndtreeModel) { - MultiDimensionSpline spline = new MultiDimensionSpline(ndtreeModel); - Map dense = getOrCreateDenseWeights(); - dense.put(featureName, spline); - } - // Update weights based on gradient and learning rate - public void update(float gradWithLearningRate, - float cap, - Map> flatFeatures) { - // update with lInfinite cap - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weights.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - Function func = familyWeightMap.get(feature.getKey()); - if (func == null) continue; - float val = feature.getValue().floatValue(); - func.update(-gradWithLearningRate, val); - func.LInfinityCap(cap); - } - } - } - - public void updateDense(float gradWithLearningRate, - float cap, - Map> denseFeatures) { - // update with lInfinite cap - if (denseFeatures != null && !denseFeatures.isEmpty()) { - assert (denseWeights != null); - for (Map.Entry> feature : denseFeatures.entrySet()) { - String featureName = feature.getKey(); - Function func = denseWeights.get(featureName); - float[] val = FunctionUtil.toFloat(feature.getValue()); - func.update(-gradWithLearningRate, val); + public void update(double gradWithLearningRate, + double cap, + MultiFamilyVector vector) { + for (FamilyVector familyVector : vector.families()) { + Function func = familyWeights.get(familyVector.family()); + if (func == null) { + for (FeatureValue value : familyVector) { + func = weights.get(value.feature()); + if (func == null) continue; + // update with lInfinite cap + func.update(-gradWithLearningRate, value.value()); + func.LInfinityCap(cap); + } + } else { + // update with lInfinite cap + func.update(-gradWithLearningRate, vector.denseArray()); func.LInfinityCap(cap); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/BoostedStumpsModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/BoostedStumpsModel.java 
index ab3454c0..43dd716a 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/BoostedStumpsModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/BoostedStumpsModel.java @@ -1,21 +1,23 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.util.Map; -import java.util.List; -import java.util.ArrayList; - +import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.DebugScoreRecord; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.util.Util; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A simple boosted decision stump model that only operates on float features. +@Accessors(fluent = true, chain = true) public class BoostedStumpsModel extends AbstractModel { private static final long serialVersionUID = 3651061358422885377L; @@ -23,54 +25,41 @@ public class BoostedStumpsModel extends AbstractModel { @Getter @Setter protected List stumps; - public BoostedStumpsModel() { + public BoostedStumpsModel(FeatureRegistry registry) { + super(registry); } // Returns true if >= stump, false otherwise. 
- public static boolean getStumpResponse(ModelRecord stump, - Map> floatFeatures) { - Map feat = floatFeatures.get(stump.featureFamily); - // missing feature corresponding to false (left branch) - if (feat == null) { - return false; - } - Double val = feat.get(stump.featureName); - if (val == null) { - return false; - } - if (val >= stump.getThreshold()) { - return true; - } else { - return false; - } + public static boolean getStumpResponse(ModelRecord stump, FeatureVector vector) { + Feature feature = vector.registry().feature(stump.getFeatureFamily(), stump.getFeatureName()); + return vector.containsKey(feature) && + vector.getDouble(feature) >= stump.getThreshold(); } @Override - public float scoreItem(FeatureVector combinedItem) { + public double scoreItem(FeatureVector combinedItem) { float sum = 0.0f; - Map> floatFeatures = Util.flattenFeature(combinedItem); for (ModelRecord stump : stumps) { - if (getStumpResponse(stump, floatFeatures)) { - sum += stump.featureWeight; + if (getStumpResponse(stump, combinedItem)) { + sum += stump.getFeatureWeight(); } } return sum; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { float sum = 0.0f; - Map> floatFeatures = Util.flattenFeature(combinedItem); for (ModelRecord stump : stumps) { - boolean response = getStumpResponse(stump, floatFeatures); - String output = stump.featureFamily + ':' + stump.getFeatureName(); - Double threshold = stump.threshold; - Double weight = stump.featureWeight; + boolean response = getStumpResponse(stump, combinedItem); + String output = stump.getFeatureFamily() + ':' + stump.getFeatureName(); + Double threshold = stump.getThreshold(); + Double weight = stump.getFeatureWeight(); if (response) { builder.append(output); builder.append(" >= " + threshold.toString() + " ==> " + weight.toString()); - sum += stump.featureWeight; + sum += stump.getFeatureWeight(); } } return sum; @@ -79,17 +68,17 
@@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { List scoreRecordsList = new ArrayList<>(); - Map> floatFeatures = Util.flattenFeature(combinedItem); for (ModelRecord stump : stumps) { - boolean response = getStumpResponse(stump, floatFeatures); + boolean response = getStumpResponse(stump, combinedItem); if (response) { + Feature feature = registry.feature(stump.getFeatureFamily(), stump.getFeatureName()); DebugScoreRecord record = new DebugScoreRecord(); - record.setFeatureFamily(stump.featureFamily); - record.setFeatureName(stump.featureName); - record.setFeatureValue(floatFeatures.get(stump.featureFamily).get(stump.featureName)); - record.setFeatureWeight(stump.featureWeight); + record.setFeatureFamily(stump.getFeatureFamily()); + record.setFeatureName(stump.getFeatureName()); + record.setFeatureValue(combinedItem.get(feature)); + record.setFeatureWeight(stump.getFeatureWeight()); scoreRecordsList.add(record); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/DecisionTreeModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/DecisionTreeModel.java index c23c911c..86e0e29c 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/DecisionTreeModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/DecisionTreeModel.java @@ -1,19 +1,24 @@ package com.airbnb.aerosolve.core.models; +import com.airbnb.aerosolve.core.DebugScoreRecord; +import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.ModelHeader; +import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.MulticlassScoringResult; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.util.Util; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; -import java.lang.StringBuilder; -import java.util.Map; -import java.util.List; import java.util.ArrayList; - -import 
com.airbnb.aerosolve.core.*; -import com.airbnb.aerosolve.core.util.Util; +import java.util.List; +import java.util.Map; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A simple decision tree model. +@Accessors(fluent = true, chain = true) public class DecisionTreeModel extends AbstractModel { private static final long serialVersionUID = 3651061358422885379L; @@ -21,39 +26,30 @@ public class DecisionTreeModel extends AbstractModel { @Getter @Setter protected ArrayList stumps; - public DecisionTreeModel() { + public DecisionTreeModel(FeatureRegistry registry) { + super(registry); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> floatFeatures = Util.flattenFeature(combinedItem); - return scoreFlattenedFeature(floatFeatures); - } - - @Override - public ArrayList scoreItemMulticlass(FeatureVector combinedItem) { - Map> floatFeatures = Util.flattenFeature(combinedItem); - return scoreFlattenedFeatureMulticlass(floatFeatures); - } - - public float scoreFlattenedFeature(Map> floatFeatures) { - int leaf = getLeafIndex(floatFeatures); - if (leaf < 0) return 0.0f; + public double scoreItem(FeatureVector vector) { + int leaf = getLeafIndex(vector); + if (leaf < 0) return 0.0d; ModelRecord stump = stumps.get(leaf); - return (float) stump.featureWeight; + return stump.getFeatureWeight(); } - public ArrayList scoreFlattenedFeatureMulticlass(Map> floatFeatures) { + @Override + public ArrayList scoreItemMulticlass(FeatureVector vector) { ArrayList results = new ArrayList<>(); - int leaf = getLeafIndex(floatFeatures); + int leaf = getLeafIndex(vector); if (leaf < 0) return results; ModelRecord stump = stumps.get(leaf); - if (stump.labelDistribution == null) return results; + if (stump.getLabelDistribution() == null) return results; - for (Map.Entry entry : stump.labelDistribution.entrySet()) { + for (Map.Entry entry : stump.getLabelDistribution().entrySet()) { MulticlassScoringResult result = new 
MulticlassScoringResult(); result.setLabel(entry.getKey()); result.setScore(entry.getValue()); @@ -62,7 +58,7 @@ public ArrayList scoreFlattenedFeatureMulticlass(Map> floatFeatures) { + public int getLeafIndex(FeatureVector vector) { if (stumps.isEmpty()) return -1; int index = 0; @@ -71,11 +67,11 @@ public int getLeafIndex(Map> floatFeatures) { if (!stump.isSetLeftChild() || !stump.isSetRightChild()) { break; } - boolean response = BoostedStumpsModel.getStumpResponse(stump, floatFeatures); + boolean response = BoostedStumpsModel.getStumpResponse(stump, vector); if (response) { - index = stump.rightChild; + index = stump.getRightChild(); } else { - index = stump.leftChild; + index = stump.getLeftChild(); } } return index; @@ -83,7 +79,7 @@ public int getLeafIndex(Map> floatFeatures) { @Override // Decision trees don't usually have debuggable components. - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { return 0.0f; } @@ -134,25 +130,25 @@ public String toDot() { ModelRecord stump = stumps.get(i); if (stump.isSetLeftChild()) { sb.append(String.format("\"node%d\" [\n", i)); - double thresh = stump.threshold; + double thresh = stump.getThreshold(); sb.append(String.format( "label = \" %s:%s | less than %f | greater than or equal%f\";\n", - stump.featureFamily, - stump.featureName, + stump.getFeatureFamily(), + stump.getFeatureName(), thresh, thresh)); sb.append("shape = \"record\";\n"); sb.append("];\n"); } else { sb.append(String.format("\"node%d\" [\n", i)); - if (stump.labelDistribution != null) { + if (stump.getLabelDistribution() != null) { sb.append(String.format("label = \" ")); - for (Map.Entry entry : stump.labelDistribution.entrySet()) { + for (Map.Entry entry : stump.getLabelDistribution().entrySet()) { sb.append(String.format("%s : %f ", entry.getKey(), entry.getValue())); } sb.append(" \";\n"); } else { - sb.append(String.format("label = \" Weight %f\";\n", 
stump.featureWeight)); + sb.append(String.format("label = \" Weight %f\";\n", stump.getFeatureWeight())); } sb.append("shape = \"record\";\n"); sb.append("];\n"); @@ -162,9 +158,11 @@ public String toDot() { for (int i = 0; i < stumps.size(); i++) { ModelRecord stump = stumps.get(i); if (stump.isSetLeftChild()) { - sb.append(String.format("\"node%d\":f1 -> \"node%d\":f0 [ id = %d ];\n", i, stump.leftChild, count)); + sb.append(String.format("\"node%d\":f1 -> \"node%d\":f0 [ id = %d ];\n", i, + stump.getLeftChild(), count)); count = count + 1; - sb.append(String.format("\"node%d\":f2 -> \"node%d\":f0 [id = %d];\n", i, stump.rightChild, count)); + sb.append(String.format("\"node%d\":f2 -> \"node%d\":f0 [id = %d];\n", i, + stump.getRightChild(), count)); count = count + 1; } } @@ -183,13 +181,14 @@ public String toHumanReadableTransform() { // Parent node, node id, family, name, threshold, left, right sb.append( String.format("P,%d,%s,%s,%f,%d,%d", i, - stump.featureFamily, - stump.featureName, - stump.threshold, - stump.leftChild, stump.rightChild)); + stump.getFeatureFamily(), + stump.getFeatureName(), + stump.getThreshold(), + stump.getLeftChild(), + stump.getRightChild())); } else { // Leaf node, node id, feature weight, human readable leaf name. - sb.append(String.format("L,%d,%f,LEAF_%d", i, stump.featureWeight, i)); + sb.append(String.format("L,%d,%f,LEAF_%d", i, stump.getFeatureWeight(), i)); } sb.append("\"\n"); } @@ -198,10 +197,11 @@ public String toHumanReadableTransform() { } // Constructs a tree from human readable transform list. 
- public static DecisionTreeModel fromHumanReadableTransform(List rows) { - DecisionTreeModel tree = new DecisionTreeModel(); + public static DecisionTreeModel fromHumanReadableTransform(List rows, + FeatureRegistry registry) { + DecisionTreeModel tree = new DecisionTreeModel(registry); ArrayList records = new ArrayList<>(); - tree.setStumps(records); + tree.stumps(records); for (String row : rows) { ModelRecord rec = new ModelRecord(); records.add(rec); diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/ForestModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/ForestModel.java index 0a5d468d..5e7f99f2 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/ForestModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/ForestModel.java @@ -1,20 +1,25 @@ package com.airbnb.aerosolve.core.models; +import com.airbnb.aerosolve.core.DebugScoreRecord; +import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.ModelHeader; +import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.MulticlassScoringResult; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.util.Util; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; -import java.lang.StringBuilder; -import java.util.Map; -import java.util.List; -import java.util.HashMap; import java.util.ArrayList; - -import com.airbnb.aerosolve.core.*; -import com.airbnb.aerosolve.core.util.Util; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A tree forest model. 
+@Accessors(fluent = true, chain = true) public class ForestModel extends AbstractModel { private static final long serialVersionUID = 3651061358422885378L; @@ -22,18 +27,17 @@ public class ForestModel extends AbstractModel { @Getter @Setter protected ArrayList trees; - public ForestModel() { + public ForestModel(FeatureRegistry registry) { + super(registry); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> floatFeatures = Util.flattenFeature(combinedItem); - + public double scoreItem(FeatureVector combinedItem) { float sum = 0.0f; // Note: we sum instead of average so that the trainer has the option of boosting the // trees together. - for (int i = 0; i < trees.size(); i++) { - sum += trees.get(i).scoreFlattenedFeature(floatFeatures); + for (DecisionTreeModel tree : trees) { + sum += tree.scoreItem(combinedItem); } return sum; } @@ -42,19 +46,16 @@ public float scoreItem(FeatureVector combinedItem) { public ArrayList scoreItemMulticlass(FeatureVector combinedItem) { HashMap map = new HashMap<>(); - Map> floatFeatures = Util.flattenFeature(combinedItem); - // Note: we sum instead of average so that the trainer has the option of boosting the // trees together. - for (int i = 0; i < trees.size(); i++) { - ArrayList tmp = trees.get(i).scoreFlattenedFeatureMulticlass( - floatFeatures); + for (DecisionTreeModel tree : trees) { + ArrayList tmp = tree.scoreItemMulticlass(combinedItem); for (MulticlassScoringResult result : tmp) { - Double v = map.get(result.label); + Double v = map.get(result.getLabel()); if (v == null) { - map.put(result.label, result.score); + map.put(result.getLabel(), result.getScore()); } else { - map.put(result.label, v + result.score); + map.put(result.getLabel(), v + result.getScore()); } } } @@ -72,7 +73,7 @@ public ArrayList scoreItemMulticlass(FeatureVector comb @Override // Forests don't usually have debuggable components. 
- public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { return 0.0f; } @@ -110,7 +111,7 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO for (long i = 0; i < numTrees; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); - DecisionTreeModel tree = new DecisionTreeModel(); + DecisionTreeModel tree = new DecisionTreeModel(registry); tree.loadInternal(record.getModelHeader(), reader); trees.add(tree); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/FullRankLinearModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/FullRankLinearModel.java index 78a9bbc3..d2b55e66 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/FullRankLinearModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/FullRankLinearModel.java @@ -1,33 +1,41 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.LabelDictionaryEntry; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.MulticlassScoringResult; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; import com.airbnb.aerosolve.core.util.FloatVector; - +import com.airbnb.aerosolve.core.util.Util; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; +import java.io.BufferedReader; +import java.io.BufferedWriter; 
+import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A full rank linear model that supports multi-class classification. // The class vector Y = W' X where X is the feature vector. // It is full rank because the matrix W is num-features by num-labels in dimension. // Use a low rank model if you want better generalization. +@Accessors(fluent = true, chain = true) public class FullRankLinearModel extends AbstractModel { private static final long serialVersionUID = -849900702679383420L; - @Getter @Setter - private Map> weightVector; + @Getter + private Map weightVector; @Getter @Setter private ArrayList labelDictionary; @@ -35,24 +43,24 @@ public class FullRankLinearModel extends AbstractModel { @Getter @Setter private Map labelToIndex; - public FullRankLinearModel() { - weightVector = new HashMap<>(); + public FullRankLinearModel(FeatureRegistry registry) { + super(registry); + weightVector = new Object2ObjectOpenHashMap<>(); labelDictionary = new ArrayList<>(); } // In the binary case this is just the score for class 0. // Ideally use a binary model for binary classification. @Override - public float scoreItem(FeatureVector combinedItem) { + public double scoreItem(FeatureVector combinedItem) { // Not supported. assert(false); - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector sum = scoreFlatFeature(flatFeatures); + FloatVector sum = scoreFlatFeature(combinedItem); return sum.values[0]; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { // TODO(hector_yee) : implement debug. 
return scoreItem(combinedItem); @@ -60,28 +68,19 @@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); List scoreRecordsList = new ArrayList<>(); int dim = labelDictionary.size(); - for (Map.Entry> entry : flatFeatures.entrySet()) { - String familyKey = entry.getKey(); - Map family = weightVector.get(familyKey); - if (family != null) { - for (Map.Entry feature : entry.getValue().entrySet()) { - String featureKey = feature.getKey(); - FloatVector featureWeights = family.get(featureKey); - float val = feature.getValue().floatValue(); - if (featureWeights != null) { - for (int i = 0; i < dim; i++) { - DebugScoreRecord record = new DebugScoreRecord(); - record.setFeatureFamily(familyKey); - record.setFeatureName(featureKey); - record.setFeatureValue(val); - record.setFeatureWeight(featureWeights.get(i)); - record.setLabel(labelDictionary.get(i).label); - scoreRecordsList.add(record); - } - } + for (FeatureValue value : combinedItem) { + FloatVector featureWeights = weightVector.get(value.feature()); + if (featureWeights != null) { + for (int i = 0; i < dim; i++) { + DebugScoreRecord record = new DebugScoreRecord(); + record.setFeatureFamily(value.feature().family().name()); + record.setFeatureName(value.feature().name()); + record.setFeatureValue(value.value()); + record.setFeatureWeight(featureWeights.get(i)); + record.setLabel(labelDictionary.get(i).getLabel()); + scoreRecordsList.add(record); } } } @@ -90,8 +89,7 @@ public List debugScoreComponents(FeatureVector combinedItem) { public ArrayList scoreItemMulticlass(FeatureVector combinedItem) { ArrayList results = new ArrayList<>(); - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector sum = scoreFlatFeature(flatFeatures); + FloatVector sum = scoreFlatFeature(combinedItem); for (int i = 0; i < labelDictionary.size(); i++) { MulticlassScoringResult result = new 
MulticlassScoringResult(); @@ -103,19 +101,14 @@ public ArrayList scoreItemMulticlass(FeatureVector comb return results; } - public FloatVector scoreFlatFeature(Map> flatFeatures) { + public FloatVector scoreFlatFeature(FeatureVector vector) { int dim = labelDictionary.size(); FloatVector sum = new FloatVector(dim); - for (Map.Entry> entry : flatFeatures.entrySet()) { - Map family = weightVector.get(entry.getKey()); - if (family != null) { - for (Map.Entry feature : entry.getValue().entrySet()) { - FloatVector vec = family.get(feature.getKey()); - if (vec != null) { - sum.multiplyAdd(feature.getValue().floatValue(), vec); - } - } + for (FeatureValue value : vector) { + FloatVector vec = weightVector.get(value.feature()); + if (vec != null) { + sum.multiplyAdd(value.value(), vec); } } return sum; @@ -124,36 +117,30 @@ public FloatVector scoreFlatFeature(Map> flatFeature public void buildLabelToIndex() { labelToIndex = new HashMap<>(); for (int i = 0; i < labelDictionary.size(); i++) { - labelToIndex.put(labelDictionary.get(i).label, i); + labelToIndex.put(labelDictionary.get(i).getLabel(), i); } } public void save(BufferedWriter writer) throws IOException { ModelHeader header = new ModelHeader(); header.setModelType("full_rank_linear"); - long count = 0; - for (Map.Entry> familyMap : weightVector.entrySet()) { - count += familyMap.getValue().entrySet().size(); - } - header.setNumRecords(count); + header.setNumRecords(weightVector.size()); header.setLabelDictionary(labelDictionary); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : weightVector.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - ArrayList arrayList = new ArrayList(); - for (int i = 0; i < feature.getValue().values.length; i++) { - 
arrayList.add((double) feature.getValue().values[i]); - } - record.setWeightVector(arrayList); - writer.write(Util.encode(record)); - writer.newLine(); + for (Map.Entry entry : weightVector.entrySet()) { + ModelRecord record = new ModelRecord(); + record.setFeatureFamily(entry.getKey().family().name()); + record.setFeatureName(entry.getKey().name()); + ArrayList arrayList = new ArrayList(); + for (int i = 0; i < entry.getValue().values.length; i++) { + arrayList.add((double) entry.getValue().values[i]); } + record.setWeightVector(arrayList); + writer.write(Util.encode(record)); + writer.newLine(); } writer.flush(); } @@ -166,22 +153,18 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO labelDictionary.add(entry); } buildLabelToIndex(); - weightVector = new HashMap<>(); + weightVector = new Reference2ObjectOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = weightVector.get(family); - if (inner == null) { - inner = new HashMap<>(); - weightVector.put(family, inner); - } + Feature feature = registry.feature(family, name); FloatVector vec = new FloatVector(record.getWeightVector().size()); for (int j = 0; j < record.getWeightVector().size(); j++) { vec.values[j] = record.getWeightVector().get(j).floatValue(); } - inner.put(name, vec); + weightVector.put(feature, vec); } } } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/KDTreeModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/KDTreeModel.java index 59299a0d..34066f8f 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/KDTreeModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/KDTreeModel.java @@ -3,21 +3,22 @@ import com.airbnb.aerosolve.core.KDTreeNode; import com.airbnb.aerosolve.core.util.Util; import 
com.google.common.base.Optional; -import lombok.Getter; -import org.apache.commons.codec.binary.Base64; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.ByteArrayInputStream; import java.io.InputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Stack; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.apache.commons.codec.binary.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static com.airbnb.aerosolve.core.KDTreeNodeType.LEAF; // A specialized 2D kd-tree that supports point and box queries. +@Accessors(fluent = true, chain = true) public class KDTreeModel implements Serializable { private static final long serialVersionUID = -2884260218927875695L; @@ -76,25 +77,25 @@ public ArrayList query(double x, double y) { private int next(int currIdx, double x, double y) { KDTreeNode node = nodes[currIdx]; int nextIndex = -1; - switch(node.nodeType) { + switch(node.getNodeType()) { case X_SPLIT: { - if (x < node.splitValue) { - nextIndex = node.leftChild; + if (x < node.getSplitValue()) { + nextIndex = node.getLeftChild(); } else { - nextIndex = node.rightChild; + nextIndex = node.getRightChild(); } } break; case Y_SPLIT: { - if (y < node.splitValue) { - nextIndex = node.leftChild; + if (y < node.getSplitValue()) { + nextIndex = node.getLeftChild(); } else { - nextIndex = node.rightChild; + nextIndex = node.getRightChild(); } } break; default: - assert (node.nodeType == LEAF); + assert (node.getNodeType() == LEAF); break; } return nextIndex; @@ -112,22 +113,22 @@ public ArrayList queryBox(double minX, double minY, double maxX, double int currIdx = stack.pop(); idx.add(currIdx); KDTreeNode node = nodes[currIdx]; - switch (node.nodeType) { + switch (node.getNodeType()) { case X_SPLIT: { - if (minX < node.splitValue) { - stack.push(node.leftChild); + if (minX < node.getSplitValue()) { + stack.push(node.getLeftChild()); } - if (maxX >= 
node.splitValue) { - stack.push(node.rightChild); + if (maxX >= node.getSplitValue()) { + stack.push(node.getRightChild()); } } break; case Y_SPLIT: { - if (minY < node.splitValue) { - stack.push(node.leftChild); + if (minY < node.getSplitValue()) { + stack.push(node.getLeftChild()); } - if (maxY >= node.splitValue) { - stack.push(node.rightChild); + if (maxY >= node.getSplitValue()) { + stack.push(node.getRightChild()); } } case LEAF: @@ -165,16 +166,16 @@ public static Optional readFromGzippedBase64String(String encoded) public static KDTreeNode stripNode(KDTreeNode node) { KDTreeNode newNode = new KDTreeNode(); if (node.isSetNodeType()) { - newNode.setNodeType(node.nodeType); + newNode.setNodeType(node.getNodeType()); } if (node.isSetSplitValue()) { - newNode.setSplitValue(node.splitValue); + newNode.setSplitValue(node.getSplitValue()); } if (node.isSetLeftChild()) { - newNode.setLeftChild(node.leftChild); + newNode.setLeftChild(node.getLeftChild()); } if (node.isSetRightChild()) { - newNode.setRightChild(node.rightChild); + newNode.setRightChild(node.getRightChild()); } return newNode; } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/KernelModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/KernelModel.java index f050dc3b..a5cec569 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/KernelModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/KernelModel.java @@ -1,34 +1,28 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.Serializable; -import java.util.Map; -import java.util.List; -import java.util.HashMap; -import java.util.ArrayList; -import java.util.PriorityQueue; -import java.util.AbstractMap; - -import com.airbnb.aerosolve.core.DictionaryRecord; +import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; 
import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.DebugScoreRecord; -import com.airbnb.aerosolve.core.FunctionForm; -import com.airbnb.aerosolve.core.util.Util; -import com.airbnb.aerosolve.core.util.StringDictionary; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.util.FloatVector; +import com.airbnb.aerosolve.core.util.StringDictionary; import com.airbnb.aerosolve.core.util.SupportVector; - +import com.airbnb.aerosolve.core.util.Util; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A kernel machine with arbitrary kernels. Different support vectors can have different kernels. // The conversion from sparse features to dense features is done by dictionary lookup. Also since // non-linear kernels are used there is no need to cross features, the feature interactions are done by // considering kernel responses to the support vectors. Try to keep features under a thousand. 
+@Accessors(fluent = true, chain = true) public class KernelModel extends AbstractModel { private static final long serialVersionUID = 7651061358422885397L; @@ -38,25 +32,24 @@ public class KernelModel extends AbstractModel { @Getter @Setter List supportVectors; - public KernelModel() { + public KernelModel(FeatureRegistry registry) { + super(registry); dictionary = new StringDictionary(); supportVectors = new ArrayList<>(); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector vec = dictionary.makeVectorFromSparseFloats(flatFeatures); + public double scoreItem(FeatureVector combinedItem) { + FloatVector vec = dictionary.makeVectorFromSparseFloats(combinedItem); float sum = 0.0f; - for (int i = 0; i < supportVectors.size(); i++) { - SupportVector sv = supportVectors.get(i); + for (SupportVector sv : supportVectors) { sum += sv.evaluate(vec); } return sum; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { return 0.0f; } @@ -64,18 +57,17 @@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { // (TODO) implement debugScoreComponents - List scoreRecordsList = new ArrayList<>(); - return scoreRecordsList; + return new ArrayList<>(); } @Override - public void onlineUpdate(float grad, float learningRate, Map> flatFeatures) { - FloatVector vec = dictionary.makeVectorFromSparseFloats(flatFeatures); - float deltaG = - learningRate * grad; + public void onlineUpdate(double grad, double learningRate, FeatureVector vector) { + FloatVector vec = dictionary.makeVectorFromSparseFloats(vector); + double deltaG = - learningRate * grad; for (SupportVector sv : supportVectors) { - float response = sv.evaluateUnweighted(vec); - float deltaW = deltaG * response; - sv.setWeight(sv.getWeight() + deltaW); + double response = 
sv.evaluateUnweighted(vec); + double deltaW = deltaG * response; + sv.setWeight((float) (sv.getWeight() + deltaW)); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/LinearModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/LinearModel.java index fdb91b78..ac1c588e 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/LinearModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/LinearModel.java @@ -1,59 +1,55 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.util.*; -import java.util.Map.Entry; - import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.util.Util; -import com.google.common.hash.HashCode; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2DoubleMap; +import it.unimi.dsi.fastutil.objects.Reference2DoubleOpenHashMap; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.PriorityQueue; import lombok.Getter; -import lombok.Setter; +import lombok.experimental.Accessors; import org.apache.http.annotation.NotThreadSafe; /** * A linear model backed by a hash map. 
*/ @NotThreadSafe +@Accessors(fluent = true, chain = true) public class LinearModel extends AbstractModel { - @Getter @Setter - protected Map> weights; + @Getter + protected Object2DoubleMap weights; - @Override - public float scoreItem(FeatureVector combinedItem) { - Map> stringFeatures = combinedItem.getStringFeatures(); - if (stringFeatures == null || weights == null) { - return 0.0f; - } - float sum = 0.0f; - for (Entry> entry : stringFeatures.entrySet()) { - String family = entry.getKey(); - Map inner = weights.get(family); - if (inner == null) { - continue; - } + public LinearModel(FeatureRegistry registry) { + super(registry); + weights = new Object2DoubleOpenHashMap<>(); + } - for (String value : entry.getValue()) { - Float weight = inner.get(value); - if (weight != null) { - sum += weight; - } - } - } - return sum; + @Override + public double scoreItem(FeatureVector combinedItem) { + return scoreItemInternal(combinedItem, null, null, null); } - public static class EntryComparator implements Comparator> { + public static class EntryComparator implements Comparator> { @Override - public int compare(Entry e1, Entry e2) { - float v1 = Math.abs(e1.getValue()); - float v2 = Math.abs(e2.getValue()); + public int compare(Entry e1, Entry e2) { + double v1 = Math.abs(e1.getValue()); + double v2 = Math.abs(e2.getValue()); if (v1 > v2) { return -1; } else if (v1 < v2) { @@ -66,47 +62,25 @@ public int compare(Entry e1, Entry e2) { // Debug scores a single item. These are explanations for why a model // came up with the score. 
@Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { - Map> stringFeatures = combinedItem.getStringFeatures(); - if (stringFeatures == null || weights == null) { - return 0.0f; - } - float sum = 0.0f; - PriorityQueue> scores = new PriorityQueue<>(100, new EntryComparator()); - Map familyScores = new HashMap<>(); - for (Entry> entry : stringFeatures.entrySet()) { - String family = entry.getKey(); - for (String value : entry.getValue()) { - HashCode code = Util.getHashCode(family, value); - Map inner = weights.get(family); - if (inner != null) { - Float weight = inner.get(value); - if (weight != null) { - String str = family + ':' + value + " = " + weight + '\n'; - if (familyScores.containsKey(family)) { - Float wt = familyScores.get(family); - familyScores.put(family, wt + weight); - } else { - familyScores.put(family, weight); - } - AbstractMap.SimpleEntry ent = new AbstractMap.SimpleEntry( - str, weight); - scores.add(ent); - sum += weight; - } - } - } + if (weights == null) { + return 0.0d; } + + PriorityQueue> scores = new PriorityQueue<>(100, new EntryComparator()); + Reference2DoubleMap familyScores = new Reference2DoubleOpenHashMap<>(); + double sum = scoreItemInternal(combinedItem, familyScores, scores, null); + builder.append("Scores by family ===>\n"); if (!familyScores.isEmpty()) { - PriorityQueue> familyPQ = new PriorityQueue<>(10, new EntryComparator()); - for (Entry entry : familyScores.entrySet()) { + PriorityQueue> familyPQ = new PriorityQueue<>(10, new EntryComparator()); + for (Entry entry : familyScores.entrySet()) { familyPQ.add(entry); } while (!familyPQ.isEmpty()) { - Entry entry = familyPQ.poll(); - builder.append(entry.getKey() + " = " + entry.getValue() + '\n'); + Entry entry = familyPQ.poll(); + builder.append(entry.getKey().name() + " = " + entry.getValue() + '\n'); } } builder.append("Top 15 scores ===>\n"); @@ -114,9 +88,12 @@ public float 
debugScoreItem(FeatureVector combinedItem, int count = 0; float subsum = 0.0f; while (!scores.isEmpty()) { - Entry entry = scores.poll(); - builder.append(entry.getKey()); - float val = entry.getValue(); + Entry entry = scores.poll(); + Feature feature = entry.getKey(); + double val = entry.getValue(); + String str = feature.family().name() + ':' + feature.name() + + " = " + val + '\n'; + builder.append(str); subsum += val; count = count + 1; if (count >= 15) { @@ -131,30 +108,56 @@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { - // linear model takes only string features - Map> stringFeatures = combinedItem.getStringFeatures(); List scoreRecordsList = new ArrayList<>(); - if (stringFeatures == null || weights == null) { - return scoreRecordsList; + scoreItemInternal(combinedItem, null, null, scoreRecordsList); + return scoreRecordsList; + } + + private double scoreItemInternal(FeatureVector combinedItem, + Reference2DoubleMap familyScores, + PriorityQueue> scores, + List scoreRecordsList) { + if (weights == null) { + return 0.0d; } - for (Entry> entry : stringFeatures.entrySet()) { - String family = entry.getKey(); - Map inner = weights.get(family); - if (inner == null) continue; - for (String value : entry.getValue()) { - DebugScoreRecord record = new DebugScoreRecord(); - Float weight = inner.get(value); - if (weight != null) { - record.setFeatureFamily(family); - record.setFeatureName(value); - // 1.0 if the string feature exists, 0.0 otherwise - record.setFeatureValue(1.0); - record.setFeatureWeight(weight); - scoreRecordsList.add(record); + double sum = 0.0d; + + // No need to filter to string values. Just iterate the keys and see if they exist in the model + // If we passed in a vector with values that matter then we made a mistake using a linear model. 
+ for (Feature feature : combinedItem.keySet()) { + if (!weights.containsKey(feature)) { + continue; + } + double weight = weights.getDouble(feature); + sum += weight; + + if (familyScores != null) { + Family family = feature.family(); + if (familyScores.containsKey(family)) { + double wt = familyScores.getDouble(family); + familyScores.put(family, wt + weight); + } else { + familyScores.put(family, weight); } } + + if (scores != null) { + AbstractMap.SimpleEntry ent = new AbstractMap.SimpleEntry<>( + feature, weight); + scores.add(ent); + } + + if (scoreRecordsList != null) { + DebugScoreRecord record = new DebugScoreRecord(); + record.setFeatureFamily(feature.family().name()); + record.setFeatureName(feature.name()); + // 1.0 if the string feature exists, 0.0 otherwise + record.setFeatureValue(1.0); + record.setFeatureWeight(weight); + scoreRecordsList.add(record); + } } - return scoreRecordsList; + return sum; } // Loads model from a buffered stream. @@ -168,19 +171,15 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO if (header.isSetOffset()) { offset = header.getOffset(); } - weights = new HashMap<>(); + weights = new Object2DoubleOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = weights.get(family); - if (inner == null) { - inner = new HashMap<>(); - weights.put(family, inner); - } - float weight = (float) record.getFeatureWeight(); - inner.put(name, weight); + Feature feature = registry.feature(family, name); + double weight = record.getFeatureWeight(); + weights.put(feature, weight); } } @@ -190,26 +189,19 @@ public void save(BufferedWriter writer) throws IOException { header.setModelType("linear"); header.setSlope(slope); header.setOffset(offset); - long count = 0; - for (Map.Entry> familyMap : weights.entrySet()) { - for (Map.Entry feature : 
familyMap.getValue().entrySet()) { - count++; - } - } - header.setNumRecords(count); + header.setNumRecords(weights.size()); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : weights.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - record.setFeatureWeight(feature.getValue()); - writer.write(Util.encode(record)); - writer.newLine(); - } + for (Object2DoubleMap.Entry entry : weights.object2DoubleEntrySet()) { + Feature feature = entry.getKey(); + ModelRecord record = new ModelRecord(); + record.setFeatureFamily(feature.family().name()); + record.setFeatureName(feature.name()); + record.setFeatureWeight(entry.getDoubleValue()); + writer.write(Util.encode(record)); + writer.newLine(); } writer.flush(); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/LowRankLinearModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/LowRankLinearModel.java index 64ff52ea..c96ed430 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/LowRankLinearModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/LowRankLinearModel.java @@ -1,22 +1,29 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.LabelDictionaryEntry; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.MulticlassScoringResult; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import 
com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; import com.airbnb.aerosolve.core.util.FloatVector; - +import com.airbnb.aerosolve.core.util.Util; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; // A low rank linear model that supports multi-class classification. // The class vector y = W' * V * x where x is d-dim feature vector. @@ -24,6 +31,7 @@ // V: D-by-d matrix, mapping from feature space to the joint embedding // W: D-by-Y matrix, mapping from label space to the joint embedding // Reference: Jason Weston et al. "WSABIE: Scaling Up To Large Vocabulary Image Annotation", IJCAI 2011. 
+@Accessors(fluent = true, chain = true) public class LowRankLinearModel extends AbstractModel { static final long serialVersionUID = -8894096678183767660L; @@ -33,7 +41,7 @@ public class LowRankLinearModel extends AbstractModel { // each FloatVector in the map is a D-dim vector @Getter @Setter - private Map> featureWeightVector; + private Map featureWeightVector; // labelWeightVector represents the projection from label space to embedding // Map label to a row in W, each FloatVector in the map is a D-dim vector @@ -54,8 +62,9 @@ public class LowRankLinearModel extends AbstractModel { @Setter private int embeddingDimension; - public LowRankLinearModel() { - featureWeightVector = new HashMap<>(); + public LowRankLinearModel(FeatureRegistry registry) { + super(registry); + featureWeightVector = new Object2ObjectOpenHashMap<>(); labelWeightVector = new HashMap<>(); labelDictionary = new ArrayList<>(); } @@ -63,16 +72,15 @@ public LowRankLinearModel() { // In the binary case this is just the score for class 0. // Ideally use a binary model for binary classification. @Override - public float scoreItem(FeatureVector combinedItem) { + public double scoreItem(FeatureVector combinedItem) { // Not supported. assert (false); - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector sum = scoreFlatFeature(flatFeatures); + FloatVector sum = scoreFlatFeature(combinedItem); return sum.values[0]; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { // TODO(peng) : implement debug. 
return scoreItem(combinedItem); @@ -88,8 +96,7 @@ public List debugScoreComponents(FeatureVector combinedItem) { public ArrayList scoreItemMulticlass(FeatureVector combinedItem) { ArrayList results = new ArrayList<>(); - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector sum = scoreFlatFeature(flatFeatures); + FloatVector sum = scoreFlatFeature(combinedItem); for (int i = 0; i < labelDictionary.size(); i++) { MulticlassScoringResult result = new MulticlassScoringResult(); @@ -101,25 +108,22 @@ public ArrayList scoreItemMulticlass(FeatureVector comb return results; } - public FloatVector scoreFlatFeature(Map> flatFeatures) { - FloatVector fvProjection = projectFeatureToEmbedding(flatFeatures); + public FloatVector scoreFlatFeature(FeatureVector vector) { + FloatVector fvProjection = projectFeatureToEmbedding(vector); return projectEmbeddingToLabel(fvProjection); } - public FloatVector projectFeatureToEmbedding(Map> flatFeatures) { + public FloatVector projectFeatureToEmbedding(FeatureVector vector) { FloatVector fvProjection = new FloatVector(embeddingDimension); // compute the projection from feature space to D-dim joint space - for (Map.Entry> entry : flatFeatures.entrySet()) { - Map family = featureWeightVector.get(entry.getKey()); - if (family != null) { - for (Map.Entry feature : entry.getValue().entrySet()) { - FloatVector vec = family.get(feature.getKey()); - if (vec != null) { - fvProjection.multiplyAdd(feature.getValue().floatValue(), vec); - } - } + for (FeatureValue value : vector) { + if (!featureWeightVector.containsKey(value.feature())) { + continue; } + + FloatVector vec = featureWeightVector.get(value.feature()); + fvProjection.multiplyAdd(value.value(), vec); } return fvProjection; } @@ -142,7 +146,7 @@ public FloatVector projectEmbeddingToLabel(FloatVector fvProjection) { public void buildLabelToIndex() { labelToIndex = new HashMap<>(); for (int i = 0; i < labelDictionary.size(); i++) { - String labelKey = 
labelDictionary.get(i).label; + String labelKey = labelDictionary.get(i).getLabel(); labelToIndex.put(labelKey, i); } } @@ -150,11 +154,7 @@ public void buildLabelToIndex() { public void save(BufferedWriter writer) throws IOException { ModelHeader header = new ModelHeader(); header.setModelType("low_rank_linear"); - long count = 0; - for (Map.Entry> familyMap : featureWeightVector.entrySet()) { - count += familyMap.getValue().entrySet().size(); - } - header.setNumRecords(count); + header.setNumRecords(featureWeightVector.size()); header.setLabelDictionary(labelDictionary); Map> labelEmbedding = new HashMap<>(); for (Map.Entry labelRepresentation : labelWeightVector.entrySet()) { @@ -172,19 +172,18 @@ public void save(BufferedWriter writer) throws IOException { headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : featureWeightVector.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - ArrayList arrayList = new ArrayList<>(); - for (int i = 0; i < feature.getValue().values.length; i++) { - arrayList.add((double) feature.getValue().values[i]); - } - record.setWeightVector(arrayList); - writer.write(Util.encode(record)); - writer.newLine(); + for (Map.Entry entry : featureWeightVector.entrySet()) { + ModelRecord record = new ModelRecord(); + Feature feature = entry.getKey(); + record.setFeatureFamily(feature.family().name()); + record.setFeatureName(feature.name()); + ArrayList arrayList = new ArrayList<>(); + for (int i = 0; i < entry.getValue().values.length; i++) { + arrayList.add((double) entry.getValue().values[i]); } + record.setWeightVector(arrayList); + writer.write(Util.encode(record)); + writer.newLine(); } writer.flush(); } @@ -211,22 +210,18 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO 
labelWeightVector.put(labelKey, labelWeight); } - featureWeightVector = new HashMap<>(); + featureWeightVector = new Reference2ObjectOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = featureWeightVector.get(family); - if (inner == null) { - inner = new HashMap<>(); - featureWeightVector.put(family, inner); - } + Feature feature = registry.feature(family, name); FloatVector vec = new FloatVector(record.getWeightVector().size()); for (int j = 0; j < record.getWeightVector().size(); j++) { vec.values[j] = record.getWeightVector().get(j).floatValue(); } - inner.put(name, vec); + featureWeightVector.put(feature, vec); } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/MaxoutModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/MaxoutModel.java index de18cf87..87a69618 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/MaxoutModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/MaxoutModel.java @@ -1,25 +1,35 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; import com.airbnb.aerosolve.core.util.FloatVector; +import com.airbnb.aerosolve.core.util.Util; +import it.unimi.dsi.fastutil.objects.Object2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import 
it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.Serializable; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; import lombok.Getter; import lombok.Setter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import lombok.experimental.Accessors; // A 2 layer maxout unit that can represent functions using difference // of piecewise linear convex functions. // http://arxiv.org/abs/1302.4389 +@Accessors(fluent = true, chain = true) public class MaxoutModel extends AbstractModel { private static final long serialVersionUID = -849900702679383422L; @@ -28,7 +38,7 @@ public class MaxoutModel extends AbstractModel { private int numHidden; @Getter @Setter - private Map> weightVector; + private Map weightVector; private WeightVector bias; @@ -55,67 +65,65 @@ public static class WeightVector implements Serializable { public float scale; } - public MaxoutModel() { + public MaxoutModel(FeatureRegistry registry) { + super(registry); } public void initForTraining(int numHidden) { this.numHidden = numHidden; - weightVector = new HashMap<>(); + weightVector = new Object2ObjectOpenHashMap<>(); bias = new WeightVector(1.0f, numHidden, false); - Map special = new HashMap<>(); - weightVector.put("$SPECIAL", special); - special.put("$BIAS", bias); + Feature specialBias = registry.feature("$SPECIAL", "$BIAS"); + weightVector.put(specialBias, bias); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - return scoreFlatFeatures(flatFeatures); + public double scoreItem(FeatureVector combinedItem) { + FloatVector response = getResponse(combinedItem); + FloatVector.MinMaxResult result = response.getMinMaxResult(); + + return result.maxValue - 
result.minValue; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - FloatVector response = getResponse(flatFeatures); + FloatVector response = getResponse(combinedItem); FloatVector.MinMaxResult result = response.getMinMaxResult(); - float sum = result.maxValue - result.minValue; + double sum = result.maxValue - result.minValue; - PriorityQueue> scores = + PriorityQueue> scores = new PriorityQueue<>(100, new LinearModel.EntryComparator()); float[] biasWt = bias.weights.getValues(); - float biasScore = biasWt[result.maxIndex] - biasWt[result.minIndex]; - scores.add(new AbstractMap.SimpleEntry( + double biasScore = biasWt[result.maxIndex] - biasWt[result.minIndex]; + scores.add(new AbstractMap.SimpleEntry<>( "bias = " + biasScore + "
\n", biasScore)); - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightVector.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightVector weightVec = familyWeightMap.get(feature.getKey()); - if (weightVec == null) continue; - float val = feature.getValue().floatValue(); - float[] wt = weightVec.weights.getValues(); - float p = wt[result.maxIndex] * weightVec.scale; - float n = wt[result.minIndex] * weightVec.scale; - float subscore = val * (p - n); - String str = featureFamily.getKey() + ":" + feature.getKey() + "=" + val - + " * (" + p + " - " + n + ") = " + subscore + "
\n"; - scores.add(new AbstractMap.SimpleEntry(str, subscore)); - } + for (FeatureValue value : combinedItem) { + WeightVector weightVec = weightVector.get(value.feature()); + if (weightVec == null) continue; + Feature feature = value.feature(); + double val = value.value(); + float[] wt = weightVec.weights.getValues(); + float p = wt[result.maxIndex] * weightVec.scale; + float n = wt[result.minIndex] * weightVec.scale; + double subscore = val * (p - n); + String str = feature.family().name() + ":" + feature.name() + "=" + val + + " * (" + p + " - " + n + ") = " + subscore + "
\n"; + scores.add(new AbstractMap.SimpleEntry<>(str, subscore)); } builder.append("Top 15 scores ===>\n"); if (!scores.isEmpty()) { int count = 0; float subsum = 0.0f; while (!scores.isEmpty()) { - Map.Entry entry = scores.poll(); + Map.Entry entry = scores.poll(); builder.append(entry.getKey()); - float val = entry.getValue(); + double val = entry.getValue(); subsum += val; count = count + 1; if (count >= 15) { @@ -136,15 +144,11 @@ public List debugScoreComponents(FeatureVector combinedItem) { return scoreRecordsList; } + // TODO (Brad): Get rid of floats // Adds a new vector with a specified scale. - public void addVector(String family, String feature, float scale) { - Map featFamily = weightVector.get(family); - if (featFamily == null) { - featFamily = new HashMap<>(); - weightVector.put(family, featFamily); - } + public void addVector(Feature feature, float scale) { WeightVector vec = new WeightVector(scale, numHidden, true); - featFamily.put(feature, vec); + weightVector.put(feature, vec); } // Updates the gradient @@ -154,24 +158,20 @@ public void update(float grad, float l2Reg, float momentum, FloatVector.MinMaxResult result, - Map> flatFeatures) { - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightVector.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightVector weightVec = familyWeightMap.get(feature.getKey()); - if (weightVec == null) continue; - float val = feature.getValue().floatValue() * weightVec.scale; - updateWeightVector(result.minIndex, - result.maxIndex, - val, - grad, - learningRate, - l1Reg, - l2Reg, - momentum, - weightVec); - } + Iterable vector) { + for (FeatureValue value : vector) { + WeightVector weightVec = weightVector.get(value.feature()); + if (weightVec == null) continue; + float val = (float) value.value() * weightVec.scale; + updateWeightVector(result.minIndex, + result.maxIndex, + val, + grad, + 
learningRate, + l1Reg, + l2Reg, + momentum, + weightVec); } updateWeightVector(result.minIndex, result.maxIndex, @@ -224,24 +224,12 @@ private void updateWeightVector(int minIndex, } } - public float scoreFlatFeatures(Map> flatFeatures) { - FloatVector response = getResponse(flatFeatures); - FloatVector.MinMaxResult result = response.getMinMaxResult(); - - return result.maxValue - result.minValue; - } - - public FloatVector getResponse(Map> flatFeatures) { + public FloatVector getResponse(Iterable vector) { FloatVector sum = new FloatVector(numHidden); - for (Map.Entry> entry : flatFeatures.entrySet()) { - Map family = weightVector.get(entry.getKey()); - if (family != null) { - for (Map.Entry feature : entry.getValue().entrySet()) { - WeightVector hidden = family.get(feature.getKey()); - if (hidden != null) { - sum.multiplyAdd(feature.getValue().floatValue() * hidden.scale, hidden.weights); - } - } + for (FeatureValue value : vector) { + WeightVector hidden = weightVector.get(value.feature()); + if (hidden != null) { + sum.multiplyAdd(value.value() * hidden.scale, hidden.weights); } } sum.add(bias.weights); @@ -252,31 +240,24 @@ public void save(BufferedWriter writer) throws IOException { ModelHeader header = new ModelHeader(); header.setModelType("maxout"); header.setNumHidden(numHidden); - long count = 0; - for (Map.Entry> familyMap : weightVector.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - count++; - } - } - header.setNumRecords(count); + header.setNumRecords(weightVector.size()); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : weightVector.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - ArrayList arrayList = new ArrayList(); - for (int i = 0; i < 
feature.getValue().weights.values.length; i++) { - arrayList.add((double) feature.getValue().weights.values[i]); - } - record.setWeightVector(arrayList); - record.setScale(feature.getValue().scale); - writer.write(Util.encode(record)); - writer.newLine(); + for (Map.Entry entry : weightVector.entrySet()) { + ModelRecord record = new ModelRecord(); + Feature feature = entry.getKey(); + record.setFeatureFamily(feature.family().name()); + record.setFeatureName(feature.name()); + ArrayList arrayList = new ArrayList(); + for (int i = 0; i < entry.getValue().weights.values.length; i++) { + arrayList.add((double) entry.getValue().weights.values[i]); } + record.setWeightVector(arrayList); + record.setScale(entry.getValue().scale); + writer.write(Util.encode(record)); + writer.newLine(); } writer.flush(); } @@ -285,28 +266,21 @@ public void save(BufferedWriter writer) throws IOException { protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException { long rows = header.getNumRecords(); numHidden = header.getNumHidden(); - weightVector = new HashMap<>(); + weightVector = new Reference2ObjectOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = weightVector.get(family); - if (inner == null) { - inner = new HashMap<>(); - weightVector.put(family, inner); - } + Feature feature = registry.feature(family, name); WeightVector vec = new WeightVector(); vec.scale = (float) record.getScale(); vec.weights = new FloatVector(numHidden); for (int j = 0; j < numHidden; j++) { vec.weights.values[j] = record.getWeightVector().get(j).floatValue(); } - inner.put(name, vec); + weightVector.put(feature, vec); } - Map special = weightVector.get("$SPECIAL"); - assert(special != null); - bias = special.get("$BIAS"); - assert(bias != null); + 
assert(weightVector.containsKey(registry.family("$SPECIAL").feature("$BIAS"))); } } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/MlpModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/MlpModel.java index f1e3bed0..05461fb8 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/MlpModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/MlpModel.java @@ -1,25 +1,36 @@ package com.airbnb.aerosolve.core.models; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.IOException; -import java.util.*; - import com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.transforms.LegacyNames; import com.airbnb.aerosolve.core.util.FloatVector; +import com.airbnb.aerosolve.core.util.Util; +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Getter; import lombok.Setter; +import lombok.experimental.Accessors; /** * Multilayer perceptron (MLP) model https://en.wikipedia.org/wiki/Multilayer_perceptron * The current implementation is for the case where there is only one output node. 
*/ +@LegacyNames("multilayer_perceptron") +@Accessors(fluent = true, chain = true) public class MlpModel extends AbstractModel { private static final long serialVersionUID = -6870862764598907090L; @@ -28,7 +39,7 @@ public class MlpModel extends AbstractModel { // or the output layer @Getter @Setter - private Map> inputLayerWeights; + private Reference2ObjectMap inputLayerWeights; // if there is hidden layer, this defines the projection // from one hidden layer to the next hidden layer or output layer @@ -60,16 +71,19 @@ public class MlpModel extends AbstractModel { @Setter private Map layerActivations; - public MlpModel() { + public MlpModel(FeatureRegistry registry) { + super(registry); layerNodeNumber = new ArrayList<>(); - inputLayerWeights = new HashMap<>(); + inputLayerWeights = new Reference2ObjectOpenHashMap<>(); hiddenLayerWeights = new HashMap<>(); layerActivations = new HashMap<>(); bias = new HashMap<>(); activationFunction = new ArrayList<>(); } - public MlpModel(ArrayList activation, ArrayList nodeNumbers) { + public MlpModel(ArrayList activation, ArrayList nodeNumbers, + FeatureRegistry registry) { + super(registry); // n is the number of hidden layers (including output layer, excluding input layer) // activation specifies activation function // nodeNumbers: specifies number of nodes in each hidden layer @@ -77,7 +91,7 @@ public MlpModel(ArrayList activation, ArrayList nodeNumbe activationFunction = activation; layerNodeNumber = nodeNumbers; assert(activation.size() == numHiddenLayers + 1); - inputLayerWeights = new HashMap<>(); + inputLayerWeights = new Reference2ObjectOpenHashMap<>(); hiddenLayerWeights = new HashMap<>(); // bias including the bias added at the output layer bias = new HashMap<>(); @@ -94,33 +108,32 @@ public MlpModel(ArrayList activation, ArrayList nodeNumbe } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - return forwardPropagation(flatFeatures); + public 
double scoreItem(FeatureVector combinedItem) { + return forwardPropagation(combinedItem); } - public float forwardPropagation(Map> flatFeatures) { - projectInputLayer(flatFeatures, 0.0); + public double forwardPropagation(FeatureVector vector) { + projectInputLayer(vector, 0.0); for (int i = 0; i < numHiddenLayers; i++) { projectHiddenLayer(i, 0.0); } return layerActivations.get(numHiddenLayers).get(0); } - public float forwardPropagationWithDropout(Map> flatFeatures, Double dropout) { + public double forwardPropagationWithDropout(FeatureVector vector, Double dropout) { // reference: George E. Dahl et al. "IMPROVING DEEP NEURAL NETWORKS FOR LVCSR USING RECTIFIED LINEAR UNITS AND DROPOUT" // scale the input to a node by 1/(1-dropout), so that we don't need to rescale model weights after training // make sure the value is between 0 and 1 assert(dropout > 0.0); assert(dropout < 1.0); - projectInputLayer(flatFeatures, dropout); + projectInputLayer(vector, dropout); for (int i = 0; i < numHiddenLayers; i++) { projectHiddenLayer(i, dropout); } return layerActivations.get(numHiddenLayers).get(0); } - public FloatVector projectInputLayer(Map> flatFeatures, Double dropout) { + public FloatVector projectInputLayer(FeatureVector vector, Double dropout) { // compute the projection from input feature space to the first hidden layer or // output layer if there is no hidden layer // output: fvProjection is a float vector representing the activation at the first layer after input layer @@ -134,16 +147,11 @@ public FloatVector projectInputLayer(Map> flatFeatur fvProjection.setConstant(0.0f); } - for (Map.Entry> entry : flatFeatures.entrySet()) { - Map family = inputLayerWeights.get(entry.getKey()); - if (family != null) { - for (Map.Entry feature : entry.getValue().entrySet()) { - FloatVector vec = family.get(feature.getKey()); - if (vec != null) { - if (dropout > 0.0 && Math.random() < dropout) continue; - fvProjection.multiplyAdd(feature.getValue().floatValue(), vec); - } - } + 
for (FeatureValue value : vector) { + FloatVector vec = inputLayerWeights.get(value.feature()); + if (vec != null) { + if (dropout > 0.0 && Math.random() < dropout) continue; + fvProjection.multiplyAdd(value.value(), vec); } } if (dropout > 0.0 && dropout < 1.0) { @@ -205,7 +213,7 @@ private void applyActivation(FloatVector input, FunctionForm func) { } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { // TODO(peng): implement debug return scoreItem(combinedItem); @@ -227,31 +235,26 @@ public void save(BufferedWriter writer) throws IOException { nodeNum.add(layerNodeNumber.get(i)); } header.setNumberHiddenNodes(nodeNum); - long count = 0; - for (Map.Entry> familyMap : inputLayerWeights.entrySet()) { - count += familyMap.getValue().entrySet().size(); - } // number of record for the input layer weights - header.setNumRecords(count); + header.setNumRecords(inputLayerWeights.size()); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); // save the input layer weight, one record per feature - for (Map.Entry> familyMap : inputLayerWeights.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - ArrayList arrayList = new ArrayList<>(); - for (int i = 0; i < feature.getValue().length(); i++) { - arrayList.add((double) feature.getValue().values[i]); - } - record.setWeightVector(arrayList); - writer.write(Util.encode(record)); - writer.newLine(); + for (Map.Entry entry : inputLayerWeights.entrySet()) { + ModelRecord record = new ModelRecord(); + Feature feature = entry.getKey(); + record.setFeatureFamily(feature.family().name()); + record.setFeatureName(feature.name()); + ArrayList arrayList = new ArrayList<>(); + for (int i = 0; i < 
entry.getValue().length(); i++) { + arrayList.add((double) entry.getValue().values[i]); } + record.setWeightVector(arrayList); + writer.write(Util.encode(record)); + writer.newLine(); } // save the bias for each layer after input layer, one record per layer @@ -302,16 +305,12 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = inputLayerWeights.get(family); - if (inner == null) { - inner = new HashMap<>(); - inputLayerWeights.put(family, inner); - } FloatVector vec = new FloatVector(record.getWeightVector().size()); for (int j = 0; j < record.getWeightVector().size(); j++) { vec.values[j] = record.getWeightVector().get(j).floatValue(); } - inner.put(name, vec); + Feature feature = registry.feature(family, name); + inputLayerWeights.put(feature, vec); } // load bias and activation function for (int i = 0; i < numHiddenLayers + 1; i++) { diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/Model.java b/core/src/main/java/com/airbnb/aerosolve/core/models/Model.java index f85a692f..b3120242 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/Model.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/Model.java @@ -9,9 +9,9 @@ interface Model { // Scores a single item. The transforms should already have been applied to // the context and item and combined item. - float scoreItem(FeatureVector combinedItem); + double scoreItem(FeatureVector combinedItem); // Debug scores a single item. These are explanations for why a model // came up with the score. 
- float debugScoreItem(FeatureVector combinedItem, + double debugScoreItem(FeatureVector combinedItem, StringBuilder builder); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/ModelFactory.java b/core/src/main/java/com/airbnb/aerosolve/core/models/ModelFactory.java index 430cb681..9410ae01 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/ModelFactory.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/ModelFactory.java @@ -2,40 +2,61 @@ import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.util.Util; import com.google.common.base.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.BufferedReader; import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.HashMap; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +@Slf4j public final class ModelFactory { - private static final Logger log = LoggerFactory.getLogger(ModelFactory.class); private ModelFactory() { } - // Creates - @SuppressWarnings("deprecation") - public static AbstractModel createByName(String name) { - switch (name) { - case "linear": return new LinearModel(); - case "maxout": return new MaxoutModel(); - case "spline": return new SplineModel(); - case "boosted_stumps": return new BoostedStumpsModel(); - case "decision_tree": return new DecisionTreeModel(); - case "forest": return new ForestModel(); - case "additive": return new AdditiveModel(); - case "kernel" : return new KernelModel(); - case "full_rank_linear" : return new FullRankLinearModel(); - case "low_rank_linear" : return new LowRankLinearModel(); - case "multilayer_perceptron" : return new MlpModel(); + /** + * This a Map from the type name of a model ("linear" for instance) to a constructor that takes + * a single argument of type FeatureRegistry. 
This caches most of the reflection and makes + * instantiation easy. But we don't have to explicitly include new models in this class. They + * are found from the classpath. + */ + private static Map> MODEL_CONSTRUCTORS; + + public static AbstractModel createByName(String name, FeatureRegistry registry) { + if (MODEL_CONSTRUCTORS == null) { + loadModelConstructors(); + } + Constructor constructor = MODEL_CONSTRUCTORS.get(name); + if (constructor == null) { + throw new IllegalArgumentException( + String.format("No model exists with name %s", name)); + } + try { + return constructor.newInstance(registry); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException( + String.format("There was an error instantiating Model of type %s : %s", + name, e.getMessage()), e); + } + } + + private static synchronized void loadModelConstructors() { + if (MODEL_CONSTRUCTORS != null) { + return; } - log.error("Could not create model of type " + name); - return null; + MODEL_CONSTRUCTORS = Util.loadConstructorsFromPackage(AbstractModel.class, + "com.airbnb.aerosolve.core.models", + "Model", + FeatureRegistry.class); } - public static Optional createFromReader(BufferedReader reader) throws IOException { + + public static Optional createFromReader(BufferedReader reader, + FeatureRegistry registry) throws IOException { Optional model = Optional.absent(); String headerLine = reader.readLine(); ModelRecord record = Util.decodeModel(headerLine); @@ -45,7 +66,7 @@ public static Optional createFromReader(BufferedReader reader) th } ModelHeader header = record.getModelHeader(); if (header != null) { - AbstractModel result = createByName(header.getModelType()); + AbstractModel result = createByName(header.getModelType(), registry); if (result != null) { result.loadInternal(header, reader); model = Optional.of(result); diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/NDTreeModel.java 
b/core/src/main/java/com/airbnb/aerosolve/core/models/NDTreeModel.java index 6a222063..8e082c8a 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/NDTreeModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/NDTreeModel.java @@ -4,6 +4,7 @@ import com.airbnb.aerosolve.core.util.Util; import com.google.common.base.Optional; import lombok.Getter; +import lombok.experimental.Accessors; import org.apache.commons.codec.binary.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,6 +20,7 @@ /* N-Dimensional KDTreeModel. */ +@Accessors(fluent = true, chain = true) public class NDTreeModel implements Serializable { private static final long serialVersionUID = -2884260218927875615L; private static final Logger log = LoggerFactory.getLogger(NDTreeModel.class); @@ -33,7 +35,7 @@ public NDTreeModel(NDTreeNode[] nodes) { this.nodes = nodes; int max = 0; for (NDTreeNode node : nodes) { - max = Math.max(max, node.axisIndex); + max = Math.max(max, node.getAxisIndex()); } dimension = max + 1; } @@ -42,7 +44,7 @@ public NDTreeModel(List nodeList) { this(nodeList.toArray(new NDTreeNode[nodeList.size()])); } - public int leaf(float ... coordinates) { + public int leaf(double ... 
coordinates) { if (nodes == null || nodes.length == 0) return -1; return binarySearch(nodes, coordinates, 0); } @@ -80,13 +82,13 @@ public List queryBox(List min, List max) { int currIdx = stack.pop(); idx.add(currIdx); NDTreeNode node = nodes[currIdx]; - int index = node.axisIndex; + int index = node.getAxisIndex(); if (index > LEAF) { - if (min.get(index) < node.splitValue) { - stack.push(node.leftChild); + if (min.get(index) < node.getSplitValue()) { + stack.push(node.getLeftChild()); } - if (max.get(index) >= node.splitValue) { - stack.push(node.rightChild); + if (max.get(index) >= node.getSplitValue()) { + stack.push(node.getRightChild()); } } } @@ -143,7 +145,7 @@ private static List query(NDTreeNode[] a, Object key, int currIdx) { // TODO use https://github.com/facebook/swift private static int next(NDTreeNode node, Object key) { - int index = node.axisIndex; + int index = node.getAxisIndex(); if (index == NDTreeModel.LEAF) { // leaf return -1; @@ -164,18 +166,18 @@ private static int next(NDTreeNode node, Object key) { } private static int nextChild(NDTreeNode node, float value) { - if (value < node.splitValue) { - return node.leftChild; + if (value < node.getSplitValue()) { + return node.getLeftChild(); } else { - return node.rightChild; + return node.getRightChild(); } } private static int nextChild(NDTreeNode node, Number value) { - if (value.doubleValue() < node.splitValue) { - return node.leftChild; + if (value.doubleValue() < node.getSplitValue()) { + return node.getLeftChild(); } else { - return node.rightChild; + return node.getRightChild(); } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/models/SplineModel.java b/core/src/main/java/com/airbnb/aerosolve/core/models/SplineModel.java index 19d429ba..9599add0 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/models/SplineModel.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/models/SplineModel.java @@ -4,16 +4,25 @@ import com.airbnb.aerosolve.core.FeatureVector; import 
com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.function.Spline; +import com.airbnb.aerosolve.core.functions.Spline; import com.airbnb.aerosolve.core.util.Util; -import lombok.Getter; -import lombok.Setter; - +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.io.Serializable; -import java.util.*; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import lombok.Getter; +import lombok.Setter; // A linear piecewise spline based model with a spline per feature. // See http://en.wikipedia.org/wiki/Generalized_additive_model @@ -28,7 +37,7 @@ public class SplineModel extends AbstractModel { private int numBins; @Getter @Setter - private Map> weightSpline; + private Map weightSpline; @Getter @Setter // Cap on the L_infinity norm of the spline. Defaults to 0 which is no cap. 
@@ -40,8 +49,8 @@ public static class WeightSpline implements Serializable { public WeightSpline() { } - public WeightSpline(float minVal, float maxVal, int numBins) { - splineWeights = new float[numBins]; + public WeightSpline(double minVal, double maxVal, int numBins) { + splineWeights = new double[numBins]; spline = new Spline(minVal, maxVal, splineWeights); } @@ -49,27 +58,32 @@ public void resample(int newBins) { spline.resample(newBins); splineWeights = spline.getWeights(); } + public Spline spline; - public float[] splineWeights; - public float L1Norm() { - float sum = 0.0f; + + public double[] splineWeights; + + public double L1Norm() { + double sum = 0.0f; for (int i = 0; i < splineWeights.length; i++) { sum += Math.abs(splineWeights[i]); } return sum; } - public float LInfinityNorm() { - float best = 0.0f; + + public double LInfinityNorm() { + double best = 0.0f; for (int i = 0; i < splineWeights.length; i++) { best = Math.max(best, Math.abs(splineWeights[i])); } return best; } + public void LInfinityCap(float cap) { if (cap <= 0.0f) return; - float currentNorm = this.LInfinityNorm(); + double currentNorm = this.LInfinityNorm(); if (currentNorm > cap) { - float scale = cap / currentNorm; + double scale = cap / currentNorm; for (int i = 0; i < splineWeights.length; i++) { splineWeights[i] *= scale; } @@ -77,54 +91,70 @@ public void LInfinityCap(float cap) { } } - public SplineModel() { + public SplineModel(FeatureRegistry registry) { + super(registry); } public void initForTraining(int numBins) { this.numBins = numBins; - weightSpline = new HashMap<>(); + weightSpline = new Object2ObjectOpenHashMap<>(); } @Override - public float scoreItem(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - return scoreFlatFeatures(flatFeatures); + public double scoreItem(FeatureVector combinedItem) { + return scoreItemInternal(combinedItem, null, null); + } + + private double scoreItemInternal(FeatureVector combinedItem, + 
PriorityQueue> scores, + List scoreRecordsList) { + double sum = 0.0d; + + for (FeatureValue value : combinedItem) { + WeightSpline ws = weightSpline.get(value.feature()); + if (ws == null) + continue; + double val = value.value(); + double subscore = ws.spline.evaluate(val); + sum += subscore; + if (scores != null) { + scores.add(new AbstractMap.SimpleEntry<>(value, subscore)); + } + if (scoreRecordsList != null) { + DebugScoreRecord record = new DebugScoreRecord(); + record.setFeatureFamily(value.feature().family().name()); + record.setFeatureName(value.feature().name()); + record.setFeatureValue(val); + record.setFeatureWeight(subscore); + scoreRecordsList.add(record); + } + } + return sum; } @Override - public float debugScoreItem(FeatureVector combinedItem, + public double debugScoreItem(FeatureVector combinedItem, StringBuilder builder) { - Map> flatFeatures = Util.flattenFeature(combinedItem); - float sum = 0.0f; - - PriorityQueue> scores = + PriorityQueue> scores = new PriorityQueue<>(100, new LinearModel.EntryComparator()); - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightSpline.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightSpline ws = familyWeightMap.get(feature.getKey()); - if (ws == null) continue; - float val = feature.getValue().floatValue(); - float subscore = ws.spline.evaluate(val); - sum += subscore; - String str = featureFamily.getKey() + ":" + feature.getKey() + "=" + val - + " = " + subscore + "
\n"; - scores.add(new AbstractMap.SimpleEntry(str, subscore)); - } - } + double sum = scoreItemInternal(combinedItem, scores, null); + final int MAX_COUNT = 100; builder.append("Top scores ===>\n"); if (!scores.isEmpty()) { int count = 0; float subsum = 0.0f; while (!scores.isEmpty()) { - Map.Entry entry = scores.poll(); - builder.append(entry.getKey()); - float val = entry.getValue(); - subsum += val; + Map.Entry entry = scores.poll(); + Feature feature = entry.getKey().feature(); + double subscore = entry.getValue(); + String str = feature.family().name() + ":" + feature.name() + + "=" + entry.getKey().value() + + " = " + subscore + "
\n"; + builder.append(str); + subsum += subscore; count = count + 1; if (count >= MAX_COUNT) { builder.append("Leftover = " + (sum - subsum) + '\n'); @@ -139,93 +169,51 @@ public float debugScoreItem(FeatureVector combinedItem, @Override public List debugScoreComponents(FeatureVector combinedItem) { - Map> flatFeatures = Util.flattenFeature(combinedItem); List scoreRecordsList = new ArrayList<>(); - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightSpline.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightSpline ws = familyWeightMap.get(feature.getKey()); - if (ws == null) continue; - float val = feature.getValue().floatValue(); - float weight = ws.spline.evaluate(val); - DebugScoreRecord record = new DebugScoreRecord(); - record.setFeatureFamily(featureFamily.getKey()); - record.setFeatureName(feature.getKey()); - record.setFeatureValue(val); - record.setFeatureWeight(weight); - scoreRecordsList.add(record); - } - } + scoreItemInternal(combinedItem, null, scoreRecordsList); return scoreRecordsList; } // Updates the gradient - public void update(float grad, - float learningRate, - Map> flatFeatures) { - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightSpline.get(featureFamily.getKey()); - if (familyWeightMap == null) continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightSpline ws = familyWeightMap.get(feature.getKey()); - if (ws == null) continue; - float val = feature.getValue().floatValue(); - updateWeightSpline(val, grad, learningRate,ws); - } + public void update(double grad, + double learningRate, + FeatureVector vector) { + for (FeatureValue value : vector) { + WeightSpline ws = weightSpline.get(value.feature()); + if (ws == null) + continue; + double val = value.value(); + updateWeightSpline(val, grad, learningRate, ws); } } @Override - public void 
onlineUpdate(float grad, float learningRate, Map> flatFeatures) { - update(grad, learningRate, flatFeatures); + public void onlineUpdate(double grad, double learningRate, FeatureVector vector) { + update(grad, learningRate, vector); } // Adds a new spline - public void addSpline(String family, String feature, float minVal, float maxVal, Boolean overwrite) { + public void addSpline(Feature feature, double minVal, double maxVal, + boolean overwrite) { // if overwrite=true, we overwrite an existing spline, otherwise we don't modify an existing spline - Map featFamily = weightSpline.get(family); - if (featFamily == null) { - featFamily = new HashMap<>(); - weightSpline.put(family, featFamily); - } - - if (overwrite || !featFamily.containsKey(feature)) { + if (overwrite || !weightSpline.containsKey(feature)) { if (maxVal <= minVal) { maxVal = minVal + 1.0f; } WeightSpline ws = new WeightSpline(minVal, maxVal, numBins); - featFamily.put(feature, ws); + weightSpline.put(feature, ws); } } - private void updateWeightSpline(float val, - float grad, - float learningRate, + private void updateWeightSpline(double val, + double grad, + double learningRate, WeightSpline ws) { ws.spline.update(-grad * learningRate, val); ws.LInfinityCap(splineNormCap); } - public float scoreFlatFeatures(Map> flatFeatures) { - float sum = 0.0f; - - for (Map.Entry> featureFamily : flatFeatures.entrySet()) { - Map familyWeightMap = weightSpline.get(featureFamily.getKey()); - if (familyWeightMap == null) - continue; - for (Map.Entry feature : featureFamily.getValue().entrySet()) { - WeightSpline ws = familyWeightMap.get(feature.getKey()); - if (ws == null) - continue; - float val = feature.getValue().floatValue(); - sum += ws.spline.evaluate(val); - } - } - return sum; - } - @Override public void save(BufferedWriter writer) throws IOException { ModelHeader header = new ModelHeader(); @@ -233,32 +221,24 @@ public void save(BufferedWriter writer) throws IOException { header.setNumHidden(numBins); 
header.setSlope(slope); header.setOffset(offset); - long count = 0; - for (Map.Entry> familyMap : weightSpline.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - count++; - } - } - header.setNumRecords(count); + header.setNumRecords(weightSpline.size()); ModelRecord headerRec = new ModelRecord(); headerRec.setModelHeader(header); writer.write(Util.encode(headerRec)); writer.newLine(); - for (Map.Entry> familyMap : weightSpline.entrySet()) { - for (Map.Entry feature : familyMap.getValue().entrySet()) { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily(familyMap.getKey()); - record.setFeatureName(feature.getKey()); - ArrayList arrayList = new ArrayList(); - for (int i = 0; i < feature.getValue().splineWeights.length; i++) { - arrayList.add((double) feature.getValue().splineWeights[i]); - } - record.setWeightVector(arrayList); - record.setMinVal(feature.getValue().spline.getMinVal()); - record.setMaxVal(feature.getValue().spline.getMaxVal()); - writer.write(Util.encode(record)); - writer.newLine(); + for (Map.Entry entry : weightSpline.entrySet()) { + ModelRecord record = new ModelRecord(); + record.setFeatureFamily(entry.getKey().family().name()); + record.setFeatureName(entry.getKey().name()); + ArrayList arrayList = new ArrayList(); + for (int i = 0; i < entry.getValue().splineWeights.length; i++) { + arrayList.add(entry.getValue().splineWeights[i]); } + record.setWeightVector(arrayList); + record.setMinVal(entry.getValue().spline.getMinVal()); + record.setMaxVal(entry.getValue().spline.getMaxVal()); + writer.write(Util.encode(record)); + writer.newLine(); } writer.flush(); } @@ -269,25 +249,20 @@ protected void loadInternal(ModelHeader header, BufferedReader reader) throws IO numBins = header.getNumHidden(); slope = header.getSlope(); offset = header.getOffset(); - weightSpline = new HashMap<>(); + weightSpline = new Reference2ObjectOpenHashMap<>(); for (long i = 0; i < rows; i++) { String line = reader.readLine(); 
ModelRecord record = Util.decodeModel(line); String family = record.getFeatureFamily(); String name = record.getFeatureName(); - Map inner = weightSpline.get(family); - if (inner == null) { - inner = new HashMap<>(); - weightSpline.put(family, inner); - } - float minVal = (float) record.getMinVal(); - float maxVal = (float) record.getMaxVal(); + double minVal = record.getMinVal(); + double maxVal = record.getMaxVal(); WeightSpline vec = new WeightSpline(minVal, maxVal, numBins); for (int j = 0; j < numBins; j++) { vec.splineWeights[j] = record.getWeightVector().get(j).floatValue(); } - inner.put(name, vec); + weightSpline.put(registry.feature(family, name), vec); } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelConfig.java b/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelConfig.java index 44419c8a..a41573e9 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelConfig.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelConfig.java @@ -1,7 +1,7 @@ package com.airbnb.aerosolve.core.scoring; +import lombok.Builder; import lombok.Getter; -import lombok.experimental.Builder; @Builder public class ModelConfig { diff --git a/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelScorer.java b/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelScorer.java index 5b661256..7aeb9461 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelScorer.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/scoring/ModelScorer.java @@ -4,27 +4,38 @@ import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.models.AbstractModel; import com.airbnb.aerosolve.core.models.ModelFactory; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; import com.airbnb.aerosolve.core.transforms.Transformer; import com.google.common.base.Optional; import com.typesafe.config.Config; import com.typesafe.config.ConfigFactory; 
-import lombok.extern.slf4j.Slf4j; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; +import lombok.extern.slf4j.Slf4j; @Slf4j public class ModelScorer { private final AbstractModel model; private final Transformer transformer; + private final Feature biasFeature; - public ModelScorer(BufferedReader reader, ModelConfig model) throws IOException { - Optional modelOpt = ModelFactory.createFromReader(reader); + public ModelScorer(BufferedReader reader, ModelConfig modelConfig) + throws IOException { + // Let the model create the FeatureRegistry because we're using one Transformer per Model in + // this class. + FeatureRegistry registry = new FeatureRegistry(); + Optional modelOpt = ModelFactory.createFromReader(reader, registry); + if (!modelOpt.isPresent()) { + throw new IllegalStateException("Could not create model from reader"); + } this.model = modelOpt.get(); - Config modelConfig = ConfigFactory.load(model.getConfigName()); - this.transformer = new Transformer(modelConfig, model.getKey()); + Config transformerConfig = ConfigFactory.load(modelConfig.getConfigName()); + this.transformer = new Transformer(transformerConfig, modelConfig.getKey(), + this.model.registry(), this.model); + this.biasFeature = this.model.registry().feature("BIAS", "B"); } /* @@ -40,9 +51,17 @@ public double rawProbability(Example example) { return model.scoreProbability(score(example)); } - public float score(Example example) { - FeatureVector featureVector = example.getExample().get(0); - transformer.combineContextAndItems(example); + public double score(Example example) { + // TODO (Brad): This is really kind of odd. Why have an example if we assume it's one item?! + FeatureVector featureVector = example.iterator().next(); + + // TODO (Brad): Maybe this should be a part of the transform itself. Why is it important? Does + // everyone want it? 
+ featureVector.putString(biasFeature); + + // TODO (Brad): Abusing the fact that transforms are in place. Let's fix this too. + example.transform(transformer); + return model.scoreItem(featureVector); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransform.java index 8a9a36cf..e5dc75ad 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransform.java @@ -1,67 +1,82 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; import com.typesafe.config.Config; - -import java.util.Map; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; /** * Given a fieldName1, low, upper key * Remaps fieldName2's key2 value such that low = 0, upper = 1.0 thus approximating * the percentile using linear interpolation. 
*/ -public class ApproximatePercentileTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String lowKey; - private String upperKey; - private String key2; - private String outputName; - private String outputKey; - private double minDiff; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class ApproximatePercentileTransform + extends ConfigurableTransform { - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - lowKey = config.getString(key + ".low"); - upperKey = config.getString(key + ".upper"); - minDiff = config.getDouble(key + ".minDiff"); - key2 = config.getString(key + ".key2"); - outputName = config.getString(key + ".output"); - outputKey = config.getString(key + ".outputKey"); - } + protected String lowFamilyName; + protected String lowFeatureName; + protected String upperFamilyName; + protected String upperFeatureName; + protected String valueFamilyName; + protected String valueFeatureName; + protected String outputFamilyName; + protected String outputFeatureName; + protected double minDiff; - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); + @Setter(AccessLevel.NONE) + protected Feature upperFeature; + @Setter(AccessLevel.NONE) + protected Feature lowFeature; + @Setter(AccessLevel.NONE) + protected Feature valueFeature; + @Setter(AccessLevel.NONE) + protected Feature outputFeature; - if (floatFeatures == null) { - return; - } + @Override + public ApproximatePercentileTransform configure(Config config, String key) { + return lowFamilyName(stringFromConfig(config, key, ".field1")) + .lowFeatureName(stringFromConfig(config, key, ".low")) + .upperFamilyName(stringFromConfig(config, key, ".field1")) + 
.upperFeatureName(stringFromConfig(config, key, ".upper")) + .valueFamilyName(stringFromConfig(config, key, ".field2")) + .valueFeatureName(stringFromConfig(config, key, ".key2")) + .minDiff(doubleFromConfig(config, key, ".minDiff")) + .outputFamilyName(stringFromConfig(config, key, ".output")) + .outputFeatureName(stringFromConfig(config, key, ".outputKey")); + } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + protected void setup() { + super.setup(); + this.lowFeature = registry.feature(lowFamilyName, lowFeatureName); + this.upperFeature = registry.feature(upperFamilyName, upperFeatureName); + this.valueFeature = registry.feature(valueFamilyName, valueFeatureName); + this.outputFeature = registry.feature(outputFamilyName, outputFeatureName); + } - Map feature2 = floatFeatures.get(fieldName2); - if (feature2 == null) { - return; - } + @Override + protected boolean checkPreconditions(MultiFamilyVector vector) { + return super.checkPreconditions(vector) && + vector.containsKey(lowFeature) && + vector.containsKey(upperFeature) && + vector.containsKey(valueFeature); + } - Double val = feature2.get(key2); - if (val == null) { - return; - } + @Override + protected void doTransform(MultiFamilyVector vector) { - Double low = feature1.get(lowKey); - Double upper = feature1.get(upperKey); - - if (low == null || upper == null) { - return; - } + double low = vector.getDouble(lowFeature); + double upper = vector.getDouble(upperFeature); + double val = vector.getDouble(valueFeature); // Abstain if the percentiles are too close. 
double denom = upper - low; @@ -69,17 +84,15 @@ public void doTransform(FeatureVector featureVector) { return; } - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - - Double outVal = 0.0; + double outVal; if (val <= low) { - outVal = 0.0; + outVal = 0.0d; } else if (val >= upper) { - outVal = 1.0; + outVal = 1.0d; } else { outVal = (val - low) / denom; } - output.put(outputKey, outVal); + vector.put(outputFeature, outVal); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransform.java deleted file mode 100644 index 88059701..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransform.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; - -import java.util.Map; -import java.util.Map.Entry; - -/** - * Buckets float features and places them in a new float column. 
- */ -public class BucketFloatTransform implements Transform { - private String fieldName1; - private double bucket; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - bucket = config.getDouble(key + ".bucket"); - outputName = config.getString(key + ".output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null || feature1.isEmpty()) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - - for (Entry feature : feature1.entrySet()) { - Double dbl = TransformUtil.quantize(feature.getValue(), bucket); - Double newVal = feature.getValue() - dbl; - String name = feature.getKey() + '[' + bucket + "]=" + dbl; - output.put(name, newVal); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketTransform.java new file mode 100644 index 00000000..246f8125 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/BucketTransform.java @@ -0,0 +1,39 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; +import com.airbnb.aerosolve.core.util.TransformUtil; +import com.typesafe.config.Config; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; + +/** + * Buckets float features and places them in a new float column. 
+ */ +@LegacyNames("bucket_float") +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class BucketTransform extends BaseFeaturesTransform { + private double bucket; + + public BucketTransform configure(Config config, String key) { + return super.configure(config, key) + .bucket(doubleFromConfig(config, key, ".bucket")); + } + + @Override + public void doTransform(MultiFamilyVector featureVector) { + for (FeatureValue value : getInput(featureVector)) { + double dbl = TransformUtil.quantize(value.value(), bucket); + double newVal = value.value() - dbl; + String name = value.feature().name() + '[' + bucket + "]=" + dbl; + featureVector.put(outputFamily.feature(name), newVal); + } + } +} \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapFloatTransform.java deleted file mode 100644 index 628faea1..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapFloatTransform.java +++ /dev/null @@ -1,30 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.transforms.types.FloatTransform; -import com.typesafe.config.Config; - -import java.util.List; -import java.util.Map; - -public class CapFloatTransform extends FloatTransform { - private List keys; - private double lowerBound; - private double upperBound; - - @Override - public void init(Config config, String key) { - keys = config.getStringList(key + ".keys"); - lowerBound = config.getDouble(key + ".lower_bound"); - upperBound = config.getDouble(key + ".upper_bound"); - } - - @Override - public void output(Map input, Map output) { - for (String key : keys) { - Double v = input.get(key); - if (v != null) { - output.put(key, Math.min(upperBound, Math.max(lowerBound, v))); - } - } - } -} diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapTransform.java new file mode 100644 index 00000000..b7418d73 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CapTransform.java @@ -0,0 +1,21 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BoundedFeaturesTransform; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +@LegacyNames("cap_float") +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class CapTransform extends BoundedFeaturesTransform { + + @Override + public void doTransform(MultiFamilyVector featureVector) { + for (FeatureValue value : getInput(featureVector)) { + double v = value.value(); + featureVector.put(produceOutputFeature(value.feature()), + Math.min(upperBound, Math.max(lowerBound, v))); + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransform.java index 32949a78..b8f7312c 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransform.java @@ -1,8 +1,12 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.transforms.types.StringTransform; - +import com.airbnb.aerosolve.core.transforms.base.StringTransform; import com.typesafe.config.Config; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; /** * Converts strings to either all lowercase or all uppercase @@ -11,28 +15,20 @@ * "output" optionally specifies the key of the output feature, if it is not 
given the transform * overwrites / replaces the input feature */ -public class ConvertStringCaseTransform extends StringTransform { - private boolean convertToUppercase; - - @Override - public void init(Config config, String key) { - convertToUppercase = config.getBoolean(key + ".convert_to_uppercase"); +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class ConvertStringCaseTransform extends StringTransform { + protected boolean convertToUppercase; + + public ConvertStringCaseTransform configure(Config config, String key) { + return super.configure(config, key) + .convertToUppercase(booleanFromConfig(config, key, ".convert_to_uppercase")); } @Override public String processString(String rawString) { - if (rawString == null) { - return null; - } - - String convertedString; - - if (convertToUppercase) { - convertedString = rawString.toUpperCase(); - } else { - convertedString = rawString.toLowerCase(); - } - - return convertedString; + return convertToUppercase ? 
rawString.toUpperCase() : rawString.toLowerCase(); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CrossTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CrossTransform.java index 1684a02d..9ccfb278 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CrossTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CrossTransform.java @@ -1,52 +1,97 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFamilyTransform; +import com.airbnb.aerosolve.core.util.TransformUtil; import com.typesafe.config.Config; -import java.util.HashSet; -import java.util.Set; -import java.util.Map; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; /** * Created by hector_yee on 8/25/14. * Takes the cross product of stringFeatures named in field1 and field2 * and places it in a stringFeature with family name specified in output. 
*/ -public class CrossTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String outputName; +@LegacyNames({"float_cross_float", "string_cross_float", "self_cross"}) +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class CrossTransform extends DualFamilyTransform + implements ModelAware { + private Double bucket; + private Double cap; + private boolean putValue = true; + private boolean ignoreNonModelFeatures = false; - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - outputName = config.getString(key + ".output"); + private AbstractModel model; + + @Setter(AccessLevel.NONE) + private boolean selfCross = false; + + public CrossTransform configure(Config config, String key) { + return super.configure(config, key) + .bucket(doubleFromConfig(config, key, ".bucket", false)) + .cap(doubleFromConfig(config, key, ".cap", false)) + .putValue(shouldPutValue(config, key)) + .ignoreNonModelFeatures(booleanFromConfig(config, key, ".ignore_non_model_features")); + } + + private boolean shouldPutValue(Config config, String key) { + String type = getTransformType(config, key); + return booleanFromConfig(config, key, ".putValue") + || (type != null && type.endsWith("float")) + || bucket != null || cap != null; } @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) return; - - Set set1 = stringFeatures.get(fieldName1); - if (set1 == null || set1.isEmpty()) return; - Set set2 = stringFeatures.get(fieldName2); - if (set2 == null || set2.isEmpty()) return; - - Set output = stringFeatures.get(outputName); - if (output == null) { - output = new HashSet<>(); - stringFeatures.put(outputName, output); - } - cross(set1, set2, 
output); + protected void setup() { + super.setup(); + selfCross = otherFamily == inputFamily; } - public static void cross(Set set1, Set set2, Set output) { - for (String s1 : set1) { - String prefix = s1 + '^'; - for (String s2 : set2) { - output.add(prefix + s2); + @Override + protected void doTransform(MultiFamilyVector featureVector) { + FamilyVector otherFamilyVector = featureVector.get(otherFamily); + for (FeatureValue value : getInput(featureVector)) { + String separator = "^"; + if (putValue) { + double doubleVal = value.value(); + if (cap != null) { + doubleVal = Math.min(cap, doubleVal); + } + if (bucket != null) { + double quantizedVal = TransformUtil.quantize(doubleVal, bucket); + separator = "=" + quantizedVal + "^"; + } + } + for (FeatureValue otherValue : otherFamilyVector) { + if (value.feature().equals(otherValue.feature())) { + continue; + } + Feature outputFeature = outputFamily.cross(value.feature(), otherValue.feature(), separator); + boolean neededForModel = !ignoreNonModelFeatures + || model == null + || model.needsFeature(outputFeature); + // For self crosses, there is no reason to cross both ways. 
+ boolean neededIfSelfCross = !selfCross + || value.feature().compareTo(otherValue.feature()) < 0; + if (neededForModel && neededIfSelfCross) { + if (putValue) { + featureVector.put(outputFeature, otherValue.value()); + } else { + featureVector.putString(outputFeature); + } + } } } } - } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransform.java deleted file mode 100644 index 98381b94..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransform.java +++ /dev/null @@ -1,123 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigObject; -import com.typesafe.config.ConfigValue; - -import java.util.*; -import java.util.Map.Entry; - -/** - * A custom quantizer that quantizes features based on upper limits and bucket sizes from config - * "field1" specifies feature family name. - * If "select_features" is specified, we only transform features in the select_features list. - * If "exclude_features" is specified, we transform features that are not in the exclude_features list. - * If both "select_features" and "exclude_features" are specified, we transform features that are in - * "select_features" list and not in "exclude_features" list. 
- */ -public class CustomLinearLogQuantizeTransform implements Transform { - - private String fieldName1; - private String outputName; - private TreeMap limitBucketPairsMap; - private double upperLimit; - - private List excludeFeatures; - private List selectFeatures; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - outputName = config.getString(key + ".output"); - if (config.hasPath(key + ".exclude_features")) { - excludeFeatures = config.getStringList(key + ".exclude_features"); - } - - if (config.hasPath(key + ".select_features")) { - selectFeatures = config.getStringList(key + ".select_features"); - } - limitBucketPairsMap = - parseTokensOutOfLimitBucketPairs(config.getObjectList(key + ".limit_bucket")); - upperLimit = limitBucketPairsMap.lastKey(); - } - - private static TreeMap parseTokensOutOfLimitBucketPairs( - List pairs) { - TreeMap parsedTokensMap = new TreeMap<>(); - for (ConfigObject configObject : pairs) { - List> entries = new ArrayList<>(configObject.entrySet()); - parsedTokensMap.put(Double.parseDouble(entries.get(0).getKey()), - Double.parseDouble(entries.get(0).getValue().unwrapped().toString())); - } - - return parsedTokensMap; - } - - private String transformFeature(String featureName, - double featureValue, - StringBuilder sb) { - sb.setLength(0); - sb.append(featureName); - boolean isValueNegative = false; - if (featureValue < 0.0) { - isValueNegative = true; - featureValue = -featureValue; - } - - if (featureValue < 1e-2) { - sb.append("=0.0"); - } else { - double limit; - double bucket; - if (featureValue >= upperLimit) { - featureValue = upperLimit; - bucket = limitBucketPairsMap.get(upperLimit); - } else { - limit = limitBucketPairsMap.higherKey(featureValue); - bucket = limitBucketPairsMap.get(limit); - } - - Double val = TransformUtil.quantize(featureValue, bucket) * 1000; - - sb.append('='); - if (isValueNegative) { - sb.append('-'); - } - - 
sb.append(val.intValue()/1000.0); - } - - return sb.toString(); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null || feature1.isEmpty()) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - StringBuilder sb = new StringBuilder(); - for (Entry feature : feature1.entrySet()) { - if ((excludeFeatures == null || !excludeFeatures.contains(feature.getKey())) && - (selectFeatures == null || selectFeatures.contains(feature.getKey()))) { - String transformedFeature = transformFeature(feature.getKey(), - feature.getValue(), - sb); - - output.add(transformedFeature); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransform.java deleted file mode 100644 index b64d4547..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransform.java +++ /dev/null @@ -1,82 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.*; -import java.util.Map.Entry; - -/** - * Quantize the floatFeature named in "field1" with buckets in "bucket" before placing - * it in the stringFeature named "output". - * "field1" specifies feature family name. - * If "select_features" is specified, we only transform features in the select_features list. - * If "exclude_features" is specified, we transform features that are not in the exclude_features list. 
- * If both "select_features" and "exclude_features" are specified, we transform features that are in - * "select_features" list and not in "exclude_features" list. - */ - -public class CustomMultiscaleQuantizeTransform implements Transform { - private String fieldName1; - private List buckets; - private String outputName; - private List excludeFeatures; - private List selectFeatures; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - buckets = config.getDoubleList(key + ".buckets"); - outputName = config.getString(key + ".output"); - if (config.hasPath(key + ".exclude_features")) { - excludeFeatures = config.getStringList(key + ".exclude_features"); - } - - if (config.hasPath(key + ".select_features")) { - selectFeatures = config.getStringList(key + ".select_features"); - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - for (Entry feature : feature1.entrySet()) { - if ((excludeFeatures == null || !excludeFeatures.contains(feature.getKey())) && - (selectFeatures == null || selectFeatures.contains(feature.getKey()))) { - transformAndAddFeature(buckets, - feature.getKey(), - feature.getValue(), - output); - } - } - } - - public static void transformAndAddFeature(List buckets, - String featureName, - Double featureValue, - Set output) { - if (featureValue == 0.0) { - output.add(featureName + "=0"); - return; - } - - for (double bucket : buckets) { - double quantized = TransformUtil.quantize(featureValue, bucket); - output.add(featureName + '[' + bucket + "]=" + quantized); - } - } -} diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutFloatTransform.java deleted file mode 100644 index 60c8cca7..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutFloatTransform.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.List; -import java.util.Map; - -/* - remove features larger than upperBound or smaller than lowerBound - */ -public class CutFloatTransform implements Transform { - private String fieldName1; - private List keys; - private double lowerBound; - private double upperBound; - private String outputName; // output family name, if not specified, output to fieldName1 - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - keys = config.getStringList(key + ".keys"); - if (config.hasPath(key + ".lower_bound")) { - lowerBound = config.getDouble(key + ".lower_bound"); - } else { - lowerBound = -Double.MAX_VALUE; - } - if (config.hasPath(key + ".upper_bound")) { - upperBound = config.getDouble(key + ".upper_bound"); - } else { - upperBound = Double.MAX_VALUE; - } - if (config.hasPath(key + ".output")) { - outputName = config.getString(key + ".output"); - } else { - outputName = fieldName1; - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - Map feature2 = Util.getOrCreateFloatFeature(outputName, floatFeatures); - for (String key : keys) { - Double v = feature1.get(key); - if (v != null) { - if (v > upperBound || v < lowerBound) { - if (feature2 == feature1) { - feature2.remove(key); 
- } - } else if (feature2 != feature1) { - feature2.put(key, v); - } - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutTransform.java new file mode 100644 index 00000000..d64ca119 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/CutTransform.java @@ -0,0 +1,36 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BoundedFeaturesTransform; +import java.util.ArrayList; +import java.util.List; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +/** + * remove features larger than upperBound or smaller than lowerBound + */ +@LegacyNames("cut_float") +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class CutTransform extends BoundedFeaturesTransform { + + @Override + public void doTransform(MultiFamilyVector featureVector) { + List toDelete = new ArrayList<>(); + for (FeatureValue value : getInput(featureVector)) { + double v = value.value(); + if (v > upperBound || v < lowerBound) { + if (inputFamily.equals(outputFamily)) { + toDelete.add(value.feature()); + } + } else if (!inputFamily.equals(outputFamily)) { + featureVector.put(produceOutputFeature(value.feature()), v); + } + } + for (Feature feature : toDelete) { + featureVector.remove(feature); + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateDiffTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateDiffTransform.java index a08bb1d1..ac1f8092 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateDiffTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateDiffTransform.java @@ -1,67 +1,45 @@ package com.airbnb.aerosolve.core.transforms; 
-import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFamilyTransform; import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; - -import java.text.DateFormat; -import java.text.ParseException; -import java.text.SimpleDateFormat; -import java.util.Date; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.joda.time.DateTime; +import org.joda.time.Duration; /** * output = date_diff(field1, field2) * get the date difference between dates in features of key "field1" and * dates in features of key "field2" */ -public class DateDiffTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String outputName; - private final static SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd"); - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - outputName = config.getString(key + ".output"); - } +// TODO (Brad): Configurable date time format. 
+@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DateDiffTransform extends DualFamilyTransform { @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - Map> floatFeatures = featureVector.getFloatFeatures(); - if (stringFeatures == null || floatFeatures == null) { - return ; - } - - Set feature1 = stringFeatures.get(fieldName1); - Set feature2 = stringFeatures.get(fieldName2); - if (feature1 == null || feature2 == null) { - return ; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); + protected void doTransform(MultiFamilyVector featureVector) { + // TODO (Brad): I made a mistake when refactoring this because it was unintuitive that the date + // in field1 is the end date. Is that intentional? try { - for (String endDateStr : feature1) { - Date endDate = format.parse(endDateStr); - for (String startDateStr : feature2) { - Date startDate = format.parse(startDateStr); - long diff = endDate.getTime() - startDate.getTime(); - long diffDays = TimeUnit.DAYS.convert(diff, TimeUnit.MILLISECONDS); - output.put(endDateStr + "-m-" + startDateStr, (double)diffDays); + for (FeatureValue value : getInput(featureVector)) { + String endDateStr = value.feature().name(); + DateTime endDate = Util.DATE_FORMAT.parseDateTime(endDateStr); + for (FeatureValue value2 : featureVector.get(otherFamily)) { + String startDateStr = value2.feature().name(); + DateTime startDate = Util.DATE_FORMAT.parseDateTime(startDateStr); + Duration duration = new Duration(startDate, endDate); + Feature feature = outputFamily.feature(endDateStr + "-m-" + startDateStr); + featureVector.put(feature, duration.getStandardDays()); } } - } catch (ParseException e) { + } catch (IllegalArgumentException e) { + // TODO (Brad): Better error handling e.printStackTrace(); - return ; } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateValTransform.java 
b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateValTransform.java index 5c497130..7b33b6b1 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateValTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DateValTransform.java @@ -1,76 +1,84 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; import com.airbnb.aerosolve.core.util.Util; - import com.typesafe.config.Config; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; +import org.joda.time.DateTime; +import org.joda.time.DateTimeFieldType; -import java.text.SimpleDateFormat; -import java.util.*; -import java.text.ParseException; +import javax.validation.constraints.NotNull; /** * Get the date value from date string * "field1" specifies the key of feature - * "field2" specifies the type of date value + * "date_type" specifies the type of date value */ -public class DateValTransform implements Transform { - protected String fieldName1; +@Data +@EqualsAndHashCode(callSuper = false) +@Slf4j +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DateValTransform extends BaseFeaturesTransform { + protected DateTimeFieldType dateTimeFieldType; + + @NotNull protected String dateType; - protected String outputName; - protected final static SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd"); + @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - dateType = config.getString(key + ".date_type"); - outputName = config.getString(key + ".output"); + public DateValTransform 
configure(Config config, String key) { + return super.configure(config, key) + .dateType(stringFromConfig(config, key, ".date_type")); } @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) { - return ; - } + protected void setup() { + super.setup(); + dateTimeFieldType = getDateTimeFieldType(dateType); + } - Set feature1 = stringFeatures.get(fieldName1); - if (feature1 == null) { - return ; + private DateTimeFieldType getDateTimeFieldType(String dateType) { + switch (dateType) { + case "day_of_month": + return DateTimeFieldType.dayOfMonth(); + case "day_of_week": + return DateTimeFieldType.dayOfWeek(); + case "day_of_year": + return DateTimeFieldType.dayOfYear(); + case "year": + return DateTimeFieldType.year(); + case "month": + return DateTimeFieldType.monthOfYear(); + default: + return null; } + } - Util.optionallyCreateFloatFeatures(featureVector); - Map> floatFeatures = featureVector.getFloatFeatures(); - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - - for (String dateStr: feature1) { + @Override + protected void doTransform(MultiFamilyVector featureVector) { + for (FeatureValue value : getInput(featureVector)) { + String dateStr = value.feature().name(); try { - Date date = format.parse(dateStr); - Calendar cal = Calendar.getInstance(); - cal.setTime(date); - double dateVal; - switch (dateType) { - case "day_of_month": - dateVal = cal.get(Calendar.DAY_OF_MONTH); - break; - case "day_of_week": - dateVal = cal.get(Calendar.DAY_OF_WEEK); - break; - case "day_of_year": - dateVal = cal.get(Calendar.DAY_OF_YEAR); - break; - case "year": - dateVal = cal.get(Calendar.YEAR); - break; - case "month": - dateVal = cal.get(Calendar.MONTH) + 1; - break; - default: - return ; + DateTime date = Util.DATE_FORMAT.parseDateTime(dateStr); + double dateVal = date.get(dateTimeFieldType); + if (dateTimeFieldType.equals(DateTimeFieldType.dayOfWeek())) { + 
// Joda DateTimes start the week with Monday. So, Sunday is 7. We mod 7 to bring it to + // 0 and add 1 to every day to offset. + dateVal = (double) ((((int) dateVal) % 7) + 1); } - output.put(dateStr, dateVal); - } catch (ParseException e) { - e.printStackTrace(); - continue ; + featureVector.put(outputFamily.feature(dateStr), dateVal); + } catch (IllegalArgumentException e) { + log.error("Error parsing date String %s with format %s: %s", + dateStr, Util.DATE_FORMAT.toString(), e.getMessage()); + // Let's just continue here. It doesn't seem worth aborting on a malformed String. + // Hopefully someone checks the logs when this happens. } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransform.java index ea7bea68..3d22c19f 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransform.java @@ -1,14 +1,22 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.models.DecisionTreeModel; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; import com.typesafe.config.Config; +import java.util.List; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import org.hibernate.validator.constraints.NotEmpty; -import java.io.Serializable; -import java.util.*; -import java.util.Map.Entry; +import javax.validation.constraints.NotNull; /** * Applies a decision tree transform 
to existing float features. @@ -16,38 +24,49 @@ * Emits the score to the float family output_score * Use tree.toHumanReadableTransform to generate the nodes list. */ -public class DecisionTreeTransform implements Transform { - private String outputLeaves; - private String outputScoreFamily; - private String outputScoreName; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DecisionTreeTransform extends ConfigurableTransform { + @NotNull + private String outputLeavesFamilyName; + @NotNull + private String outputScoreFamilyName; + @NotNull + private String outputScoreFeatureName; + @NotNull + @NotEmpty + private List nodes; + @Setter(AccessLevel.NONE) + private Family outputLeavesFamily; + @Setter(AccessLevel.NONE) + private Feature outputScoreFeature; + @Setter(AccessLevel.NONE) private DecisionTreeModel tree; @Override - public void configure(Config config, String key) { - outputLeaves = config.getString(key + ".output_leaves"); - outputScoreFamily = config.getString(key + ".output_score_family"); - outputScoreName = config.getString(key + ".output_score_name"); - List nodes = config.getStringList(key + ".nodes"); - tree = DecisionTreeModel.fromHumanReadableTransform(nodes); + public DecisionTreeTransform configure(Config config, String key) { + return outputLeavesFamilyName(stringFromConfig(config, key, ".output_leaves")) + .outputScoreFamilyName(stringFromConfig(config, key, ".output_score_family")) + .outputScoreFeatureName(stringFromConfig(config, key, ".output_score_name")) + .nodes(stringListFromConfig(config, key, ".nodes", true)); } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } + protected void setup() { + super.setup(); + outputLeavesFamily = registry.family(outputLeavesFamilyName); + outputScoreFeature = 
registry.feature(outputScoreFamilyName, outputScoreFeatureName); + tree = DecisionTreeModel.fromHumanReadableTransform(nodes, registry); + } - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set outputString = Util.getOrCreateStringFeature(outputLeaves, stringFeatures); - - Map outputFloat = Util.getOrCreateFloatFeature(outputScoreFamily, floatFeatures); - int leafIdx = tree.getLeafIndex(floatFeatures); - ModelRecord rec = tree.getStumps().get(leafIdx); - outputString.add(rec.featureName); - outputFloat.put(outputScoreName, rec.featureWeight); + @Override + public void doTransform(MultiFamilyVector featureVector) { + int leafIdx = tree.getLeafIndex(featureVector); + ModelRecord rec = tree.stumps().get(leafIdx); + featureVector.put(outputLeavesFamily.feature(rec.getFeatureName()), 1.0d); + featureVector.put(outputScoreFeature, rec.getFeatureWeight()); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransform.java index de71d34c..533a15da 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransform.java @@ -1,12 +1,18 @@ package com.airbnb.aerosolve.core.transforms; import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - -import java.util.Map; -import java.util.Set; - +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; import com.typesafe.config.Config; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import 
lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; /** * Tokenizes and counts strings using a regex and optionally generates bigrams from the tokens @@ -14,82 +20,68 @@ * "regex" specifies the regex used to tokenize * "generateBigrams" specifies whether bigrams should also be generated */ -public class DefaultStringTokenizerTransform implements Transform { +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DefaultStringTokenizerTransform + extends BaseFeaturesTransform { public static final String BIGRAM_SEPARATOR = " "; - private String fieldName1; - private String regex; - private String outputName; - private boolean generateBigrams; - private String bigramsOutputName; + // TODO (Brad): Is there a good default regex? + private String regex = " "; + private boolean generateBigrams = false; + private String bigramsOutputFamilyName; + + + @Setter(AccessLevel.NONE) + private Family bigramsOutputFamily; @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - regex = config.getString(key + ".regex"); - outputName = config.getString(key + ".output"); - if (config.hasPath(key + ".generate_bigrams")) { - generateBigrams = config.getBoolean(key + ".generate_bigrams"); - } else { - generateBigrams = false; - } - if (generateBigrams) { - bigramsOutputName = config.getString(key + ".bigrams_output"); - } + public DefaultStringTokenizerTransform configure(Config config, String key) { + return super.configure(config, key) + .regex(stringFromConfig(config, key, ".regex")) + .generateBigrams(booleanFromConfig(config, key, ".generate_bigrams")) + .bigramsOutputFamilyName(stringFromConfig(config, key, ".bigrams_output")); } @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) { - return; 
+ protected void setup() { + super.setup(); + if (generateBigrams && bigramsOutputFamilyName != null) { + bigramsOutputFamily = registry.family(bigramsOutputFamilyName); } + } - Set feature1 = stringFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + @Override + protected void doTransform(MultiFamilyVector featureVector) { - Util.optionallyCreateFloatFeatures(featureVector); - Map> floatFeatures = featureVector.getFloatFeatures(); - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - Map bigramOutput = null; - if (generateBigrams) { - bigramOutput = Util.getOrCreateFloatFeature(bigramsOutputName, floatFeatures); - } + for (FeatureValue value : getInput(featureVector)) { + if (value.feature().name() == null) { + continue; + } - for (String rawString : feature1) { - if (rawString == null) continue; - String[] tokenizedString = rawString.split(regex); - for (String token : tokenizedString) { + String previousToken = null; + for (String token : value.feature().name().split(regex)) { if (token.length() == 0) continue; - incrementOutput(token, output); - } - if (generateBigrams) { - String previousToken = null; - for (String token : tokenizedString) { - if (token.length() == 0) continue; - if (previousToken == null) { - previousToken = token; - } else { + incrementOutput(outputFamily.feature(token), featureVector); + if (generateBigrams) { + if (previousToken != null) { String bigram = previousToken + BIGRAM_SEPARATOR + token; - incrementOutput(bigram, bigramOutput); - previousToken = token; + incrementOutput(bigramsOutputFamily.feature(bigram), featureVector); } + previousToken = token; } } } } - private static void incrementOutput(String key, Map output) { - if (key == null || output == null) { - return; - } - if (output.containsKey(key)) { - double count = output.get(key); - output.put(key, (count + 1.0)); + private static void incrementOutput(Feature feature, FeatureVector vector) { + if (vector.containsKey(feature)) { + double 
count = vector.get(feature); + vector.put(feature, (count + 1.0)); } else { - output.put(key, 1.0); + vector.put(feature, 1.0); } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransform.java deleted file mode 100644 index caf64e81..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransform.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; - -import com.typesafe.config.Config; - -/** - * "fields" specifies a list of float feature families to be deleted - */ -public class DeleteFloatFeatureFamilyTransform implements Transform { - private List fieldNames; - - @Override - public void configure(Config config, String key) { - fieldNames = config.getStringList(key + ".fields"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - - if (fieldNames == null) { - return; - } - - for (String fieldName: fieldNames) { - Map feature = floatFeatures.get(fieldName); - if (feature != null) { - floatFeatures.remove(fieldName); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransform.java deleted file mode 100644 index 6118b007..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransform.java +++ /dev/null @@ -1,34 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import java.util.List; -import java.util.Map; - -public class 
DeleteFloatFeatureTransform implements Transform { - private String fieldName1; - private List keys; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - keys = config.getStringList(key + ".keys"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - for (String key : keys) { - feature1.remove(key); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureColumnTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureColumnTransform.java deleted file mode 100644 index e1eeb5f5..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureColumnTransform.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -// TODO: remove this once all configs have migrated over to the new transform names -public class DeleteStringFeatureColumnTransform extends DeleteStringFeatureFamilyTransform {} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransform.java deleted file mode 100644 index dd63b71b..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransform.java +++ /dev/null @@ -1,41 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; - -/** - * "fields" specifies a list of string feature families to be deleted - */ -public class DeleteStringFeatureFamilyTransform implements Transform { - 
private List fieldNames; - - @Override - public void configure(Config config, String key) { - fieldNames = config.getStringList(key + ".fields"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) { - return; - } - - if (fieldNames == null) { - return; - } - - for (String fieldName: fieldNames) { - Set feature = stringFeatures.get(fieldName); - if (feature != null) { - stringFeatures.remove(fieldName); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransform.java deleted file mode 100644 index e3a90428..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransform.java +++ /dev/null @@ -1,46 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; - -public class DeleteStringFeatureTransform implements Transform { - private String fieldName1; - private List keys; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - keys = config.getStringList(key + ".keys"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - - if (stringFeatures == null) { - return; - } - - Set feature1 = stringFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - List toDelete = new ArrayList(); - for(String feat : feature1) { - for (String key : keys) { - if (feat.startsWith(key)) { - toDelete.add(feat); - break; - } - } - } - for (String feat : toDelete) { - feature1.remove(feat); - } - } -} diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteTransform.java new file mode 100644 index 00000000..a4383d86 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DeleteTransform.java @@ -0,0 +1,118 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; +import com.typesafe.config.Config; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; +import org.apache.commons.lang3.tuple.Pair; + +/** + * + */ +@LegacyNames({"delete_float_feature_family", + "delete_string_feature_column", + "delete_string_feature_family", + "delete_float_feature", + "delete_string_feature"}) +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DeleteTransform extends ConfigurableTransform { + private List familyNames; + private List> featureNames; + private boolean deleteByPrefix = false; + + private List families; + private List features; + + @Override + public DeleteTransform configure(Config config, String key) { + DeleteTransform transform = familyNames(stringListFromConfig(config, key, ".fields", false)); + String familyName = stringFromConfig(config, key, ".field1", false); + List featureList = stringListFromConfig(config, key, ".keys", false); + if (featureList != null && familyName != null) { + List> featurePairs = new ArrayList<>(); + for (String 
featureName : featureList) { + featurePairs.add(Pair.of(familyName, featureName)); + } + transform.featureNames(featurePairs); + } + boolean deleteByPre = booleanFromConfig(config, key, ".delete_by_prefix"); + if (!deleteByPre) { + String transformType = getTransformType(config, key); + deleteByPre = transformType != null && + (transformType.equals("delete_string_feature") || + transformType.equals("delete_string_column")); + } + transform.deleteByPrefix(deleteByPre); + return transform; + } + + @Override + protected void setup() { + super.setup(); + if (familyNames != null) { + families = new ArrayList<>(); + for (String familyName : familyNames) { + families.add(registry.family(familyName)); + } + } + if (featureNames != null) { + features = new ArrayList<>(); + for (Map.Entry pair : featureNames) { + features.add(registry.feature(pair.getKey(), pair.getValue())); + } + } + } + + @Override + protected void validate() { + super.validate(); + if (familyNames == null && featureNames == null) { + throw new IllegalArgumentException( "At least one of familyNames or featureNames must be set."); + } + } + + @Override + protected void doTransform(MultiFamilyVector vector) { + if (features != null) { + for (Feature feature : features) { + if (deleteByPrefix) { + FamilyVector familyVector = vector.get(feature.family()); + if (familyVector == null) { + continue; + } + List toDelete = new ArrayList<>(); + for (FeatureValue value : familyVector) { + if (value.feature().name().startsWith(feature.name())) { + toDelete.add(value.feature()); + } + } + for (Feature deleteFeature : toDelete) { + vector.removeDouble(deleteFeature); + } + } else { + vector.removeDouble(feature); + } + } + } + + if (families != null) { + for (Family family : families) { + vector.remove(family); + } + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DenseTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DenseTransform.java new file mode 100644 index 
00000000..5dbee359 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DenseTransform.java @@ -0,0 +1,117 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.features.SimpleFeatureValue; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; +import com.typesafe.config.Config; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import org.hibernate.validator.constraints.NotEmpty; + +import javax.validation.constraints.NotNull; + +/* + Turn several float features into one dense feature, feature number must > 1 + 1. IF all float features are null, create a string feature, + with family name string_output, feature name output^null + 2. IF only one float feature is not null, create a float feature + with family name same as family of the only not null float feature + 3. Other cases create dense features + both 2 and 3, feature name: output^key keys. 
+ */ +@LegacyNames("float_to_dense") +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DenseTransform extends ConfigurableTransform { + private static final int FEATURE_AVG_SIZE = 16; + + @NotNull + @NotEmpty + private List fields; + @NotNull + @NotEmpty + private List keys; + @NotNull + private String outputFamilyName; + @NotNull + private String outputStringFamilyName; + + @Setter(AccessLevel.NONE) + private List inputFeatures; + @Setter(AccessLevel.NONE) + private Family outputFamily; + @Setter(AccessLevel.NONE) + private Family outputStringFamily; + + @Override + public DenseTransform configure(Config config, String key) { + return outputStringFamilyName(stringFromConfig(config, key, ".string_output")) + .outputFamilyName(stringFromConfig(config, key, ".output")) + .keys(stringListFromConfig(config, key, ".keys", true)) + .fields(stringListFromConfig(config, key, ".fields", true)); + } + + @Override + protected void setup() { + super.setup(); + outputFamily = registry.family(outputFamilyName); + outputStringFamily = registry.family(outputStringFamilyName); + inputFeatures = new ArrayList<>(); + for (int i = 0; i < fields.size(); i++) { + String familyName = fields.get(i); + String featureName = keys.get(i); + inputFeatures.add(registry.feature(familyName, featureName)); + } + } + + @Override + protected void validate() { + super.validate(); + if (fields.size() != keys.size() || fields.size() <= 1) { + String msg = String.format("fields size %d keys size %d", fields.size(), keys.size()); + throw new RuntimeException(msg); + } + } + + @Override + public void doTransform(MultiFamilyVector featureVector) { + List values = inputFeatures.stream() + .filter(featureVector::containsKey) + .map((Feature feature) -> + SimpleFeatureValue.of(feature, featureVector.getDouble(feature))) + .collect(Collectors.toList()); + + switch (values.size()) { + case 0: + Feature
outputFeature = outputStringFamily.feature(outputFamilyName + "^null"); + featureVector.putString(outputFeature); + break; + case 1: + FeatureValue value = values.get(0); + String name = outputFamilyName + "^" + value.feature().name(); + featureVector.put(outputFamily.feature(name), value.value()); + break; + default: + double[] output = new double[values.size()]; + StringBuilder nameBuilder = new StringBuilder(); + for (int i = 0; i < values.size(); i++) { + output[i] = values.get(i).value(); + nameBuilder.append("^"); + nameBuilder.append(values.get(i).feature().name()); + } + featureVector.putDense(registry.family(nameBuilder.toString()), output); + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DivideTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DivideTransform.java index 05dcc562..7e302f44 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/DivideTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/DivideTransform.java @@ -1,68 +1,54 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.OtherFeatureTransform; import com.typesafe.config.Config; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; /** * output = field1.keys / (field2.key2 + constant) */ -public class DivideTransform implements Transform { - private String fieldName1; - private String fieldName2; - private List keys; - private String key2; - private String 
outputName; - private Double constant; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class DivideTransform extends OtherFeatureTransform { + private double constant = 0d; @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - keys = config.getStringList(key + ".keys"); - key2 = config.getString(key + ".key2"); - constant = config.getDouble((key + ".constant")); - outputName = config.getString(key + ".output"); + public DivideTransform configure(Config config, String key) { + return super.configure(config, key) + .constant(doubleFromConfig(config, key, ".constant", false, 0.0d)); } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } + protected String otherFeatureKey() { + return ".key2"; + } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + @Override + public String produceOutputFeatureName(String name) { + return name + "-d-" + otherFeatureName; + } - Map feature2 = floatFeatures.get(fieldName2); - if (feature2 == null) { - return; - } - Double div = feature2.get(key2); - if (div == null) { - return; - } + @Override + public void doTransform(MultiFamilyVector featureVector) { + double div = featureVector.getDouble(otherFeature); - Double scale = 1.0 / (constant + div); - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); + double scale = 1.0 / (constant + div); - for (String key : keys) { - if (feature1.containsKey(key)) { - Double val = feature1.get(key); - if (val != null) { - output.put(key + "-d-" + key2, val * scale); - } + for (FeatureValue value : getInput(featureVector)) { + Feature feature = value.feature(); + if (featureVector.containsKey(feature)) { + double v = 
value.value(); + Feature outputFeature = produceOutputFeature(feature); + featureVector.put(outputFeature, v * scale); } } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransform.java deleted file mode 100644 index ed60ba8b..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransform.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; - -import java.util.Map; - -import com.typesafe.config.Config; - -/** - * Takes the floats in fieldName1, quantizes them into buckets, converts them to strings, then - * crosses them with the floats in fieldName2 and then stores the result in a new float feature - * output specified by outputName. - */ -public class FloatCrossFloatTransform implements Transform { - private String fieldName1; - private double bucket; - private double cap; - private String fieldName2; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - bucket = config.getDouble(key + ".bucket"); - if (config.hasPath(key + ".cap")) { - cap = config.getDouble(key + ".cap"); - } else { - cap = 1e10; - } - fieldName2 = config.getString(key + ".field2"); - outputName = config.getString(key + ".output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.floatFeatures; - - if (floatFeatures == null || floatFeatures.isEmpty()) { - return; - } - - Map map1 = floatFeatures.get(fieldName1); - - if (map1 == null || map1.isEmpty()) { - return; - } - - Map map2 = floatFeatures.get(fieldName2); - - if (map2 == null || map2.isEmpty()) { - return; - } - - Map output = 
Util.getOrCreateFloatFeature(outputName, floatFeatures); - - for (Map.Entry entry1 : map1.entrySet()) { - String float1Key = entry1.getKey(); - Double float1Value = entry1.getValue(); - - if (float1Value > cap) { - float1Value = cap; - } - - Double float1Quantized = TransformUtil.quantize(float1Value, bucket); - - for (Map.Entry entry2 : map2.entrySet()) { - String float2Key = entry2.getKey(); - Double float2Value = entry2.getValue(); - - String outputKey = float1Key + "=" + float1Quantized + "^" + float2Key; - - output.put(outputKey, float2Value); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransform.java deleted file mode 100644 index 95859424..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransform.java +++ /dev/null @@ -1,80 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -/* - Turn several float features into one dense feature, feature number must > 1 - 1. IF all float features are null, create a string feature, - with family name string_output, feature name output^null - 2. IF only one float feature is not null, create a float feature - with family name same as family of the only not null float feature - 3. Other cases create dense features - both 2 and 3, feature name: output^key keys. 
- */ -public class FloatToDenseTransform implements Transform{ - private List fields; - private List keys; - private String outputName; - private String outputStringFamily; - private static final int featureAVGSize = 16; - @Override - public void configure(Config config, String key) { - outputStringFamily = config.getString(key + ".string_output"); - outputName = config.getString(key + ".output"); - keys = config.getStringList(key + ".keys"); - fields = config.getStringList(key + ".fields"); - if (fields.size() != keys.size() || fields.size() <= 1) { - String msg = String.format("fields size {} keys size {}", fields.size(), keys.size()); - throw new RuntimeException(msg); - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - int size = fields.size(); - StringBuilder sb = new StringBuilder((size + 1) * featureAVGSize); - sb.append(outputName); - List output = new ArrayList<>(size); - Map floatFamily = null; - for (int i = 0; i < size; ++i) { - String familyName = fields.get(i); - Map family = floatFeatures.get(familyName); - if (family == null) { - continue; - } - String featureName = keys.get(i); - Double feature = family.get(keys.get(i)); - if (feature != null) { - output.add(feature); - sb.append('^'); - sb.append(featureName); - floatFamily = family; - } - } - - switch (output.size()) { - case 0: { - sb.append("^null"); - Util.setStringFeature(featureVector, outputStringFamily, sb.toString()); - } - break; - case 1: { - floatFamily.put(sb.toString(), output.get(0)); - } - break; - default: - Util.setDenseFeature(featureVector, sb.toString(), output); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransform.java deleted file mode 100644 index 5751a8b7..00000000 --- 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransform.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.airbnb.aerosolve.core.models.KDTreeModel; -import com.airbnb.aerosolve.core.KDTreeNode; -import com.google.common.base.Optional; -import com.typesafe.config.Config; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; - -/** - * Inputs = fieldName1 (value1, value2) - * Outputs = list of kdtree nodes and the distance from the split - * This is the continuous version of the kd-tree transform and encodes - * the distance from each splitting plane to the point being queried. - * One can think of this as a tree kernel transform of a point. - */ -public class KdtreeContinuousTransform implements Transform { - private String fieldName1; - private String value1; - private String value2; - private String outputName; - private Integer maxCount; - private Optional modelOptional; - private static final Logger log = LoggerFactory.getLogger(KdtreeContinuousTransform.class); - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - value1 = config.getString(key + ".value1"); - value2 = config.getString(key + ".value2"); - outputName = config.getString(key + ".output"); - maxCount = config.getInt(key + ".max_count"); - String modelEncoded = config.getString(key + ".model_base64"); - - modelOptional = KDTreeModel.readFromGzippedBase64String(modelEncoded); - - if (!modelOptional.isPresent()) { - log.error("Could not load KDTree from encoded field"); - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - if (!modelOptional.isPresent()) { - return; - } - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Map feature1 = 
floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Double v1 = feature1.get(value1); - Double v2 = feature1.get(value2); - - if (v1 == null || v2 == null) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - - ArrayList result = modelOptional.get().query(v1, v2); - int count = Math.min(result.size(), maxCount); - KDTreeNode[] nodes = modelOptional.get().getNodes(); - - for (int i = 0; i < count; i++) { - Integer res = result.get(result.size() - 1 - i); - double split = nodes[res].getSplitValue(); - switch (nodes[res].getNodeType()) { - case X_SPLIT: { - output.put(res.toString(), v1 - split); - } - break; - case Y_SPLIT: { - output.put(res.toString(), v2 - split); - } - break; - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeTransform.java index 35149905..29fd019b 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/KdtreeTransform.java @@ -1,77 +1,103 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFeatureTransform; import com.airbnb.aerosolve.core.models.KDTreeModel; +import com.airbnb.aerosolve.core.KDTreeNode; import com.google.common.base.Optional; import com.typesafe.config.Config; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; +import java.util.ArrayList; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; + +import javax.validation.constraints.NotNull; /** * Inputs = fieldName1 
(value1, value2) - * Outputs = list of kdtree nodes + * Outputs = list of kdtree nodes and the distance from the split + * This is the continuous version of the kd-tree transform and encodes + * the distance from each splitting plane to the point being queried. + * One can think of this as a tree kernel transform of a point. */ -public class KdtreeTransform implements Transform { - private String fieldName1; - private String value1; - private String value2; - private String outputName; +@Slf4j +@LegacyNames("kdtree_continuous") +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class KdtreeTransform extends DualFeatureTransform { private Integer maxCount; + @NotNull + private String modelEncoded; + private boolean continuous; + + @Setter(AccessLevel.NONE) private Optional modelOptional; - private static final Logger log = LoggerFactory.getLogger(KdtreeTransform.class); @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - value1 = config.getString(key + ".value1"); - value2 = config.getString(key + ".value2"); - outputName = config.getString(key + ".output"); - maxCount = config.getInt(key + ".max_count"); - String modelEncoded = config.getString(key + ".model_base64"); - - modelOptional = KDTreeModel.readFromGzippedBase64String(modelEncoded); + public KdtreeTransform configure(Config config, String key) { + return super.configure(config, key) + .maxCount(intFromConfig(config, key, ".max_count")) + .modelEncoded(stringFromConfig(config, key, ".model_base64")) + .continuous(isContinuous(config, key)); + } - if (!modelOptional.isPresent()) { - log.error("Could not load KDTree from encoded field"); + private boolean isContinuous(Config config, String key) { + if (!booleanFromConfig(config, key, ".continuous")) { + String transformType = getTransformType(config, key); + // For legacy transform types. 
+ return transformType != null && transformType.endsWith("continuous"); } + return true; } @Override - public void doTransform(FeatureVector featureVector) { + protected void setup() { + super.setup(); + modelOptional = KDTreeModel.readFromGzippedBase64String(modelEncoded); if (!modelOptional.isPresent()) { - return; - } - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Double v1 = feature1.get(value1); - Double v2 = feature1.get(value2); - - if (v1 == null || v2 == null) { - return; + String message = "Could not load KDTree from encoded field"; + log.error(message); + throw new IllegalStateException(message); } + } - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); + @Override + public void doTransform(MultiFamilyVector featureVector) { + double v1 = featureVector.getDouble(inputFeature); + double v2 = featureVector.getDouble(otherFeature); ArrayList result = modelOptional.get().query(v1, v2); int count = Math.min(result.size(), maxCount); - for (int i = 0; i < count; i++) { - Integer res = result.get(result.size() - 1 - i); - output.add(res.toString()); + if (continuous) { + KDTreeNode[] nodes = modelOptional.get().nodes(); + + for (int i = 0; i < count; i++) { + Integer res = result.get(result.size() - 1 - i); + double split = nodes[res].getSplitValue(); + switch (nodes[res].getNodeType()) { + case X_SPLIT: { + featureVector.put(outputFamily.feature(res.toString()), v1 - split); + } + break; + case Y_SPLIT: { + featureVector.put(outputFamily.feature(res.toString()), v2 - split); + } + break; + } + } + } else { + for (int i = 0; i < count; i++) { + Integer res = result.get(result.size() - 1 - i); + featureVector.putString(outputFamily.feature(res.toString())); + } } } } diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/LegacyNames.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/LegacyNames.java new file mode 100644 index 00000000..e65bcd39 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/LegacyNames.java @@ -0,0 +1,16 @@ +package com.airbnb.aerosolve.core.transforms; + +/** + * + */ + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface LegacyNames { + String[] value(); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransform.java deleted file mode 100644 index 87601725..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransform.java +++ /dev/null @@ -1,138 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.Map; -import java.util.Map.Entry; - -/** - * A quantizer that starts out with linearly space buckets that get coarser and coarser - * and eventually transitions to log buckets. 
- */ -public class LinearLogQuantizeTransform implements Transform { - private String fieldName1; - private String outputName; - - private static StringBuilder sb; - // Upper limit of each bucket to check if feature value falls in the bucket - private static List limits; - // Step size used for quantization, for the correponding limit - private static List stepSizes; - // Limit beyond which quantized value would be rounded to integer (ignoring decimals) - private static double integerRoundingLimit; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - outputName = config.getString(key + ".output"); - sb = new StringBuilder(); - - limits = new ArrayList<>(); - stepSizes = new ArrayList<>(); - - limits.add(1.0); - stepSizes.add(1.0 / 32.0); - - limits.add(10.0); - stepSizes.add(0.125); - - limits.add(25.0); - stepSizes.add(0.25); - - limits.add(50.0); - stepSizes.add(5.0); - - limits.add(100.0); - stepSizes.add(10.0); - - limits.add(400.0); - stepSizes.add(25.0); - - limits.add(2000.0); - stepSizes.add(100.0); - - limits.add(10000.0); - stepSizes.add(250.0); - - integerRoundingLimit = 25.0; - } - - private static boolean checkAndQuantize(Double featureValue, double limit, double stepSize, boolean integerRounding) { - if (featureValue <= limit) { - if (!integerRounding) { - sb.append(TransformUtil.quantize(featureValue, stepSize)); - } else { - sb.append(TransformUtil.quantize(featureValue, stepSize).intValue()); - } - - return true; - } - - return false; - } - - private static String logQuantize(String featureName, double featureValue) { - sb.setLength(0); - sb.append(featureName); - sb.append('='); - - Double dbl = featureValue; - if (dbl < 0.0) { - sb.append('-'); - dbl = -dbl; - } - // At every stage we quantize roughly to a precision 10% of the magnitude. 
- if (dbl < 1e-2) { - sb.append('0'); - } else { - boolean isQuantized = false; - for (int i = 0; i < limits.size(); i++) { - Double limit = limits.get(i); - Double stepSize = stepSizes.get(i); - if (limit > integerRoundingLimit) { - isQuantized = checkAndQuantize(dbl, limit, stepSize, true); - } else { - isQuantized = checkAndQuantize(dbl, limit, stepSize, false); - } - - if (isQuantized) { - break; - } - } - - if (! isQuantized) { - Double exp = Math.log(dbl) / Math.log(2.0); - Long val = 1L << exp.intValue(); - sb.append(val); - } - } - - return sb.toString(); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null || feature1.isEmpty()) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - for (Entry feature : feature1.entrySet()) { - output.add(logQuantize(feature.getKey(), feature.getValue())); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ListTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ListTransform.java index 079660b7..33a66ad2 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ListTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ListTransform.java @@ -1,37 +1,54 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; import com.typesafe.config.Config; - -import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.Vector; +import 
java.util.function.Function; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; /** * Created by hector_yee on 8/25/14. * A transform that accepts a list of other transforms and applies them as a group * in the order specified by the list. */ -public class ListTransform implements Transform { - private List transforms; +@Slf4j +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class ListTransform extends ConfigurableTransform + implements ModelAware { + + private AbstractModel model; + + @Override + public AbstractModel model() { + return model; + } @Override - public void configure(Config config, String key) { - transforms = new ArrayList<>(); + public ListTransform model(AbstractModel model) { + this.model = model; + return this; + } + + // Starts with identity to make null cases easier. + private Transform bigTransform = f -> f; + + @Override + public ListTransform configure(Config config, String key) { List transformKeys = config.getStringList(key + ".transforms"); for (String transformKey : transformKeys) { - Transform tmpTransform = TransformFactory.createTransform(config, transformKey); - if (tmpTransform != null) { - transforms.add(tmpTransform); - } + Transform tmpTransform = + TransformFactory.createTransform(config, transformKey, registry, model); + bigTransform = (Transform) bigTransform.andThen(tmpTransform); } + return this; } @Override - public void doTransform(FeatureVector featureVector) { - for (Transform transform : transforms) { - transform.doTransform(featureVector); - } + public void doTransform(MultiFamilyVector featureVector) { + bigTransform.apply(featureVector); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathFloatTransform.java deleted file mode 100644 index ba31cb2c..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathFloatTransform.java +++ /dev/null @@ 
-1,86 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.google.common.base.Optional; -import com.typesafe.config.Config; - -import java.util.List; -import java.util.Map; -import java.util.function.DoubleFunction; - -/** - * Apply given Math function on specified float features defined by fieldName1 and keys - * fieldName1: feature family name - * keys: feature names - * outputName: output feature family name (feature names or keys remain the same) - * function: a string that specified the function that is going to apply to the given feature - */ -public class MathFloatTransform implements Transform { - private String fieldName1; // feature family name - private List keys; // feature names - private String outputName; // output feature family name - private String functionName; // a string that specified the function that is going to apply to the given feature - private Optional> func; - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - if (config.hasPath(key + ".keys")) { - keys = config.getStringList(key + ".keys"); - } - outputName = config.getString(key + ".output"); - functionName = config.getString(key + ".function"); - func = getFunction(); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (keys.isEmpty()) { - return; - } - - if (!func.isPresent()) { - return; - } - - if (floatFeatures == null) { - return; - } - - Map feature1 = floatFeatures.get(fieldName1); - - if (feature1 == null) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - for (String key : keys) { - Double v = feature1.get(key); - if (v != null) { - Double result = func.get().apply(v); - if (!result.isNaN() && !result.isInfinite()) { - output.put(key, result); - } - } - } - } - - private 
Optional> getFunction() { - switch (functionName) { - case "sin": - return Optional.of((double x) -> Math.sin(x)); - case "cos": - return Optional.of((double x) -> Math.cos(x)); - case "log10": - // return the original value if x <= 0 - return Optional.of((double x) -> Math.log10(x)); - case "log": - // return the original value if x <= 0 - return Optional.of((double x) -> Math.log(x)); - case "abs": - return Optional.of((double x) -> Math.abs(x)); - } - return Optional.>absent(); - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathTransform.java new file mode 100644 index 00000000..21f3117b --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MathTransform.java @@ -0,0 +1,87 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; +import com.google.common.base.Optional; +import com.typesafe.config.Config; +import java.util.function.DoubleUnaryOperator; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; + +import javax.validation.constraints.NotNull; + +/** + * Apply given Math function on specified float features defined by fieldName1 and keys + * fieldName1: feature family name + * keys: feature names + * outputName: output feature family name (feature names or keys remain the same) + * function: a string that specified the function that is going to apply to the given feature + */ +@LegacyNames("math_float") +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class MathTransform 
extends BaseFeaturesTransform { + @NotNull + private String functionName; + + @Setter(AccessLevel.NONE) + private Optional func; + + @Override + public MathTransform configure(Config config, String key) { + return super.configure(config, key) + .functionName(stringFromConfig(config, key, ".function")); + } + + @Override + protected void setup() { + super.setup(); + func = getFunction(functionName); + if (!func.isPresent()) { + throw new IllegalArgumentException( + String.format("Cannot run math transform. %s function is unknown.", functionName)); + } + } + + @Override + public void doTransform(MultiFamilyVector featureVector) { + for (FeatureValue value : getInput(featureVector)) { + if (featureVector.containsKey(value.feature())) { + double v = value.value(); + double result = func.get().applyAsDouble(v); + if (Double.isNaN(result) + || Double.isInfinite(result)) { + continue; + } + Feature outputFeature = produceOutputFeature(value.feature()); + featureVector.put(outputFeature, result); + } + } + } + + private Optional getFunction(String functionName) { + switch (functionName) { + case "sin": + return Optional.of(Math::sin); + case "cos": + return Optional.of(Math::cos); + case "log10": + // return the original value if x <= 0 + return Optional.of(Math::log10); + case "log": + // return the original value if x <= 0 + return Optional.of(Math::log); + case "abs": + return Optional.of(Math::abs); + } + return Optional.absent(); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ModelAware.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ModelAware.java new file mode 100644 index 00000000..d646d7cb --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ModelAware.java @@ -0,0 +1,13 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.models.AbstractModel; + +/** + * + */ +public interface ModelAware { + + AbstractModel model(); + + T model(AbstractModel model); +} diff --git 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java deleted file mode 100644 index 23687d98..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransform.java +++ /dev/null @@ -1,93 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.Collection; -import java.util.Map; -import java.util.Set; - -/** - * Takes the floats in the keys of fieldName1 (or if keys are not specified, all floats) and - * quantizes them into buckets. If the quantized float is less than or equal to a maximum specified - * bucket value or greater than or equal to a minimum specified bucket value, then the quantized - * float is stored as a string in a new string feature output specified by stringOutputName. - * Otherwise, the original, unchanged float is stored in a new float feature output specified by - * floatOutputName. The input float feature remains unchanged. 
- */ -public class MoveFloatToStringAndFloatTransform implements Transform { - private String fieldName1; - private Collection keys; - private double bucket; - private double maxBucket; - private double minBucket; - private String stringOutputName; - private String floatOutputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - if (config.hasPath(key + ".keys")) { - keys = config.getStringList(key + ".keys"); - } - bucket = config.getDouble(key + ".bucket"); - maxBucket = config.getDouble(key + ".max_bucket"); - minBucket = config.getDouble(key + ".min_bucket"); - stringOutputName = config.getString(key + ".string_output"); - floatOutputName = config.getString(key + ".float_output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.floatFeatures; - - if (floatFeatures == null || floatFeatures.isEmpty()) { - return; - } - - Map input = floatFeatures.get(fieldName1); - - if (input == null || input.isEmpty()) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set stringOutput = Util.getOrCreateStringFeature(stringOutputName, stringFeatures); - - Map floatOutput = Util.getOrCreateFloatFeature(floatOutputName, floatFeatures); - - if (keys == null) { - keys = input.keySet(); - } - - for (String key : keys) { - moveFloatToStringAndFloat( - input, key, bucket, minBucket, maxBucket, stringOutput, floatOutput); - } - } - - private static void moveFloatToStringAndFloat( - Map input, - String key, - double bucket, - double minBucket, - double maxBucket, - Set stringOutput, - Map floatOutput) { - if (input.containsKey(key)) { - Double inputFloatValue = input.get(key); - - Double inputFloatQuantized = TransformUtil.quantize(inputFloatValue, bucket); - - if (inputFloatQuantized >= minBucket && inputFloatQuantized <= maxBucket) { - String movedFloat = key + "=" + 
inputFloatQuantized; - stringOutput.add(movedFloat); - } else { - floatOutput.put(key, inputFloatValue); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransform.java deleted file mode 100644 index f19927cd..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransform.java +++ /dev/null @@ -1,89 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.Iterator; -import java.util.Set; -import java.util.Map; -import java.util.Map.Entry; -import java.util.List; - -/** - * Moves named fields from one family to another. If keys are not specified, all keys are moved - * from the float family. - */ -public class MoveFloatToStringTransform implements Transform { - private String fieldName1; - private double bucket; - private String outputName; - private List keys; - private double cap; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - bucket = config.getDouble(key + ".bucket"); - outputName = config.getString(key + ".output"); - if (config.hasPath(key + ".keys")) { - keys = config.getStringList(key + ".keys"); - } - if (config.hasPath(key + ".cap")) { - cap = config.getDouble(key + ".cap"); - } else { - cap = 1e10; - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null || feature1.isEmpty()) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output 
= Util.getOrCreateStringFeature(outputName, stringFeatures); - - if (keys != null) { - for (String key : keys) { - moveFloat(feature1, output, key, cap, bucket); - feature1.remove(key); - } - } else { - for (Iterator> iterator = feature1.entrySet().iterator(); - iterator.hasNext();) { - Entry entry = iterator.next(); - String key = entry.getKey(); - - moveFloat(feature1, output, key, cap, bucket); - iterator.remove(); - } - } - } - - public static void moveFloat( - Map feature1, - Set output, - String key, - double cap, - double bucket) { - if (feature1.containsKey(key)) { - Double dbl = feature1.get(key); - if (dbl > cap) { - dbl = cap; - } - - Double quantized = TransformUtil.quantize(dbl, bucket); - output.add(key + '=' + quantized); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveTransform.java new file mode 100644 index 00000000..b87f78ec --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MoveTransform.java @@ -0,0 +1,90 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; +import com.airbnb.aerosolve.core.util.TransformUtil; +import com.typesafe.config.Config; +import java.util.List; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; + +/** + * Moves named fields from one family to another. If keys are not specified, all keys are moved + * from the float family. 
+ */ +@LegacyNames({"move_float_to_string", + "multiscale_move_float_to_string", + "move_float_to_string_and_float"}) +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class MoveTransform extends BaseFeaturesTransform { + private Double bucket; + private Double cap; + private List buckets; + private Double maxBucket; + private Double minBucket; + + @Override + public MoveTransform configure(Config config, String key) { + MoveTransform transform = super.configure(config, key); + if (outputFamilyName == null) { + // TODO (Brad): Make this ".output" in configs. + transform.outputFamilyName(stringFromConfig(config, key, ".float_output")); + // TODO (Brad): Think about ".string_output". Do we really need it? It doesn't make sense + // now that we don't distinguish strings and floats. But we may need multiple output families. + } + return transform.bucket(doubleFromConfig(config, key, ".bucket", false)) + .buckets(doubleListFromConfig(config, key, ".buckets", false)) + .maxBucket(doubleFromConfig(config, key, ".max_bucket", false)) + .minBucket(doubleFromConfig(config, key, ".min_bucket", false)) + .cap(doubleFromConfig(config, key, ".cap", false, 1e10)); + } + + @Override + public void doTransform(MultiFamilyVector featureVector) { + for (FeatureValue value : getInput(featureVector)) { + double dbl = value.value(); + if (dbl > cap) { + dbl = cap; + } + + boolean placed = false; + if (bucket != null) { + Double quantized = TransformUtil.quantize(dbl, bucket); + if (minBucket == null || maxBucket == null + || (quantized >= minBucket && quantized <= maxBucket)) { + featureVector.putString(outputFamily.feature(value.feature().name() + '=' + quantized)); + placed = true; + } + } else if (buckets != null) { + for (Double bucket : buckets) { + Double quantized = TransformUtil.quantize(dbl, bucket); + Feature outputFeature = outputFamily.feature( + value.feature().name() + '[' + 
bucket + "]=" + quantized); + featureVector.putString(outputFeature); + placed = true; + } + } + if (!placed) { + // TODO (Brad): Question for Chris: Is it important that for move_float_to_string_and_float + // the features end up in different families now that we can have Strings and Floats in + // the same family? Would it make sense to just leave it in the original family + // if it's not inside the bounds? + // We can handle it as is, but it feels a little complex and seems like it might be better + // handled with two Moves. One that keeps the features in the existing family and deletes + // keys that match the interval. + // Then we could have another Move that moves the family that remains if it's important it + // have a different name. + featureVector.put(outputFamily.feature(value.feature().name()), value.value()); + } + featureVector.remove(value.feature()); + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransform.java index af23fd21..5cd5438f 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransform.java @@ -1,73 +1,63 @@ package com.airbnb.aerosolve.core.transforms; import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFeatureTransform; import com.typesafe.config.Config; - -import java.util.*; -import java.util.Map.Entry; +import java.util.List; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; /** * Quantizes the floatFeature named in "field1" with 
buckets in "bucket" before placing * it in the floatFeature named "output" subtracting the origin of the box. */ -public class MultiscaleGridContinuousTransform implements Transform { - private String fieldName1; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class MultiscaleGridContinuousTransform + extends DualFeatureTransform { + private List buckets; - private String outputName; - private String value1; - private String value2; @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - buckets = config.getDoubleList(key + ".buckets"); - outputName = config.getString(key + ".output"); - value1 = config.getString(key + ".value1"); - value2 = config.getString(key + ".value2"); - + public MultiscaleGridContinuousTransform configure(Config config, String key) { + return super.configure(config, key) + .buckets(doubleListFromConfig(config, key, ".buckets", true)); } - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Double v1 = feature1.get(value1); - Double v2 = feature1.get(value2); - if (v1 == null || v2 == null) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); + @Override + public void doTransform(MultiFamilyVector featureVector) { + double v1 = featureVector.getDouble(inputFeature); + double v2 = featureVector.getDouble(otherFeature); - transformFeature(v1, v2, output); + transformFeature(v1, v2, featureVector); } public void transformFeature(double v1, double v2, - Map output) { + FeatureVector vector) { for (Double bucket : buckets) { - transformFeature(v1, v2, bucket, output); + transformFeature(v1, v2, bucket, outputFamily, vector); } } public static void 
transformFeature(double v1, double v2, double bucket, - Map output) { - Double mult1 = v1 / bucket; - double q1 = bucket * mult1.intValue(); - Double mult2 = v2 / bucket; - double q2 = bucket * mult2.intValue(); + Family outputFamily, + FeatureVector vector) { + double mult1 = v1 / bucket; + double q1 = bucket * Math.floor(mult1); + double mult2 = v2 / bucket; + double q2 = bucket * Math.floor(mult2); String bucketName = "[" + bucket + "]=(" + q1 + ',' + q2 + ')'; - output.put(bucketName + "@1", v1 - q1); - output.put(bucketName + "@2", v2 - q2); + vector.put(outputFamily.feature(bucketName + "@1"), v1 - q1); + vector.put(outputFamily.feature(bucketName + "@2"), v2 - q2); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransform.java index dc0c8c5d..937ed05b 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransform.java @@ -1,66 +1,58 @@ package com.airbnb.aerosolve.core.transforms; import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFeatureTransform; import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; import com.typesafe.config.Config; - -import java.util.*; +import java.util.List; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.Accessors; /** * Created by hector_yee on 8/25/14. 
* Quantizes the floatFeature named in "field1" with buckets in "bucket" before placing * it in the stringFeature named "output" */ -public class MultiscaleGridQuantizeTransform implements Transform { - private String fieldName1; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class MultiscaleGridQuantizeTransform + extends DualFeatureTransform { private List buckets; - private String outputName; - private String value1; - private String value2; @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - buckets = config.getDoubleList(key + ".buckets"); - outputName = config.getString(key + ".output"); - value1 = config.getString(key + ".value1"); - value2 = config.getString(key + ".value2"); - + public MultiscaleGridQuantizeTransform configure(Config config, String key) { + return super.configure(config, key) + .buckets(doubleListFromConfig(config, key, ".buckets", true)); } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + public void doTransform(MultiFamilyVector featureVector) { - Double v1 = feature1.get(value1); - Double v2 = feature1.get(value2); - if (v1 == null || v2 == null) { - return; - } + double v1 = featureVector.getDouble(inputFeature); + double v2 = featureVector.getDouble(otherFeature); - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - transformFeature(v1, v2, buckets, output); + transformFeature(v1, v2, buckets, outputFamily, featureVector); } - public static void transformFeature(double v1, double v2, List buckets, Set output) { + public static void 
transformFeature(double v1, double v2, List buckets, + Family outputFamily, + FeatureVector vector) { for (Double bucket : buckets) { - transformFeature(v1, v2, bucket, output); + transformFeature(v1, v2, bucket, outputFamily, vector); } } - public static void transformFeature(double v1, double v2, double bucket, Set output) { + public static void transformFeature(double v1, double v2, double bucket, + Family outputFamily, FeatureVector vector) { double q1 = TransformUtil.quantize(v1, bucket); double q2 = TransformUtil.quantize(v2, bucket); - output.add("[" + bucket + "]=(" + q1 + ',' + q2 + ')'); + vector.putString(outputFamily.feature("[" + bucket + "]=(" + q1 + ',' + q2 + ')')); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransform.java deleted file mode 100644 index 175d581d..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransform.java +++ /dev/null @@ -1,64 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; -import java.util.Set; -import java.util.Map; -import java.util.List; - -/** - * Moves named fields from one family to another. 
- */ -public class MultiscaleMoveFloatToStringTransform implements Transform { - private String fieldName1; - private List buckets; - private String outputName; - private List keys; - private double cap; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - buckets = config.getDoubleList(key + ".buckets"); - outputName = config.getString(key + ".output"); - keys = config.getStringList(key + ".keys"); - try { - cap = config.getDouble(key + ".cap"); - } catch (Exception e) { - cap = 1e10; - } - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - for (String key : keys) { - if (feature1.containsKey(key)) { - Double dbl = feature1.get(key); - if (dbl > cap) { - dbl = cap; - } - for (Double bucket : buckets) { - Double quantized = TransformUtil.quantize(dbl, bucket); - output.add(key + '[' + bucket + "]=" + quantized); - } - feature1.remove(key); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransform.java deleted file mode 100644 index 8d420a14..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransform.java +++ /dev/null @@ -1,65 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.TransformUtil; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.*; -import java.util.Map.Entry; - 
-/** - * Created by hector_yee on 8/25/14. - * Quantizes the floatFeature named in "field1" with buckets in "bucket" before placing - * it in the stringFeature named "output" - */ -public class MultiscaleQuantizeTransform implements Transform { - private String fieldName1; - private List buckets; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - buckets = config.getDoubleList(key + ".buckets"); - outputName = config.getString(key + ".output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - for (Entry feature : feature1.entrySet()) { - transformAndAddFeature(buckets, - feature.getKey(), - feature.getValue(), - output); - } - } - - public static void transformAndAddFeature(List buckets, - String featureName, - Double featureValue, - Set output) { - if (featureValue == 0.0) { - output.add(featureName + "=0"); - return; - } - - for (double bucket : buckets) { - double quantized = TransformUtil.quantize(featureValue, bucket); - output.add(featureName + '[' + bucket + "]=" + quantized); - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NearestTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NearestTransform.java index 55281cad..4995d549 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NearestTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NearestTransform.java @@ -1,68 +1,34 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; 
-import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.HashMap; -import java.util.Set; -import java.util.HashSet; -import java.util.Map; -import java.util.Map.Entry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.OtherFeatureTransform; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; /** * output = nearest of (field1, field2.key) */ -public class NearestTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String key2; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - key2 = config.getString(key + ".key"); - outputName = config.getString(key + ".output"); - } - +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class NearestTransform extends OtherFeatureTransform { @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + protected void doTransform(MultiFamilyVector featureVector) { + double sub = featureVector.getDouble(otherFeature); - Map feature2 = floatFeatures.get(fieldName2); - if (feature2 == null) { - return; - } - Double sub = feature2.get(key2); - if (sub == null) { - return; - } - - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); String nearest = "nothing"; double bestDist = 1e10; - for (Entry f1 : feature1.entrySet()) { - double dist = Math.abs(f1.getValue() - sub); + for (FeatureValue value : getInput(featureVector)) { + double dist = Math.abs(value.value() - sub); if (dist < bestDist) { - 
nearest = f1.getKey(); + nearest = value.feature().name(); bestDist = dist; } } - output.add(key2 + "~=" + nearest); + featureVector.putString(outputFamily.feature(otherFeature.name() + "~=" + nearest)); + } - Util.optionallyCreateStringFeatures(featureVector); + protected String otherFeatureKey() { + return ".key"; } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransform.java deleted file mode 100644 index 6082a0ff..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransform.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; - -import java.util.Map; - -// L2 normalizes a float feature -public class NormalizeFloatTransform implements Transform { - private String fieldName1; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - double norm = 0.0; - for (Map.Entry feat : feature1.entrySet()) { - norm += feat.getValue() * feat.getValue(); - } - if (norm > 0.0) { - double scale = 1.0 / Math.sqrt(norm); - for (Map.Entry feat : feature1.entrySet()) { - feat.setValue(feat.getValue() * scale); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeTransform.java new file mode 100644 index 00000000..b07f771f --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeTransform.java @@ -0,0 +1,38 @@ +package 
com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; + +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import java.util.LinkedList; +import java.util.List; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +/** + * L2 normalizes a float feature + */ +@LegacyNames("normalize_float") +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class NormalizeTransform + extends BaseFeaturesTransform { + + @Override + protected void doTransform(MultiFamilyVector featureVector) { + double norm = 0.0; + List values = new LinkedList<>(); + for (FeatureValue value : getInput(featureVector)) { + norm += value.value() * value.value(); + // TODO (Brad): May need to copy here due to re-using instances in the iterable. + values.add(value); + } + if (norm > 0.0) { + double scale = 1.0 / Math.sqrt(norm); + // Not sure it's necessary but I'm storing the values to avoid mutating the vector + // while iterating its values. 
+ for (FeatureValue value : values) { + featureVector.put(value.feature(), value.value() * scale); + } + } + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8Transform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8Transform.java index 19b8fe6f..a8b19ff9 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8Transform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8Transform.java @@ -1,10 +1,14 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.transforms.types.StringTransform; - -import java.text.Normalizer; - +import com.airbnb.aerosolve.core.transforms.base.StringTransform; import com.typesafe.config.Config; +import java.text.Normalizer; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; /** * Normalizes strings to UTF-8 NFC, NFD, NFKC or NFKD form (NFD by default) @@ -13,16 +17,33 @@ * "output" optionally specifies the key of the output feature, if it is not given the transform * overwrites / replaces the input feature */ -public class NormalizeUtf8Transform extends StringTransform { +// This is what it used to be called but the automatic namer turns it into normalize_utf8 +// I think that's arguably more correct so we'll support both. 
+@LegacyNames("normalize_utf_8") +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class NormalizeUtf8Transform extends StringTransform { public static final Normalizer.Form DEFAULT_NORMALIZATION_FORM = Normalizer.Form.NFD; + private String normalizationFormString; + + @Setter(AccessLevel.NONE) private Normalizer.Form normalizationForm; @Override - public void init(Config config, String key) { - String normalizationFormString = DEFAULT_NORMALIZATION_FORM.name(); - if (config.hasPath(key + ".normalization_form")) { - normalizationFormString = config.getString(key + ".normalization_form"); + public NormalizeUtf8Transform configure(Config config, String key) { + return super.configure(config, key) + .normalizationFormString(stringFromConfig(config, key, ".normalization_form", false)); + } + + @Override + protected void setup() { + super.setup(); + if (normalizationFormString == null) { + normalizationForm = DEFAULT_NORMALIZATION_FORM; + return; } if (normalizationFormString.equalsIgnoreCase("NFC")) { normalizationForm = Normalizer.Form.NFC; @@ -39,10 +60,6 @@ public void init(Config config, String key) { @Override public String processString(String rawString) { - if (rawString == null) { - return null; - } - return Normalizer.normalize(rawString, normalizationForm); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ProductTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ProductTransform.java index 188b2be2..2e56500d 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ProductTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ProductTransform.java @@ -1,51 +1,37 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; -import java.util.List; -import java.util.Map; 
-import java.util.Map.Entry; -import java.util.HashMap; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; /** * Computes the polynomial product of all values in field1 * i.e. prod_i 1 + x_i * and places the result in outputName */ -public class ProductTransform implements Transform { - private String fieldName1; - private List keys; - private String outputName; +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class ProductTransform extends BaseFeaturesTransform { + private Feature outputFeature; @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - outputName = config.getString(key + ".output"); - keys = config.getStringList(key + ".keys"); + protected void setup() { + super.setup(); + outputFeature = outputFamily.feature("*"); } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - + public void doTransform(MultiFamilyVector featureVector) { Double prod = 1.0; - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - for (String key : keys) { - Double dbl = feature1.get(key); - if (dbl != null) { - prod *= 1.0 + dbl; - } + boolean computedSomething = false; + for (FeatureValue value : getInput(featureVector)) { + computedSomething = true; + prod *= 1.0 + value.value(); + } + if (computedSomething) { + featureVector.put(outputFeature, prod); } - output.put("*", prod); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/QuantizeTransform.java 
b/core/src/main/java/com/airbnb/aerosolve/core/transforms/QuantizeTransform.java index 76818b23..c14e2cbc 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/QuantizeTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/QuantizeTransform.java @@ -1,61 +1,229 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.BaseFeaturesTransform; +import com.airbnb.aerosolve.core.util.TransformUtil; +import com.google.common.collect.ImmutableSet; import com.typesafe.config.Config; -import java.util.Set; import java.util.HashSet; -import java.util.Map; -import java.util.Map.Entry; +import java.util.List; +import java.util.Set; +import java.util.TreeMap; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; /** * Created by hector_yee on 8/25/14. 
+ * TODO (Brad): Update doc * Multiplies the floatFeature named in "field1" with "scale" before placing * it in the stringFeature named "output" */ -public class QuantizeTransform implements Transform { - private String fieldName1; - private double scale; - private String outputName; +@LegacyNames({"custom_linear_log_quantize", + "custom_multiscale_quantize", + "linear_log_quantize", + "multiscale_quantize" +}) +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class QuantizeTransform extends BaseFeaturesTransform { + /** Upper limit of each bucket to check if feature value falls in the bucket **/ + private static double[] LIMITS = { + 1.0, 10.0, 25.0, 50.0, 100.0, 400.0, 2000.0, 10000.0 + }; + + /** Step size used for quantization, for the corresponding limit **/ + private static double[] STEP_SIZES = { + 1.0 / 32.0, 0.125, 0.25, 5.0, 10.0, 25.0, 100.0, 250.0 + }; + + /** Limit beyond which quantized value would be rounded to integer (ignoring decimals) **/ + private static double INTEGER_ROUNDING_LIMIT = 25.0; + + private Double scale; + private QuantizeType type = QuantizeType.SIMPLE; + private List buckets; + private TreeMap limitBucketPairs; + + @Setter(AccessLevel.NONE) + private double upperLimit; + + @Override + public QuantizeTransform configure(Config config, String key) { + return super.configure(config, key) + .type(figureOutType(config, key)) + .scale(doubleFromConfig(config, key, ".scale", false)) + .buckets(doubleListFromConfig(config, key, ".buckets", false)) + .limitBucketPairs(doubleTreeMapFromConfig(config, key, ".limit_bucket", false)); + } + + private QuantizeType figureOutType(Config config, String key) { + String type = stringFromConfig(config, key, ".type", false); + if (type != null) { + return QuantizeType.valueOf(type.toUpperCase()); + } + String transformType = getTransformType(config, key); + // For legacy reasons + if (transformType != null && 
transformType.equals("linear_log_quantize")) { + return QuantizeType.LOG; + } + return QuantizeType.SIMPLE; + } @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - scale = config.getDouble(key + ".scale"); - outputName = config.getString(key + ".output"); + protected void setup() { + super.setup(); + if (limitBucketPairs != null) { + upperLimit = limitBucketPairs.lastKey(); + } } + // TODO (Brad): Validation + @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); + protected void doTransform(MultiFamilyVector featureVector) { - if (floatFeatures == null) { - return; + for (FeatureValue value : getInput(featureVector)) { + Set resultFeatures; + String featureName = value.feature().name(); + double featureValue = value.value(); + if (limitBucketPairs != null) { + resultFeatures = customLogQuantize(featureName, featureValue); + } else if (type == QuantizeType.LOG) { + resultFeatures = logQuantize(featureName, featureValue); + } else if (buckets != null) { + resultFeatures = bucket(featureName, featureValue); + } else if (scale != null) { + resultFeatures = scale(featureName, featureValue); + } else { + // TODO (Brad): Log issue + resultFeatures = ImmutableSet.of(); + } + for (Feature feature : resultFeatures) { + featureVector.putString(feature); + } } - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null || feature1.isEmpty()) { - return; + } + + private Set bucket(String featureName, double featureValue) { + if (featureValue == 0.0) { + return ImmutableSet.of(outputFamily.feature(featureName + "=0")); } - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); + Set features = new HashSet<>(); + for (double bucket : buckets) { + double quantized = TransformUtil.quantize(featureValue, bucket); + features.add(outputFamily.feature(featureName + '[' + bucket + "]=" + 
quantized)); + } + return features; + } - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); + private Set customLogQuantize(String featureName, double featureValue) { + StringBuilder sb = new StringBuilder(); + sb.setLength(0); + sb.append(featureName); + boolean isValueNegative = false; + if (featureValue < 0.0) { + isValueNegative = true; + featureValue = -featureValue; + } - for (Entry feature : feature1.entrySet()) { - transformAndAddFeature(scale, - feature.getKey(), - feature.getValue(), - output); + if (featureValue < 1e-2) { + sb.append("=0.0"); + } else { + double limit; + double bucket; + if (featureValue >= upperLimit) { + featureValue = upperLimit; + bucket = limitBucketPairs.get(upperLimit); + } else { + limit = limitBucketPairs.higherKey(featureValue); + bucket = limitBucketPairs.get(limit); + } + + Double val = TransformUtil.quantize(featureValue, bucket) * 1000; + + sb.append('='); + if (isValueNegative) { + sb.append('-'); + } + + sb.append(val.intValue()/1000.0); + } + + return ImmutableSet.of(outputFamily.feature(sb.toString())); + } + + protected Set scale(String featureName, double featureValue) { + double dbl = featureValue * scale; + return ImmutableSet.of(outputFamily.feature(featureName + '=' + (int) dbl)); + } + + private Set logQuantize(String featureName, double featureValue) { + StringBuilder sb = new StringBuilder(); + sb.append(featureName); + sb.append('='); + + Double dbl = featureValue; + if (dbl < 0.0) { + sb.append('-'); + dbl = -dbl; } + // At every stage we quantize roughly to a precision 10% of the magnitude. 
+ if (dbl < 1e-2) { + sb.append('0'); + } else { + boolean isQuantized = false; + for (int i = 0; i < LIMITS.length; i++) { + double limit = LIMITS[i]; + double stepSize = STEP_SIZES[i]; + if (limit > INTEGER_ROUNDING_LIMIT) { + isQuantized = checkAndQuantize(sb, dbl, limit, stepSize, true); + } else { + isQuantized = checkAndQuantize(sb, dbl, limit, stepSize, false); + } + + if (isQuantized) { + break; + } + } + + if (! isQuantized) { + Double exp = Math.log(dbl) / Math.log(2.0); + Long val = 1L << exp.intValue(); + sb.append(val); + } + } + + return ImmutableSet.of(outputFamily.feature(sb.toString())); + } + + private static boolean checkAndQuantize(StringBuilder sb, + double featureValue, + double limit, + double stepSize, + boolean integerRounding) { + if (featureValue <= limit) { + if (!integerRounding) { + sb.append(TransformUtil.quantize(featureValue, stepSize)); + } else { + sb.append(TransformUtil.quantize(featureValue, stepSize).intValue()); + } + + return true; + } + + return false; } - public static void transformAndAddFeature(Double scale, - String featureName, - Double featureValue, - Set output) { - Double dbl = featureValue * scale; - int val = dbl.intValue(); - output.add(featureName + '=' + val); + public enum QuantizeType { + SIMPLE, LOG } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransform.java index 48a54550..bf7db0b3 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransform.java @@ -1,12 +1,21 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.transforms.types.StringTransform; - +import com.airbnb.aerosolve.core.transforms.base.StringTransform; +import com.typesafe.config.Config; import java.util.List; import java.util.Map; - -import 
com.typesafe.config.Config; -import com.typesafe.config.ConfigObject; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import org.apache.commons.lang3.tuple.Pair; +import org.hibernate.validator.constraints.NotEmpty; + +import javax.validation.constraints.NotNull; /** * Replaces all substrings that match a given regex with a replacement string @@ -15,30 +24,37 @@ * Replacements are performed in the same order as specified in the list of pairs * "replacement" specifies the replacement string */ -public class ReplaceAllStringsTransform extends StringTransform { - private List replacements; +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class ReplaceAllStringsTransform extends StringTransform { + @NotNull + @NotEmpty + private Map replacements; + + @Setter(AccessLevel.NONE) + private List> patterns; @Override - public void init(Config config, String key) { - replacements = config.getObjectList(key + ".replacements"); + public ReplaceAllStringsTransform configure(Config config, String key) { + return super.configure(config, key) + .replacements(stringMapFromConfig(config, key, ".replacements", true)); } @Override - public String processString(String rawString) { - if (rawString == null) { - return null; - } - - for (ConfigObject replacementCO : replacements) { - Map replacementMap = replacementCO.unwrapped(); + protected void setup() { + super.setup(); + patterns = replacements.entrySet().stream() + .map(e -> Pair.of(Pattern.compile(e.getKey()), e.getValue())) + .collect(Collectors.toList()); + } - for (Map.Entry replacementEntry : replacementMap.entrySet()) { - String regex = replacementEntry.getKey(); - String replacement = (String) replacementEntry.getValue(); - rawString 
= rawString.replaceAll(regex, replacement); - } + @Override + public String processString(String rawString) { + for (Pair replacement : patterns) { + rawString = replacement.getKey().matcher(rawString).replaceAll(replacement.getValue()); } - return rawString; } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransform.java deleted file mode 100644 index 90f98bf9..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransform.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; -import java.util.HashSet; -import java.util.Set; -import java.util.Map; - -/** - * Takes the self cross product of stringFeatures named in field1 - * and places it in a stringFeature with family name specified in output. - */ -public class SelfCrossTransform implements Transform { - private String fieldName1; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - outputName = config.getString(key + ".output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) return; - - Set set1 = stringFeatures.get(fieldName1); - if (set1 == null) return; - - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - selfCross(set1, output); - } - - public static void selfCross(Set set1, Set output) { - for (String s1 : set1) { - for (String s2 : set1) { - // To prevent duplication we only take pairs there s1 < s2. 
- if (s1.compareTo(s2) < 0) { - output.add(s1 + '^' + s2); - } - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransform.java deleted file mode 100644 index 569f8ab9..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransform.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; -import java.util.HashMap; -import java.util.Set; -import java.util.Map; - -public class StringCrossFloatTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - outputName = config.getString(key + ".output"); - } - - @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.stringFeatures; - Map> floatFeatures = featureVector.floatFeatures; - if (stringFeatures == null || stringFeatures.isEmpty()) return; - if (floatFeatures == null || floatFeatures.isEmpty()) return; - - Set list1 = stringFeatures.get(fieldName1); - if (list1 == null || list1.isEmpty()) return; - Map list2 = floatFeatures.get(fieldName2); - if (list2 == null || list2.isEmpty()) return; - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - - for (String s1 : list1) { - for (Map.Entry s2 : list2.entrySet()) { - output.put(s1 + "^" + s2.getKey(), s2.getValue()); - } - } - } -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StuffIdTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/StuffIdTransform.java index 7c57d5ec..30a4a80c 100644 --- 
a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StuffIdTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/StuffIdTransform.java @@ -1,10 +1,9 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.Map; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.DualFeatureTransform; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; /** * id = fieldName1.key1 @@ -18,49 +17,25 @@ * On the other hand searches_at_leaf @ 123 can tell you how the model changes * for searches at a particular place changing from day to day. */ -public class StuffIdTransform implements Transform { - private String fieldName1; - private String fieldName2; - private String key1; - private String key2; - private String outputName; +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class StuffIdTransform extends DualFeatureTransform { @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - key1 = config.getString(key + ".key1"); - key2 = config.getString(key + ".key2"); - outputName = config.getString(key + ".output"); + public void doTransform(MultiFamilyVector featureVector) { + double v1 = featureVector.getDouble(inputFeature); + double v2 = featureVector.getDouble(otherFeature); + + String newname = otherFeature.name() + '@' + (long)v1; + featureVector.put(outputFamily.feature(newname), v2); } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - Map feature2 = floatFeatures.get(fieldName2); - if (feature2 == null) { - return; - } 
- - Double v1 = feature1.get(key1); - Double v2 = feature2.get(key2); - if (v1 == null || v2 == null) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); + protected String inputFeatureKey() { + return ".key1"; + } - String newname = key2 + '@' + v1.longValue(); - output.put(newname, v2); + @Override + protected String otherFeatureKey() { + return ".key2"; } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StumpTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/StumpTransform.java index cb04acba..b01d4a73 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/StumpTransform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/StumpTransform.java @@ -1,12 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; import com.typesafe.config.Config; - import java.io.Serializable; -import java.util.*; -import java.util.Map.Entry; +import java.util.ArrayList; +import java.util.List; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.Value; +import lombok.experimental.Accessors; +import org.hibernate.validator.constraints.NotEmpty; + +import javax.validation.constraints.NotNull; /** * Applies boosted stump transform to float features and places them in string feature output. 
@@ -15,70 +26,69 @@ * val model = sc.textFile(name).map(Util.decodeModel).take(10).map(x => * "%s,%s,%f".format(x.featureFamily,x.featureName,x.threshold)).foreach(println) */ -public class StumpTransform implements Transform { - private String outputName; - - private class StumpDescription implements Serializable { - public StumpDescription(String featureName, Double threshold, String descriptiveName) { - this.featureName = featureName; - this.threshold = threshold; - this.descriptiveName = descriptiveName; - } - public String featureName; - public Double threshold; - public String descriptiveName; - } +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(fluent = true, chain = true) +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class StumpTransform extends ConfigurableTransform { + @NotNull + private String outputFamilyName; + @NotNull + @NotEmpty + private List stumps; // Family -> description - private Map> thresholds; + @Setter(AccessLevel.NONE) + private List thresholds; + @Setter(AccessLevel.NONE) + private Family outputFamily; @Override - public void configure(Config config, String key) { - outputName = config.getString(key + ".output"); - thresholds = new HashMap<>(); + public StumpTransform configure(Config config, String key) { + return outputFamilyName(stringFromConfig(config, key, ".output")) + .stumps(stringListFromConfig(config, key, ".stumps", true)); + } - List stumps = config.getStringList(key + ".stumps"); + @Override + protected void setup() { + super.setup(); + outputFamily = registry.family(outputFamilyName); + thresholds = new ArrayList<>(stumps.size()); for (String stump : stumps) { String[] tokens = stump.split(","); if (tokens.length == 4) { String family = tokens[0]; String featureName = tokens[1]; - Double threshold = Double.parseDouble(tokens[2]); + Feature feature = registry.feature(family, featureName); + double threshold = Double.parseDouble(tokens[2]); String descriptiveName = tokens[3]; - List featureList = 
thresholds.get(family); - if (featureList == null) { - featureList = new ArrayList<>(); - thresholds.put(family, featureList); - } - StumpDescription description = new StumpDescription(featureName, + Feature outputFeature = outputFamily.feature(descriptiveName); + StumpDescription description = new StumpDescription(feature, threshold, - descriptiveName); - featureList.add(description); + outputFeature); + thresholds.add(description); } } } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Util.optionallyCreateStringFeatures(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); + public void doTransform(MultiFamilyVector featureVector) { + for (StumpDescription stump : thresholds) { + if (!featureVector.containsKey(stump.feature())) { + continue; + } - for (Entry> stumpFamily : thresholds.entrySet()) { - Map feature = floatFeatures.get(stumpFamily.getKey()); - if (feature == null) continue; - for (StumpDescription desc : stumpFamily.getValue()) { - Double value = feature.get(desc.featureName); - if (value != null && value >= desc.threshold) { - output.add(desc.descriptiveName); - } + double value = featureVector.getDouble(stump.feature()); + if (value >= stump.threshold()) { + featureVector.putString(stump.outputFeature()); } } } + + @Value + private static class StumpDescription implements Serializable { + private final Feature feature; + private final double threshold; + private final Feature outputFeature; + } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/SubtractTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/SubtractTransform.java index 630939b4..64870ec2 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/SubtractTransform.java +++ 
b/core/src/main/java/com/airbnb/aerosolve/core/transforms/SubtractTransform.java @@ -1,67 +1,34 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.util.Util; - -import com.typesafe.config.Config; - -import java.util.HashMap; -import java.util.List; -import java.util.Set; -import java.util.HashSet; -import java.util.Map; -import java.util.Map.Entry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.OtherFeatureTransform; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; /** * output = field1 - field2.key */ -public class SubtractTransform implements Transform { - private String fieldName1; - private String fieldName2; - private List keys; - private String key2; - private String outputName; +@NoArgsConstructor(access = AccessLevel.PACKAGE) +public class SubtractTransform extends OtherFeatureTransform { @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - fieldName2 = config.getString(key + ".field2"); - keys = config.getStringList(key + ".keys"); - key2 = config.getString(key + ".key2"); - outputName = config.getString(key + ".output"); + protected String produceOutputFeatureName(String featureName) { + return featureName + '-' + otherFeatureName; } @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - - Map feature1 = floatFeatures.get(fieldName1); - if (feature1 == null) { - return; - } + public void doTransform(MultiFamilyVector featureVector) { + double sub = featureVector.getDouble(otherFeature); - Map feature2 = floatFeatures.get(fieldName2); - if (feature2 == null) { - return; + for (FeatureValue value : getInput(featureVector)) { + double val = 
featureVector.getDouble(value.feature()); + featureVector.put(produceOutputFeature(value.feature()), val - sub); } - Double sub = feature2.get(key2); - if (sub == null) { - return; - } - - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); + } - for (String key : keys) { - if (feature1.containsKey(key)) { - Double val = feature1.get(key); - if (val != null) { - output.put(key + '-' + key2, val - sub); - } - } - } + @Override + protected String otherFeatureKey() { + return ".key2"; } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transform.java index 657bf814..ff84a5e0 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transform.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transform.java @@ -1,18 +1,12 @@ package com.airbnb.aerosolve.core.transforms; import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; - import java.io.Serializable; +import java.util.function.Function; /** * Created by hector_yee on 8/25/14. * Base class for feature transforms. */ -public interface Transform extends Serializable { - // Configure the transform from the supplied config and key. - void configure(Config config, String key); - - // Applies a transform to the featureVector. 
- void doTransform(FeatureVector featureVector); -} +@FunctionalInterface +public interface Transform extends Function, Serializable {} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/TransformFactory.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/TransformFactory.java index eb2286e6..1591ff76 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/TransformFactory.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/TransformFactory.java @@ -1,13 +1,24 @@ package com.airbnb.aerosolve.core.transforms; -import com.google.common.base.CaseFormat; +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform; +import com.airbnb.aerosolve.core.util.Util; import com.typesafe.config.Config; +import java.util.Map; +import lombok.Synchronized; /** * Created by hector_yee on 8/25/14. */ public class TransformFactory { - public static Transform createTransform(Config config, String key) { + + private static Map> TRANSFORM_CLASSES; + + public static Transform createTransform(Config config, String key, + FeatureRegistry registry, + AbstractModel model) { if (config == null || key == null) { return null; } @@ -16,18 +27,46 @@ public static Transform createTransform(Config config, String key) { return null; } - String name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, transformName); - Transform result = null; + // TODO (Brad): I don't love that this is a static initialization. Can lead to awkward bugs but + // it's probably tricky to make clients use a singleton instance of this class without Guice. + // Not sure a static singleton in the class is any better. 
+ if (TRANSFORM_CLASSES == null) { + loadTransformMap(); + } + + Class clazz = TRANSFORM_CLASSES.get(transformName); + if (clazz == null) { + throw new IllegalArgumentException( + String.format("No transform exists with name %s", transformName)); + } try { - result = (Transform) Class.forName("com.airbnb.aerosolve.core.transforms." + name + "Transform").newInstance(); - result.configure(config, key); - } catch (InstantiationException e) { - e.printStackTrace(); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } catch (ClassNotFoundException e) { - e.printStackTrace(); + ConfigurableTransform transform = clazz.newInstance(); + // (Brad): It's kind of awkward we have to do all this initialization. I tried to make + // Transforms immutable and was hoping to use Builder inheritance to make them buildable in + // a logical way. But Builder inheritance is a mess in Java and constructor inheritance + // leads to a bunch of hard to understand boilerplate in the concrete classes. So, this is + // what we have. Every Transform has to have a package private constructor. They can't be + // initialized any other way. So as long as we do all the steps correctly and in the right + // order here, we should be safe. 
+ transform.registry(registry); + if (transform instanceof ModelAware && model != null) { + ((ModelAware)transform).model(model); + } + transform.configure(config, key); + return transform; + } catch (InstantiationException | IllegalAccessException e) { + throw new IllegalStateException( + String.format("There was an error instantiating Transform of class %s", clazz.getName())); + } + } + + @Synchronized + private static void loadTransformMap() { + if (TRANSFORM_CLASSES != null) { + return; } - return result; + TRANSFORM_CLASSES = Util.loadFactoryNamesFromPackage(ConfigurableTransform.class, + "com.airbnb.aerosolve.core.transforms", + "Transform"); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transformer.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transformer.java index b38cf296..5bef2050 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transformer.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/Transformer.java @@ -1,131 +1,49 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.models.AbstractModel; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import com.typesafe.config.Config; - import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; public class Transformer implements Serializable { private static final long serialVersionUID = 1569952057032186608L; // The transforms to be applied to the context, item and combined // (context | item) respectively. 
package com.airbnb.aerosolve.core.transforms;

import com.airbnb.aerosolve.core.models.AbstractModel;
import com.airbnb.aerosolve.core.features.FeatureRegistry;
import com.airbnb.aerosolve.core.features.MultiFamilyVector;
import com.typesafe.config.Config;
import java.io.Serializable;

/**
 * Bundles the three transforms a model applies: one for the context vector,
 * one for each item vector, and one for each combined (context | item) vector.
 *
 * NOTE(review): generic parameters were lost to formatting damage in this
 * patch; {@code <MultiFamilyVector>} is restored from the imports and sibling
 * classes — confirm against the original file.
 */
public class Transformer implements Serializable {
  private static final long serialVersionUID = 1569952057032186608L;

  // The transforms to be applied to the context, item and combined
  // (context | item) vectors respectively.
  private final Transform<MultiFamilyVector> contextTransform;
  private final Transform<MultiFamilyVector> itemTransform;
  private final Transform<MultiFamilyVector> combinedTransform;

  /** Convenience constructor for transforms that do not need a model. */
  public Transformer(Config config, String key, FeatureRegistry registry) {
    this(config, key, registry, null);
  }

  /**
   * Configures the model transforms. Under {@code key} the config must define:
   *   context_transform  : name of transform to apply to the context
   *   item_transform     : name of transform to apply to each item
   *   combined_transform : name of transform to apply to each (item, context) pair
   */
  public Transformer(Config config, String key, FeatureRegistry registry, AbstractModel model) {
    contextTransform = TransformFactory.createTransform(
        config, config.getString(key + ".context_transform"), registry, model);
    itemTransform = TransformFactory.createTransform(
        config, config.getString(key + ".item_transform"), registry, model);
    combinedTransform = TransformFactory.createTransform(
        config, config.getString(key + ".combined_transform"), registry, model);
  }

  public Transform<MultiFamilyVector> getContextTransform() {
    return contextTransform;
  }

  public Transform<MultiFamilyVector> getItemTransform() {
    return itemTransform;
  }

  public Transform<MultiFamilyVector> getCombinedTransform() {
    return combinedTransform;
  }
}
package com.airbnb.aerosolve.core.transforms;

import com.airbnb.aerosolve.core.features.DenseVector;
import com.airbnb.aerosolve.core.features.Family;
import com.airbnb.aerosolve.core.features.FamilyVector;
import com.airbnb.aerosolve.core.features.MultiFamilyVector;
import com.airbnb.aerosolve.core.transforms.base.ConfigurableTransform;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.hibernate.validator.constraints.NotEmpty;

import javax.validation.constraints.Max;
import javax.validation.constraints.NotNull;

/**
 * A transform that applies the winner-takes-all hash to one or more dense
 * feature families, emitting string "words" into an output family. Windows of
 * WINDOW_SIZE permuted entries are compared to generate 2-bit tokens, and each
 * word packs num_tokens_per_word of these.
 */
@Data
@EqualsAndHashCode(callSuper = false)
@Accessors(fluent = true, chain = true)
@NoArgsConstructor(access = AccessLevel.PACKAGE)
public class WtaTransform extends ConfigurableTransform<WtaTransform> {
  private static final byte WINDOW_SIZE = 4;

  @NotNull
  @NotEmpty
  private List<String> familyNames;

  @NotNull
  private String outputFamilyName;

  // The seed of the random number generator.
  private int seed;

  // The number of words per feature.
  private int numWordsPerFeature;

  // The number of tokens per word. Each token is 2 bits, so at most 16 fit
  // into the int produced by getWord.
  @Max(16)
  private int numTokensPerWord;

  @Setter(AccessLevel.NONE)
  private Random rnd;
  @Setter(AccessLevel.NONE)
  private List<Family> families;
  @Setter(AccessLevel.NONE)
  private Family outputFamily;

  @Override
  public WtaTransform configure(Config config, String key) {
    // BUGFIX: num_words_per_feature and num_tokens_per_word previously used
    // the optional lookup (failIfAbsent = false) with no default. A missing
    // key returned null, which blew up with an opaque NullPointerException
    // when unboxed into the primitive int fields. They are required values
    // (the pre-refactor code used config.getInt), so fail fast with a clear
    // config error instead.
    return
        familyNames(stringListFromConfig(config, key, ".field_names", true))
        .outputFamilyName(stringFromConfig(config, key, ".output"))
        .seed(intFromConfig(config, key, ".seed", false, (int) System.currentTimeMillis()))
        .numWordsPerFeature(intFromConfig(config, key, ".num_words_per_feature"))
        .numTokensPerWord(intFromConfig(config, key, ".num_tokens_per_word"));
  }

  @Override
  protected void setup() {
    // TODO (Brad): I may be introducing a bug here. Need to confirm. It's
    // expensive to generate a new Random on every transform. It's also a bit
    // weird because it will produce the same values for every transform this
    // way. Is that intentional? Do we want "deterministic" randomness for
    // some reason?
    rnd = new Random(seed);
    families = familyNames.stream()
        .map(registry::family)
        .collect(Collectors.toList());
    outputFamily = registry.family(outputFamilyName);
  }

  /**
   * Generates a random permutation of [0, size) and enqueues it onto dq.
   *
   * NOTE(review): the middle of this method fell between two patch hunks and
   * is not visible; the Fisher-Yates swap loop below is reconstructed from
   * the visible "permutation[other] = tmp" line — confirm against the
   * original source.
   */
  private void generatePermutation(int size,
                                   IntArrayFIFOQueue dq) {
    dq.clear();
    int[] permutation = new int[size];
    for (int i = 0; i < size; i++) {
      permutation[i] = i;
    }
    for (int i = 0; i < size; i++) {
      int other = rnd.nextInt(size);
      int tmp = permutation[i];
      permutation[i] = permutation[other];
      permutation[other] = tmp;
    }
    for (int i = 0; i < size; i++) {
      dq.enqueue(permutation[i]);
    }
  }

  // Produces one 2-bit token: the index (0..WINDOW_SIZE-1) of the largest
  // value among the next WINDOW_SIZE permuted positions.
  private int getToken(IntArrayFIFOQueue dq,
                       double[] features) {
    if (dq.size() < WINDOW_SIZE) {
      generatePermutation(features.length, dq);
    }
    byte largest = 0;
    double largestValue = features[dq.dequeueInt()];
    for (byte i = 1; i < WINDOW_SIZE; i++) {
      double value = features[dq.dequeueInt()];
      if (value > largestValue) {
        largestValue = value;
        largest = i;
      }
    }
    return largest;
  }

  // Packs numTokensPerWord 2-bit tokens into a single int "word".
  private int getWord(IntArrayFIFOQueue dq,
                      double[] features) {
    int result = 0;
    for (int i = 0; i < numTokensPerWord; i++) {
      result |= getToken(dq, features) << 2 * i;
    }
    return result;
  }

  // Appends the "words" for one dense family vector to outputs.
  private void getWordsForFeature(FamilyVector vector, Set<String> outputs) {
    if (vector == null) {
      return;
    }
    Preconditions.checkArgument(vector instanceof DenseVector,
        "Each family in WTAHashTransform must be a DenseVector.");
    double[] features = ((DenseVector) vector).denseArray();
    // We use IntArrayFIFOQueue instead of a Deque to avoid boxing and
    // unboxing ints.
    IntArrayFIFOQueue dq = new IntArrayFIFOQueue(features.length);
    for (int i = 0; i < numWordsPerFeature; i++) {
      String word = vector.family().name() + i + ':' + getWord(dq, features);
      outputs.add(word);
    }
  }

  @Override
  protected void doTransform(MultiFamilyVector featureVector) {
    Set<String> outputs = new HashSet<>();
    for (Family family : families) {
      getWordsForFeature(featureVector.get(family), outputs);
    }
    for (String output : outputs) {
      featureVector.putString(outputFamily.feature(output));
    }
  }
}
package com.airbnb.aerosolve.core.transforms.base;

import com.airbnb.aerosolve.core.features.Family;
import com.airbnb.aerosolve.core.features.Feature;
import com.airbnb.aerosolve.core.features.FeatureValue;
import com.airbnb.aerosolve.core.features.MultiFamilyVector;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.typesafe.config.Config;
import it.unimi.dsi.fastutil.objects.Reference2ObjectMap;
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Spliterator;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import lombok.Getter;

import javax.validation.constraints.NotNull;

/**
 * Base class for transforms that read some or all features of one input
 * family and write results to an output family (defaulting to the input
 * family when no output is configured).
 */
@SuppressWarnings("unchecked")
public abstract class BaseFeaturesTransform<T extends BaseFeaturesTransform>
    extends ConfigurableTransform<T> {
  @Getter
  @NotNull
  protected String inputFamilyName;

  @Getter
  protected String outputFamilyName;

  @Getter
  protected Set<String> inputFeatureNames;

  @Getter
  protected Set<String> excludedFeatureNames;

  protected BaseFeaturesTransform() {
  }

  // Resolved in setup() from the names above.
  protected Family inputFamily;
  protected Family outputFamily;
  protected List<Feature> inputFeatures;
  protected Reference2ObjectMap<Feature, Feature> outputFeatures;
  protected Set<Feature> excludedFeatures;

  public T inputFamilyName(String name) {
    this.inputFamilyName = name;
    return (T) this;
  }

  public T outputFamilyName(String name) {
    this.outputFamilyName = name;
    return (T) this;
  }

  public T inputFeatureNames(Set<String> names) {
    this.inputFeatureNames = names;
    return (T) this;
  }

  public T excludedFeatureNames(Set<String> names) {
    this.excludedFeatureNames = names == null ? null : ImmutableSet.copyOf(names);
    return (T) this;
  }

  @Override
  protected void setup() {
    super.setup();
    inputFamily = registry.family(inputFamilyName);
    outputFamily = (outputFamilyName == null) ? inputFamily : registry.family(outputFamilyName);
    if (inputFeatureNames != null) {
      inputFeatures = new ArrayList<>();
      for (String featureName : inputFeatureNames) {
        Feature feature = inputFamily.feature(featureName);
        inputFeatures.add(feature);
        // When writing to a different family, pre-compute the output feature
        // corresponding to each input feature.
        if (outputFamily != inputFamily) {
          if (outputFeatures == null) {
            outputFeatures = new Reference2ObjectOpenHashMap<>();
          }
          outputFeatures.put(feature, outputFamily.feature(produceOutputFeatureName(featureName)));
        }
      }
    }
    if (excludedFeatureNames != null) {
      excludedFeatures = new HashSet<>();
      for (String featureName : excludedFeatureNames) {
        excludedFeatures.add(inputFamily.feature(featureName));
      }
    }
  }

  /** Maps an input feature name to its output name; identity by default. */
  protected String produceOutputFeatureName(String featureName) {
    return featureName;
  }

  /**
   * Returns the feature values this transform should read: the configured
   * subset when one was given, otherwise the whole input family, minus any
   * excluded features.
   */
  protected Iterable<FeatureValue> getInput(MultiFamilyVector featureVector) {
    Spliterator<FeatureValue> spliterator =
        (inputFeatures != null)
            ? featureVector.iterateMatching(inputFeatures).spliterator()
            : featureVector.get(inputFamily).spliterator();
    Stream<FeatureValue> values = StreamSupport.stream(spliterator, false);
    if (excludedFeatures != null) {
      values = values.filter(value -> !excludedFeatures.contains(value.feature()));
    }
    return values::iterator;
  }

  protected Feature produceOutputFeature(Feature feature) {
    if (outputFeatures != null) {
      return outputFeatures.get(feature);
    }
    // TODO (Brad): Handle the case where outputFamily == inputFamily and we're
    // not going to do any sort of transform here.
    return outputFamily.feature(produceOutputFeatureName(feature.name()));
  }

  @Override
  protected boolean checkPreconditions(MultiFamilyVector vector) {
    return super.checkPreconditions(vector) && vector.contains(inputFamily);
  }

  @Override
  public T configure(Config config, String key) {
    // Historically input feature names were specified under either ".keys" or
    // ".select_features". No one should specify both, but if they do we take
    // the union.
    Set<String> keyNames = stringSetFromConfig(config, key, ".keys", false);
    Set<String> selectNames = stringSetFromConfig(config, key, ".select_features", false);
    Set<String> combinedNames;
    if (keyNames == null) {
      combinedNames = selectNames;
    } else if (selectNames == null) {
      combinedNames = keyNames;
    } else {
      combinedNames = Sets.union(keyNames, selectNames);
    }
    return (T)
        inputFamilyName(stringFromConfig(config, key, ".field1"))
        .outputFamilyName(stringFromConfig(config, key, ".output", false))
        .inputFeatureNames(combinedNames)
        .excludedFeatureNames(stringSetFromConfig(config, key, ".exclude_features", false));
  }
}

/**
 * Base class for feature transforms bounded by optional lower/upper limits
 * (config keys .lower_bound / .upper_bound, defaulting to the full double
 * range).
 *
 * NOTE(review): kept in this block because the patch blob interleaves it with
 * BaseFeaturesTransform; in the repository it lives in its own file.
 */
@SuppressWarnings("unchecked")
class BoundedFeaturesTransformHolder {
}
package com.airbnb.aerosolve.core.transforms.base;

import com.airbnb.aerosolve.core.features.FeatureRegistry;
import com.airbnb.aerosolve.core.features.MultiFamilyVector;
import com.airbnb.aerosolve.core.transforms.Transform;
import com.google.common.collect.ImmutableSet;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigObject;
import com.typesafe.config.ConfigValue;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import lombok.Getter;

import javax.validation.ConstraintViolation;
import javax.validation.Validation;
import javax.validation.ValidatorFactory;
import javax.validation.constraints.NotNull;

/**
 * Base class for transforms driven by a typesafe Config. Subclasses populate
 * their fields in configure(); on the first apply() the instance is
 * bean-validated once and setup() is run, after which doTransform() is called
 * for every vector that passes checkPreconditions().
 *
 * NOTE(review): generic parameters were lost to formatting damage in this
 * patch; the declarations below are restored from usage ((T) casts, the
 * MultiFamilyVector apply signature) — confirm against the original file.
 */
@SuppressWarnings("unchecked")
public abstract class ConfigurableTransform<T extends ConfigurableTransform>
    implements Transform<MultiFamilyVector> {
  private static final ValidatorFactory
      VALIDATION_FACTORY = Validation.buildDefaultValidatorFactory();

  // Guards the one-time validate() + setup(); flipped on the first apply().
  private boolean setupComplete = false;

  @Getter
  @NotNull
  protected FeatureRegistry registry;

  public T registry(FeatureRegistry registry) {
    this.registry = registry;
    return (T) this;
  }

  protected ConfigurableTransform() {
  }

  /** Reads this transform's settings from config under the given key. */
  public abstract T configure(Config config, String key);

  /** Performs the actual transform; only called once preconditions pass. */
  abstract protected void doTransform(MultiFamilyVector vector);

  @Override
  public MultiFamilyVector apply(MultiFamilyVector vector) {
    if (!setupComplete) {
      validate();
      setup();
      setupComplete = true;
    }
    if (checkPreconditions(vector)) {
      doTransform(vector);
    }
    return vector;
  }

  /** Override to skip vectors this transform cannot operate on. */
  protected boolean checkPreconditions(MultiFamilyVector vector) {
    return true;
  }

  protected void setup() {
    // Do nothing. Override.
  }

  protected void validate() {
    // This sort of sucks and I wish I could make it immutable instead.
    // This is a small speed hack. If the user mutates the Transformer after this call,
    // things can get bad. But we don't have immutability right now and I don't want to
    // muck things up with a dirty flag. So this will have to do for the time being.
    if (setupComplete) {
      return;
    }
    Set<ConstraintViolation<ConfigurableTransform<T>>> violations =
        VALIDATION_FACTORY.getValidator().validate(this);
    int numViolations = violations.size();

    if (numViolations > 0) {
      String violationMessage = String.join("\n", violations.stream()
          .map(v -> v.getPropertyPath() + " " + v.getMessage()).collect(Collectors.toList()));
      throw new IllegalArgumentException(
          String.format("Transformer failed validation with %d violations:\n %s", numViolations,
                        violationMessage));
    }
  }

  // --- Config helpers. Each "failIfAbsent" variant lets the underlying
  // typesafe-config getter throw its descriptive Missing exception when a
  // required path is absent; otherwise a null/default is returned. ---

  protected static String stringFromConfig(Config config, String key, String field) {
    return stringFromConfig(config, key, field, true);
  }

  protected static String stringFromConfig(Config config, String key, String field,
                                           boolean failIfAbsent) {
    String path = key + field;
    if (failIfAbsent || config.hasPath(path)) {
      // This fails with an exception if the path doesn't exist.
      // FIX: use the computed path instead of recomputing key + field, for
      // consistency with every other helper in this class.
      return config.getString(path);
    }
    return null;
  }

  protected static Double doubleFromConfig(Config config, String key, String field) {
    return doubleFromConfig(config, key, field, true, null);
  }

  protected static Double doubleFromConfig(Config config, String key, String field,
                                           boolean failIfAbsent) {
    return doubleFromConfig(config, key, field, failIfAbsent, null);
  }

  protected static Double doubleFromConfig(Config config, String key, String field,
                                           boolean failIfAbsent, Double defaultValue) {
    String path = key + field;
    if (failIfAbsent || config.hasPath(path)) {
      return config.getDouble(path);
    }
    return defaultValue;
  }

  protected static Integer intFromConfig(Config config, String key, String field) {
    return intFromConfig(config, key, field, true, null);
  }

  protected static Integer intFromConfig(Config config, String key, String field,
                                         boolean failIfAbsent) {
    return intFromConfig(config, key, field, failIfAbsent, null);
  }

  protected static Integer intFromConfig(Config config, String key, String field,
                                         boolean failIfAbsent,
                                         Integer defaultValue) {
    String fullPath = key + field;
    if (failIfAbsent || config.hasPath(fullPath)) {
      return config.getInt(fullPath);
    }
    return defaultValue;
  }

  // Absent paths are treated as false; booleans are never required.
  protected static boolean booleanFromConfig(Config config, String key, String field) {
    String fullPath = key + field;
    return config.hasPath(fullPath) && config.getBoolean(fullPath);
  }

  protected static Set<String> stringSetFromConfig(Config config, String key, String field,
                                                   boolean failIfAbsent) {
    List<String> result = stringListFromConfig(config, key, field, failIfAbsent);
    return result == null ? null : ImmutableSet.copyOf(result);
  }

  protected static List<String> stringListFromConfig(Config config, String key, String field,
                                                     boolean failIfAbsent) {
    String fullKey = key + field;
    if (failIfAbsent || config.hasPath(fullKey)) {
      return config.getStringList(fullKey);
    }
    return null;
  }

  protected static List<Double> doubleListFromConfig(Config config, String key, String field,
                                                     boolean failIfAbsent) {
    String fullKey = key + field;
    if (failIfAbsent || config.hasPath(fullKey)) {
      return config.getDoubleList(fullKey);
    }
    return null;
  }

  protected static TreeMap<Double, Double> doubleTreeMapFromConfig(Config config, String key,
                                                                   String field,
                                                                   boolean failIfAbsent) {
    String fullPath = key + field;
    if (failIfAbsent || config.hasPath(fullPath)) {
      TreeMap<Double, Double> parsedTokensMap = new TreeMap<>();
      // NOTE(review): only the FIRST entry of each ConfigObject is read —
      // presumably each object holds a single threshold:value pair; confirm
      // with callers before relying on multi-entry objects.
      for (ConfigObject configObject : config.getObjectList(fullPath)) {
        List<Map.Entry<String, ConfigValue>> entries = new ArrayList<>(configObject.entrySet());
        parsedTokensMap.put(Double.parseDouble(entries.get(0).getKey()),
                            Double.parseDouble(entries.get(0).getValue().unwrapped().toString()));
      }

      return parsedTokensMap;
    }
    return null;
  }

  protected static Map<String, String> stringMapFromConfig(Config config, String key, String field,
                                                           boolean failIfAbsent) {
    String fullPath = key + field;
    if (failIfAbsent || config.hasPath(fullPath)) {
      return config.getObjectList(fullPath)
          .stream()
          .flatMap((ConfigObject o) -> o.unwrapped().entrySet().stream())
          .collect(Collectors.toMap(Map.Entry::getKey, e -> (String) e.getValue()));
    }
    return null;
  }

  /** Returns the transform's type name, or null when not configured. */
  protected static String getTransformType(Config config, String key) {
    return stringFromConfig(config, key, ".transform", false);
  }
}
b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFamilyTransform.java @@ -0,0 +1,40 @@ +package com.airbnb.aerosolve.core.transforms.base; + +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.typesafe.config.Config; + +/** + * + */ +@SuppressWarnings("unchecked") +public abstract class DualFamilyTransform + extends BaseFeaturesTransform { + protected String otherFamilyName; + + protected Family otherFamily; + + protected DualFamilyTransform() { + } + + public T otherFamilyName(String name) { + this.otherFamilyName = name; + return (T) this; + } + + @Override + protected void setup() { + super.setup(); + this.otherFamily = otherFamilyName == null ? inputFamily : registry.family(otherFamilyName); + } + + @Override + protected boolean checkPreconditions(MultiFamilyVector vector) { + return super.checkPreconditions(vector) && vector.contains(otherFamily); + } + + public T configure(Config config, String key) { + return (T) super.configure(config, key) + .otherFamilyName(stringFromConfig(config, key, ".field2", false)); + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFeatureTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFeatureTransform.java new file mode 100644 index 00000000..4789f19e --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/DualFeatureTransform.java @@ -0,0 +1,60 @@ +package com.airbnb.aerosolve.core.transforms.base; + +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.typesafe.config.Config; +import lombok.Getter; + + +/** + * + */ +@SuppressWarnings("unchecked") +public abstract class DualFeatureTransform + extends SingleFeatureTransform{ + + @Getter + protected String otherFeatureName; + + @Getter + protected String otherFamilyName; + + 
protected Family otherFamily; + protected Feature otherFeature; + + protected DualFeatureTransform() { + } + + public T otherFamilyName(String name) { + this.otherFamilyName = name; + return (T) this; + } + + public T otherFeatureName(String name) { + this.otherFeatureName = name; + return (T) this; + } + + @Override + protected void setup() { + super.setup(); + this.otherFamily = otherFamilyName == null ? inputFamily : registry.family(otherFamilyName); + this.otherFeature = otherFamily.feature(otherFeatureName); + } + + @Override + protected boolean checkPreconditions(MultiFamilyVector vector) { + return super.checkPreconditions(vector) && vector.containsKey(otherFeature); + } + + public T configure(Config config, String key) { + return (T) super.configure(config, key) + .otherFamilyName(stringFromConfig(config, key, ".field2", false)) + .otherFeatureName(stringFromConfig(config, key, otherFeatureKey())); + } + + protected String otherFeatureKey() { + return ".value2"; + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/OtherFeatureTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/OtherFeatureTransform.java new file mode 100644 index 00000000..26d33fae --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/OtherFeatureTransform.java @@ -0,0 +1,51 @@ +package com.airbnb.aerosolve.core.transforms.base; + +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.typesafe.config.Config; +import lombok.Getter; + +import javax.validation.constraints.NotNull; + +/** + * + */ +@SuppressWarnings("unchecked") +public abstract class OtherFeatureTransform + extends DualFamilyTransform { + + @Getter + @NotNull + protected String otherFeatureName; + + protected Feature otherFeature; + + protected OtherFeatureTransform() { + } + + public T otherFeatureName(String name) { + this.otherFeatureName = name; + return (T) this; + } + + 
@Override + protected void setup() { + super.setup(); + otherFeature = otherFamily.feature(otherFeatureName); + } + + @Override + protected boolean checkPreconditions(MultiFamilyVector vector) { + return super.checkPreconditions(vector) && vector.containsKey(otherFeature); + } + + @Override + public T configure(Config config, String key) { + return (T) super.configure(config, key) + .otherFeatureName(stringFromConfig(config, key, otherFeatureKey(), false)); + } + + protected String otherFeatureKey() { + return ".value2"; + } +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/SingleFeatureTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/SingleFeatureTransform.java new file mode 100644 index 00000000..05df4090 --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/SingleFeatureTransform.java @@ -0,0 +1,79 @@ +package com.airbnb.aerosolve.core.transforms.base; + +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.typesafe.config.Config; +import lombok.Getter; + +import javax.validation.constraints.NotNull; + +/** + * + */ +@SuppressWarnings("unchecked") +public abstract class SingleFeatureTransform + extends ConfigurableTransform{ + + @Getter + @NotNull + protected String inputFamilyName; + + @Getter + @NotNull + protected String inputFeatureName; + + @Getter + protected String outputFamilyName; + + protected Family inputFamily; + protected Feature inputFeature; + protected Family outputFamily; + + protected SingleFeatureTransform() { + } + + public T inputFamilyName(String name) { + this.inputFamilyName = name; + + return (T) this; + } + + public T outputFamilyName(String name) { + this.outputFamilyName = name; + return (T) this; + } + + public T inputFeatureName(String name) { + this.inputFeatureName = name; + return (T) this; + } + + @Override + protected void 
setup() { + super.setup(); + this.inputFamily = registry.family(inputFamilyName); + this.inputFeature = inputFamily.feature(inputFeatureName); + this.outputFamily = outputFamilyName == null + ? this.inputFamily + : registry.family(outputFamilyName); + } + + @Override + protected boolean checkPreconditions(MultiFamilyVector vector) { + return super.checkPreconditions(vector) && vector.containsKey(inputFeature); + } + + @Override + public T configure(Config config, String key) { + return (T) + inputFamilyName(stringFromConfig(config, key, ".field1")) + .inputFeatureName(stringFromConfig(config, key, inputFeatureKey())) + .outputFamilyName(stringFromConfig(config, key, ".output", false)); + } + + protected String inputFeatureKey() { + return ".value1"; + } + +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/StringTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/StringTransform.java new file mode 100644 index 00000000..f70fedda --- /dev/null +++ b/core/src/main/java/com/airbnb/aerosolve/core/transforms/base/StringTransform.java @@ -0,0 +1,41 @@ +package com.airbnb.aerosolve.core.transforms.base; + +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import java.util.HashSet; + +/** + * Abstract representation of a transform that processes all strings in a string feature and + * outputs a new string feature or overwrites /replaces the input string feature. 
+ * "field1" specifies the key of the feature + * "output" optionally specifies the key of the output feature, if it is not given the transform + * overwrites / replaces the input feature + */ +public abstract class StringTransform extends BaseFeaturesTransform { + + protected StringTransform() { + } + + @Override + public void doTransform(MultiFamilyVector featureVector) { + HashSet processedStrings = new HashSet<>(); + + for (FeatureValue featureValue : getInput(featureVector)) { + if (featureValue.feature().name() != null) { + processedStrings.add(processString(featureValue.feature().name())); + } + } + + // Check reference equality to determine whether the output should overwrite the input + if (outputFamily == inputFamily) { + // TODO (Brad): Not sure how I feel about doing this. Are we sure? + featureVector.remove(inputFamily); + } + + for (String string : processedStrings) { + featureVector.putString(outputFamily.feature(string)); + } + } + + public abstract String processString(String rawString); +} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/FloatTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/FloatTransform.java deleted file mode 100644 index 7f4451cf..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/FloatTransform.java +++ /dev/null @@ -1,42 +0,0 @@ -package com.airbnb.aerosolve.core.transforms.types; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.transforms.Transform; -import com.airbnb.aerosolve.core.util.Util; -import com.typesafe.config.Config; - -import java.util.Map; - -public abstract class FloatTransform implements Transform { - protected String fieldName1; - protected String outputName; // output family name, if not specified, output to fieldName1 - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - if (config.hasPath(key + ".output")) { - outputName = 
config.getString(key + ".output"); - } else { - outputName = fieldName1; - } - init(config, key); - } - protected abstract void init(Config config, String key); - - @Override - public void doTransform(FeatureVector featureVector) { - Map> floatFeatures = featureVector.getFloatFeatures(); - - if (floatFeatures == null) { - return; - } - Map input = floatFeatures.get(fieldName1); - if (input == null) { - return; - } - Map output = Util.getOrCreateFloatFeature(outputName, floatFeatures); - output(input, output); - } - - protected abstract void output(Map input, Map output); -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/StringTransform.java b/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/StringTransform.java deleted file mode 100644 index 5480b8fa..00000000 --- a/core/src/main/java/com/airbnb/aerosolve/core/transforms/types/StringTransform.java +++ /dev/null @@ -1,69 +0,0 @@ -package com.airbnb.aerosolve.core.transforms.types; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.transforms.Transform; -import com.airbnb.aerosolve.core.util.Util; - -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; - -/** - * Abstract representation of a transform that processes all strings in a string feature and - * outputs a new string feature or overwrites /replaces the input string feature. 
- * "field1" specifies the key of the feature - * "output" optionally specifies the key of the output feature, if it is not given the transform - * overwrites / replaces the input feature - */ -public abstract class StringTransform implements Transform { - protected String fieldName1; - protected String outputName; - - @Override - public void configure(Config config, String key) { - fieldName1 = config.getString(key + ".field1"); - if (config.hasPath(key + ".output")) { - outputName = config.getString(key + ".output"); - } else { - outputName = fieldName1; - } - init(config, key); - } - - protected abstract void init(Config config, String key); - - @Override - public void doTransform(FeatureVector featureVector) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) { - return; - } - - Set feature1 = stringFeatures.get(fieldName1); - if (feature1 == null) { - return; - } - - HashSet processedStrings = new HashSet<>(); - - for (String rawString : feature1) { - if (rawString != null) { - String processedString = processString(rawString); - processedStrings.add(processedString); - } - } - - Set output = Util.getOrCreateStringFeature(outputName, stringFeatures); - - // Check reference equality to determine whether the output should overwrite the input - if (output == feature1) { - output.clear(); - } - - output.addAll(processedStrings); - } - - public abstract String processString(String rawString); -} diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/DateUtil.java b/core/src/main/java/com/airbnb/aerosolve/core/util/DateUtil.java index 733a0360..bfafe4af 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/DateUtil.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/DateUtil.java @@ -1,7 +1,5 @@ package com.airbnb.aerosolve.core.util; -import lombok.extern.slf4j.Slf4j; - import java.text.ParseException; import java.text.SimpleDateFormat; import java.time.LocalDate; @@ -10,6 +8,7 @@ import 
java.util.HashMap; import java.util.Map; import java.util.function.Function; +import lombok.extern.slf4j.Slf4j; @Slf4j public class DateUtil { diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/Debug.java b/core/src/main/java/com/airbnb/aerosolve/core/util/Debug.java index e9789af9..047529f4 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/Debug.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/Debug.java @@ -2,138 +2,46 @@ import com.airbnb.aerosolve.core.Example; import com.airbnb.aerosolve.core.FeatureVector; -import lombok.extern.slf4j.Slf4j; -import org.apache.thrift.TDeserializer; -import org.apache.thrift.TSerializer; -import org.apache.thrift.protocol.TBinaryProtocol; - +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.SimpleExample; +import com.google.common.collect.MapDifference; +import com.google.common.collect.Maps; import java.io.FileOutputStream; +import java.io.FileWriter; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.Map; -import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.codec.binary.Base64; +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; @Slf4j public class Debug { public static int printDiff(FeatureVector a, FeatureVector b) { - final Map> stringFeaturesA = a.getStringFeatures(); - final Map> stringFeaturesB = b.getStringFeatures(); - int diff = printDiff(stringFeaturesA, stringFeaturesB); - final Map> floatFeaturesA = a.getFloatFeatures(); - final Map> floatFeaturesB = b.getFloatFeatures(); - diff += printFloatDiff(floatFeaturesA, floatFeaturesB); - return diff; - } - - private static int printFloatDiff(Map> a, - Map> b) { - int diff = 0; - for (Map.Entry> entry : a.entrySet()) { - String key = entry.getKey(); - 
Map bSet = b.get(key); - if (bSet == null) { - log.info("b miss float family {}", key); - diff++; - } else { - diff += printMapDiff(entry.getValue(), bSet); - } - } - - for (Map.Entry> entry : b.entrySet()) { - String key = entry.getKey(); - Map bSet = a.get(key); - if (bSet == null) { - log.info("a miss float family {}", key); - diff++; - } - } - return diff; - } - - private static int printMapDiff(Map a, Map b) { - int diff = 0; - for (Map.Entry entry : a.entrySet()) { - String key = entry.getKey(); - Double bValue = b.get(key); - if (bValue == null) { - log.info("b miss feature {} {}", key, entry.getValue()); - diff++; - } else { - if (Math.abs(bValue- entry.getValue()) > 0.01) { - log.info("feature {} a: {}, b: {}", key, entry.getValue(), bValue); - diff++; - } - } - } - - for (Map.Entry entry : b.entrySet()) { - String key = entry.getKey(); - Double bValue = a.get(key); - if (bValue == null) { - log.info("a miss feature {} {}", key, entry.getValue()); - diff++; - } - } - return diff; - } - + MapDifference diff = Maps.difference(a, b); - public static int printDiff(Map> a, Map> b) { - int diff = 0; - for (Map.Entry> entry : a.entrySet()) { - String key = entry.getKey(); - Set bSet = b.get(key); - if (bSet == null) { - log.info("b miss string family {}", key); - diff++; - } else { - diff += printDiff(entry.getValue(), bSet); - } - } - - for (Map.Entry> entry : b.entrySet()) { - String key = entry.getKey(); - Set bSet = a.get(key); - if (bSet == null) { - log.info("a miss string family {}", key); - diff++; - } - } - return diff; - } - - private static int printDiff(Set a, Set b) { - int diff = 0; - for(String s : a) { - if (!b.contains(s)) { - log.info("b missing {}", s); - diff++; - } - } - for(String s : b) { - if (!a.contains(s)) { - log.info("a missing {}", s); - diff++; - } - } - return diff; + // TODO (Brad): We can format this differently or something if needed. + // It should print out pretty reasonably though. 
+ log.info(diff.toString()); + return diff.entriesDiffering().size() + + diff.entriesOnlyOnLeft().size() + + diff.entriesOnlyOnRight().size(); } /* loadExampleFromResource read example from resources folder, i.e. test/resources use it on unit test to load example from disk */ - public static Example loadExampleFromResource(String name) { + public static Example loadExampleFromResource(String name, FeatureRegistry registry) { URL url = Debug.class.getResource("/" + name); try { Path path = Paths.get(url.toURI()); byte[] bytes = Files.readAllBytes(path); - TDeserializer deserializer = new TDeserializer(new TBinaryProtocol.Factory()); - Example example = new Example(); - deserializer.deserialize(example, bytes); - return example; + return Util.decodeExample(new String(Base64.encodeBase64(bytes)), registry); } catch (Exception e) { e.printStackTrace(); } @@ -144,12 +52,13 @@ public static Example loadExampleFromResource(String name) { // Save example to path // If you hit permission error, touch and chmod the file public static void saveExample(Example example, String path) { - TSerializer serializer = new TSerializer(new TBinaryProtocol.Factory()); + // TODO (Brad): This base64 encoding stuff is crazy. Let's fix that. 
+ String encoded = Util.encodeExample(example); + byte[] bytes = Base64.decodeBase64(encoded); try { - byte[] buf = serializer.serialize(example); - FileOutputStream fos = new FileOutputStream(path); - fos.write(buf); - fos.close(); + FileOutputStream out = new FileOutputStream(path); + out.write(bytes); + out.close(); } catch (Exception e) { e.printStackTrace(); } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureDictionary.java b/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureDictionary.java index af6678a5..a79fd332 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureDictionary.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureDictionary.java @@ -1,13 +1,13 @@ package com.airbnb.aerosolve.core.util; import com.airbnb.aerosolve.core.FeatureVector; -import lombok.Getter; -import lombok.Setter; - +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import java.io.Serializable; -import java.util.*; -import java.util.AbstractMap.SimpleEntry; +import java.util.Comparator; +import java.util.List; import java.util.Map.Entry; +import lombok.Setter; /** * A class that maintains a dictionary of features and returns @@ -15,18 +15,21 @@ */ public abstract class FeatureDictionary implements Serializable { + + protected final FeatureRegistry registry; /** * The dictionary to maintain */ @Setter - protected List dictionaryList; + protected List dictionaryList; - public FeatureDictionary() { + public FeatureDictionary(FeatureRegistry registry) { + this.registry = registry; } - class EntryComparator implements Comparator> { - public int compare(Entry e1, - Entry e2) { + class EntryComparator implements Comparator> { + public int compare(Entry e1, + Entry e2) { if (e1.getValue() > e2.getValue()) { return 1; } else if (e1.getValue() < e2.getValue()) { @@ -39,7 +42,7 @@ public int compare(Entry e1, /** * Returns the k-nearest neighbors as 
floatFeatures in featureVector. */ - public abstract FeatureVector getKNearestNeighbors( + public abstract MultiFamilyVector getKNearestNeighbors( KNearestNeighborsOptions options, FeatureVector featureVector); } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureVectorUtil.java b/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureVectorUtil.java index 802cd333..4505e8f2 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureVectorUtil.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/FeatureVectorUtil.java @@ -1,37 +1,28 @@ package com.airbnb.aerosolve.core.util; import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.List; -import java.util.Set; +import com.airbnb.aerosolve.core.features.DenseVector; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; public class FeatureVectorUtil { /** - * Computes the min kernel for one feature family. - * @param featureKey - name of feature e.g. "rgb" - * @param a - first feature vector - * @param b - second feature vector + * Computes the min kernel between two double arrays. + * @param a - first double array + * @param b - second double array * @return - sum(min(a(i), b(i)) */ - public static double featureMinKernel(String featureKey, - FeatureVector a, - FeatureVector b) { + public static double minKernel(double[] a, double[] b) { double sum = 0.0; - if (a.getDenseFeatures() == null || b.getDenseFeatures() == null) { - return 0.0; - } - List aFeat = a.getDenseFeatures().get(featureKey); - List bFeat = b.getDenseFeatures().get(featureKey); - if (aFeat == null || bFeat == null) { + if (a == null || b == null) { return 0.0; } - int count = aFeat.size(); - for (int i = 0; i < count; i++) { - if (aFeat.get(i) < bFeat.get(i)) { - sum += aFeat.get(i); - } else { - sum += bFeat.get(i); - } + + // This stops at the shorter array. 
Since we're taking the min, we can assume the shorter array + // could be interpreted as 0 beyond it's length and that would be less than b. Is this true? + // what if the other array has negatives? + for (int i = 0; i < Math.min(a.length, b.length); i++) { + sum += Math.min(a[i], b[i]); } return sum; } @@ -44,14 +35,23 @@ public static double featureMinKernel(String featureKey, */ public static double featureVectorMinKernel(FeatureVector a, FeatureVector b) { - double sum = 0.0; - if (a.getDenseFeatures() == null) { - return 0.0; + if (a instanceof MultiFamilyVector && b instanceof MultiFamilyVector) { + double sum = 0.0; + + MultiFamilyVector multiB = (MultiFamilyVector) b; + for (FamilyVector vec : ((MultiFamilyVector) a).families()) { + sum += denseVectorMinKernel(vec, multiB.get(vec.family())); + } + return sum; } - Set keys = a.getDenseFeatures().keySet(); - for (String key : keys) { - sum += featureMinKernel(key, a, b); + return denseVectorMinKernel(a, b); + } + + private static double denseVectorMinKernel(FeatureVector a, + FeatureVector b) { + if (a instanceof DenseVector && b instanceof DenseVector) { + return minKernel(a.denseArray(), b.denseArray()); } - return sum; + return 0.0; } } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/FloatVector.java b/core/src/main/java/com/airbnb/aerosolve/core/util/FloatVector.java index aa8bc7a6..6e9e9493 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/FloatVector.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/FloatVector.java @@ -1,12 +1,12 @@ package com.airbnb.aerosolve.core.util; -import lombok.Getter; -import lombok.Setter; - import java.io.Serializable; import java.util.Random; +import lombok.Getter; +import lombok.Setter; // TODO change to FloatVector +// TODO (Brad): Floats to Doubles public class FloatVector implements Serializable { private static final Random rnd = new java.util.Random(); @@ -148,7 +148,7 @@ public void add(FloatVector 
other) { } } - public void multiplyAdd(float w, FloatVector other) { + public void multiplyAdd(double w, FloatVector other) { assert(values.length == other.values.length); for (int i = 0; i < values.length; i++) { values[i] += w * other.values[i]; diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/KNearestNeighborsOptions.java b/core/src/main/java/com/airbnb/aerosolve/core/util/KNearestNeighborsOptions.java index d38ab4a0..792ee62c 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/KNearestNeighborsOptions.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/KNearestNeighborsOptions.java @@ -1,21 +1,14 @@ package com.airbnb.aerosolve.core.util; -import lombok.Getter; -import lombok.Setter; +import com.airbnb.aerosolve.core.features.Family; +import lombok.Builder; +import lombok.Value; +@Value +@Builder public class KNearestNeighborsOptions { - public KNearestNeighborsOptions() { - numNearest = 5; - idKey = ""; - outputKey = ""; - featureKey = ""; - } - @Getter @Setter - private int numNearest; - @Getter @Setter - public String idKey; - @Getter @Setter - public String outputKey; - @Getter @Setter - public String featureKey; + private final int numNearest; + private final Family idKey; + private final Family outputKey; + private final Family featureKey; } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/LocalitySensitiveHashSparseFeatureDictionary.java b/core/src/main/java/com/airbnb/aerosolve/core/util/LocalitySensitiveHashSparseFeatureDictionary.java index 198409b1..fe2f6c12 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/LocalitySensitiveHashSparseFeatureDictionary.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/LocalitySensitiveHashSparseFeatureDictionary.java @@ -1,12 +1,21 @@ package com.airbnb.aerosolve.core.util; import com.airbnb.aerosolve.core.FeatureVector; -import lombok.Setter; - -import java.io.Serializable; -import java.util.*; +import 
com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.Sets; import java.util.AbstractMap.SimpleEntry; -import java.util.Map.Entry; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; /** * A class that maintains a dictionary for sparse features and returns @@ -15,57 +24,47 @@ public class LocalitySensitiveHashSparseFeatureDictionary extends FeatureDictionary { - private Map> LSH; - private boolean haveLSH; + private Map> LSH; - public LocalitySensitiveHashSparseFeatureDictionary() { - haveLSH = false; + public LocalitySensitiveHashSparseFeatureDictionary(FeatureRegistry registry) { + super(registry); } - private int similarity(FeatureVector f1, - FeatureVector f2, - String featureKey) { - Set s1 = f1.getStringFeatures().get(featureKey); - if (s1 == null) { - return 0; - } - Set s2 = f2.getStringFeatures().get(featureKey); - if (s2 == null) { + private int similarity(MultiFamilyVector f1, + MultiFamilyVector f2, + Family featureKey) { + FamilyVector fam1 = f1.get(featureKey); + FamilyVector fam2 = f2.get(featureKey); + if (fam1 == null || fam2 == null) { return 0; } - Set intersection = new HashSet(s1); - intersection.retainAll(s2); - - return intersection.size(); + return Sets.intersection(fam1.keySet(), fam2.keySet()).size(); } // Builds the hash table lookup for the LSH. 
- private void buildHashTable(String featureKey) { + private void buildHashTable(Family featureKey) { LSH = new HashMap<>(); assert(dictionaryList instanceof ArrayList); int size = dictionaryList.size(); for (int i = 0; i < size; i++) { - FeatureVector featureVector = dictionaryList.get(i); - Set keys = featureVector.getStringFeatures().get(featureKey); - if (keys == null) { + MultiFamilyVector featureVector = dictionaryList.get(i); + FamilyVector vec = featureVector.get(featureKey); + if (vec == null) { continue; } - for (String key : keys) { - Set row = LSH.get(key); - if (row == null) { - row = new HashSet<>(); - LSH.put(key, row); - } + for (FeatureValue value : vec) { + Set row = LSH.computeIfAbsent(value.feature(), + f -> new HashSet<>()); row.add(i); } } } // Returns all the candidates with a hash overlap. - private Set getCandidates(Set keys) { + private Set getCandidates(FamilyVector vector) { Set result = new HashSet<>(); - for (String key : keys) { - Set row = LSH.get(key); + for (FeatureValue value : vector) { + Set row = LSH.get(value.feature()); if (row != null) { result.addAll(row); } @@ -74,60 +73,67 @@ private Set getCandidates(Set keys) { } @Override - public FeatureVector getKNearestNeighbors( + public MultiFamilyVector getKNearestNeighbors( KNearestNeighborsOptions options, FeatureVector featureVector) { - FeatureVector result = new FeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); + MultiFamilyVector result = new BasicMultiFamilyVector(registry); - if (stringFeatures == null) { + if (!(featureVector instanceof MultiFamilyVector)) { return result; } - String featureKey = options.getFeatureKey(); - Set keys = stringFeatures.get(featureKey); + MultiFamilyVector vector = (MultiFamilyVector) featureVector; + + Family featureKey = options.getFeatureKey(); + + FamilyVector keys = vector.get(featureKey); if (keys == null) { return result; } - if (!haveLSH) { + if (LSH == null) { buildHashTable(featureKey); } - String idKey = 
options.getIdKey(); - PriorityQueue> pq = new PriorityQueue<>( + Family idKey = options.getIdKey(); + PriorityQueue> pq = new PriorityQueue<>( options.getNumNearest() + 1, new EntryComparator()); - Map> floatFeatures = new HashMap<>(); - String myId = featureVector.getStringFeatures() - .get(idKey).iterator().next(); + FamilyVector idVector = vector.get(idKey); + if (idVector == null || idVector.isEmpty()) { + return result; + } + + Feature myId = idVector.iterator().next().feature(); Set candidates = getCandidates(keys); for (Integer candidate : candidates) { - FeatureVector supportVector = dictionaryList.get(candidate); - double sim = similarity(featureVector, - supportVector, - featureKey); - Set idSet = supportVector.getStringFeatures().get(idKey); - String id = idSet.iterator().next(); + MultiFamilyVector supportVector = dictionaryList.get(candidate); + double sim = similarity(vector, + supportVector, + featureKey); + + FamilyVector idSet = supportVector.get(idKey); + if (idSet == null) { + continue; + } + Feature id = idSet.iterator().next().feature(); if (id == myId) { continue; } - SimpleEntry entry = new SimpleEntry(id, sim); - pq.add(entry); + pq.add(new SimpleEntry<>(id, sim)); if (pq.size() > options.getNumNearest()) { pq.poll(); } } - HashMap newFeature = new HashMap<>(); + Family outputFamily = options.getOutputKey(); while (pq.peek() != null) { - SimpleEntry entry = pq.poll(); - newFeature.put(entry.getKey(), entry.getValue()); + SimpleEntry entry = pq.poll(); + Feature outputFeature = outputFamily.feature(entry.getKey().name()); + result.put(outputFeature, (double) entry.getValue()); } - floatFeatures.put(options.getOutputKey(), newFeature); - result.setFloatFeatures(floatFeatures); return result; } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/MinKernelDenseFeatureDictionary.java b/core/src/main/java/com/airbnb/aerosolve/core/util/MinKernelDenseFeatureDictionary.java index a0f9e830..4d3d3f41 100644 --- 
a/core/src/main/java/com/airbnb/aerosolve/core/util/MinKernelDenseFeatureDictionary.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/MinKernelDenseFeatureDictionary.java @@ -1,12 +1,14 @@ package com.airbnb.aerosolve.core.util; import com.airbnb.aerosolve.core.FeatureVector; -import lombok.Setter; - -import java.io.Serializable; -import java.util.*; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import java.util.AbstractMap.SimpleEntry; -import java.util.Map.Entry; +import java.util.PriorityQueue; /** * A class that maintains a dictionary for dense features and returns @@ -14,51 +16,64 @@ */ public class MinKernelDenseFeatureDictionary extends FeatureDictionary { - /** + + public MinKernelDenseFeatureDictionary(FeatureRegistry registry) { + super(registry); + } + /** * Calculates the Min Kernel distance to each dictionary element. * Returns the top K elements as a new sparse feature. 
*/ @Override - public FeatureVector getKNearestNeighbors( + public MultiFamilyVector getKNearestNeighbors( KNearestNeighborsOptions options, FeatureVector featureVector) { - FeatureVector result = new FeatureVector(); - Map> denseFeatures = featureVector.getDenseFeatures(); + MultiFamilyVector result = new BasicMultiFamilyVector(registry); - if (denseFeatures == null) { + if (!(featureVector instanceof MultiFamilyVector)) { return result; } - PriorityQueue> pq = new PriorityQueue<>( + + MultiFamilyVector vector = (MultiFamilyVector) featureVector; + + PriorityQueue> pq = new PriorityQueue<>( options.getNumNearest() + 1, new EntryComparator()); - String idKey = options.getIdKey(); + Family idKey = options.getIdKey(); + + FamilyVector idVector = vector.get(idKey); + if (idVector == null || idVector.isEmpty()) { + return result; + } - Map> floatFeatures = new HashMap<>(); - String myId = featureVector.getStringFeatures() - .get(idKey).iterator().next(); + Feature myId = idVector.iterator().next().feature(); - for (FeatureVector supportVector : dictionaryList) { - Double minKernel = FeatureVectorUtil.featureVectorMinKernel(featureVector, + for (MultiFamilyVector supportVector : dictionaryList) { + double minKernel = FeatureVectorUtil.featureVectorMinKernel(featureVector, supportVector); - Set idSet = supportVector.getStringFeatures().get(idKey); - String id = idSet.iterator().next(); - if (id == myId) continue; - SimpleEntry entry = new SimpleEntry(id, minKernel); + FamilyVector idSet = supportVector.get(idKey); + if (idSet == null) { + continue; + } + Feature id = idSet.iterator().next().feature(); + if (id == myId) { + continue; + } + SimpleEntry entry = new SimpleEntry<>(id, minKernel); pq.add(entry); if (pq.size() > options.getNumNearest()) { pq.poll(); } } - HashMap newFeature = new HashMap<>(); + Family outputFamily = options.getOutputKey(); while (pq.peek() != null) { - SimpleEntry entry = pq.poll(); - newFeature.put(entry.getKey(), entry.getValue()); + 
SimpleEntry entry = pq.poll(); + Feature outputFeature = outputFamily.feature(entry.getKey().name()); + result.put(outputFeature, (double) entry.getValue()); } - floatFeatures.put(options.getOutputKey(), newFeature); - result.setFloatFeatures(floatFeatures); return result; } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/ReinforcementLearning.java b/core/src/main/java/com/airbnb/aerosolve/core/util/ReinforcementLearning.java index 7c25dea2..34475c1a 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/ReinforcementLearning.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/ReinforcementLearning.java @@ -4,19 +4,11 @@ * Utilities for reinforcement learning */ -import com.airbnb.aerosolve.core.models.AbstractModel; import com.airbnb.aerosolve.core.FeatureVector; -import com.google.common.hash.HashCode; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; - -import org.apache.commons.codec.binary.Base64; -import org.apache.thrift.TDeserializer; -import org.apache.thrift.TSerializer; -import org.apache.thrift.TBase; - +import com.airbnb.aerosolve.core.models.AbstractModel; import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.Random; public class ReinforcementLearning implements Serializable { // Updates a model using SARSA @@ -28,28 +20,28 @@ public static void updateSARSA(AbstractModel model, FeatureVector nextStateAction, float learningRate, float discountRate) { - Map> flatSA = Util.flattenFeature(stateAction); - float nextQ = 0.0f; + double nextQ = 0.0f; if (nextStateAction != null) { nextQ = model.scoreItem(nextStateAction); } - float currentQ = model.scoreItem(stateAction); - float expectedQ = reward + discountRate * nextQ; - float grad = currentQ - expectedQ; - model.onlineUpdate(grad, learningRate, flatSA); + double currentQ = model.scoreItem(stateAction); + double expectedQ = reward + discountRate * nextQ; + double grad = currentQ - expectedQ; + 
model.onlineUpdate(grad, learningRate, stateAction); } // Picks a random action with probability epsilon. - public static int epsilonGreedyPolicy(AbstractModel model, ArrayList stateAction, float epsilon, Random rnd) { + public static int epsilonGreedyPolicy(AbstractModel model, ArrayList stateAction, + double epsilon, Random rnd) { if (rnd.nextFloat() <= epsilon) { return rnd.nextInt(stateAction.size()); } int bestAction = 0; - float bestScore = model.scoreItem(stateAction.get(0)); + double bestScore = model.scoreItem(stateAction.get(0)); for (int i = 1; i < stateAction.size(); i++) { FeatureVector sa = stateAction.get(i); - float score = model.scoreItem(sa); + double score = model.scoreItem(sa); if (score > bestScore) { bestAction = i; bestScore = score; @@ -60,11 +52,12 @@ public static int epsilonGreedyPolicy(AbstractModel model, ArrayList stateAction, float temperature, Random rnd) { + public static int softmaxPolicy(AbstractModel model, ArrayList stateAction, + double temperature, Random rnd) { int count = stateAction.size(); - float[] scores = new float[count]; - float[] cumScores = new float[count]; - float maxVal = -1e10f; + double[] scores = new double[count]; + double[] cumScores = new double[count]; + double maxVal = -1e10d; for (int i = 0; i < count; i++) { FeatureVector sa = stateAction.get(i); scores[i] = model.scoreItem(sa); @@ -77,7 +70,7 @@ public static int softmaxPolicy(AbstractModel model, ArrayList st cumScores[i] += cumScores[i - 1]; } } - float threshold = rnd.nextFloat() * cumScores[count - 1]; + double threshold = rnd.nextFloat() * cumScores[count - 1]; for (int i = 0; i < count; i++) { if (threshold <= cumScores[i]) { return i; diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/StringDictionary.java b/core/src/main/java/com/airbnb/aerosolve/core/util/StringDictionary.java index dfe479d4..db46f3d4 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/StringDictionary.java +++ 
b/core/src/main/java/com/airbnb/aerosolve/core/util/StringDictionary.java @@ -1,16 +1,14 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.DictionaryEntry; import com.airbnb.aerosolve.core.DictionaryRecord; - -import lombok.Getter; -import lombok.Setter; - +import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureValue; import java.io.Serializable; -import java.util.*; -import java.util.AbstractMap.SimpleEntry; -import java.util.Map.Entry; +import java.util.HashMap; +import java.util.Map; +import lombok.Getter; /** * A class that maps strings to indices. It can be used to map sparse @@ -37,7 +35,7 @@ public StringDictionary(DictionaryRecord dict) { // Returns the dictionary entry, null if not present public DictionaryEntry getEntry(String family, String feature) { - Map familyMap = dictionary.dictionary.get(family); + Map familyMap = dictionary.getDictionary().get(family); if (familyMap == null) { return null; } @@ -45,31 +43,31 @@ public DictionaryEntry getEntry(String family, String feature) { } // Returns -1 if key exists, the index it was inserted if successful. 
- public int possiblyAdd(String family, String feature, double mean, double scale) { - Map familyMap = dictionary.dictionary.get(family); + public int possiblyAdd(Feature feature, double mean, double scale) { + String familyName = feature.family().name(); + Map familyMap = dictionary.getDictionary().get(familyName); if (familyMap == null) { familyMap = new HashMap<>(); - dictionary.dictionary.put(family, familyMap); + dictionary.getDictionary().put(familyName, familyMap); } - if (familyMap.containsKey(feature)) return -1; + if (familyMap.containsKey(feature.name())) return -1; DictionaryEntry entry = new DictionaryEntry(); int currIdx = dictionary.getEntryCount(); entry.setIndex(currIdx); entry.setMean(mean); entry.setScale(scale); dictionary.setEntryCount(currIdx + 1); - familyMap.put(feature, entry); + familyMap.put(feature.name(), entry); return currIdx; } - public FloatVector makeVectorFromSparseFloats(Map> sparseFloats) { + public FloatVector makeVectorFromSparseFloats(FeatureVector vector) { FloatVector vec = new FloatVector(dictionary.getEntryCount()); - for (Map.Entry> kv : sparseFloats.entrySet()) { - for (Map.Entry feat : kv.getValue().entrySet()) { - DictionaryEntry entry = getEntry(kv.getKey(), feat.getKey()); - if (entry != null) { - vec.values[entry.index] = (float) entry.scale * (feat.getValue().floatValue() - (float) entry.mean); - } + for (FeatureValue value : vector) { + DictionaryEntry entry = getEntry(value.feature().family().name(), value.feature().name()); + if (entry != null) { + vec.values[entry.getIndex()] = (float) (entry.getScale() * + (value.value() - entry.getMean())); } } return vec; diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/SupportVector.java b/core/src/main/java/com/airbnb/aerosolve/core/util/SupportVector.java index 00530c10..be8733a0 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/SupportVector.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/SupportVector.java @@ -1,15 +1,13 @@ 
package com.airbnb.aerosolve.core.util; -import lombok.Getter; -import lombok.Setter; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Random; - import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelRecord; +import java.io.Serializable; +import java.util.ArrayList; +import lombok.Getter; +import lombok.Setter; +// TODO (Brad): Floats to Doubles public class SupportVector implements Serializable { // Dense support vector value. @Getter @Setter @@ -36,13 +34,13 @@ public SupportVector(FloatVector fv, FunctionForm f, float s, float wt) { } public SupportVector(ModelRecord rec) { - scale = (float) rec.scale; + scale = (float) rec.getScale(); form = rec.getFunctionForm(); weight = (float) rec.getFeatureWeight(); - int size = rec.weightVector.size(); + int size = rec.getWeightVector().size(); floatVector = new FloatVector(size); for (int i = 0; i < size; i++) { - floatVector.getValues()[i] = rec.weightVector.get(i).floatValue(); + floatVector.getValues()[i] = rec.getWeightVector().get(i).floatValue(); } } diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/Util.java b/core/src/main/java/com/airbnb/aerosolve/core/util/Util.java index 4a5538b5..6b72ad4d 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/Util.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/Util.java @@ -6,20 +6,59 @@ * Utilities for machine learning */ -import com.airbnb.aerosolve.core.*; +import com.airbnb.aerosolve.core.DebugScoreDiffRecord; +import com.airbnb.aerosolve.core.DebugScoreRecord; +import com.airbnb.aerosolve.core.Example; +import com.airbnb.aerosolve.core.KDTreeNode; +import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.ThriftExample; +import com.airbnb.aerosolve.core.ThriftFeatureVector; +import com.airbnb.aerosolve.core.features.DenseVector; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; 
+import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.features.SimpleExample; +import com.airbnb.aerosolve.core.transforms.LegacyNames; +import com.google.common.base.CaseFormat; import com.google.common.hash.HashCode; import com.google.common.hash.Hasher; import com.google.common.hash.Hashing; +import com.google.common.primitives.Doubles; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Serializable; +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.binary.Base64; import org.apache.thrift.TBase; import org.apache.thrift.TDeserializer; import org.apache.thrift.TSerializer; - -import java.io.*; -import java.util.*; import java.util.zip.GZIPInputStream; +import org.apache.commons.lang3.tuple.Pair; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; +import org.reflections.Reflections; +@Slf4j public class Util implements Serializable { + + // SimpleDateFormat is not thread safe. . . + public static final DateTimeFormatter DATE_FORMAT = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static double LOG2 = Math.log(2); // Coder / decoder utilities for various protos. This makes it easy to // manipulate in spark. e.g. 
if we wanted to see the 50 weights in a model @@ -33,30 +72,71 @@ public static String encode(TBase obj) { return ""; } } - public static FeatureVector decodeFeatureVector(String str) { - return decode(FeatureVector.class, str); + + public static String encodeFeatureVector(MultiFamilyVector vector) { + return encode(getThriftFeatureVector(vector)); } - public static FeatureVector createNewFeatureVector() { - FeatureVector featureVector = new FeatureVector(); - Map> floatFeatures = new HashMap<>(); + private static ThriftFeatureVector getThriftFeatureVector(MultiFamilyVector vector) { + ThriftFeatureVector tVec = new ThriftFeatureVector(); Map> stringFeatures = new HashMap<>(); - featureVector.setFloatFeatures(floatFeatures); - featureVector.setStringFeatures(stringFeatures); + Map> floatFeatures = new HashMap<>(); + Map> denseFeatures = new HashMap<>(); + for (FamilyVector familyVector : vector.families()) { + String familyName = familyVector.family().name(); + // This might seem to break compatibility because we write all SparseVectors + // as floatFeatures. But it's fine as long as we don't try to read this vector in + // old versions of aerosolve. 
+ if (familyVector instanceof DenseVector) { + denseFeatures.put(familyName, Doubles.asList(((DenseVector) familyVector).denseArray())); + } else { + floatFeatures.put(familyName, familyVector.entrySet() + .stream() + .map(e -> Pair.of(e.getKey().name(), e.getValue())) + .collect(Collectors.toMap(Pair::getKey, Pair::getValue))); + } + } + tVec.setDenseFeatures(denseFeatures); + tVec.setStringFeatures(stringFeatures); + tVec.setFloatFeatures(floatFeatures); + return tVec; + } - return featureVector; + public static String encodeExample(Example example) { + return encode(getThriftExample(example)); } - public static Example createNewExample() { - Example example = new Example(); - example.setContext(createNewFeatureVector()); - example.setExample(new ArrayList()); + public static ThriftExample getThriftExample(Example example) { + ThriftExample thriftExample = new ThriftExample(); + thriftExample.setContext(getThriftFeatureVector(example.context())); + for (MultiFamilyVector innerVec : example) { + thriftExample.addToExample(getThriftFeatureVector(innerVec)); + } + return thriftExample; + } - return example; + public static MultiFamilyVector decodeFeatureVector(String str, FeatureRegistry registry) { + ThriftFeatureVector tmp = new ThriftFeatureVector(); + try { + byte[] bytes = Base64.decodeBase64(str.getBytes()); + TDeserializer deserializer = new TDeserializer(); + deserializer.deserialize(tmp, bytes); + } catch (Exception e) { + log.error("Error deserializing ThriftFeatureVector", e); + } + return new BasicMultiFamilyVector(tmp, registry); } - public static Example decodeExample(String str) { - return decode(Example.class, str); + public static Example decodeExample(String str, FeatureRegistry registry) { + ThriftExample tmp = new ThriftExample(); + try { + byte[] bytes = Base64.decodeBase64(str.getBytes()); + TDeserializer deserializer = new TDeserializer(); + deserializer.deserialize(tmp, bytes); + } catch (Exception e) { + log.error("Error deserializing 
ThriftExample", e); + } + return new SimpleExample(tmp, registry); } public static ModelRecord decodeModel(String str) { @@ -74,6 +154,8 @@ public static T decode(T base, String str) { return base; } + + public static T decode(Class clazz, String str) { try { return decode(clazz.newInstance(), str); @@ -113,76 +195,6 @@ public static List readFromGzippedStream(Class clazz, In return Collections.EMPTY_LIST; } - public static void optionallyCreateStringFeatures(FeatureVector featureVector) { - if (featureVector.getStringFeatures() == null) { - Map> stringFeatures = new HashMap<>(); - featureVector.setStringFeatures(stringFeatures); - } - } - - public static void optionallyCreateFloatFeatures(FeatureVector featureVector) { - if (featureVector.getFloatFeatures() == null) { - Map> floatFeatures = new HashMap<>(); - featureVector.setFloatFeatures(floatFeatures); - } - } - - public static void setStringFeature( - FeatureVector featureVector, - String family, - String value) { - Map> stringFeatures = featureVector.getStringFeatures(); - if (stringFeatures == null) { - stringFeatures = new HashMap<>(); - featureVector.setStringFeatures(stringFeatures); - } - Set stringFamily = getOrCreateStringFeature(family, stringFeatures); - stringFamily.add(value); - } - - public static Set getOrCreateStringFeature( - String name, - Map> stringFeatures) { - Set output = stringFeatures.get(name); - if (output == null) { - output = new HashSet<>(); - stringFeatures.put(name, output); - } - return output; - } - - public static Map getOrCreateFloatFeature( - String name, - Map> floatFeatures) { - Map output = floatFeatures.get(name); - if (output == null) { - output = new HashMap<>(); - floatFeatures.put(name, output); - } - return output; - } - - public static Map> getOrCreateDenseFeatures(FeatureVector featureVector) { - if (featureVector.getDenseFeatures() == null) { - Map> dense = new HashMap<>(); - featureVector.setDenseFeatures(dense); - } - return featureVector.getDenseFeatures(); - 
} - - public static void setDenseFeature( - FeatureVector featureVector, - String name, - List value) { - Map> denseFeatures = featureVector.getDenseFeatures(); - if (denseFeatures == null) { - denseFeatures = new HashMap<>(); - featureVector.setDenseFeatures(denseFeatures); - } - denseFeatures.put(name, value); - } - - public static HashCode getHashCode(String family, String value) { Hasher hasher = Hashing.murmur3_128().newHasher(); hasher.putBytes(family.getBytes()); @@ -251,55 +263,30 @@ public static Map prepareRankMap(List scores, List> flattenFeature( - FeatureVector featureVector) { - Map> flatFeature = new HashMap<>(); - if (featureVector.stringFeatures != null) { - for (Map.Entry> entry : featureVector.stringFeatures.entrySet()) { - Map out = new HashMap<>(); - flatFeature.put(entry.getKey(), out); - for (String feature : entry.getValue()) { - out.put(feature, 1.0); - } - } - } - if (featureVector.floatFeatures != null) { - for (Map.Entry> entry : featureVector.floatFeatures.entrySet()) { - Map out = new HashMap<>(); - flatFeature.put(entry.getKey(), out); - for (Map.Entry feature : entry.getValue().entrySet()) { - out.put(feature.getKey(), feature.getValue()); - } - } - } - return flatFeature; + public static Map> loadFactoryNamesFromPackage( + Class superClass, String packageName, String endWord) { + Reflections reflections = new Reflections(packageName); + return reflections.getSubTypesOf(superClass).stream() + .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers())) + .flatMap(clazz -> getFactoryNames(clazz, endWord).stream()) + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); } - public static Map> flattenFeatureWithDropout( - FeatureVector featureVector, - double dropout) { - Map> flatFeature = new HashMap<>(); - if (featureVector.stringFeatures != null) { - for (Map.Entry> entry : featureVector.stringFeatures.entrySet()) { - Map out = new HashMap<>(); - flatFeature.put(entry.getKey(), out); - for (String feature 
: entry.getValue()) { - if (Math.random() < dropout) continue; - out.put(feature, 1.0); - } - } - } - if (featureVector.floatFeatures != null) { - for (Map.Entry> entry : featureVector.floatFeatures.entrySet()) { - Map out = new HashMap<>(); - flatFeature.put(entry.getKey(), out); - for (Map.Entry feature : entry.getValue().entrySet()) { - if (Math.random() < dropout) continue; - out.put(feature.getKey(), feature.getValue()); - } + public static Map> loadConstructorsFromPackage( + Class superClass, String packageName, String endWord, + Class constructorParam) { + Map> tmpMap = new HashMap<>(); + for (Map.Entry> entry : + Util.loadFactoryNamesFromPackage(superClass, packageName, endWord).entrySet()) { + try { + tmpMap.put(entry.getKey(), entry.getValue().getConstructor(constructorParam)); + } catch (NoSuchMethodException ex) { + throw new IllegalStateException("AbstractModel of type %s does not have a single argument" + + "constructor that takes a FeatureRegistry. " + + "Please add one."); } } - return flatFeature; + return tmpMap; } public static class DebugDiffRecordComparator implements Comparator { @@ -320,10 +307,10 @@ private static Map> debugScoreRecordListToMap(List> recordMap = new HashMap<>(); for(int i = 0; i < recordList.size(); i++){ - String key = recordList.get(i).featureFamily + '\t' + recordList.get(i).featureName; + String key = recordList.get(i).getFeatureFamily() + '\t' + recordList.get(i).getFeatureName(); Map record = new HashMap<>(); - record.put("featureValue", recordList.get(i).featureValue); - record.put("featureWeight", recordList.get(i).featureWeight); + record.put("featureValue", recordList.get(i).getFeatureValue()); + record.put("featureWeight", recordList.get(i).getFeatureWeight()); recordMap.put(key, record); } return recordMap; @@ -373,23 +360,46 @@ public static List compareDebugRecords(List y) { + public static double euclideanDistance(double[] x, List y) { assert (x.length == y.size()); double sum = 0; for (int i = 0; i < 
x.length; i++) { final double dp = x[i] - y.get(i); sum += dp * dp; } - return (float) Math.sqrt(sum); + return Math.sqrt(sum); } - public static float euclideanDistance(List x, List y) { + public static double euclideanDistance(List x, List y) { assert (x.size() == y.size()); double sum = 0; for (int i = 0; i < y.size(); i++) { final double dp = x.get(i) - y.get(i); sum += dp * dp; } - return (float) Math.sqrt(sum); + return Math.sqrt(sum); + } + + public static List>> getFactoryNames( + Class clazz, String endWord) { + List>> result = new ArrayList<>(); + + String baseName = clazz.getSimpleName(); + + // Cut off the word Model or Transform at the end + if (baseName.endsWith(endWord)) { + baseName = baseName.substring(0, baseName.length() - endWord.length()); + } + baseName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, baseName); + result.add(Pair.of(baseName, clazz)); + + // Handle any old names we used to use that are annotated on the class. + if (clazz.isAnnotationPresent(LegacyNames.class)) { + LegacyNames legacyNames = clazz.getAnnotation(LegacyNames.class); + for (String legacyName : legacyNames.value()) { + result.add(Pair.of(legacyName, clazz)); + } + } + return result; } } \ No newline at end of file diff --git a/core/src/main/java/com/airbnb/aerosolve/core/util/Weibull.java b/core/src/main/java/com/airbnb/aerosolve/core/util/Weibull.java index 0efe33da..4777d800 100644 --- a/core/src/main/java/com/airbnb/aerosolve/core/util/Weibull.java +++ b/core/src/main/java/com/airbnb/aerosolve/core/util/Weibull.java @@ -1,13 +1,14 @@ package com.airbnb.aerosolve.core.util; -import lombok.experimental.Builder; +import lombok.Builder; import lombok.extern.slf4j.Slf4j; /* weibull(x) = exp(a*(x)^k + b) default max x is Double.MAX_VALUE */ -@Slf4j @Builder public class Weibull { +@Slf4j @Builder +public class Weibull { final private double k, a, b, maxX; public WeibullBuilder defaultBuilder() { diff --git a/core/src/main/thrift/MLSchema.thrift 
b/core/src/main/thrift/MLSchema.thrift index a2c4744a..f8fad883 100644 --- a/core/src/main/thrift/MLSchema.thrift +++ b/core/src/main/thrift/MLSchema.thrift @@ -5,23 +5,20 @@ */ namespace java com.airbnb.aerosolve.core -// Function name correspondent to Function class, -// so it can be created by java reflection -// we can save string ModelRecord, but that breaks released model file -// so to add new function, please add to FunctionForm in order +// Please add new function to the end of FunctionForm in order not to break serialization. enum FunctionForm { - Spline, - Linear, + SPLINE, + LINEAR, RADIAL_BASIS_FUNCTION, ARC_COSINE, SIGMOID, RELU, TANH, IDENTITY, - MultiDimensionSpline + MULTI_DIMENSION_SPLINE } -struct FeatureVector { +struct ThriftFeatureVector { // The first field is the feature family. e.g. "geo" // The rest are string feature values. e.g. "SF," CA", "USA" // e.g. "geo" -> "San Francisco", "CA", "USA" @@ -42,13 +39,13 @@ struct FeatureVector { 3: optional map> denseFeatures; } -struct Example { +struct ThriftExample { // Repeated list of examples in a bag, e.g. groups by user session // or ranked list. - 1: optional list example; + 1: optional list example; // The context feature, e.g. query / user features that is in common // over the whole session. 
- 2: optional FeatureVector context; + 2: optional ThriftFeatureVector context; } struct DictionaryEntry { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureGenTest.java b/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureGenTest.java deleted file mode 100644 index e3fbbf66..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureGenTest.java +++ /dev/null @@ -1,28 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -public class FeatureGenTest { - - @Test - public void add() throws Exception { - FeatureMapping m = new FeatureMapping(100); - String[] doubleNames = {"a", "b"}; - m.add(Double.class, doubleNames); - String[] booleanNames = {"c", "d"}; - m.add(Boolean.class, booleanNames); - String[] strNames = {"e", "f"}; - m.add(String.class, strNames); - m.finish(); - - FeatureGen f = new FeatureGen(m); - f.add(new float[]{Float.MIN_VALUE, 5}, Double.class); - Features p = f.gen(); - assertEquals(p.names.length, 6); - assertEquals(p.values[0], null); - assertEquals((Double) p.values[1], 5, 0.1); - - } -} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureMappingTest.java b/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureMappingTest.java deleted file mode 100644 index 09a52d29..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/features/FeatureMappingTest.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import org.junit.Test; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -public class FeatureMappingTest { - - @Test - public void add() throws Exception { - FeatureMapping m = new FeatureMapping(100); - String[] doubleNames = {"a", "b"}; - m.add(Double.class, doubleNames); - String[] booleanNames = {"c", "d"}; - m.add(Boolean.class, booleanNames); - String[] strNames = {"e", "f"}; 
- m.add(String.class, strNames); - m.finish(); - - assertEquals(m.getNames().length, 6); - assertArrayEquals(m.getNames(), - new String[]{"a", "b", "c", "d", "e", "f"}); - assertEquals(m.getMapping().get(String.class).start, 4); - assertEquals(m.getMapping().get(String.class).length, 2); - } -} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/features/FeaturesTest.java b/core/src/test/java/com/airbnb/aerosolve/core/features/FeaturesTest.java deleted file mode 100644 index 11b8349b..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/features/FeaturesTest.java +++ /dev/null @@ -1,186 +0,0 @@ -package com.airbnb.aerosolve.core.features; - -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.Test; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -import static org.junit.Assert.*; - -public class FeaturesTest { - public static Features createFeature() { - Object[] values = new Object[6]; - String[] names = new String[6]; - - names[0] = Features.LABEL; - values[0] = new Double(5.0); - - names[1] = "f_RAW"; - values[1] = "raw_feature"; - names[2] = "K_star"; - values[2] = "monkey"; - names[3] = "K_good"; - values[3] = Boolean.FALSE; - names[4] = "S_speed"; - values[4] = new Double(10.0); - - names[5] = "X_jump"; - values[5] = null; - - return Features.builder().names(names).values(values).build(); - } - - public static Features createMultiClassFeature() { - Object[] values = new Object[1]; - String[] names = new String[1]; - - names[0] = Features.LABEL; - values[0] = "a:1,b:2"; - return Features.builder().names(names).values(values).build(); - } - - @Test - public void toExample() throws Exception { - Example example = createFeature().toExample(false); - FeatureVector featureVector = example.getExample().get(0); - 
final Map> stringFeatures = featureVector.getStringFeatures(); - final Map> floatFeatures = featureVector.getFloatFeatures(); - - // we have default BIAS - assertEquals(4, stringFeatures.size()); - Set stringFeature = stringFeatures.get("f"); - assertEquals(1, stringFeature.size()); - assertTrue(stringFeature.contains("raw_feature")); - - stringFeature = stringFeatures.get("K"); - assertEquals(2, stringFeature.size()); - assertTrue(stringFeature.contains("star:monkey")); - assertTrue(stringFeature.contains("good:F")); - - stringFeature = stringFeatures.get("X"); - assertNull(stringFeature); - - stringFeature = stringFeatures.get(Features.MISS); - assertEquals(1, stringFeature.size()); - assertTrue(stringFeature.contains("X_jump")); - - assertEquals(2, floatFeatures.size()); - Map floatFeature = floatFeatures.get("S"); - assertEquals(1, floatFeature.size()); - assertEquals(10.0, floatFeature.get("speed"), 0); - - floatFeature = floatFeatures.get(Features.LABEL); - assertEquals(1, floatFeature.size()); - assertEquals(5.0, floatFeature.get(Features.LABEL_FEATURE_NAME), 0); - } - - @Test - public void toExampleMultiClass() throws Exception { - Example example = createMultiClassFeature().toExample(true); - FeatureVector featureVector = example.getExample().get(0); - final Map> floatFeatures = featureVector.getFloatFeatures(); - - assertEquals(1, floatFeatures.size()); - Map floatFeature = floatFeatures.get(Features.LABEL); - assertEquals(2, floatFeature.size()); - assertEquals(1, floatFeature.get("a"), 0); - assertEquals(2, floatFeature.get("b"), 0); - } - - @Test - public void addNumberFeature() throws Exception { - Pair featurePair = new ImmutablePair<>("family", "feature"); - Map> floatFeatures = new HashMap<>(); - Features.addNumberFeature(4, featurePair, floatFeatures); - Map feature = floatFeatures.get("family"); - assertEquals(1, feature.size()); - assertEquals(4, feature.get("feature"), 0); - - featurePair = new ImmutablePair<>("family", "feature_float"); - 
Features.addNumberFeature(5.0f, featurePair, floatFeatures); - assertEquals(2, feature.size()); - assertEquals(5.0, feature.get("feature_float"), 0); - } - - @Test - public void addBoolFeature() throws Exception { - Pair featurePair = new ImmutablePair<>("family", "feature"); - Map> stringFeatures = new HashMap<>(); - Features.addBoolFeature(false, featurePair, stringFeatures); - Features.addBoolFeature(true, featurePair, stringFeatures); - Set feature = stringFeatures.get("family"); - assertEquals(2, feature.size()); - assertTrue(feature.contains("feature:T")); - assertTrue(feature.contains("feature:F")); - } - - @Test - public void addStringFeature() throws Exception { - Pair featurePair = new ImmutablePair<>("family", "feature"); - Map> stringFeatures = new HashMap<>(); - Features.addStringFeature("value", featurePair, stringFeatures); - - Pair raw = new ImmutablePair<>("family", Features.RAW); - Features.addStringFeature("feature_1", raw, stringFeatures); - - Set feature = stringFeatures.get("family"); - assertEquals(2, feature.size()); - assertTrue(feature.contains("feature:value")); - assertTrue(feature.contains("feature_1")); - } - - @Test - public void addMultiClassLabel() throws Exception { - Map> floatFeatures = new HashMap<>(); - - Features.addMultiClassLabel("a:1,b:2", floatFeatures); - Map feature = floatFeatures.get(Features.LABEL); - assertEquals(2, feature.size()); - assertEquals(1, feature.get("a"), 0); - assertEquals(2, feature.get("b"), 0); - } - - @Test (expected = RuntimeException.class) - public void addMultiClassManyColon() throws Exception { - Features.addMultiClassLabel("a:1:2,b:2", Collections.EMPTY_MAP); - } - - @Test (expected = RuntimeException.class) - public void addMultiClassLabelNoolon() throws Exception { - Features.addMultiClassLabel("abc,b:2", Collections.EMPTY_MAP); - } - - @Test - public void isLabel() throws Exception { - assertTrue(Features.isLabel(Features.getFamily("LABEL"))); - 
assertFalse(Features.isLabel(Features.getFamily("LABE_ab"))); - } - - @Test(expected = RuntimeException.class) - public void getFamilyEmpty() throws Exception { - Pair p = Features.getFamily(""); - } - - @Test(expected = RuntimeException.class) - public void getFamilyNotLABEL() throws Exception { - Pair p = Features.getFamily("LABE"); - } - - @Test(expected = RuntimeException.class) - public void getFamilyPrefix() throws Exception { - Pair p = Features.getFamily("_abc"); - } - - @Test - public void getFamily() throws Exception { - Pair p = Features.getFamily("f_ab_cd"); - assertEquals(p.getLeft(), "f"); - assertEquals(p.getRight(), "ab_cd"); - } -} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/features/InputGenerationTest.java b/core/src/test/java/com/airbnb/aerosolve/core/features/InputGenerationTest.java new file mode 100644 index 00000000..b8143813 --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/features/InputGenerationTest.java @@ -0,0 +1,232 @@ +package com.airbnb.aerosolve.core.features; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class InputGenerationTest { + private static final FeatureRegistry registry = new FeatureRegistry(); + public static MultiFamilyVector createVector() { + Object[] values = new Object[6]; + String[] names = new String[6]; + + names[0] = GenericNamingConvention.LABEL; + values[0] = 5.0d; + + names[1] = "f_RAW"; + values[1] = "raw_feature"; + names[2] = "K_star"; + values[2] = "monkey"; + names[3] = "K_good"; + values[3] = false; + names[4] = "S_speed"; + values[4] = 10.0d; + + names[5] = "X_jump"; + values[5] = null; + + BasicMultiFamilyVector vector = new BasicMultiFamilyVector(registry); + return vector.putAll(names, values); + } + + public static MultiFamilyVector createMulticlassVector() { + 
Object[] values = new Object[1]; + String[] names = new String[1]; + + names[0] = GenericNamingConvention.LABEL; + values[0] = "a:1,b:2"; + + BasicMultiFamilyVector vector = new BasicMultiFamilyVector(registry); + return vector.putAll(names, values); + } + + @Test + public void addInputToGenerator() throws Exception { + InputSchema m = getInputSchema(); + + InputGenerator f = new InputGenerator(m); + f.add(new double[]{Double.MIN_VALUE, 5}, Double.class); + FeatureRegistry registry = new FeatureRegistry(); + BasicMultiFamilyVector vector = new BasicMultiFamilyVector(registry); + f.load(vector); + + assertEquals(vector.getDouble(registry.feature("z", "b")), 5d, 0.1); + } + + @Test + public void addToSchema() throws Exception { + InputSchema m = getInputSchema(); + + assertEquals(m.getNames().length, 6); + assertArrayEquals(m.getNames(), + new String[]{"z_a", "z_b", "z_c", "z_d", "z_e", "z_f"}); + assertEquals(m.getMapping().get(String.class).start, 4); + assertEquals(m.getMapping().get(String.class).length, 2); + } + + private InputSchema getInputSchema() { + InputSchema m = new InputSchema(100); + String[] doubleNames = {"z_a", "z_b"}; + m.add(Double.class, doubleNames); + String[] booleanNames = {"z_c", "z_d"}; + m.add(Boolean.class, booleanNames); + String[] strNames = {"z_e", "z_f"}; + m.add(String.class, strNames); + m.finish(); + return m; + } + + @Test + public void testVector() throws Exception { + MultiFamilyVector featureVector = createVector(); + + // we don't want default BIAS families to be here. We'll do that in the scorer. 
+ assertEquals(5, featureVector.numFamilies()); + FamilyVector fFamily = featureVector.get(registry.family("f")); + assertEquals(1, fFamily.size()); + assertTrue(featureVector.containsKey(fFamily.family().feature("raw_feature"))); + + FamilyVector kFamily = featureVector.get(registry.family("K")); + assertEquals(2, kFamily.size()); + assertTrue(featureVector.containsKey(kFamily.family().feature("star:monkey"))); + assertTrue(featureVector.containsKey(kFamily.family().feature("good:F"))); + + assertFalse(featureVector.contains(registry.family("X"))); + + FamilyVector missFamily = featureVector.get(registry.family(GenericNamingConvention.MISS)); + assertEquals(1, missFamily.size()); + assertTrue(featureVector.containsKey(missFamily.family().feature("X_jump"))); + + FamilyVector sFamily = featureVector.get(registry.family("S")); + assertEquals(1, sFamily.size()); + assertEquals(featureVector.getDouble(sFamily.family().feature("speed")), 10d, .01); + + FamilyVector labelFamily = featureVector.get(registry.family(GenericNamingConvention.LABEL)); + assertEquals(1, labelFamily.size()); + assertEquals(featureVector.getDouble( + labelFamily.family().feature(GenericNamingConvention.LABEL_FEATURE_NAME)), 5d, .01); + } + + @Test + public void testMulticlass() throws Exception { + MultiFamilyVector featureVector = createMulticlassVector(); + + assertEquals(1, featureVector.numFamilies()); + + FamilyVector labelFamily = featureVector.get(registry.family(GenericNamingConvention.LABEL)); + assertEquals(2, labelFamily.size()); + assertEquals(1, labelFamily.get(labelFamily.family().feature("a")), 0); + assertEquals(2, labelFamily.get(labelFamily.family().feature("b")), 0); + } + + // TODO (Brad): Move tests to correct classes. 
+ + /* @Test + public void addNumberFeature() throws Exception { + MultiFamilyVector vector = new FastMultiFamilyVector(registry); + Feature feature = Features.calculateFeature("family_feature", 4d, registry); + vector.put(feature, 4d); + FamilyVector family = vector.get(registry.family("family")); + assertEquals(1, family.size()); + assertEquals(4, family.get(family.family().feature("feature")), 0); + + feature = Features.calculateFeature("family_feature_float", 5d, registry); + vector.put(feature, 5d); + family = vector.get(registry.family("family")); + assertEquals(2, family.size()); + assertEquals(5d, family.get(family.family().feature("feature_float")), 0); + } + + @Test + public void addBoolFeature() throws Exception { + MultiFamilyVector vector = new FastMultiFamilyVector(registry); + Feature feature = Features.calculateFeature("family_feature", false, registry); + vector.putString(feature); + FamilyVector family = vector.get(registry.family("family")); + assertEquals(1, family.size()); + assertTrue(family.containsKey(family.family().feature("feature:F"))); + + feature = Features.calculateFeature("family_feature", true, registry); + vector.putString(feature); + assertEquals(2, family.size()); + assertTrue(family.containsKey(family.family().feature("feature:T"))); + // Other one is still there. 
+ assertTrue(family.containsKey(family.family().feature("feature:F"))); + } + + @Test + public void addStringFeature() throws Exception { + MultiFamilyVector vector = new FastMultiFamilyVector(registry); + Feature feature = Features.calculateFeature("family_feature", "value", registry); + vector.putString(feature); + + feature = Features.calculateFeature("family_RAW", "feature_1", registry); + vector.putString(feature); + + FamilyVector family = vector.get(registry.family("family")); + assertEquals(2, family.size()); + assertTrue(family.containsKey(family.family().feature("feature:value"))); + assertTrue(family.containsKey(family.family().feature("feature_1"))); + } + + @Test + public void addMultiClassLabel() throws Exception { + MultiFamilyVector vector = empty(); + + Family labelFamily = labelFamily(); + Features.addMultiClassLabel("a:1,b:2", vector, labelFamily); + FamilyVector labelVector = vector.get(labelFamily); + assertEquals(2, labelVector.size()); + assertEquals(1, labelVector.get(labelFamily.feature("a")), 0); + assertEquals(2, labelVector.get(labelFamily.feature("b")), 0); + } + + private MultiFamilyVector empty() { + return new FastMultiFamilyVector(registry); + } + + private Family labelFamily() { + return registry.family(Features.LABEL); + } + + @Test (expected = RuntimeException.class) + public void addMultiClassManyColon() throws Exception { + Features.addMultiClassLabel("a:1:2,b:2", empty(), labelFamily()); + } + + @Test (expected = RuntimeException.class) + public void addMultiClassLabelNoolon() throws Exception { + Features.addMultiClassLabel("abc,b:2", empty(), labelFamily()); + } + + @Test + public void isLabel() throws Exception { + assertTrue(Features.isLabel(Features.calculateFeature("LABEL", "A", registry), labelFamily())); + assertFalse(Features.isLabel(Features.calculateFeature("LABE_ab", "A", registry), labelFamily())); + } + + @Test(expected = RuntimeException.class) + public void getFamilyEmpty() throws Exception { + 
Features.calculateFeature("", 0d, registry); + } + + @Test(expected = RuntimeException.class) + public void getFamilyNotLABEL() throws Exception { + Features.calculateFeature("LABE", 0d, registry); + } + + @Test(expected = RuntimeException.class) + public void getFamilyPrefix() throws Exception { + Features.calculateFeature("_abc", 0d, registry); + } + + @Test + public void getFamily() throws Exception { + Feature feature = Features.calculateFeature("f_ab_cd", 0d, registry); + assertEquals(feature.family(), registry.family("f")); + assertEquals(feature.name(), "ab_cd"); + } */ +} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/function/SplineTest.java b/core/src/test/java/com/airbnb/aerosolve/core/function/SplineTest.java deleted file mode 100644 index a02b01c4..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/function/SplineTest.java +++ /dev/null @@ -1,197 +0,0 @@ -package com.airbnb.aerosolve.core.function; - -import com.airbnb.aerosolve.core.ModelRecord; -import lombok.extern.slf4j.Slf4j; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -import static org.junit.Assert.assertEquals; - -/** - * @author Hector Yee - */ -@Slf4j -public class SplineTest { - @Test - public void testSplineEvaluate() { - float[] weights = {5.0f, 10.0f, -20.0f}; - Spline spline = new Spline(1.0f, 3.0f, weights); - testSpline(spline, 0.1f); - } - - public static Spline getSpline() { - float[] weights = {5.0f, 10.0f, -20.0f}; - return new Spline(1.0f, 3.0f, weights); - - } - @Test - public void testSplineResampleConstructor() { - Spline spline = getSpline(); - - // Same size - spline.resample(3); - testSpline(spline, 0.1f); - - // Smaller - Spline spline3 = getSpline(); - spline3.resample(2); - assertEquals(5.0f, spline3.evaluate(-1.0f), 0.1f); - assertEquals(5.0f, spline3.evaluate(1.0f), 0.1f); - assertEquals((5.0f - 20.0f) * 0.5f, spline3.evaluate(2.0f), 0.1f); - 
assertEquals(-20.0f, spline3.evaluate(3.0f), 0.1f); - assertEquals(-20.0f, spline3.evaluate(4.0f), 0.1f); - - // Larger - Spline spline4 = getSpline(); - spline4.resample(100); - testSpline(spline4, 0.2f); - } - - @Test - public void testSplineModelRecordConstructor() { - ModelRecord record = new ModelRecord(); - record.setFeatureFamily("TEST"); - record.setFeatureName("a"); - record.setMinVal(1.0); - record.setMaxVal(3.0); - List weightVec = new ArrayList(); - weightVec.add(5.0); - weightVec.add(10.0); - weightVec.add(-20.0); - record.setWeightVector(weightVec); - Spline spline = new Spline(record); - testSpline(spline, 0.1f); - } - - @Test - public void testSplineToModelRecord() { - float[] weights = {5.0f, 10.0f, -20.0f}; - Spline spline = new Spline(1.0f, 3.0f, weights); - ModelRecord record = spline.toModelRecord("family", "name"); - assertEquals(record.getFeatureFamily(), "family"); - assertEquals(record.getFeatureName(), "name"); - List weightVector = record.getWeightVector(); - assertEquals(5.0f, weightVector.get(0).floatValue(), 0.01f); - assertEquals(10.0f, weightVector.get(1).floatValue(), 0.01f); - assertEquals(-20.0f, weightVector.get(2).floatValue(), 0.01f); - assertEquals(1.0f, record.getMinVal(), 0.01f); - assertEquals(3.0f, record.getMaxVal(), 0.01f); - } - - @Test - public void testSplineResample() { - float[] weights = {5.0f, 10.0f, -20.0f}; - // Same size - Spline spline1 = new Spline(1.0f, 3.0f, weights); - spline1.resample(3); - testSpline(spline1, 0.1f); - - // Smaller - Spline spline2 = new Spline(1.0f, 3.0f, weights); - spline2.resample(2); - assertEquals(5.0f, spline2.evaluate(-1.0f), 0.1f); - assertEquals(5.0f, spline2.evaluate(1.0f), 0.1f); - assertEquals((5.0f - 20.0f) * 0.5f, spline2.evaluate(2.0f), 0.1f); - assertEquals(-20.0f, spline2.evaluate(3.0f), 0.1f); - assertEquals(-20.0f, spline2.evaluate(4.0f), 0.1f); - - // Larger - Spline spline3 = new Spline(1.0f, 3.0f, weights); - spline3.resample(100); - testSpline(spline3, 0.2f); - 
spline3.resample(200); - testSpline(spline3, 0.2f); - } - - void testSpline(Spline spline, float tol) { - float a = spline.evaluate(1.5f); - log.info("spline 1.5 is " + a); - assertEquals(5.0f, spline.evaluate(-1.0f), tol); - assertEquals(5.0f, spline.evaluate(1.0f), tol); - assertEquals(7.5f, spline.evaluate(1.5f), tol); - assertEquals(10.0f, spline.evaluate(1.99f), tol); - assertEquals(10.0f, spline.evaluate(2.0f), tol); - assertEquals(0.0f, spline.evaluate(2.3333f), tol); - assertEquals(-10.0f, spline.evaluate(2.667f), tol); - assertEquals(-20.0f, spline.evaluate(2.99999f), tol); - assertEquals(-20.0f, spline.evaluate(3.0f), tol); - assertEquals(-20.0f, spline.evaluate(4.0f), tol); - } - - float func(float x) { - return 0.1f * (x + 0.5f) * (x - 4.0f) * (x - 1.0f); - } - - @Test - public void testSplineUpdate() { - float[] weights = new float[8]; - Spline spline = new Spline(-1.0f, 5.0f, weights); - Random rnd = new java.util.Random(123); - for (int i = 0; i < 1000; i++) { - float x = (float) (rnd.nextDouble() * 6.0 - 1.0); - float y = func(x); - float tmp = spline.evaluate(x); - float delta =0.1f * (y - tmp); - spline.update(delta, x); - } - // Check we get roots where we expect them to be. 
- assertEquals(0.0f, spline.evaluate(-0.5f), 0.1f); - assertEquals(0.0f, spline.evaluate(1.0f), 0.1f); - assertEquals(0.0f, spline.evaluate(4.0f), 0.1f); - for (int i = 0; i < 20; i++) { - float x = (float) (6.0 * i / 20.0 - 1.0f); - float expected = func(x); - float eval = spline.evaluate(x); - log.info("x = " + x + " expected = " + expected + " got = " + eval); - assertEquals(expected, spline.evaluate(x), 0.1f); - } - } - - @Test - public void testSplineL1Norm() { - float[] weights1 = {5.0f, 10.0f, -20.0f}; - Spline spline1 = new Spline(1.0f, 3.0f, weights1); - assertEquals(35.0f, spline1.L1Norm(), 0.01f); - - float[] weights2 = {0.0f, 0.0f}; - Spline spline2 = new Spline(1.0f, 3.0f, weights2); - assertEquals(0.0f, spline2.L1Norm(), 0.01f); - } - - @Test - public void testSplineLInfinityNorm() { - float[] weights1 = {5.0f, 10.0f, -20.0f}; - Spline spline1 = new Spline(1.0f, 3.0f, weights1); - assertEquals(20.0f, spline1.LInfinityNorm(), 0.01f); - - float[] weights2 = {0.0f, 0.0f}; - Spline spline2 = new Spline(1.0f, 3.0f, weights2); - assertEquals(0.0f, spline2.LInfinityNorm(), 0.01f); - } - - @Test - public void testSplineLInfinityCap() { - float[] weights = {5.0f, 10.0f, -20.0f}; - Spline spline1 = new Spline(1.0f, 3.0f, weights); - // Larger (no scale) - spline1.LInfinityCap(30.0f); - assertEquals(5.0f, spline1.getWeights()[0], 0.01f); - assertEquals(10.0f, spline1.getWeights()[1], 0.01f); - assertEquals(-20.0f, spline1.getWeights()[2], 0.01f); - // Negative - spline1.LInfinityCap(-10.0f); - assertEquals(5.0f, spline1.getWeights()[0], 0.01f); - assertEquals(10.0f, spline1.getWeights()[1], 0.01f); - assertEquals(-20.0f, spline1.getWeights()[2], 0.01f); - // Smaller (with scale) - Spline spline2 = new Spline(1.0f, 3.0f, weights); - spline2.LInfinityCap(10.0f); - float scale = 10.0f / 20.0f; - assertEquals(5.0f * scale, spline2.getWeights()[0], 0.01f); - assertEquals(10.0f * scale, spline2.getWeights()[1], 0.01f); - assertEquals(-20.0f * scale, 
spline2.getWeights()[2], 0.01f); - } -} diff --git a/core/src/test/java/com/airbnb/aerosolve/core/function/LinearTest.java b/core/src/test/java/com/airbnb/aerosolve/core/functions/LinearTest.java similarity index 58% rename from core/src/test/java/com/airbnb/aerosolve/core/function/LinearTest.java rename to core/src/test/java/com/airbnb/aerosolve/core/functions/LinearTest.java index 8f1e1b10..75e5ea7a 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/function/LinearTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/functions/LinearTest.java @@ -1,11 +1,10 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.ModelRecord; -import org.junit.Test; - import java.util.ArrayList; import java.util.List; import java.util.Random; +import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -14,14 +13,13 @@ */ public class LinearTest { - float func(float x) { - return 0.2f + 1.5f * (x + 6.0f) / 11.0f; + double func(double x) { + return 0.2d + 1.5d * (x + 6.0d) / 11.0d; } Linear createLinearTestExample() { - float [] weights = {0.2f, 1.5f}; - Linear linearFunc = new Linear(-6.0f, 5.0f, weights); - return linearFunc; + double [] weights = {0.2d, 1.5d}; + return new Linear(-6.0d, 5.0d, weights); } @Test @@ -39,8 +37,8 @@ public void testLinearModelRecordConstructor() { weightVec.add(0.2); weightVec.add(1.5); record.setWeightVector(weightVec); - record.setMinVal(-6.0f); - record.setMaxVal(5.0f); + record.setMinVal(-6.0d); + record.setMaxVal(5.0d); Linear linearFunc = new Linear(record); testLinear(linearFunc); } @@ -52,29 +50,29 @@ public void testLinearToModelRecord() { assertEquals(record.getFeatureFamily(), "family"); assertEquals(record.getFeatureName(), "name"); List weightVector = record.getWeightVector(); - assertEquals(0.2f, weightVector.get(0).floatValue(), 0.01f); - assertEquals(1.5f, weightVector.get(1).floatValue(), 0.01f); - assertEquals(-6.0f, record.getMinVal(), 
0.01f); - assertEquals(5.0f, record.getMaxVal(), 0.01f); + assertEquals(0.2d, weightVector.get(0), 0.01d); + assertEquals(1.5d, weightVector.get(1), 0.01d); + assertEquals(-6.0d, record.getMinVal(), 0.01d); + assertEquals(5.0d, record.getMaxVal(), 0.01d); } @Test public void testLinearUpdate() { - Linear linearFunc = new Linear(-6.0f, 5.0f, new float[2]); + Linear linearFunc = new Linear(-6.0d, 5.0d, new double[2]); Random rnd = new java.util.Random(123); for (int i = 0; i < 1000; i++) { - float x = (float) (rnd.nextDouble() * 10.0 - 5.0); - float y = func(x); - float tmp = linearFunc.evaluate(x); - float delta = 0.5f * (y - tmp); + double x = (rnd.nextDouble() * 10.0 - 5.0); + double y = func(x); + double tmp = linearFunc.evaluate(x); + double delta = 0.5d * (y - tmp); linearFunc.update(delta, x); } testLinear(linearFunc); } void testLinear(Linear linearFunc) { - assertEquals(0.2f + 1.5f * 6.0f / 11.0f, linearFunc.evaluate(0.0f), 0.01f); - assertEquals(0.2f + 1.5f * 7.0f / 11.0f, linearFunc.evaluate(1.0f), 0.01f); - assertEquals(0.2f + 1.5f * 5.0f / 11.0f, linearFunc.evaluate(-1.0f), 0.01f); + assertEquals(0.2d + 1.5d * 6.0d / 11.0d, linearFunc.evaluate(0.0d), 0.01d); + assertEquals(0.2d + 1.5d * 7.0d / 11.0d, linearFunc.evaluate(1.0d), 0.01d); + assertEquals(0.2d + 1.5d * 5.0d / 11.0d, linearFunc.evaluate(-1.0d), 0.01d); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionPointTest.java b/core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionPointTest.java similarity index 94% rename from core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionPointTest.java rename to core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionPointTest.java index 9abe916d..a377d42b 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionPointTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionPointTest.java @@ -1,4 +1,4 @@ -package 
com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.util.Util; import org.junit.Test; @@ -94,7 +94,7 @@ public void getCombination() throws Exception { @Test public void testZero() { - float[] af = new float[]{0f, 1f, -2f, 3.4f, 5.0f, -6.7f, 8.9f}; + double[] af = new double[]{0f, 1f, -2f, 3.4f, 5.0f, -6.7f, 8.9f}; List al = Arrays.asList(0f, 1f, -2f, 3.4f, 5.0f, -6.7f, 8.9f); List b = Arrays.asList(0f, 1.0f, -2.0f, 3.4f, 5.0f, -6.7f, 8.9f); @@ -104,8 +104,8 @@ public void testZero() { @Test public void test() { - float[] af = new float[]{1.0f, -2.0f, 3.0f, 4.0f}; - float[] bf = new float[]{-5.0f, -6.0f, 7.0f, 8.0f}; + double[] af = new double[]{1.0f, -2.0f, 3.0f, 4.0f}; + double[] bf = new double[]{-5.0f, -6.0f, 7.0f, 8.0f}; List a = Arrays.asList(1.0f, -2.0f, 3.0f, 4.0f); List b = Arrays.asList(-5.0f, -6.0f, 7.0f, 8.0f); diff --git a/core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionSplineTest.java b/core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionSplineTest.java similarity index 98% rename from core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionSplineTest.java rename to core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionSplineTest.java index d6e94551..e0d26df9 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/function/MultiDimensionSplineTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/functions/MultiDimensionSplineTest.java @@ -1,4 +1,4 @@ -package com.airbnb.aerosolve.core.function; +package com.airbnb.aerosolve.core.functions; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.models.NDTreeModel; diff --git a/core/src/test/java/com/airbnb/aerosolve/core/functions/SplineTest.java b/core/src/test/java/com/airbnb/aerosolve/core/functions/SplineTest.java new file mode 100644 index 00000000..6b70ca82 --- /dev/null +++ 
b/core/src/test/java/com/airbnb/aerosolve/core/functions/SplineTest.java @@ -0,0 +1,197 @@ +package com.airbnb.aerosolve.core.functions; + +import com.airbnb.aerosolve.core.ModelRecord; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +/** + * @author Hector Yee + */ +@Slf4j +public class SplineTest { + @Test + public void testSplineEvaluate() { + double[] weights = {5.0d, 10.0d, -20.0d}; + Spline spline = new Spline(1.0d, 3.0d, weights); + testSpline(spline, 0.1d); + } + + public static Spline getSpline() { + double[] weights = {5.0d, 10.0d, -20.0d}; + return new Spline(1.0d, 3.0d, weights); + } + + @Test + public void testSplineResampleConstructor() { + Spline spline = getSpline(); + + // Same size + spline.resample(3); + testSpline(spline, 0.1d); + + // Smaller + Spline spline3 = getSpline(); + spline3.resample(2); + assertEquals(5.0d, spline3.evaluate(-1.0d), 0.1d); + assertEquals(5.0d, spline3.evaluate(1.0d), 0.1d); + assertEquals((5.0d - 20.0d) * 0.5d, spline3.evaluate(2.0d), 0.1d); + assertEquals(-20.0d, spline3.evaluate(3.0d), 0.1d); + assertEquals(-20.0d, spline3.evaluate(4.0d), 0.1d); + + // Larger + Spline spline4 = getSpline(); + spline4.resample(100); + testSpline(spline4, 0.2d); + } + + @Test + public void testSplineModelRecordConstructor() { + ModelRecord record = new ModelRecord(); + record.setFeatureFamily("TEST"); + record.setFeatureName("a"); + record.setMinVal(1.0); + record.setMaxVal(3.0); + List weightVec = new ArrayList(); + weightVec.add(5.0); + weightVec.add(10.0); + weightVec.add(-20.0); + record.setWeightVector(weightVec); + Spline spline = new Spline(record); + testSpline(spline, 0.1d); + } + + @Test + public void testSplineToModelRecord() { + double[] weights = {5.0d, 10.0d, -20.0d}; + Spline spline = new Spline(1.0d, 3.0d, weights); + ModelRecord record = spline.toModelRecord("family", "name"); + 
assertEquals(record.getFeatureFamily(), "family"); + assertEquals(record.getFeatureName(), "name"); + List weightVector = record.getWeightVector(); + assertEquals(5.0d, weightVector.get(0), 0.01d); + assertEquals(10.0d, weightVector.get(1), 0.01d); + assertEquals(-20.0d, weightVector.get(2), 0.01d); + assertEquals(1.0d, record.getMinVal(), 0.01d); + assertEquals(3.0d, record.getMaxVal(), 0.01d); + } + + @Test + public void testSplineResample() { + double[] weights = {5.0d, 10.0d, -20.0d}; + // Same size + Spline spline1 = new Spline(1.0d, 3.0d, weights); + spline1.resample(3); + testSpline(spline1, 0.1d); + + // Smaller + Spline spline2 = new Spline(1.0d, 3.0d, weights); + spline2.resample(2); + assertEquals(5.0d, spline2.evaluate(-1.0d), 0.1d); + assertEquals(5.0d, spline2.evaluate(1.0d), 0.1d); + assertEquals((5.0d - 20.0d) * 0.5d, spline2.evaluate(2.0d), 0.1d); + assertEquals(-20.0d, spline2.evaluate(3.0d), 0.1d); + assertEquals(-20.0d, spline2.evaluate(4.0d), 0.1d); + + // Larger + Spline spline3 = new Spline(1.0d, 3.0d, weights); + spline3.resample(100); + testSpline(spline3, 0.2d); + spline3.resample(200); + testSpline(spline3, 0.2d); + } + + void testSpline(Spline spline, double tol) { + double a = spline.evaluate(1.5d); + log.info("spline 1.5 is " + a); + assertEquals(5.0d, spline.evaluate(-1.0d), tol); + assertEquals(5.0d, spline.evaluate(1.0d), tol); + assertEquals(7.5f, spline.evaluate(1.5d), tol); + assertEquals(10.0d, spline.evaluate(1.99d), tol); + assertEquals(10.0d, spline.evaluate(2.0d), tol); + assertEquals(0.0d, spline.evaluate(2.3333d), tol); + assertEquals(-10.0d, spline.evaluate(2.667d), tol); + assertEquals(-20.0d, spline.evaluate(2.99999d), tol); + assertEquals(-20.0d, spline.evaluate(3.0d), tol); + assertEquals(-20.0d, spline.evaluate(4.0d), tol); + } + + double func(double x) { + return 0.1d * (x + 0.5d) * (x - 4.0d) * (x - 1.0d); + } + + @Test + public void testSplineUpdate() { + double[] weights = new double[8]; + Spline spline = new 
Spline(-1.0d, 5.0d, weights); + Random rnd = new java.util.Random(123); + for (int i = 0; i < 1000; i++) { + double x = (double) (rnd.nextDouble() * 6.0 - 1.0); + double y = func(x); + double tmp = spline.evaluate(x); + double delta =0.1d * (y - tmp); + spline.update(delta, x); + } + // Check we get roots where we expect them to be. + assertEquals(0.0d, spline.evaluate(-0.5d), 0.1d); + assertEquals(0.0d, spline.evaluate(1.0d), 0.1d); + assertEquals(0.0d, spline.evaluate(4.0d), 0.1d); + for (int i = 0; i < 20; i++) { + double x = (double) (6.0 * i / 20.0 - 1.0d); + double expected = func(x); + double eval = spline.evaluate(x); + log.info("x = " + x + " expected = " + expected + " got = " + eval); + assertEquals(expected, spline.evaluate(x), 0.1d); + } + } + + @Test + public void testSplineL1Norm() { + double[] weights1 = {5.0d, 10.0d, -20.0d}; + Spline spline1 = new Spline(1.0d, 3.0d, weights1); + assertEquals(35.0d, spline1.L1Norm(), 0.01d); + + double[] weights2 = {0.0d, 0.0d}; + Spline spline2 = new Spline(1.0d, 3.0d, weights2); + assertEquals(0.0d, spline2.L1Norm(), 0.01d); + } + + @Test + public void testSplineLInfinityNorm() { + double[] weights1 = {5.0d, 10.0d, -20.0d}; + Spline spline1 = new Spline(1.0d, 3.0d, weights1); + assertEquals(20.0d, spline1.LInfinityNorm(), 0.01d); + + double[] weights2 = {0.0d, 0.0d}; + Spline spline2 = new Spline(1.0d, 3.0d, weights2); + assertEquals(0.0d, spline2.LInfinityNorm(), 0.01d); + } + + @Test + public void testSplineLInfinityCap() { + double[] weights = {5.0d, 10.0d, -20.0d}; + Spline spline1 = new Spline(1.0d, 3.0d, weights); + // Larger (no scale) + spline1.LInfinityCap(30.0d); + assertEquals(5.0d, spline1.getWeights()[0], 0.01d); + assertEquals(10.0d, spline1.getWeights()[1], 0.01d); + assertEquals(-20.0d, spline1.getWeights()[2], 0.01d); + // Negative + spline1.LInfinityCap(-10.0d); + assertEquals(5.0d, spline1.getWeights()[0], 0.01d); + assertEquals(10.0d, spline1.getWeights()[1], 0.01d); + assertEquals(-20.0d, 
spline1.getWeights()[2], 0.01d); + // Smaller (with scale) + Spline spline2 = new Spline(1.0d, 3.0d, weights); + spline2.LInfinityCap(10.0d); + double scale = 10.0d / 20.0d; + assertEquals(5.0d * scale, spline2.getWeights()[0], 0.01d); + assertEquals(10.0d * scale, spline2.getWeights()[1], 0.01d); + assertEquals(-20.0d * scale, spline2.getWeights()[2], 0.01d); + } +} diff --git a/core/src/test/java/com/airbnb/aerosolve/core/images/HOGFeatureTest.java b/core/src/test/java/com/airbnb/aerosolve/core/images/HOGFeatureTest.java index 288d05de..0e145f61 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/images/HOGFeatureTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/images/HOGFeatureTest.java @@ -1,12 +1,11 @@ package com.airbnb.aerosolve.core.images; +import java.awt.image.BufferedImage; +import java.util.List; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.awt.image.BufferedImage; -import java.util.List; - import static org.junit.Assert.assertTrue; public class HOGFeatureTest { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/images/HSVFeatureTest.java b/core/src/test/java/com/airbnb/aerosolve/core/images/HSVFeatureTest.java index edb52c07..08a58c51 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/images/HSVFeatureTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/images/HSVFeatureTest.java @@ -1,12 +1,10 @@ package com.airbnb.aerosolve.core.images; +import java.awt.image.BufferedImage; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.awt.image.BufferedImage; -import java.util.List; - import static org.junit.Assert.assertTrue; public class HSVFeatureTest { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractorTest.java b/core/src/test/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractorTest.java index a2e19b50..3cc07131 100644 --- 
a/core/src/test/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractorTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/images/ImageFeatureExtractorTest.java @@ -1,43 +1,44 @@ package com.airbnb.aerosolve.core.images; -import com.airbnb.aerosolve.core.FeatureVector; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import com.airbnb.aerosolve.core.features.DenseVector; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import java.awt.image.BufferedImage; -import java.util.Map; -import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; import static org.junit.Assert.assertTrue; +@Slf4j public class ImageFeatureExtractorTest { - private static final Logger log = LoggerFactory.getLogger(ImageFeatureExtractorTest.class); + private final FeatureRegistry registry = new FeatureRegistry(); - public void validateFeature(Map> denseFeatures, + public void validateFeature(MultiFamilyVector vector, String name, int expectedCount) { - assertTrue(denseFeatures.containsKey(name)); - assertTrue(denseFeatures.get(name).size() == expectedCount); - log.info("feature " + name + "[0] = " + denseFeatures.get(name).get(0)); + assertTrue(vector.contains(name)); + FamilyVector fVec = vector.get(name); + assertTrue(fVec.size() == expectedCount); + assertTrue(fVec instanceof DenseVector); + log.info("feature " + name + "[0] = " + ((DenseVector) fVec).denseArray()[0]); } - public void validateFeatureVector(FeatureVector featureVector) { - Map> denseFeatures = featureVector.getDenseFeatures(); - assertTrue(denseFeatures != null); - assertTrue(denseFeatures.containsKey("rgb")); + public void validateFeatureVector(MultiFamilyVector featureVector) { + assertTrue(featureVector.numFamilies() == 4); + assertTrue(featureVector.contains("rgb")); final int kNumGrids = 1 + 4 + 16; - 
validateFeature(denseFeatures, "rgb", 512 * kNumGrids); - validateFeature(denseFeatures, "hog", 9 * kNumGrids); - validateFeature(denseFeatures, "lbp", 256 * kNumGrids); - validateFeature(denseFeatures, "hsv", 65 * kNumGrids); + validateFeature(featureVector, "rgb", 512 * kNumGrids); + validateFeature(featureVector, "hog", 9 * kNumGrids); + validateFeature(featureVector, "lbp", 256 * kNumGrids); + validateFeature(featureVector, "hsv", 65 * kNumGrids); } @Test public void testBlackImage() { BufferedImage image = new BufferedImage(10, 10, BufferedImage.TYPE_BYTE_GRAY); ImageFeatureExtractor featureExtractor = ImageFeatureExtractor.getInstance(); - FeatureVector featureVector = featureExtractor.getFeatureVector(image); + MultiFamilyVector featureVector = featureExtractor.getFeatureVector(image, registry); validateFeatureVector(featureVector); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/images/LBPFeatureTest.java b/core/src/test/java/com/airbnb/aerosolve/core/images/LBPFeatureTest.java index 823aa08d..7e6032fc 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/images/LBPFeatureTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/images/LBPFeatureTest.java @@ -1,12 +1,10 @@ package com.airbnb.aerosolve.core.images; +import java.awt.image.BufferedImage; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.awt.image.BufferedImage; -import java.util.List; - import static org.junit.Assert.assertTrue; public class LBPFeatureTest { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/images/RGBFeatureTest.java b/core/src/test/java/com/airbnb/aerosolve/core/images/RGBFeatureTest.java index 89678f7b..0e36cffd 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/images/RGBFeatureTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/images/RGBFeatureTest.java @@ -1,12 +1,11 @@ package com.airbnb.aerosolve.core.images; +import java.awt.image.BufferedImage; +import java.util.List; import 
org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.awt.image.BufferedImage; -import java.util.List; - import static org.junit.Assert.assertTrue; public class RGBFeatureTest { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelPerfTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelPerfTest.java new file mode 100644 index 00000000..afe38259 --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelPerfTest.java @@ -0,0 +1,90 @@ +package com.airbnb.aerosolve.core.models; + +import com.airbnb.aerosolve.core.Example; +import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.SimpleExample; +import com.airbnb.aerosolve.core.transforms.Transformer; +import com.airbnb.aerosolve.core.util.Util; +import com.google.common.base.Optional; +import com.google.common.io.ByteStreams; +import com.typesafe.config.Config; +import com.typesafe.config.ConfigFactory; +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import org.apache.commons.codec.binary.Base64; +import org.junit.Test; + +/** + * + */ +public class AdditiveModelPerfTest { + private final FeatureRegistry registry = new FeatureRegistry(); + private final Random random = new Random(); + + @Test + public void testPerformance() throws Exception { + + ClassLoader loader = Thread.currentThread().getContextClassLoader(); + byte[] vectorBytes = ByteStreams.toByteArray(loader.getResourceAsStream("vector.bin")); + String vectorStr = new String(Base64.encodeBase64(vectorBytes)); + BasicMultiFamilyVector fastVector = (BasicMultiFamilyVector) Util.decodeFeatureVector(vectorStr, + registry); 
+ List features = new ArrayList<>(fastVector.keySet()); + InputStream modelStream = loader.getResourceAsStream("daphne.model"); + BufferedReader fileReader = new BufferedReader(new InputStreamReader(modelStream)); + Optional modelOpt = ModelFactory.createFromReader(fileReader, registry); + if (!modelOpt.isPresent()) { + throw new IllegalStateException("Could not load model"); + } + AdditiveModel model = (AdditiveModel) modelOpt.get(); + + Config config = ConfigFactory.parseResources("model_daphne.conf"); + Transformer transformer = new Transformer(config, "pricing_model_config", registry, model); + + double transformTime = 0; + double scoreTime = 0; + double scoreNum = 0; + int iterations = 100; + for (int i = 0; i < iterations; i++) { + Example newExample = newExample(fastVector, features); + + long millis = System.nanoTime(); + newExample.transform(transformer, model); + transformTime += System.nanoTime() - millis; + millis = System.nanoTime(); + + for (FeatureVector scoreVector : newExample) { + scoreNum += model.scoreItem(scoreVector); + } + scoreTime += System.nanoTime() - millis; + } + transformTime = transformTime / 1000000; + System.out.println(String.format("Took %.3f s to transform", transformTime/1000)); + System.out.println(String.format("Took %.2f micros per transform", transformTime/iterations)); + + scoreTime = scoreTime / 1000000; + System.out.println(String.format("Took %.3f s to score", scoreTime/1000)); + System.out.println(String.format("Took %.2f micros per score", scoreTime/iterations)); + + System.out.println("ScoreNum " + scoreNum); + } + + private Example newExample(BasicMultiFamilyVector fastVector, List features) { + Example example = new SimpleExample(registry); + for (int i = 0; i < 1000; i++) { + fastVector = new BasicMultiFamilyVector(fastVector); + int index = random.nextInt(features.size()); + Feature feature = features.get(index); + fastVector.put(feature, random.nextFloat()); + example.addToExample(fastVector); + } + return 
example; + } +} diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelTest.java index 34d392aa..51226581 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/AdditiveModelTest.java @@ -1,18 +1,19 @@ package com.airbnb.aerosolve.core.models; import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; -import com.airbnb.aerosolve.core.FunctionForm; -import com.airbnb.aerosolve.core.function.Function; -import com.airbnb.aerosolve.core.function.Linear; -import com.airbnb.aerosolve.core.function.Spline; +import com.airbnb.aerosolve.core.functions.Function; +import com.airbnb.aerosolve.core.functions.Linear; +import com.airbnb.aerosolve.core.functions.Spline; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; import com.airbnb.aerosolve.core.util.Util; import com.google.common.base.Optional; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.CharArrayWriter; @@ -20,100 +21,85 @@ import java.io.StringReader; import java.io.StringWriter; import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; import java.util.Map; -import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * Test the additive model */ +@Slf4j public class 
AdditiveModelTest { - private static final Logger log = LoggerFactory.getLogger(AdditiveModelTest.class); + private final FeatureRegistry registry = new FeatureRegistry(); AdditiveModel makeAdditiveModel() { - AdditiveModel model = new AdditiveModel(); - Map> weights = new HashMap<>(); - Map innerSplineFloat = new HashMap(); - Map innerLinearFloat = new HashMap(); - Map innerSplineString = new HashMap(); - Map innerLinearString = new HashMap(); - weights.put("spline_float", innerSplineFloat); - weights.put("linear_float", innerLinearFloat); - weights.put("spline_string", innerSplineString); - weights.put("linear_string", innerLinearString); - float [] ws = {5.0f, 10.0f, -20.0f}; - innerSplineFloat.put("aaa", new Spline(1.0f, 3.0f, ws)); + AdditiveModel model = new AdditiveModel(registry); + Map weights = model.weights(); + double [] ws = {5.0d, 10.0d, -20.0d}; + weights.put(registry.feature("spline_float", "aaa"), new Spline(1.0d, 3.0d, ws)); + // for string feature, only the first element in weight is meaningful. 
- innerSplineString.put("bbb", new Spline(1.0f, 2.0f, ws)); - float [] wl = {1.0f, 2.0f}; - innerLinearFloat.put("ccc", new Linear(-10.0f, 5.0f, wl)); - innerLinearString.put("ddd", new Linear(1.0f, 1.0f, wl)); - model.setWeights(weights); - model.setOffset(0.5f); - model.setSlope(1.5f); + weights.put(registry.feature("spline_string", "bbb"), new Spline(1.0d, 2.0d, ws)); + double [] wl = {1.0d, 2.0d}; + weights.put(registry.feature("linear_float", "ccc"), new Linear(-10.0d, 5.0d, wl)); + weights.put(registry.feature("linear_string", "ddd"), new Linear(1.0d, 1.0d, wl)); + model.offset(0.5f); + model.slope(1.5f); return model; } - public FeatureVector makeFeatureVector(float a, float c) { - FeatureVector featureVector = new FeatureVector(); - HashMap stringFeatures = new HashMap>(); - featureVector.setStringFeatures(stringFeatures); - HashMap floatFeatures = new HashMap>(); - featureVector.setFloatFeatures(floatFeatures); + public FeatureVector makeFeatureVector(double a, double c) { + FeatureVector featureVector = TransformTestingHelper.makeEmptyVector(registry); + // prepare string features - Set list1 = new HashSet(); - list1.add("bbb"); // weight = 5.0f - list1.add("ggg"); // this feature is missing in the model - stringFeatures.put("spline_string", list1); - - Set list2 = new HashSet(); - list2.add("ddd"); // weight = 3.0f - list2.add("ggg"); // this feature is missing in the model - stringFeatures.put("linear_string", list2); - featureVector.setStringFeatures(stringFeatures); + Family splineStringFamily = registry.family("spline_string"); + featureVector.putString(splineStringFamily.feature("bbb")); // weight = 5.0d + featureVector.putString(splineStringFamily.feature("ggg")); // this feature is missing in the model + + Family linearStringFamily = registry.family("linear_string"); + featureVector.putString(linearStringFamily.feature("ddd")); // weight = 3.0d + featureVector.putString(linearStringFamily.feature("ggg")); // this feature is missing in the model + 
// prepare float features - HashMap splineFloat = new HashMap(); - HashMap linearFloat = new HashMap(); - floatFeatures.put("spline_float", splineFloat); - floatFeatures.put("linear_float", linearFloat); + Family splineFloat = registry.family("spline_float"); + featureVector.put(splineFloat.feature("aaa"), a); // corresponds to Spline(1.0d, 3.0d, {5, 10, -20}) + featureVector.put(splineFloat.feature("ggg"), 1.0); // missing features - splineFloat.put("aaa", (double) a); // corresponds to Spline(1.0f, 3.0f, {5, 10, -20}) - splineFloat.put("ggg", 1.0); // missing features - linearFloat.put("ccc", (double) c); // weight = 1+2*c - linearFloat.put("ggg", 10.0); // missing features + Family linearFloat = registry.family("linear_float"); + featureVector.put(linearFloat.feature("ccc"), c); // weight = 1+2*c + featureVector.put(splineFloat.feature("ggg"), 10.0); // missing features return featureVector; } @Test public void testScoreEmptyFeature() { - FeatureVector featureVector = new FeatureVector(); - AdditiveModel model = new AdditiveModel(); - float score = model.scoreItem(featureVector); - assertEquals(0.0f, score, 1e-10f); + FeatureVector featureVector = TransformTestingHelper.makeEmptyVector(registry); + AdditiveModel model = new AdditiveModel(registry); + double score = model.scoreItem(featureVector); + assertEquals(0.0d, score, 1e-10d); } @Test public void testScoreNonEmptyFeature() { AdditiveModel model = makeAdditiveModel(); - FeatureVector fv1 = makeFeatureVector(1.0f, 0.0f); - float score1 = model.scoreItem(fv1); - assertEquals(8.0f + 5.0f + (1.0f + 2.0f * (0.0f + 10.0f) / 15.0f), score1, 0.001f); + FeatureVector fv1 = makeFeatureVector(1.0d, 0.0d); + double score1 = model.scoreItem(fv1); + assertEquals(8.0d + 5.0d + (1.0d + 2.0d * (0.0d + 10.0d) / 15.0d), score1, 0.001d); - FeatureVector fv2 = makeFeatureVector(-1.0f, 0.0f); - float score2 = model.scoreItem(fv2); - assertEquals(8.0f + 5.0f + (1.0f + 2.0f * (0.0f + 10.0f) / 15.0f), score2, 0.001f); + 
FeatureVector fv2 = makeFeatureVector(-1.0d, 0.0d); + double score2 = model.scoreItem(fv2); + assertEquals(8.0d + 5.0d + (1.0d + 2.0d * (0.0d + 10.0d) / 15.0d), score2, 0.001d); - FeatureVector fv3 = makeFeatureVector(4.0f, 1.0f); - float score3 = model.scoreItem(fv3); - assertEquals(8.0f - 20.0f + (1.0f + 2.0f * (1.0f + 10.0f) / 15.0f), score3, 0.001f); + FeatureVector fv3 = makeFeatureVector(4.0d, 1.0d); + double score3 = model.scoreItem(fv3); + assertEquals(8.0d - 20.0d + (1.0d + 2.0d * (1.0d + 10.0d) / 15.0d), score3, 0.001d); - FeatureVector fv4 = makeFeatureVector(2.0f, 7.0f); - float score4 = model.scoreItem(fv4); - assertEquals(8.0f + 10.0f + (1.0f + 2.0f * (7.0f + 10.0f) / 15.0f), score4, 0.001f); + FeatureVector fv4 = makeFeatureVector(2.0d, 7.0d); + double score4 = model.scoreItem(fv4); + assertEquals(8.0d + 10.0d + (1.0d + 2.0d * (7.0d + 10.0d) / 15.0d), score4, 0.001d); } @Test @@ -135,26 +121,26 @@ public void testLoad() { ModelRecord record1 = new ModelRecord(); record1.setModelHeader(header); ModelRecord record2 = new ModelRecord(); - record2.setFunctionForm(FunctionForm.Spline); + record2.setFunctionForm(FunctionForm.SPLINE); record2.setFeatureFamily("spline_float"); record2.setFeatureName("aaa"); record2.setWeightVector(ws); record2.setMinVal(1.0); record2.setMaxVal(3.0); ModelRecord record3 = new ModelRecord(); - record3.setFunctionForm(FunctionForm.Spline); + record3.setFunctionForm(FunctionForm.SPLINE); record3.setFeatureFamily("spline_string"); record3.setFeatureName("bbb"); record3.setWeightVector(ws); record3.setMinVal(1.0); record3.setMaxVal(2.0); ModelRecord record4 = new ModelRecord(); - record4.setFunctionForm(FunctionForm.Linear); + record4.setFunctionForm(FunctionForm.LINEAR); record4.setFeatureFamily("linear_float"); record4.setFeatureName("ccc"); record4.setWeightVector(wl); ModelRecord record5 = new ModelRecord(); - record5.setFunctionForm(FunctionForm.Linear); + record5.setFunctionForm(FunctionForm.LINEAR); 
record5.setFeatureFamily("linear_string"); record5.setFeatureName("ddd"); record5.setWeightVector(wl); @@ -173,12 +159,12 @@ public void testLoad() { assertTrue(serialized.length() > 0); StringReader strReader = new StringReader(serialized); BufferedReader reader = new BufferedReader(strReader); - FeatureVector featureVector = makeFeatureVector(2.0f, 7.0f); + FeatureVector featureVector = makeFeatureVector(2.0d, 7.0d); try { - Optional model = ModelFactory.createFromReader(reader); + Optional model = ModelFactory.createFromReader(reader, registry); assertTrue(model.isPresent()); - float score = model.get().scoreItem(featureVector); - assertEquals(8.0f + 10.0f + 15.0f, score, 0.001f); + double score = model.get().scoreItem(featureVector); + assertEquals(8.0d + 10.0d + 15.0d, score, 0.001d); } catch (IOException e) { assertTrue("Could not read", false); } @@ -201,38 +187,39 @@ public void testSave() { public void testAddFunction() { AdditiveModel model = makeAdditiveModel(); // add an existing feature without overwrite - model.addFunction("spline_float", "aaa", new Spline(2.0f, 10.0f, 5), false); + model.addFunction(registry.feature("spline_float", "aaa"), + new Spline(2.0d, 10.0d, 5), false); // add an existing feature with overwrite - model.addFunction("linear_float", "ccc", new Linear(3.0f, 5.0f), true); + model.addFunction(registry.feature("linear_float", "ccc"), + new Linear(3.0d, 5.0d), true); // add a new feature - model.addFunction("spline_float", "new", new Spline(2.0f, 10.0f, 5), false); - - Map> weights = model.getWeights(); - for (Map.Entry> featureFamily: weights.entrySet()) { - String familyName = featureFamily.getKey(); - Map features = featureFamily.getValue(); - for (Map.Entry feature: features.entrySet()) { - String featureName = feature.getKey(); - Function func = feature.getValue(); - if (familyName.equals("spline_float")) { - Spline spline = (Spline) func; - if (featureName.equals("aaa")) { - assertTrue(spline.getMaxVal() == 3.0f); - 
assertTrue(spline.getMinVal() == 1.0f); - assertTrue(spline.getWeights().length == 3); - } else if (featureName.equals("new")) { - assertTrue(spline.getMaxVal() == 10.0f); - assertTrue(spline.getMinVal() == 2.0f); - assertTrue(spline.getWeights().length == 5); - } - } else if(familyName.equals("linear_float") && featureName.equals("ccc")) { - Linear linear = (Linear) func; - assertTrue(linear.getWeights().length == 2); - assertTrue(linear.getWeights()[0] == 0.0f); - assertTrue(linear.getWeights()[1] == 0.0f); - assertTrue(linear.getMinVal() == 3.0f); - assertTrue(linear.getMaxVal() == 5.0f); + model.addFunction(registry.feature("spline_float", "new"), + new Spline(2.0d, 10.0d, 5), false); + + Map weights = model.weights(); + for (Map.Entry entry: weights.entrySet()) { + + String featureName = entry.getKey().name(); + String familyName = entry.getKey().family().name(); + Function func = entry.getValue(); + if (familyName.equals("spline_float")) { + Spline spline = (Spline) func; + if (featureName.equals("aaa")) { + assertTrue(spline.getMaxVal() == 3.0d); + assertTrue(spline.getMinVal() == 1.0d); + assertTrue(spline.getWeights().length == 3); + } else if (featureName.equals("new")) { + assertTrue(spline.getMaxVal() == 10.0d); + assertTrue(spline.getMinVal() == 2.0d); + assertTrue(spline.getWeights().length == 5); } + } else if(familyName.equals("linear_float") && featureName.equals("ccc")) { + Linear linear = (Linear) func; + assertTrue(linear.getWeights().length == 2); + assertTrue(linear.getWeights()[0] == 0.0d); + assertTrue(linear.getWeights()[1] == 0.0d); + assertTrue(linear.getMinVal() == 3.0d); + assertTrue(linear.getMaxVal() == 5.0d); } } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/KDTreeModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/KDTreeModelTest.java index bbd008e6..68a21ae8 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/models/KDTreeModelTest.java +++ 
b/core/src/test/java/com/airbnb/aerosolve/core/models/KDTreeModelTest.java @@ -2,13 +2,12 @@ import com.airbnb.aerosolve.core.KDTreeNode; import com.airbnb.aerosolve.core.KDTreeNodeType; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.HashSet; import java.util.Set; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/LinearModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/LinearModelTest.java index e5adbe01..1142bd87 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/models/LinearModelTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/LinearModelTest.java @@ -4,61 +4,54 @@ import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; import com.airbnb.aerosolve.core.util.Util; - -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.*; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.HashSet; -import java.util.HashMap; -import com.google.common.hash.HashCode; import com.google.common.base.Optional; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.CharArrayWriter; +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ 
+@Slf4j public class LinearModelTest { - private static final Logger log = LoggerFactory.getLogger(LinearModelTest.class); + private final FeatureRegistry registry = new FeatureRegistry(); - public FeatureVector makeFeatureVector() { - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - // add a feature that is missing in the model - list.add("ccc"); - HashMap stringFeatures = new HashMap>(); - stringFeatures.put("string_feature", list); - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("string_feature", "aaa") + .string("string_feature", "bbb") + .string("string_feature", "ccc") + .build(); } public LinearModel makeLinearModel() { - LinearModel model = new LinearModel(); - Map> weights = new HashMap<>(); - Map inner = new HashMap<>(); - weights.put("string_feature", inner); - inner.put("aaa", 0.5f); - inner.put("bbb", 0.25f); - model.setWeights(weights); - model.setOffset(0.5f); - model.setSlope(1.5f); + LinearModel model = new LinearModel(registry); + Family stringFamily = registry.family("string_feature"); + model.weights().put(stringFamily.feature("aaa"), 0.5d); + model.weights().put(stringFamily.feature("bbb"), 0.25d); + model.offset(0.5d); + model.slope(1.5d); return model; } @Test public void testScoreEmptyFeature() { - FeatureVector featureVector = new FeatureVector(); - LinearModel model = new LinearModel(); - float score = model.scoreItem(featureVector); + MultiFamilyVector featureVector = TransformTestingHelper.makeEmptyVector(registry); + LinearModel model = new LinearModel(registry); + double score = model.scoreItem(featureVector); assertTrue(score < 1e-10f); assertTrue(score > -1e-10f); } @@ -66,14 +59,8 @@ public void testScoreEmptyFeature() { @Test public void testScoreNonEmptyFeature() { FeatureVector featureVector = makeFeatureVector(); - LinearModel 
model = new LinearModel(); - Map> weights = new HashMap<>(); - Map inner = new HashMap<>(); - weights.put("string_feature", inner); - inner.put("aaa", 0.5f); - inner.put("bbb", 0.25f); - model.setWeights(weights); - float score = model.scoreItem(featureVector); + LinearModel model = makeLinearModel(); + double score = model.scoreItem(featureVector); assertTrue(score < 0.76f); assertTrue(score > 0.74f); } @@ -104,9 +91,9 @@ public void testLoad() { BufferedReader reader = new BufferedReader(strReader); FeatureVector featureVector = makeFeatureVector(); try { - Optional model = ModelFactory.createFromReader(reader); + Optional model = ModelFactory.createFromReader(reader, registry); assertTrue(model.isPresent()); - float score = model.get().scoreItem(featureVector); + double score = model.get().scoreItem(featureVector); assertTrue(score > 0.89f); assertTrue(score < 0.91f); } catch (IOException e) { @@ -134,13 +121,13 @@ public void testDebugScoreComponents() { List scoreRecordsList = model.debugScoreComponents(fv); assertTrue(scoreRecordsList.size() == 2); for (DebugScoreRecord record : scoreRecordsList) { - assertTrue(record.featureFamily == "string_feature"); - assertTrue(record.featureName == "aaa" || record.featureName == "bbb"); - assertTrue(record.featureValue == 1.0); - if (record.featureName == "aaa") { - assertTrue(record.featureWeight == 0.5f); + assertTrue("string_feature".equals(record.getFeatureFamily())); + assertTrue("aaa".equals(record.getFeatureName()) || "bbb".equals(record.getFeatureName())); + assertTrue(record.getFeatureValue() == 1.0); + if ("aaa".equals(record.getFeatureName())) { + assertTrue(record.getFeatureWeight() == 0.5d); } else { - assertTrue(record.featureWeight == 0.25f); + assertTrue(record.getFeatureWeight() == 0.25d); } } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/LowRankLinearModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/LowRankLinearModelTest.java index d6a2bcea..933dff8a 100644 --- 
a/core/src/test/java/com/airbnb/aerosolve/core/models/LowRankLinearModelTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/LowRankLinearModelTest.java @@ -5,11 +5,16 @@ import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.MulticlassScoringResult; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; import com.airbnb.aerosolve.core.util.FloatVector; import com.airbnb.aerosolve.core.util.Util; import com.google.common.base.Optional; -import org.junit.Test; - +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.CharArrayWriter; @@ -18,17 +23,19 @@ import java.io.StringWriter; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; +import org.junit.Test; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /* Test the low rank linear model */ public class LowRankLinearModelTest { + private final FeatureRegistry registry = new FeatureRegistry(); + ArrayList makeLabelDictionary() { ArrayList labelDictionary = new ArrayList<>(); // construct label dictionary @@ -61,12 +68,12 @@ Map makeLabelWeightVector() { LowRankLinearModel makeLowRankLinearModel() { // A naive model with three classes 'animal', 'color' and 'fruit' // and the size of embedding D = number of labels, W is an identity matrix - LowRankLinearModel model = new LowRankLinearModel(); - model.setEmbeddingDimension(3); - model.setLabelDictionary(makeLabelDictionary()); + LowRankLinearModel model = new 
LowRankLinearModel(registry); + model.embeddingDimension(3); + model.labelDictionary(makeLabelDictionary()); // construct featureWeightVector - Map> featureWeights = new HashMap<>(); + Reference2ObjectMap featureWeights = new Reference2ObjectOpenHashMap<>(); Map animalFeatures = new HashMap<>(); Map colorFeatures = new HashMap<>(); Map fruitFeatures = new HashMap<>(); @@ -78,69 +85,65 @@ LowRankLinearModel makeLowRankLinearModel() { float[] colorFeature = {0.0f, 1.0f, 0.0f}; float[] fruitFeature = {0.0f, 0.0f, 1.0f}; + Family animalFamily = registry.family("a"); for (String word: animalWords) { - animalFeatures.put(word, new FloatVector(animalFeature)); + featureWeights.put(animalFamily.feature(word), new FloatVector(animalFeature)); } + Family colorFamily = registry.family("c"); for (String word: colorWords) { - colorFeatures.put(word, new FloatVector(colorFeature)); + featureWeights.put(colorFamily.feature(word), new FloatVector(colorFeature)); } + Family fruitFamily = registry.family("f"); for (String word: fruitWords) { - fruitFeatures.put(word, new FloatVector(fruitFeature)); + featureWeights.put(fruitFamily.feature(word), new FloatVector(fruitFeature)); } - featureWeights.put("a", animalFeatures); - featureWeights.put("c", colorFeatures); - featureWeights.put("f", fruitFeatures); - model.setFeatureWeightVector(featureWeights); + model.featureWeightVector(featureWeights); // set labelWeightVector - model.setLabelWeightVector(makeLabelWeightVector()); + model.labelWeightVector(makeLabelWeightVector()); model.buildLabelToIndex(); return model; } public FeatureVector makeFeatureVector(String label) { - FeatureVector featureVector = new FeatureVector(); - HashMap stringFeatures = new HashMap>(); - featureVector.setStringFeatures(stringFeatures); - HashMap floatFeatures = new HashMap>(); - featureVector.setFloatFeatures(floatFeatures); - HashMap feature = new HashMap(); + MultiFamilyVector featureVector = TransformTestingHelper.makeEmptyVector(registry); switch 
(label) { case "animal": { - feature.put("cat", 1.0); - feature.put("dog", 2.0); - floatFeatures.put("a", feature); + Family family = registry.family("a"); + featureVector.put(family.feature("cat"), 1.0); + featureVector.put(family.feature("dog"), 2.0); break; } case "color": { - feature.put("red", 2.0); - feature.put("black", 4.0); - floatFeatures.put("c", feature); + Family family = registry.family("c"); + featureVector.put(family.feature("red"), 2.0); + featureVector.put(family.feature("black"), 4.0); break; } case "fruit": { - feature.put("apple", 1.0); - feature.put("kiwi", 3.0); - floatFeatures.put("f", feature); + Family family = registry.family("f"); + featureVector.put(family.feature("apple"), 1.0); + featureVector.put(family.feature("kiwi"), 3.0); break; } default: break; } + return featureVector; } @Test public void testScoreEmptyFeature() { - FeatureVector featureVector = new FeatureVector(); + MultiFamilyVector featureVector = TransformTestingHelper.makeEmptyVector(registry); LowRankLinearModel model = makeLowRankLinearModel(); ArrayList score = model.scoreItemMulticlass(featureVector); assertEquals(score.size(), 3); - assertEquals(0.0f, score.get(0).score, 1e-10f); - assertEquals(0.0f, score.get(1).score, 1e-10f); - assertEquals(0.0f, score.get(2).score, 1e-10f); + assertEquals(0.0d, score.get(0).getScore(), 1e-10d); + assertEquals(0.0d, score.get(1).getScore(), 1e-10d); + assertEquals(0.0d, score.get(2).getScore(), 1e-10d); } @Test @@ -152,21 +155,21 @@ public void testScoreNonEmptyFeature() { ArrayList s1 = model.scoreItemMulticlass(animalFv); assertEquals(s1.size(), 3); - assertEquals(0.0f, s1.get(0).score, 3.0f); - assertEquals(0.0f, s1.get(1).score, 1e-10f); - assertEquals(0.0f, s1.get(2).score, 1e-10f); + assertEquals(3.0d, s1.get(0).getScore(), 1e-10d); + assertEquals(0.0d, s1.get(1).getScore(), 1e-10d); + assertEquals(0.0d, s1.get(2).getScore(), 1e-10d); ArrayList s2 = model.scoreItemMulticlass(colorFv); assertEquals(s2.size(), 3); - 
assertEquals(0.0f, s2.get(0).score, 1e-10f); - assertEquals(0.0f, s2.get(1).score, 6.0f); - assertEquals(0.0f, s2.get(2).score, 1e-10f); + assertEquals(0.0d, s2.get(0).getScore(), 1e-10d); + assertEquals(6.0d, s2.get(1).getScore(), 1e-10d); + assertEquals(0.0d, s2.get(2).getScore(), 1e-10d); ArrayList s3 = model.scoreItemMulticlass(fruitFv); assertEquals(s3.size(), 3); - assertEquals(0.0f, s3.get(0).score, 1e-10f); - assertEquals(0.0f, s3.get(1).score, 1e-10f); - assertEquals(0.0f, s3.get(2).score, 4.0f); + assertEquals(0.0d, s3.get(0).getScore(), 1e-10d); + assertEquals(0.0d, s3.get(1).getScore(), 1e-10d); + assertEquals(4.0d, s3.get(2).getScore(), 1e-10d); } @Test @@ -231,18 +234,18 @@ public void testLoad() { FeatureVector animalFv = makeFeatureVector("animal"); FeatureVector colorFv = makeFeatureVector("color"); try { - Optional model = ModelFactory.createFromReader(reader); + Optional model = ModelFactory.createFromReader(reader, registry); assertTrue(model.isPresent()); ArrayList s1 = model.get().scoreItemMulticlass(animalFv); assertEquals(s1.size(), 3); - assertEquals(0.0f, s1.get(0).score, 3.0f); - assertEquals(0.0f, s1.get(1).score, 1e-10f); - assertEquals(0.0f, s1.get(2).score, 1e-10f); + assertEquals(3.0d, s1.get(0).getScore(), 1e-10d); + assertEquals(0.0d, s1.get(1).getScore(), 1e-10d); + assertEquals(0.0d, s1.get(2).getScore(), 1e-10d); ArrayList s2 = model.get().scoreItemMulticlass(colorFv); assertEquals(s2.size(), 3); - assertEquals(0.0f, s2.get(0).score, 1e-10f); - assertEquals(0.0f, s2.get(1).score, 1e-10f); - assertEquals(0.0f, s2.get(2).score, 1e-10f); + assertEquals(0.0d, s2.get(0).getScore(), 1e-10d); + assertEquals(0.0d, s2.get(1).getScore(), 1e-10d); + assertEquals(0.0d, s2.get(2).getScore(), 1e-10d); } catch (IOException e) { assertTrue("Could not read", false); } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/MlpModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/MlpModelTest.java index 97a0738d..9a5cc1da 
100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/models/MlpModelTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/MlpModelTest.java @@ -4,11 +4,16 @@ import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelHeader; import com.airbnb.aerosolve.core.ModelRecord; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; import com.airbnb.aerosolve.core.util.FloatVector; import com.airbnb.aerosolve.core.util.Util; import com.google.common.base.Optional; -import org.junit.Test; - +import it.unimi.dsi.fastutil.objects.Reference2ObjectMap; +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.CharArrayWriter; @@ -17,43 +22,40 @@ import java.io.StringWriter; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; +import org.junit.Test; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /* Test the MLP model */ public class MlpModelTest { - public FeatureVector makeFeatureVector() { - FeatureVector featureVector = new FeatureVector(); - HashMap stringFeatures = new HashMap>(); - featureVector.setStringFeatures(stringFeatures); - HashMap floatFeatures = new HashMap>(); - featureVector.setFloatFeatures(floatFeatures); - HashMap feature = new HashMap(); - feature.put("a", 1.0); - feature.put("b", 2.0); - floatFeatures.put("in", feature); + private final FeatureRegistry registry = new FeatureRegistry(); + + public MultiFamilyVector makeFeatureVector() { + MultiFamilyVector featureVector = TransformTestingHelper.makeEmptyVector(registry); + Family family = 
registry.family("in"); + featureVector.put(family.feature("a"), 1.0); + featureVector.put(family.feature("b"), 2.0); + return featureVector; } public MlpModel makeMlpModel(FunctionForm func) { // construct a network with 1 hidden layer // and there are 3 nodes in the hidden layer - ArrayList nodeNum = new ArrayList(2); + ArrayList nodeNum = new ArrayList<>(2); nodeNum.add(3); nodeNum.add(1); // assume bias at each node are zeros - ArrayList activations = new ArrayList(); + ArrayList activations = new ArrayList<>(); activations.add(func); activations.add(func); - MlpModel model = new MlpModel(activations, nodeNum); + MlpModel model = new MlpModel(activations, nodeNum, registry); // set input layer - HashMap inputLayer = new HashMap<>(); - HashMap inner = new HashMap<>(); + Reference2ObjectMap inputLayer = new Reference2ObjectOpenHashMap<>(); FloatVector f11 = new FloatVector(3); f11.set(0, 0.0f); f11.set(1, 1.0f); @@ -62,45 +64,45 @@ public MlpModel makeMlpModel(FunctionForm func) { f12.set(0, 1.0f); f12.set(1, 1.0f); f12.set(2, 0.0f); - inner.put("a", f11); - inner.put("b", f12); - inputLayer.put("in", inner); - model.setInputLayerWeights(inputLayer); + Family inputFamily = registry.family("in"); + inputLayer.put(inputFamily.feature("a"), f11); + inputLayer.put(inputFamily.feature("b"), f12); + model.inputLayerWeights(inputLayer); // set hidden layer - HashMap hiddenLayer = new HashMap<>(); + Map> hiddenLayer = new HashMap<>(); FloatVector f21 = new FloatVector(1); FloatVector f22 = new FloatVector(1); FloatVector f23 = new FloatVector(1); f21.set(0, 0.5f); f22.set(0, 1.0f); f23.set(0, 2.0f); - ArrayList hidden = new ArrayList(3); + ArrayList hidden = new ArrayList<>(3); hidden.add(f21); hidden.add(f22); hidden.add(f23); hiddenLayer.put(0, hidden); - model.setHiddenLayerWeights(hiddenLayer); + model.hiddenLayerWeights(hiddenLayer); return model; } @Test public void testConstructedModel() { MlpModel model = makeMlpModel(FunctionForm.RELU); - 
assertEquals(model.getNumHiddenLayers(), 1); - assertEquals(model.getActivationFunction().get(0), FunctionForm.RELU); - assertEquals(model.getHiddenLayerWeights().size(), 1); - assertEquals(model.getHiddenLayerWeights().get(0).size(), 3); - assertEquals(model.getInputLayerWeights().entrySet().size(), 1); - assertEquals(model.getInputLayerWeights().get("in").entrySet().size(), 2); - assertEquals(model.getInputLayerWeights().get("in").get("a").length(), 3); - assertEquals(model.getInputLayerWeights().get("in").get("b").length(), 3); + assertEquals(model.numHiddenLayers(), 1); + assertEquals(model.activationFunction().get(0), FunctionForm.RELU); + assertEquals(model.hiddenLayerWeights().size(), 1); + assertEquals(model.hiddenLayerWeights().get(0).size(), 3); + assertEquals(model.inputLayerWeights().entrySet().size(), 2); + Family inputFamily = registry.family("in"); + assertEquals(model.inputLayerWeights().get(inputFamily.feature("a")).length(), 3); + assertEquals(model.inputLayerWeights().get(inputFamily.feature("b")).length(), 3); } @Test public void testScoring() { FeatureVector fv = makeFeatureVector(); MlpModel model = makeMlpModel(FunctionForm.RELU); - float output = model.scoreItem(fv); + double output = model.scoreItem(fv); assertEquals(output, 6.0f, 1e-10f); } @@ -199,9 +201,9 @@ public void testLoad() { BufferedReader reader = new BufferedReader(strReader); FeatureVector fv = makeFeatureVector(); try { - Optional model = ModelFactory.createFromReader(reader); + Optional model = ModelFactory.createFromReader(reader, registry); assertTrue(model.isPresent()); - float s = model.get().scoreItem(fv); + double s = model.get().scoreItem(fv); assertEquals(s, 6.0f, 1e-10f); } catch (IOException e) { assertTrue("Could not read", false); diff --git a/core/src/test/java/com/airbnb/aerosolve/core/models/NDTreeModelTest.java b/core/src/test/java/com/airbnb/aerosolve/core/models/NDTreeModelTest.java index 9804e90e..edf48655 100644 --- 
a/core/src/test/java/com/airbnb/aerosolve/core/models/NDTreeModelTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/models/NDTreeModelTest.java @@ -85,7 +85,7 @@ public static NDTreeModel getNDTreeModel1D() { @Test public void testDimension() { NDTreeModel tree = getNDTreeModel(); - assertEquals(2, tree.getDimension()); + assertEquals(2, tree.dimension()); NDTreeNode parent = new NDTreeNode(); parent.setAxisIndex(0); @@ -93,10 +93,10 @@ public void testDimension() { one.setAxisIndex(3); NDTreeNode[] arr = {parent, one}; tree = new NDTreeModel(arr); - assertEquals(4, tree.getDimension()); + assertEquals(4, tree.dimension()); tree = NDTreeModelTest.getNDTreeModel1D(); - assertEquals(1, tree.getDimension()); + assertEquals(1, tree.dimension()); } @Test diff --git a/core/src/test/java/com/airbnb/aerosolve/core/scoring/ModelScorerTest.java b/core/src/test/java/com/airbnb/aerosolve/core/scoring/ModelScorerTest.java index c171504c..f54c5680 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/scoring/ModelScorerTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/scoring/ModelScorerTest.java @@ -1,22 +1,22 @@ package com.airbnb.aerosolve.core.scoring; import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; -import com.airbnb.aerosolve.core.features.*; +import com.airbnb.aerosolve.core.features.InputGenerator; +import com.airbnb.aerosolve.core.features.InputSchema; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.features.SimpleExample; import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @Slf4j public class ModelScorerTest { + 
private final FeatureRegistry registry = new FeatureRegistry(); @Test public void rawProbability() throws Exception { @@ -28,51 +28,42 @@ public void rawProbability() throws Exception { ModelScorer modelScorer = new ModelScorer(incomeModel); - FeatureMapping featureMapping = new FeatureMapping(); - featureMapping.add(dataName1); - featureMapping.add(dataName2); - featureMapping.add(dataName3); - featureMapping.finish(); + InputSchema inputSchema = new InputSchema(); + inputSchema.add(dataName1); + inputSchema.add(dataName2); + inputSchema.add(dataName3); + inputSchema.finish(); - FeatureGen f = new FeatureGen(featureMapping); + InputGenerator f = new InputGenerator(inputSchema); f.add(data1, dataName1); f.add(data2, dataName2); f.add(data3, dataName3); - Features features = f.gen(); - - List stringFamilies = new ArrayList<>(); - stringFamilies.add(new StringFamily("S")); - - List floatFamilies = new ArrayList<>(); - floatFamilies.add(new FloatFamily("F")); - Example example = FeatureVectorGen.toSingleFeatureVectorExample(features, stringFamilies, floatFamilies); + Example example = new SimpleExample(registry); + MultiFamilyVector featureVector = f.load(example.createVector()); - FeatureVector featureVector = example.getExample().get(0); - final Map> floatFeatures = featureVector.getFloatFeatures(); - Map floatFeatureFamily = floatFeatures.get("F"); - assertEquals(floatFeatureFamily.get("age"), 30, 0.1); - assertEquals(floatFeatureFamily.get("hours"), 40, 0.1); + Family floatFeatureFamily = registry.family("F"); + assertEquals(featureVector.get(floatFeatureFamily.feature("age")), 30, 0.1); + assertEquals(featureVector.get(floatFeatureFamily.feature("hours")), 40, 0.1); - final Map> stringFeatures = featureVector.getStringFeatures(); - Set stringFeatureFamily = stringFeatures.get("S"); - assertFalse(stringFeatureFamily.contains("marital-status")); - assertTrue(stringFeatureFamily.contains("married")); + Family stringFeatureFamily = registry.family("S"); + 
assertFalse(featureVector.containsKey(stringFeatureFamily.feature("marital-status"))); + assertTrue(featureVector.containsKey(stringFeatureFamily.feature("marital-status:married"))); double score = modelScorer.score(example); log.info("score {}", score); } - private static final String[] dataName1 = {"age", "fnlwgt", "edu-num"}; - private static final float[] data1 = {30, 10, 10}; + private static final String[] dataName1 = {"F_age", "F_fnlwgt", "F_edu-num"}; + private static final double[] data1 = {30, 10, 10}; - private static final String[] dataName2 = {"capital-gain", "capital-loss", "hours"}; - private static final float[] data2 = {3000, 1000, 40}; + private static final String[] dataName2 = {"F_capital-gain", "F_capital-loss", "F_hours"}; + private static final double[] data2 = {3000, 1000, 40}; private static final String[] dataName3 = { - "workclass", "education", "marital-status", - "occupation", "relationship", "race", "sex", - "native-country" + "S_workclass", "S_education", "S_marital-status", + "S_occupation", "S_relationship", "S_race", "S_sex", + "S_native-country" }; private static final String[] data3 = { diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransformTest.java index 6ca48889..251fcca1 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ApproximatePercentileTransformTest.java @@ -1,45 +1,24 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import 
static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class ApproximatePercentileTransformTest { - private static final Logger log = LoggerFactory.getLogger(ApproximatePercentileTransformTest.class); - - public FeatureVector makeFeatureVector(double low, double high, double val) { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("10th", low); - map.put("90th", high); - floatFeatures.put("DECILES", map); - - Map map2 = new HashMap<>(); - map2.put("foo", val); - floatFeatures.put("F", map2); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +public class ApproximatePercentileTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector(double low, double high, double val) { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .sparse("DECILES", "10th", low) + .sparse("DECILES", "90th", high) + .sparse("F", "foo", val) + .build(); } public String makeConfig() { @@ -55,20 +34,15 @@ public String makeConfig() { " outputKey : percentile\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_approximate_percentile"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_approximate_percentile"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = 
TransformFactory.createTransform(config, "test_approximate_percentile"); + Transform transform = getTransform(); double[] values = { -1.0, 10.0, 15.0, 20.0, 50.0, 60.0, 100.0, 200.0 }; double[] expected = { 0.0, 0.0, 0.05, 0.11, 0.44, 0.55, 1.0, 1.0 }; @@ -76,24 +50,21 @@ public void testTransform() { for (int i = 0; i < values.length; i++) { double val = values[i]; - FeatureVector featureVector = makeFeatureVector(10.0, 100.0, val); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + MultiFamilyVector featureVector = makeFeatureVector(10.0, 100.0, val); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); - Map out = featureVector.floatFeatures.get("PERCENTILE"); - assertTrue(out.size() == 1); - assertEquals(expected[i], out.get("percentile"), 0.01); + assertSparseFamily(featureVector, "PERCENTILE", 1, + ImmutableMap.of("percentile", expected[i])); } } @Test public void testAbstain() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_approximate_percentile"); + Transform transform = getTransform(); - FeatureVector featureVector = makeFeatureVector(10.0, 11.0, 1.0); - transform.doTransform(featureVector); - assertTrue(featureVector.floatFeatures.get("PERCENTILE") == null); + MultiFamilyVector featureVector = makeFeatureVector(10.0, 11.0, 1.0); + transform.apply(featureVector); + assertFalse(featureVector.contains(registry.family("PERCENTILE"))); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/BaseTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/BaseTransformTest.java new file mode 100644 index 00000000..b2c2c8a0 --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/BaseTransformTest.java @@ -0,0 +1,123 @@ +package com.airbnb.aerosolve.core.transforms; + +import 
com.airbnb.aerosolve.core.features.DenseVector; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FamilyVector; +import com.airbnb.aerosolve.core.features.Feature; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.FeatureValue; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; +import com.typesafe.config.Config; +import com.typesafe.config.ConfigFactory; +import java.util.Arrays; +import java.util.Map; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * + */ +@Slf4j +public abstract class BaseTransformTest { + + protected final FeatureRegistry registry = new FeatureRegistry(); + + abstract public String makeConfig(); + + abstract public String configKey(); + + // Hackz. This could be better. 
+ protected boolean runEmptyTest() { + return true; + } + + @Test + public void testEmptyFeatureVector() { + if (!runEmptyTest()) { + return; + } + MultiFamilyVector featureVector = transformVector( + TransformTestingHelper.makeEmptyVector(registry)); + + assertTrue(featureVector.size() == 0); + } + + protected MultiFamilyVector transformVector(MultiFamilyVector featureVector) { + getTransform().apply(featureVector); + return featureVector; + } + + protected Transform getTransform() { + return getTransform(makeConfig(), configKey()); + } + + protected Transform getTransform(String configStr, String configKey) { + Config config = ConfigFactory.parseString(configStr); + return TransformFactory.createTransform(config, configKey, registry, null); + } + + public void assertStringFamily(MultiFamilyVector vector, String familyName, + int expectedSize, Set expected) { + assertStringFamily(vector, familyName, expectedSize, expected, ImmutableSet.of()); + } + + // Set expectedSize to -1 if you don't want to test the size. + public void assertStringFamily(MultiFamilyVector vector, String familyName, + int expectedSize, Set expected, + Set unexpectedKeys) { + Family family = registry.family(familyName); + FamilyVector fam = vector.get(family); + assertNotNull(fam); + for (FeatureValue value : fam) { + log.info(value.toString()); + } + assertTrue(fam.size() == expectedSize || expectedSize == -1); + for (String name : expected) { + assertTrue(fam.containsKey(family.feature(name))); + } + for (String name : unexpectedKeys) { + assertFalse(fam.containsKey(family.feature(name))); + } + } + + protected void assertSparseFamily(MultiFamilyVector vector, String familyName, + int expectedSize, + Map expected) { + assertSparseFamily(vector, familyName, expectedSize, expected, ImmutableSet.of()); + } + + // Set expectedSize to -1 if you don't want to test the size. 
+ protected void assertSparseFamily(MultiFamilyVector vector, String familyName, + int expectedSize, + Map expected, + Set unexpectedKeys) { + Family family = registry.family(familyName); + FamilyVector fam = vector.get(family); + assertNotNull(fam); + for (FeatureValue value : fam) { + log.info(value.toString()); + } + assertTrue(fam.size() == expectedSize || expectedSize == -1); + for (Map.Entry entry : expected.entrySet()) { + Feature feat = family.feature(entry.getKey()); + assertTrue(fam.containsKey(feat)); + assertEquals(fam.getDouble(feat), entry.getValue(), 0.01); + } + for (String name : unexpectedKeys) { + assertFalse(fam.containsKey(family.feature(name))); + } + } + + public void assertDenseFamily(MultiFamilyVector vector, String familyName, double[] values) { + FamilyVector fVec = vector.get(registry.family(familyName)); + assertTrue(fVec instanceof DenseVector); + assertTrue(Arrays.equals(((DenseVector) fVec).denseArray(), values)); + } +} diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransformTest.java deleted file mode 100644 index f7061674..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketFloatTransformTest.java +++ /dev/null @@ -1,81 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; - -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertEquals; - -/** - * @author Hector Yee - */ -public class BucketFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(BucketFloatTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); 
- - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.4); - map.put("zero", 0.0); - map.put("negative", -1.5); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } - - public String makeConfig() { - return "test_quantize {\n" + - " transform : bucket_float\n" + - " field1 : loc\n" + - " bucket : 1.0\n" + - " output : loc_quantized\n" + - "}"; - } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); - } - - @Test - public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - - Map out = featureVector.getFloatFeatures().get("loc_quantized"); - log.info("quantize output"); - for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + "=" + entry.getValue()); - } - assertTrue(out.size() == 4); - assertEquals(0.7, out.get("lat[1.0]=37.0").doubleValue(), 0.1); - assertEquals(0.4, out.get("long[1.0]=40.0").doubleValue(), 0.1); - assertEquals(0.0, out.get("zero[1.0]=0.0").doubleValue(), 0.1); - assertEquals(-0.5, out.get("negative[1.0]=-1.0").doubleValue(), 0.1); - } -} \ No newline at end of file diff --git 
a/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketTransformTest.java new file mode 100644 index 00000000..ac13fbda --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/BucketTransformTest.java @@ -0,0 +1,52 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +/** + * @author Hector Yee + */ +public class BucketTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .sparse("loc", "lat", 37.7) + .sparse("loc", "long", 40.4) + .sparse("loc", "zero", 0.0) + .sparse("loc", "negative", -1.5) + .build(); + } + + public String makeConfig() { + return "test_quantize {\n" + + " transform : bucket_float\n" + + " field1 : loc\n" + + " bucket : 1.0\n" + + " output : loc_quantized\n" + + "}"; + } + + @Override + public String configKey() { + return "test_quantize"; + } + + @Test + public void testTransform() { + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "loc_quantized", 4, + ImmutableMap.of("lat[1.0]=37.0", 0.7, + "long[1.0]=40.0", 0.4, + "zero[1.0]=0.0", 0.0, + "negative[1.0]=-1.0", -0.5)); + } +} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CapFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CapFloatTransformTest.java index 4f290a9b..e539cdfe 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CapFloatTransformTest.java +++ 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CapFloatTransformTest.java @@ -1,23 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class CapFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(CapFloatTransformTest.class); - +public class CapFloatTransformTest extends BaseTransformTest { public String makeConfig() { return "test_cap {\n" + " transform : cap_float\n" + @@ -38,54 +30,42 @@ public String makeConfigWithOutput() { " output : new_output \n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cap"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_cap"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cap"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); - Map feat1 = 
featureVector.getFloatFeatures().get("loc"); - - assertEquals(3, feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertEquals(39.0, feat1.get("long"), 0.1); - assertEquals(1.0, feat1.get("z"), 0.1); + assertSparseFamily(featureVector, "loc", 3, + ImmutableMap.of("lat", 37.7, + "long", 39.0, + "z", 1.0)); } @Test public void testTransformWithNewOutput() { - Config config = ConfigFactory.parseString(makeConfigWithOutput()); - Transform transform = TransformFactory.createTransform(config, "test_cap"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(makeConfigWithOutput(), "test_cap"); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 5); + // original feature should not change - Map feat1 = featureVector.getFloatFeatures().get("loc"); - assertEquals(3, feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertEquals(40.0, feat1.get("long"), 0.1); - assertEquals(-20, feat1.get("z"), 0.1); + assertSparseFamily(featureVector, "loc", 3, + ImmutableMap.of("lat", 37.7, + "long", 40.0, + "z", -20.0)); // capped features are in a new feature family - assertTrue(featureVector.getFloatFeatures().containsKey("new_output")); - Map feat2 = featureVector.getFloatFeatures().get("new_output"); - assertEquals(3, feat2.size()); - assertEquals(37.7, feat2.get("lat"), 0.1); - assertEquals(39.0, feat2.get("long"), 0.1); - assertEquals(1.0, feat2.get("z"), 0.1); + assertSparseFamily(featureVector, "new_output", 3, + ImmutableMap.of("lat", 37.7, + "long", 39.0, + "z", 1.0)); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransformTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransformTest.java index a08bd94d..986874c3 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ConvertStringCaseTransformTest.java @@ -1,27 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 1/27/16. */ -public class ConvertStringCaseTransformTest { - private static final Logger log = LoggerFactory.getLogger(ConvertStringCaseTransformTest.class); +public class ConvertStringCaseTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("strFeature1", "I like BLUEBERRY pie, APPLE pie; and I also like BLUE!") + .string("strFeature1", "I'm so excited: I like blue!?!!") + .build(); + } public String makeConfig(boolean convertToUppercase, boolean overwriteInput) { StringBuilder sb = new StringBuilder(); @@ -41,83 +37,48 @@ public String makeConfig(boolean convertToUppercase, boolean overwriteInput) { return sb.toString(); } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - - Set list = new HashSet<>(); - list.add("I like BLUEBERRY pie, APPLE pie; and I also like BLUE!"); - list.add("I'm so excited: I like blue!?!!"); - stringFeatures.put("strFeature1", list); - - 
FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + public String makeConfig() { + return makeConfig(false, false); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig(false, false)); - Transform transform = TransformFactory.createTransform(config, "test_convert_string_case"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); + @Override + public String configKey() { + return "test_convert_string_case"; } @Test public void testTransformConvertToLowercase() { - Config config = ConfigFactory.parseString(makeConfig(false, false)); - Transform transform = TransformFactory.createTransform(config, "test_convert_string_case"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); - - Set output = stringFeatures.get("bar"); - - assertNotNull(output); - assertEquals(2, output.size()); - assertTrue(output.contains("i like blueberry pie, apple pie; and i also like blue!")); - assertTrue(output.contains("i'm so excited: i like blue!?!!")); + doTest(false, false, ImmutableSet.of( + "i like blueberry pie, apple pie; and i also like blue!", + "i'm so excited: i like blue!?!!"), + 2, "bar"); } @Test public void testTransformConvertToUppercase() { - Config config = ConfigFactory.parseString(makeConfig(true, false)); - Transform transform = TransformFactory.createTransform(config, "test_convert_string_case"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); + doTest(true, false, ImmutableSet.of( + "I LIKE BLUEBERRY PIE, APPLE PIE; AND I ALSO LIKE BLUE!", + "I'M SO EXCITED: I LIKE 
BLUE!?!!"), + 2, "bar"); + } - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); + private void doTest(boolean convertToUpperCase, boolean overwriteInput, + Set expected, int numFamilies, String family) { + Transform transform = getTransform( + makeConfig(convertToUpperCase, overwriteInput), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - Set output = stringFeatures.get("bar"); + assertTrue(featureVector.numFamilies() == numFamilies); - assertNotNull(output); - assertEquals(2, output.size()); - assertTrue(output.contains("I LIKE BLUEBERRY PIE, APPLE PIE; AND I ALSO LIKE BLUE!")); - assertTrue(output.contains("I'M SO EXCITED: I LIKE BLUE!?!!")); + assertStringFamily(featureVector, family, 2, expected); } @Test public void testTransformOverwriteInput() { - Config config = ConfigFactory.parseString(makeConfig(true, true)); - Transform transform = TransformFactory.createTransform(config, "test_convert_string_case"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - - assertNotNull(stringFeatures); - assertEquals(1, stringFeatures.size()); - - Set output = stringFeatures.get("strFeature1"); - - assertNotNull(output); - assertEquals(2, output.size()); - assertTrue(output.contains("I LIKE BLUEBERRY PIE, APPLE PIE; AND I ALSO LIKE BLUE!")); - assertTrue(output.contains("I'M SO EXCITED: I LIKE BLUE!?!!")); + doTest(true, true, ImmutableSet.of( + "I LIKE BLUEBERRY PIE, APPLE PIE; AND I ALSO LIKE BLUE!", + "I'M SO EXCITED: I LIKE BLUE!?!!"), + 1, "strFeature1"); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CrossTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CrossTransformTest.java index 2a23e916..b774cf6a 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CrossTransformTest.java +++ 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CrossTransformTest.java @@ -1,42 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Map; -import java.util.HashSet; -import java.util.Set; -import java.util.HashMap; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class CrossTransformTest { - private static final Logger log = LoggerFactory.getLogger(CrossTransformTest.class); - - public FeatureVector makeFeatureVector() { - HashMap stringFeatures = new HashMap>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("feature1", list); - - Set list2 = new HashSet(); - list2.add("11"); - list2.add("22"); - stringFeatures.put("feature2", list2); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; +public class CrossTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("feature1", "aaa") + .string("feature1", "bbb") + .string("feature2", "11") + .string("feature2", "22") + .build(); } public String makeConfig() { @@ -47,34 +28,23 @@ public String makeConfig() { " output : out\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cross"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String 
configKey() { + return "test_cross"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cross"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 3); - Set out = stringFeatures.get("out"); - assertTrue(out.size() == 4); - log.info("Cross output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("aaa^11")); - assertTrue(out.contains("aaa^22")); - assertTrue(out.contains("bbb^11")); - assertTrue(out.contains("bbb^22")); - + MultiFamilyVector featureVector = makeFeatureVector(); + transformVector(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "out", 4, ImmutableSet.of( + "aaa^11", + "aaa^22", + "bbb^11", + "bbb^22" + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransformTest.java index 7700f289..31a4389d 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomLinearLogQuantizeTransformTest.java @@ -1,46 +1,19 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; -public class CustomLinearLogQuantizeTransformTest { - 
private static final Logger log = LoggerFactory.getLogger(CustomLinearLogQuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("a", 0.0); - map.put("b", 0.13); - map.put("c", 1.23); - map.put("d", 5.0); - map.put("e", 17.5); - map.put("f", 99.98); - map.put("g", 365.0); - map.put("h", 65537.0); - map.put("i", -1.0); - map.put("j", -23.0); - - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +@Slf4j +public class CustomLinearLogQuantizeTransformTest extends BaseTransformTest{ + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .complexLocation() + .build(); } public String makeConfig() { @@ -61,42 +34,27 @@ public String makeConfig() { "}"; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - 
for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 10); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("a=0.0")); - assertTrue(out.contains("b=0.125")); - assertTrue(out.contains("c=1.0")); - assertTrue(out.contains("d=5.0")); - assertTrue(out.contains("e=16.0")); - assertTrue(out.contains("f=90.0")); - assertTrue(out.contains("g=350.0")); - assertTrue(out.contains("h=10000.0")); - assertTrue(out.contains("i=-1.0")); - assertTrue(out.contains("j=-22.0")); + MultiFamilyVector featureVector = transformVector(makeFeatureVector()); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 10, + ImmutableSet.of("a=0.0", + "b=0.125", + "c=1.0", + "d=5.0", + "e=16.0", + "f=90.0", + "g=350.0", + "h=10000.0", + "i=-1.0", + "j=-22.0")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransformTest.java index 9ef3083f..8a820c86 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CustomMultiscaleQuantizeTransformTest.java @@ -1,141 +1,96 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashSet; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import static org.junit.Assert.assertTrue; -public class CustomMultiscaleQuantizeTransformTest { - private static final Logger log = 
LoggerFactory.getLogger(CustomMultiscaleQuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - map.put("zero", 0.0); - map.put("negative", -1.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } - - public String makeConfig(String input) { - return "test_quantize {\n" + - " transform : custom_multiscale_quantize\n" + - " field1 : loc\n" + input + - " buckets : [1, 10]\n" + - " output : loc_quantized\n" + - "}"; - } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig("")); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); - } - - @Test - public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig("")); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 7); - assertTrue(out.contains("lat[10.0]=30.0")); - assertTrue(out.contains("long[1.0]=40.0")); - assertTrue(out.contains("long[10.0]=40.0")); - assertTrue(out.contains("lat[1.0]=37.0")); - assertTrue(out.contains("zero=0")); - 
assertTrue(out.contains("negative[1.0]=-1.0")); - assertTrue(out.contains("negative[10.0]=0.0")); - } +public class CustomMultiscaleQuantizeTransformTest extends BaseTransformTest { + + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .location() + .sparse("loc", "zero", 0.0) + .sparse("loc", "negative", -1.0) + .build(); + } + + public String makeConfig() { + return makeConfig(""); + } + + public String makeConfig(String input) { + return "test_quantize {\n" + + " transform : custom_multiscale_quantize\n" + + " field1 : loc\n" + input + + " buckets : [1, 10]\n" + + " output : loc_quantized\n" + + "}"; + } + + @Override + public String configKey() { + return "test_quantize"; + } + + @Test + public void testTransform() { + MultiFamilyVector featureVector = makeFeatureVector(); + transformVector(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 7, ImmutableSet.of( + "lat[10.0]=30.0", + "long[1.0]=40.0", + "long[10.0]=40.0", + "lat[1.0]=37.0", + "zero=0", + "negative[1.0]=-1.0", + "negative[10.0]=0.0" + )); + } @Test public void testSelectFeatures() { - Config config = ConfigFactory.parseString(makeConfig("select_features: [\"lat\"] \n")); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 2); - assertTrue(out.contains("lat[10.0]=30.0")); - assertTrue(out.contains("lat[1.0]=37.0")); + Transform transform = getTransform(makeConfig("select_features: [\"lat\"] \n"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + 
transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 2, ImmutableSet.of( + "lat[10.0]=30.0", + "lat[1.0]=37.0")); } @Test public void testExcludeFeatures() { - Config config = ConfigFactory.parseString(makeConfig("exclude_features: [\"lat\"] \n")); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 5); - assertTrue(out.contains("long[1.0]=40.0")); - assertTrue(out.contains("long[10.0]=40.0")); - assertTrue(out.contains("zero=0")); - assertTrue(out.contains("negative[1.0]=-1.0")); - assertTrue(out.contains("negative[10.0]=0.0")); + Transform transform = getTransform(makeConfig("exclude_features: [\"lat\"] \n"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 5, ImmutableSet.of( + "long[1.0]=40.0", + "long[10.0]=40.0", + "zero=0", + "negative[1.0]=-1.0", + "negative[10.0]=0.0")); } @Test public void testSelectAndExcludeFeatures() { - Config config = ConfigFactory.parseString( - makeConfig("select_features: [\"lat\", \"long\"] \n" + "exclude_features: [\"lat\"] \n")); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - 
log.info(string); - } - assertTrue(out.size() == 2); - assertTrue(out.contains("long[1.0]=40.0")); - assertTrue(out.contains("long[10.0]=40.0")); + Transform transform = getTransform(makeConfig( + "select_features: [\"lat\", \"long\"] \n" + "exclude_features: [\"lat\"] \n"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 2, ImmutableSet.of( + "long[1.0]=40.0", + "long[10.0]=40.0")); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutFloatTransformTest.java deleted file mode 100644 index e10f8c87..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutFloatTransformTest.java +++ /dev/null @@ -1,114 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; -import java.util.Set; - -import static org.junit.Assert.*; - -public class CutFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(CapFloatTransformTest.class); - - public String makeConfig() { - return "test_cut {\n" + - " transform : cut_float\n" + - " field1 : loc\n" + - " upper_bound : 39.0\n" + - " keys : [lat,long,z,aaa]\n" + - "}"; - } - - public String makeConfigWithOutput() { - return "test_cut {\n" + - " transform : cut_float\n" + - " field1 : loc\n" + - " lower_bound : 1.0\n" + - " upper_bound : 39.0\n" + - " keys : [lat,long,z,aaa]\n" + - " output : new_output \n" + - "}"; - } - - public String makeConfigWithLowerBoundOnly() { - return "test_cut {\n" + - " transform : cut_float\n" + - " field1 : loc\n" + - " lower_bound : 1.0\n" + - " keys : 
[lat,long,z,aaa]\n" + - "}"; - } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cut"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); - } - - @Test - public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cut"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - - Map feat1 = featureVector.getFloatFeatures().get("loc"); - - assertEquals(2, feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertNull(feat1.get("long")); - assertEquals(-20.0, feat1.get("z"), 0.1); - } - - @Test - public void testTransformWithNewOutput() { - Config config = ConfigFactory.parseString(makeConfigWithOutput()); - Transform transform = TransformFactory.createTransform(config, "test_cut"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - // original feature should not change - Map feat1 = featureVector.getFloatFeatures().get("loc"); - assertEquals(3, feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertEquals(40.0, feat1.get("long"), 0.1); - assertEquals(-20, feat1.get("z"), 0.1); - - // capped features are in a new feature family - assertTrue(featureVector.getFloatFeatures().containsKey("new_output")); - Map feat2 = featureVector.getFloatFeatures().get("new_output"); - assertEquals(1, feat2.size()); - assertEquals(37.7, feat2.get("lat"), 0.1); - assertNull(feat2.get("long")); - 
assertNull(feat2.get("z")); - } - - @Test - public void testTransformLowerBoundOnly() { - Config config = ConfigFactory.parseString(makeConfigWithLowerBoundOnly()); - Transform transform = TransformFactory.createTransform(config, "test_cut"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - - Map feat1 = featureVector.getFloatFeatures().get("loc"); - - assertEquals(2, feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertEquals(40.0, feat1.get("long"), 0.1); - assertNull(feat1.get("z")); - } - -} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutTransformTest.java new file mode 100644 index 00000000..30654533 --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/CutTransformTest.java @@ -0,0 +1,89 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class CutTransformTest extends BaseTransformTest { + + public String makeConfig() { + return "test_cut {\n" + + " transform : cut_float\n" + + " field1 : loc\n" + + " upper_bound : 39.0\n" + + " keys : [lat,long,z,aaa]\n" + + "}"; + } + + public String makeConfigWithOutput() { + return "test_cut {\n" + + " transform : cut_float\n" + + " field1 : loc\n" + + " lower_bound : 1.0\n" + + " upper_bound : 39.0\n" + + " keys : [lat,long,z,aaa]\n" + + " output : new_output \n" + + "}"; + } + + public String makeConfigWithLowerBoundOnly() { + return "test_cut {\n" + + " transform : cut_float\n" + + " field1 : loc\n" + + " lower_bound : 1.0\n" + + " keys : 
[lat,long,z,aaa]\n" + + "}"; + } + + @Override + public String configKey() { + return "test_cut"; + } + + @Test + public void testTransform() { + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); + + assertSparseFamily(featureVector, "loc", 2, ImmutableMap.of( + "lat", 37.7, + "z", -20.0), + ImmutableSet.of("long")); + } + + @Test + public void testTransformWithNewOutput() { + Transform transform = getTransform(makeConfigWithOutput(), configKey()); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 5); + + // original feature should not change + assertSparseFamily(featureVector, "loc", 3, ImmutableMap.of( + "lat", 37.7, + "long", 40.0, + "z", -20.0)); + + assertSparseFamily(featureVector, "new_output", 1, ImmutableMap.of( + "lat", 37.7), + ImmutableSet.of("long", "z")); + } + + @Test + public void testTransformLowerBoundOnly() { + Transform transform = getTransform(makeConfigWithLowerBoundOnly(), configKey()); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); + + assertSparseFamily(featureVector, "loc", 2, ImmutableMap.of( + "lat", 37.7, + "long", 40.0), + ImmutableSet.of("z")); + } +} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateDiffTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateDiffTransformTest.java index 1e481701..ea074eca 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateDiffTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateDiffTransformTest.java @@ -1,26 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import 
com.airbnb.aerosolve.core.FeatureVector; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.LoggerFactory; -import org.slf4j.Logger; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import static org.junit.Assert.assertTrue; /** * Created by seckcoder on 12/17/15. */ -public class DateDiffTransformTest { - private static final Logger log = LoggerFactory.getLogger(DateDiffTransformTest.class); +public class DateDiffTransformTest extends BaseTransformTest { public String makeConfig() { return "test_datediff {\n" + " transform: date_diff\n" + @@ -30,33 +19,27 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - Set endDates = new HashSet(); - Set startDates = new HashSet(); - endDates.add("2009-03-01"); - startDates.add("2009-02-27"); - stringFeatures.put("endDates", endDates); - stringFeatures.put("startDates", startDates); + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("endDates", "2009-03-01") + .string("startDates", "2009-02-27") + .build(); + } - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + @Override + public String configKey() { + return "test_datediff"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_datediff"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> 
floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - assertEquals(2, out.get("2009-03-01-m-2009-02-27").intValue()); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "bar", 1, ImmutableMap.of( + "2009-03-01-m-2009-02-27", 2.0 + )); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateValTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateValTransformTest.java index 95ba9fc2..89b3a07d 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateValTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DateValTransformTest.java @@ -1,24 +1,19 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import static org.junit.Assert.assertTrue; /** * Created by seckcoder on 12/20/15. 
*/ -public class DateValTransformTest { +public class DateValTransformTest extends BaseTransformTest { + public String makeConfig() { + return makeConfig("day_of_month"); + } public String makeConfig(String dateType) { return "test_date {\n" + " transform: date_val\n" + @@ -27,98 +22,80 @@ public String makeConfig(String dateType) { " output: bar\n" + "}"; } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set dates = new HashSet(); - dates.add("2009-03-01"); - dates.add("2009-02-27"); - stringFeatures.put("dates", dates); + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("dates", "2009-03-01") + .string("dates", "2009-02-27") + .build(); + } - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + @Override + public String configKey() { + return "test_date"; } @Test public void testDayOfMonthTransform() { - Config config = ConfigFactory.parseString(makeConfig("day_of_month")); - Transform transform = TransformFactory.createTransform(config, "test_date"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - - assertEquals(out.get("2009-03-01"), 1, 0.1); - assertEquals(out.get("2009-02-27"), 27, 0.1); + Transform transform = getTransform(makeConfig("day_of_month"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 2, ImmutableMap.of( + "2009-03-01", 1.0, + "2009-02-27", 27.0 + )); } @Test public void testDayOfWeekTransform() { - Config config = 
ConfigFactory.parseString(makeConfig("day_of_week")); - Transform transform = TransformFactory.createTransform(config, "test_date"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - - assertEquals(out.get("2009-03-01"), 1, 0.1); - assertEquals(out.get("2009-02-27"), 6, 0.1); + Transform transform = getTransform(makeConfig("day_of_week"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 2, ImmutableMap.of( + "2009-03-01", 1.0, + "2009-02-27", 6.0 + )); } @Test public void testDayOfYearTransform() { - Config config = ConfigFactory.parseString(makeConfig("day_of_year")); - Transform transform = TransformFactory.createTransform(config, "test_date"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - - assertEquals(out.get("2009-03-01"), 60, 0.1); - assertEquals(out.get("2009-02-27"), 58, 0.1); + Transform transform = getTransform(makeConfig("day_of_year"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 2, ImmutableMap.of( + "2009-03-01", 60.0, + "2009-02-27", 58.0 + )); } @Test public void testYearOfDateTransform() { - Config config = ConfigFactory.parseString(makeConfig("year")); - Transform transform = TransformFactory.createTransform(config, "test_date"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> floatFeatures = featureVector.getFloatFeatures(); - 
assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - - assertEquals(out.get("2009-03-01"), 2009, 0.1); - assertEquals(out.get("2009-02-27"), 2009, 0.1); + Transform transform = getTransform(makeConfig("year"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 2, ImmutableMap.of( + "2009-03-01", 2009.0, + "2009-02-27", 2009.0 + )); } @Test public void testMonthOfDateTransform() { - Config config = ConfigFactory.parseString(makeConfig("month")); - Transform transform = TransformFactory.createTransform(config, "test_date"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 1); - - Map out = floatFeatures.get("bar"); - - assertEquals(out.get("2009-03-01"), 3, 0.1); - assertEquals(out.get("2009-02-27"), 2, 0.1); + Transform transform = getTransform(makeConfig("month"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 2, ImmutableMap.of( + "2009-03-01", 3.0, + "2009-02-27", 2.0 + )); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransformTest.java index 649a01f6..693c0829 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DecisionTreeTransformTest.java @@ -1,15 +1,13 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.ModelRecord; import com.airbnb.aerosolve.core.models.DecisionTreeModel; -import 
com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -17,27 +15,15 @@ /** * @author Hector Yee */ -public class DecisionTreeTransformTest { - private static final Logger log = LoggerFactory.getLogger(DecisionTreeTransformTest.class); - - public FeatureVector makeFeatureVector(double x, double y) { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("x", x); - map.put("y", y); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +@Slf4j +public class DecisionTreeTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector(double x, double y) { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .sparse("loc", "x", x) + .sparse("loc", "y", y) + .build(); } public String makeConfig() { @@ -58,6 +44,11 @@ public String makeConfig() { "}"; } + @Override + public String configKey() { + return "test_tree"; + } + /* * XOR like decision regions * @@ -73,8 +64,8 @@ public String makeConfig() { public DecisionTreeModel makeTree() { ArrayList records = new ArrayList<>(); - DecisionTreeModel tree = new DecisionTreeModel(); - tree.setStumps(records); + DecisionTreeModel tree = new DecisionTreeModel(registry); + tree.stumps(records); // 0 - an x split at 2 ModelRecord 
record = new ModelRecord(); @@ -126,6 +117,16 @@ record = new ModelRecord(); return tree; } + // TODO (Brad): An empty vector does not result in an empty result. + // From what I can tell, the old test worked because it short circuited when the float features + // were null. + // But if the float features were empty instead of null, it would create a non-empty result. . . + // Is this intended? + @Override + protected boolean runEmptyTest() { + return false; + } + @Test public void testToHumanReadableConfig() { DecisionTreeModel tree = makeTree(); @@ -137,34 +138,14 @@ public void testToHumanReadableConfig() { assertTrue(tokens[4].contains("L,3,0.250000,LEAF_3")); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_tree"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); - } - public void testTransformAt(double x, double y, String expectedLeaf, double expectedOutput) { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_tree"); - - FeatureVector featureVector; - featureVector = makeFeatureVector(x, y); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertEquals(2, stringFeatures.size()); - - Set out = featureVector.stringFeatures.get("LEAF"); - for (String entry : out) { - log.info(entry); - } - assertTrue(out.contains(expectedLeaf)); - - Map treeOutput = featureVector.floatFeatures.get("SCORE"); - assertTrue(treeOutput.containsKey("TREE0")); - assertEquals(expectedOutput, treeOutput.get("TREE0"), 0.1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(x, y); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); + + 
assertStringFamily(featureVector, "LEAF", -1, ImmutableSet.of(expectedLeaf)); + assertSparseFamily(featureVector, "SCORE", -1, ImmutableMap.of("TREE0", expectedOutput)); } @Test diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransformTest.java index fcc0d3c8..f1f9a690 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DefaultStringTokenizerTransformTest.java @@ -1,27 +1,19 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 1/27/16. 
*/ -public class DefaultStringTokenizerTransformTest { - private static final Logger log = LoggerFactory.getLogger( - DefaultStringTokenizerTransformTest.class); +public class DefaultStringTokenizerTransformTest extends BaseTransformTest { + + public String makeConfig() { + return makeConfig("regex", false); + } public String makeConfig(String regex, boolean generateBigrams) { return "test_tokenizer {\n" + @@ -34,117 +26,80 @@ public String makeConfig(String regex, boolean generateBigrams) { "}"; } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("I like blueberry pie, apple pie; and I also like blue!"); - list.add("I'm so excited: I like blue!?!!"); - stringFeatures.put("strFeature1", list); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("strFeature1", "I like blueberry pie, apple pie; and I also like blue!") + .string("strFeature1", "I'm so excited: I like blue!?!!") + .build(); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig("regex", false)); - Transform transform = TransformFactory.createTransform(config, "test_tokenizer"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); - assertTrue(featureVector.getFloatFeatures() == null); + @Override + public String configKey() { + return "test_tokenizer"; } @Test public void testTransformWithoutBigrams() { - Config config = ConfigFactory.parseString(makeConfig("\"\"\"[\\s\\p{Punct}]\"\"\"", false)); - Transform transform = TransformFactory.createTransform(config, "test_tokenizer"); - FeatureVector featureVector 
= makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Map> floatFeatures = featureVector.getFloatFeatures(); - - assertEquals(1, stringFeatures.size()); - assertEquals(1, floatFeatures.size()); - - Map output = floatFeatures.get("bar"); - - assertEquals(11, output.size()); - assertEquals(1.0, output.get("apple"), 0.0); - assertEquals(1.0, output.get("blueberry"), 0.0); - assertEquals(2.0, output.get("blue"), 0.0); - assertEquals(3.0, output.get("like"), 0.0); - assertEquals(1.0, output.get("excited"), 0.0); - assertEquals(1.0, output.get("and"), 0.0); - assertEquals(4.0, output.get("I"), 0.0); - assertEquals(1.0, output.get("also"), 0.0); - assertEquals(1.0, output.get("so"), 0.0); - assertEquals(2.0, output.get("pie"), 0.0); - assertEquals(1.0, output.get("m"), 0.0); + Transform transform = getTransform(makeConfig("\"\"\"[\\s\\p{Punct}]\"\"\"", false), + configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "bar", 11, ImmutableMap.builder() + .put("apple", 1.0) + .put("blueberry", 1.0) + .put("blue", 2.0) + .put("like", 3.0) + .put("excited", 1.0) + .put("and", 1.0) + .put("I", 4.0) + .put("also", 1.0) + .put("so", 1.0) + .put("pie", 2.0) + .put("m", 1.0) + .build()); } @Test public void testTransformWithBigrams() { - Config config = ConfigFactory.parseString(makeConfig("\"\"\"[\\s\\p{Punct}]\"\"\"", true)); - Transform transform = TransformFactory.createTransform(config, "test_tokenizer"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - Map> floatFeatures = featureVector.getFloatFeatures(); - - assertEquals(1, stringFeatures.size()); - assertEquals(2, floatFeatures.size()); - - Map output = floatFeatures.get("bar"); - - assertEquals(11, output.size()); - 
assertEquals(1.0, output.get("apple"), 0.0); - assertEquals(1.0, output.get("blueberry"), 0.0); - assertEquals(2.0, output.get("blue"), 0.0); - assertEquals(3.0, output.get("like"), 0.0); - assertEquals(1.0, output.get("excited"), 0.0); - assertEquals(1.0, output.get("and"), 0.0); - assertEquals(4.0, output.get("I"), 0.0); - assertEquals(1.0, output.get("also"), 0.0); - assertEquals(1.0, output.get("so"), 0.0); - assertEquals(2.0, output.get("pie"), 0.0); - assertEquals(1.0, output.get("m"), 0.0); - - Map bigrams = floatFeatures.get("bigrams"); - - assertEquals(14, bigrams.size()); - assertEquals(2.0, bigrams.get( - "I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "like"), 0.0); - assertEquals(1.0, bigrams.get( - "like" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "blueberry"), 0.0); - assertEquals(1.0, bigrams.get( - "blueberry" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "pie"), 0.0); - assertEquals(1.0, bigrams.get( - "pie" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "apple"), 0.0); - assertEquals(1.0, bigrams.get( - "apple" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "pie"), 0.0); - assertEquals(1.0, bigrams.get( - "pie" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "and"), 0.0); - assertEquals(1.0, bigrams.get( - "and" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "I"), 0.0); - assertEquals(1.0, bigrams.get( - "I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "also"), 0.0); - assertEquals(1.0, bigrams.get( - "also" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "like"), 0.0); - assertEquals(2.0, bigrams.get( - "like" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "blue"), 0.0); - assertEquals(1.0, bigrams.get( - "I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "m"), 0.0); - assertEquals(1.0, bigrams.get( - "m" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "so"), 0.0); - assertEquals(1.0, bigrams.get( - "so" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "excited"), 
0.0); - assertEquals(1.0, bigrams.get( - "excited" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "I"), 0.0); + Transform transform = getTransform(makeConfig("\"\"\"[\\s\\p{Punct}]\"\"\"", true), + configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "bar", 11, ImmutableMap.builder() + .put("apple", 1.0) + .put("blueberry", 1.0) + .put("blue", 2.0) + .put("like", 3.0) + .put("excited", 1.0) + .put("and", 1.0) + .put("I", 4.0) + .put("also", 1.0) + .put("so", 1.0) + .put("pie", 2.0) + .put("m", 1.0) + .build()); + + assertSparseFamily(featureVector, "bigrams", 14, ImmutableMap.builder() + .put("I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "like", 2.0) + .put("like" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "blueberry", 1.0) + .put("blueberry" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "pie", 1.0) + .put("pie" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "apple", 1.0) + .put("apple" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "pie", 1.0) + .put("pie" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "and", 1.0) + .put("and" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "I", 1.0) + .put("I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "also", 1.0) + .put("also" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "like", 1.0) + .put("like" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "blue", 2.0) + .put("I" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "m", 1.0) + .put("m" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "so", 1.0) + .put("so" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "excited", 1.0) + .put("excited" + DefaultStringTokenizerTransform.BIGRAM_SEPARATOR + "I", 1.0) + .build()); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFeatureTransformTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFeatureTransformTest.java new file mode 100644 index 00000000..ebcbae34 --- /dev/null +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFeatureTransformTest.java @@ -0,0 +1,46 @@ +package com.airbnb.aerosolve.core.transforms; + +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +/** + * @author Hector Yee + */ +public class DeleteFeatureTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.makeSimpleVector(registry); + } + + public String makeConfig() { + return "test_delete {\n" + + " transform : delete_float_feature\n" + + " field1 : loc\n" + + " keys : [long,aaa]\n" + + "}"; + } + + @Override + public String configKey() { + return "test_delete"; + } + + @Test + public void testTransform() { + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "loc", 1, + ImmutableMap.of("lat", 37.7), + ImmutableSet.of("long")); + + assertStringFamily(featureVector, "strFeature1", 2, + ImmutableSet.of("aaa", "bbb")); + } +} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransformTest.java index fa7fa031..a23dc7ec 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureFamilyTransformTest.java @@ -1,24 +1,13 @@ package com.airbnb.aerosolve.core.transforms; -import 
com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.Map; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -public class DeleteFloatFeatureFamilyTransformTest { - private static final Logger log = LoggerFactory.getLogger( - DeleteFloatFeatureFamilyTransformTest.class); +public class DeleteFloatFeatureFamilyTransformTest extends BaseTransformTest { public String makeConfig() { return "test_delete_float_feature_family {\n" + @@ -27,67 +16,40 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVector() { - Map> floatFeatures = new HashMap<>(); - - Map family1 = new HashMap<>(); - family1.put("A", 1.0); - family1.put("B", 2.0); - - Map family2 = new HashMap<>(); - family2.put("C", 3.0); - family2.put("D", 4.0); - - Map family3 = new HashMap<>(); - family3.put("E", 5.0); - family3.put("F", 6.0); - - Map family4 = new HashMap<>(); - family4.put("G", 7.0); - family4.put("H", 8.0); - - floatFeatures.put("F1", family1); - floatFeatures.put("F2", family2); - floatFeatures.put("F3", family3); - floatFeatures.put("F4", family4); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + @Override + public String configKey() { + return "test_delete_float_feature_family"; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_delete_float_feature_family"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - 
assertTrue(featureVector.getFloatFeatures() == null); + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .sparse("F1", "A", 1.0) + .sparse("F1", "B", 2.0) + .sparse("F2", "C", 3.0) + .sparse("F2", "D", 4.0) + .sparse("F3", "E", 5.0) + .sparse("F3", "F", 6.0) + .sparse("F4", "G", 7.0) + .sparse("F4", "H", 8.0) + .build(); } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_delete_float_feature_family"); - FeatureVector fv = makeFeatureVector(); - - assertNotNull(fv.getFloatFeatures()); - assertTrue(fv.getFloatFeatures().containsKey("F1")); - assertTrue(fv.getFloatFeatures().containsKey("F2")); - assertTrue(fv.getFloatFeatures().containsKey("F3")); - assertTrue(fv.getFloatFeatures().containsKey("F4")); - assertEquals(4, fv.getFloatFeatures().size()); - - transform.doTransform(fv); - - assertNotNull(fv.getFloatFeatures()); - assertFalse(fv.getFloatFeatures().containsKey("F1")); - assertFalse(fv.getFloatFeatures().containsKey("F2")); - assertFalse(fv.getFloatFeatures().containsKey("F3")); - assertTrue(fv.getFloatFeatures().containsKey("F4")); - assertEquals(1, fv.getFloatFeatures().size()); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + + assertTrue(featureVector.contains(registry.family("F1"))); + assertTrue(featureVector.contains(registry.family("F2"))); + assertTrue(featureVector.contains(registry.family("F3"))); + assertTrue(featureVector.contains(registry.family("F4"))); + assertEquals(4, featureVector.numFamilies()); + + transform.apply(featureVector); + assertFalse(featureVector.contains(registry.family("F1"))); + assertFalse(featureVector.contains(registry.family("F2"))); + assertFalse(featureVector.contains(registry.family("F3"))); + assertTrue(featureVector.contains(registry.family("F4"))); + assertEquals(1, featureVector.numFamilies()); } } 
diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransformTest.java deleted file mode 100644 index 321af3d4..00000000 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteFloatFeatureTransformTest.java +++ /dev/null @@ -1,72 +0,0 @@ -package com.airbnb.aerosolve.core.transforms; - -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -/** - * @author Hector Yee - */ -public class DeleteFloatFeatureTransformTest { - private static final Logger log = LoggerFactory.getLogger(DeleteFloatFeatureTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } - - public String makeConfig() { - return "test_delete {\n" + - " transform : delete_float_feature\n" + - " field1 : loc\n" + - " keys : [long,aaa]\n" + - "}"; - } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_delete"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); - } - - @Test - 
public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_delete"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - - Map feat1 = featureVector.getFloatFeatures().get("loc"); - - assertEquals(feat1.get("lat"), 37.7, 0.1); - assertTrue(!feat1.containsKey("long")); - } -} \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransformTest.java index 7fafb33d..736b7c88 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureFamilyTransformTest.java @@ -1,29 +1,16 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 1/29/16. 
*/ -public class DeleteStringFeatureFamilyTransformTest { - private static final Logger log = LoggerFactory.getLogger( - DeleteStringFeatureFamilyTransformTest.class); +public class DeleteStringFeatureFamilyTransformTest extends BaseTransformTest { public String makeConfig() { return "test_delete_string_feature_family {\n" + @@ -32,63 +19,37 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - - Set list1 = new HashSet<>(); - list1.add("I am a string in string feature 1"); - stringFeatures.put("strFeature1", list1); - - Set list2 = new HashSet<>(); - list2.add("I am a string in string feature 2"); - stringFeatures.put("strFeature2", list2); - - Set list3 = new HashSet<>(); - list3.add("I am a string in string feature 3"); - stringFeatures.put("strFeature3", list3); - - Set list4 = new HashSet<>(); - list4.add("I am a string in string feature 4"); - stringFeatures.put("strFeature4", list4); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("strFeature1", "I am a string in string feature 1") + .string("strFeature2", "I am a string in string feature 2") + .string("strFeature3", "I am a string in string feature 3") + .string("strFeature4", "I am a string in string feature 4") + .build(); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_delete_string_feature_family"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); + @Override + public String configKey() { + return "test_delete_string_feature_family"; } @Test public void testTransform() { - Config config = 
ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_delete_string_feature_family"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - - assertNotNull(stringFeatures); - assertTrue(stringFeatures.containsKey("strFeature1")); - assertTrue(stringFeatures.containsKey("strFeature2")); - assertTrue(stringFeatures.containsKey("strFeature3")); - assertTrue(stringFeatures.containsKey("strFeature4")); - assertEquals(4, stringFeatures.size()); - - transform.doTransform(featureVector); - - assertNotNull(stringFeatures); - assertFalse(stringFeatures.containsKey("strFeature1")); - assertFalse(stringFeatures.containsKey("strFeature2")); - assertFalse(stringFeatures.containsKey("strFeature3")); - assertTrue(stringFeatures.containsKey("strFeature4")); - assertEquals(1, stringFeatures.size()); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + + assertTrue(featureVector.contains(registry.family("strFeature1"))); + assertTrue(featureVector.contains(registry.family("strFeature2"))); + assertTrue(featureVector.contains(registry.family("strFeature3"))); + assertTrue(featureVector.contains(registry.family("strFeature4"))); + assertEquals(4, featureVector.numFamilies()); + + transform.apply(featureVector); + + assertFalse(featureVector.contains(registry.family("strFeature1"))); + assertFalse(featureVector.contains(registry.family("strFeature2"))); + assertFalse(featureVector.contains(registry.family("strFeature3"))); + assertTrue(featureVector.contains(registry.family("strFeature4"))); + assertEquals(1, featureVector.numFamilies()); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransformTest.java index c28816f5..2bab3707 100644 --- 
a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DeleteStringFeatureTransformTest.java @@ -1,42 +1,22 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class DeleteStringFeatureTransformTest { - private static final Logger log = LoggerFactory.getLogger(DeleteStringFeatureTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("aaa:bbbb"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +public class DeleteStringFeatureTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .location() + .string("strFeature1", "aaa:bbbb") + .build(); } public String makeConfig() { @@ -46,28 +26,22 @@ public String makeConfig() { " keys : [long,aaa]\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_delete"); - 
FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_delete"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_delete"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); - Set feat1 = stringFeatures.get("strFeature1"); - assertEquals(1, feat1.size()); - assertTrue(!feat1.contains("aaa")); - assertTrue(!feat1.contains("aaa:bbbb")); + assertStringFamily(featureVector, "strFeature1", 1, + ImmutableSet.of("bbb"), + ImmutableSet.of("aaa", "aaa:bbbb")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DenseTransformTest.java similarity index 90% rename from core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransformTest.java rename to core/src/test/java/com/airbnb/aerosolve/core/transforms/DenseTransformTest.java index 6b121c27..faee22a8 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatToDenseTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DenseTransformTest.java @@ -1,20 +1,6 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; -import org.junit.Test; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; 
- -import static junit.framework.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -public class FloatToDenseTransformTest { +public class DenseTransformTest { public String makeConfig() { return "test_float_cross_float {\n" + " transform : float_to_dense\n" + @@ -25,7 +11,11 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVectorFull() { + // TODO (Brad): Talk to Julian and make sure we really need this transform + // If we have .denseArray on MultiFamilyVector, can we achieve the same thing with a + // MoveTransform that puts the right features in the family and denseArray()? + + /*public FeatureVector makeFeatureVectorFull() { Map> floatFeatures = new HashMap<>(); Map floatFeature1 = new HashMap<>(); @@ -111,7 +101,7 @@ public void testEmptyFeatureVector() { Config config = ConfigFactory.parseString(makeConfig()); Transform transform = TransformFactory.createTransform(config, "test_float_cross_float"); FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); + transform.apply(featureVector); assertTrue(featureVector.getFloatFeatures() == null); } @@ -182,7 +172,7 @@ public void testString() { public FeatureVector testTransform(FeatureVector featureVector) { Config config = ConfigFactory.parseString(makeConfig()); Transform transform = TransformFactory.createTransform(config, "test_float_cross_float"); - transform.doTransform(featureVector); + transform.apply(featureVector); return featureVector; - } + }*/ } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DivideTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DivideTransformTest.java index bac06164..efc3d6b6 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/DivideTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/DivideTransformTest.java @@ -1,24 +1,17 @@ package 
com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class DivideTransformTest { - private static final Logger log = LoggerFactory.getLogger(DivideTransformTest.class); +public class DivideTransformTest extends BaseTransformTest { - public String makeConfigWithKeys() { + public String makeConfig() { return "test_divide {\n" + " transform : divide\n" + " field1 : loc\n" + @@ -29,33 +22,23 @@ public String makeConfigWithKeys() { " output : bar\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfigWithKeys()); - Transform transform = TransformFactory.createTransform(config, "test_divide"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_divide"; } @Test public void testTransformWithKeys() { - Config config = ConfigFactory.parseString(makeConfigWithKeys()); - Transform transform = TransformFactory.createTransform(config, "test_divide"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 4); - Map out = featureVector.floatFeatures.get("bar"); 
- for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + "=" + entry.getValue()); - } - assertTrue(out.size() == 3); - // the existing features under the family "bar" should not be deleted - assertEquals(1.0, out.get("bar_fv"), 0.1); - assertEquals(37.7 / 1.6, out.get("lat-d-foo"), 0.1); - assertEquals(40.0 / 1.6, out.get("long-d-foo"), 0.1); + assertSparseFamily(featureVector, "bar", 3, ImmutableMap.of( + "bar_fv", 1.0, + "lat-d-foo", 37.7 / 1.6, + "long-d-foo", 40.0 / 1.6 + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransformTest.java index acb08ab2..5a45fedb 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/FloatCrossFloatTransformTest.java @@ -1,27 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 4/8/16. 
*/ -public class FloatCrossFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(FloatCrossFloatTransformTest.class); +public class FloatCrossFloatTransformTest extends BaseTransformTest { public String makeConfig() { return "test_float_cross_float {\n" + @@ -34,64 +22,40 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVector() { - Map> floatFeatures = new HashMap<>(); - - Map floatFeature1 = new HashMap<>(); - - floatFeature1.put("x", 50.0); - floatFeature1.put("y", 1.3); - floatFeature1.put("z", 2000.0); - - Map floatFeature2 = new HashMap<>(); - - floatFeature2.put("i", 1.2); - floatFeature2.put("j", 3.4); - floatFeature2.put("k", 5.6); - - floatFeatures.put("floatFeature1", floatFeature1); - floatFeatures.put("floatFeature2", floatFeature2); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + @Override + public String configKey() { + return "test_float_cross_float"; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_float_cross_float"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getFloatFeatures() == null); + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .sparse("floatFeature1", "x", 50.0) + .sparse("floatFeature1", "y", 1.3) + .sparse("floatFeature1", "z", 2000.0) + .sparse("floatFeature2", "i", 1.2) + .sparse("floatFeature2", "j", 3.4) + .sparse("floatFeature2", "k", 5.6) + .build(); } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_float_cross_float"); - FeatureVector featureVector = makeFeatureVector(); - - transform.doTransform(featureVector); - - Map> 
floatFeatures = featureVector.getFloatFeatures(); - - assertNotNull(floatFeatures); - assertEquals(3, floatFeatures.size()); - - Map out = floatFeatures.get("out"); - - assertEquals(9, out.size()); - - assertEquals(1.2, out.get("x=50.0^i"), 0.0); - assertEquals(1.2, out.get("y=1.0^i"), 0.0); - assertEquals(1.2, out.get("z=1000.0^i"), 0.0); - assertEquals(3.4, out.get("x=50.0^j"), 0.0); - assertEquals(3.4, out.get("y=1.0^j"), 0.0); - assertEquals(3.4, out.get("z=1000.0^j"), 0.0); - assertEquals(5.6, out.get("x=50.0^k"), 0.0); - assertEquals(5.6, out.get("y=1.0^k"), 0.0); - assertEquals(5.6, out.get("z=1000.0^k"), 0.0); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "out", 9, ImmutableMap.builder() + .put("x=50.0^i", 1.2) + .put("y=1.0^i", 1.2) + .put("z=1000.0^i", 1.2) + .put("x=50.0^j", 3.4) + .put("y=1.0^j", 3.4) + .put("z=1000.0^j", 3.4) + .put("x=50.0^k", 5.6) + .put("y=1.0^k", 5.6) + .put("z=1000.0^k", 5.6) + .build() + ); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransformTest.java index b33ecb64..2b64c633 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeContinuousTransformTest.java @@ -1,42 +1,19 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import com.typesafe.config.Config; import com.typesafe.config.ConfigFactory; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static 
org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class KdtreeContinuousTransformTest { - private static final Logger log = LoggerFactory.getLogger(KdtreeContinuousTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +@Slf4j +public class KdtreeContinuousTransformTest extends BaseTransformTest { public String makeConfig() { return "test_kdtree {\n" + @@ -49,37 +26,29 @@ public String makeConfig() { " output : loc_kdt\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_kdtree"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_kdtree"; } @Test public void testTransform() { Config config = ConfigFactory.parseString(makeConfig()); log.info("Model encoded is " + config.getString("test_kdtree.model_base64")); - Transform transform = TransformFactory.createTransform(config, "test_kdtree"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - Map> floatFeatures = featureVector.getFloatFeatures(); - Map out = floatFeatures.get("loc_kdt"); - log.info("loc_kdt"); - for (Map.Entry entry 
: out.entrySet()) { - log.info(entry.getKey() + " = " + entry.getValue()); - } - assertTrue(out.size() == 2); + Transform transform = + TransformFactory.createTransform(config, "test_kdtree", registry, null); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + // 4 // |--------------- y = 2 // 1 | 2 3 // x = 1 - assertEquals(out.get("0"), 37.7 - 1.0, 0.1); - assertEquals(out.get("2"), 40.0 - 2.0, 0.1); + assertSparseFamily(featureVector, "loc_kdt", 2, ImmutableMap.of( + "0", 37.7 - 1.0, + "2", 40.0 - 2.0 + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeTransformTest.java index 1e80ce9c..35b2fddf 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/KdtreeTransformTest.java @@ -1,41 +1,19 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import com.typesafe.config.Config; import com.typesafe.config.ConfigFactory; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class KdtreeTransformTest { - private static final Logger log = LoggerFactory.getLogger(KdtreeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 
40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +@Slf4j +public class KdtreeTransformTest extends BaseTransformTest { public String makeConfig() { return "test_kdtree {\n" + @@ -48,32 +26,23 @@ public String makeConfig() { " output : loc_kdt\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_kdtree"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_kdtree"; } @Test public void testTransform() { Config config = ConfigFactory.parseString(makeConfig()); log.info("Model encoded is " + config.getString("test_kdtree.model_base64")); - Transform transform = TransformFactory.createTransform(config, "test_kdtree"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_kdt"); - log.info("loc_kdt"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 2); - assertTrue(out.contains("2")); - assertTrue(out.contains("4")); + Transform transform = + TransformFactory.createTransform(config, "test_kdtree", registry, null); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_kdt", 2, ImmutableSet.of("2", "4")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransformTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransformTest.java index 68075855..338d2328 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/LinearLogQuantizeTransformTest.java @@ -1,49 +1,21 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class LinearLogQuantizeTransformTest { - private static final Logger log = LoggerFactory.getLogger(LinearLogQuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("a", 0.0); - map.put("b", 0.13); - map.put("c", 1.23); - map.put("d", 5.0); - map.put("e", 17.5); - map.put("f", 99.98); - map.put("g", 365.0); - map.put("h", 65537.0); - map.put("i", -1.0); - map.put("j", -23.0); +public class LinearLogQuantizeTransformTest extends BaseTransformTest { - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .complexLocation() + .build(); } public String makeConfig() { @@ -54,39 +26,29 @@ public String makeConfig() { " output : loc_quantized\n" + "}"; } - - 
@Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - assertTrue(out.size() == 10); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("a=0")); - assertTrue(out.contains("b=0.125")); - assertTrue(out.contains("c=1.125")); - assertTrue(out.contains("d=5.0")); - assertTrue(out.contains("e=17.5")); - assertTrue(out.contains("f=90")); - assertTrue(out.contains("g=350")); - assertTrue(out.contains("h=65536")); - assertTrue(out.contains("i=-1.0")); - assertTrue(out.contains("j=-23.0")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 10, + ImmutableSet.of("a=0", + "b=0.125", + "c=1.125", + "d=5.0", + "e=17.5", + "f=90", + "g=350", + "h=65536", + "i=-1.0", + "j=-23.0")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ListTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ListTransformTest.java index 24070adc..79e128e9 100644 --- 
a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ListTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ListTransformTest.java @@ -1,44 +1,18 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.HashSet; -import java.util.Map; -import java.util.List; -import java.util.HashMap; -import java.util.Set; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class ListTransformTest { - private static final Logger log = LoggerFactory.getLogger(ListTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); +public class ListTransformTest extends BaseTransformTest { - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.makeSimpleVector(registry); } public String makeConfig() { @@ -59,31 +33,22 @@ public String makeConfig() { " transforms : [test_quantize, test_cross]\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_list"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - 
assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_list"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_list"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 3); - Set out = stringFeatures.get("out"); - assertTrue(out.size() == 4); - log.info("crossed quantized output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("bbb^long=400")); - assertTrue(out.contains("aaa^lat=377")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 4); + + assertStringFamily(featureVector, "out", 4, + ImmutableSet.of("bbb^long=400", + "aaa^lat=377")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MathFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MathFloatTransformTest.java index af1ceafc..871976ad 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MathFloatTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MathFloatTransformTest.java @@ -1,21 +1,16 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; import static 
org.junit.Assert.assertTrue; +public class MathFloatTransformTest extends BaseTransformTest { -public class MathFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(MathFloatTransformTest.class); + public String makeConfig() { + return makeConfig("log10"); + } public String makeConfig(String functionName) { return "test_math {\n" + @@ -27,54 +22,51 @@ public String makeConfig(String functionName) { "}"; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig("log10")); - Transform transform = TransformFactory.createTransform(config, "test_math"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getFloatFeatures() == null); + @Override + public String configKey() { + return "test_math"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig("log10")); - Transform transform = TransformFactory.createTransform(config, "test_math"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 4); + + // the original features are unchanged + assertSparseFamily(featureVector, "loc", 3, + ImmutableMap.of("lat", 37.7, + "long", 40.0, + "z", -20.0)); + + assertSparseFamily(featureVector, "bar", 3, + ImmutableMap.of("lat", Math.log10(37.7), + "long", Math.log10(40.0), + // existing feature in 'bar' should not change + "bar_fv", 1.0), + // for negative value, it would be a missing feature + ImmutableSet.of("z")); + } - Map feat1 = featureVector.getFloatFeatures().get("loc"); - // the original features are not changed - assertEquals(3, 
feat1.size()); - assertEquals(37.7, feat1.get("lat"), 0.1); - assertEquals(40.0, feat1.get("long"), 0.1); - assertEquals(-20.0, feat1.get("z"), 0.1); + @Test(expected = IllegalArgumentException.class) + public void testUndefinedFunction() { + Transform transform = getTransform(makeConfig("tan"), configKey()); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); - Map feat2 = featureVector.getFloatFeatures().get("bar"); - assertEquals(3, feat2.size()); - assertEquals(Math.log10(37.7), feat2.get("lat"), 0.1); - assertEquals(Math.log10(40.0), feat2.get("long"), 0.1); - // for negative value, it would be a missing feature - assertTrue(!feat2.containsKey("z")); - // existing feature in 'bar' should not change - assertEquals(1.0, feat2.get("bar_fv"), 0.1); + // TODO (Brad): I made this throw an exception. Does it really make sense to just continue + // as usual doing nothing if the function is unknown? + assertTrue(featureVector.numFamilies() == 4); - // test an undefined function - Config config2 = ConfigFactory.parseString(makeConfig("tan")); - Transform transform2 = TransformFactory.createTransform(config2, "test_math"); - FeatureVector featureVector2 = TransformTestingHelper.makeFeatureVector(); - transform2.doTransform(featureVector2); // the original features are unchanged - Map feat3 = featureVector2.getFloatFeatures().get("loc"); - assertEquals(3, feat3.size()); - assertEquals(37.7, feat3.get("lat"), 0.1); - assertEquals(40.0, feat3.get("long"), 0.1); - assertEquals(-20.0, feat3.get("z"), 0.1); + assertSparseFamily(featureVector, "loc", 3, + ImmutableMap.of("lat", 37.7, + "long", 40.0, + "z", -20.0)); // new features should not exist - Map feat4 = featureVector2.getFloatFeatures().get("bar"); - assertEquals(1, feat4.size()); - assertEquals(1.0, feat4.get("bar_fv"), 0.1); + assertSparseFamily(featureVector, "bar", 1, + ImmutableMap.of("bar_fv", 1.0)); } } diff --git 
a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ModelTransformsTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ModelTransformsTest.java index fdbf1550..9b2ab1ef 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ModelTransformsTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ModelTransformsTest.java @@ -1,46 +1,33 @@ package com.airbnb.aerosolve.core.transforms; import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.features.SimpleExample; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.typesafe.config.Config; import com.typesafe.config.ConfigFactory; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.HashMap; import static org.junit.Assert.assertTrue; // Tests all the model transforms. -public class ModelTransformsTest { - private static final Logger log = LoggerFactory.getLogger(ModelTransformsTest.class); +@Slf4j +public class ModelTransformsTest extends BaseTransformTest { // Creates a feature vector given a feature family name and latitude and longitude features. 
- public FeatureVector makeFeatureVector(String familyName, - Double lat, - Double lng) { - Map> floatFeatures = new HashMap<>(); - Map> denseFeatures = new HashMap<>(); - - Map map = new HashMap<>(); - map.put("lat", lat); - map.put("long", lng); - floatFeatures.put(familyName, map); - - List list = new ArrayList<>(); - list.add(lat); - list.add(lng); - String denseFamilyName = familyName + "_dense"; - denseFeatures.put(denseFamilyName, list); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setFloatFeatures(floatFeatures); - featureVector.setDenseFeatures(denseFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector( + MultiFamilyVector vector, + String familyName, + Double lat, + Double lng) { + return TransformTestingHelper.builder(registry, vector) + .sparse(familyName, "lat", lat) + .sparse(familyName, "long", lng) + .dense(familyName + "_dense", new double[]{lat, lng}) + .build(); } public String makeConfig() { @@ -82,38 +69,49 @@ public String makeConfig() { } private Example makeExample() { - Example example = new Example(); - example.setContext(makeFeatureVector("guest_loc", 1.0, 2.0)); - example.addToExample(makeFeatureVector("host_loc", 3.1, 4.2)); - example.addToExample(makeFeatureVector("host_loc", 5.3, 6.4)); + Example example = new SimpleExample(registry); + makeFeatureVector(example.context(), "guest_loc", 1.0, 2.0); + makeFeatureVector(example.createVector(), "host_loc", 3.1, 4.2); + makeFeatureVector(example.createVector(), "host_loc", 5.3, 6.4); return example; } + // Bit of a hack to use the quantize_guest_loc config. I did this because the main transform + // won't work with the base test. 
+ @Override + public String configKey() { + return "quantize_guest_loc"; + } + @Test public void testTransform() { Config config = ConfigFactory.parseString(makeConfig()); - Transformer transformer = new Transformer(config, "model_transforms"); + Transformer transformer = new Transformer(config, "model_transforms", registry, null); Example example = makeExample(); - transformer.combineContextAndItems(example); - assertTrue(example.example.size() == 2); - FeatureVector ex = example.example.get(0); - assertTrue(ex.stringFeatures.size() == 3); - assertTrue(ex.stringFeatures.get("guest_loc_quantized").contains("lat=10")); - assertTrue(ex.stringFeatures.get("guest_loc_quantized").contains("long=20")); - assertTrue(ex.stringFeatures.get("host_loc_quantized").contains("lat=31")); - assertTrue(ex.stringFeatures.get("host_loc_quantized").contains("long=42")); - assertTrue(ex.stringFeatures.get("gxh_loc").contains("lat=10^lat=31")); - assertTrue(ex.stringFeatures.get("gxh_loc").contains("long=20^lat=31")); - assertTrue(ex.stringFeatures.get("gxh_loc").contains("lat=10^long=42")); - assertTrue(ex.stringFeatures.get("gxh_loc").contains("long=20^long=42")); - assertTrue(ex.floatFeatures.get("guest_loc").get("lat") == 1.0); - assertTrue(ex.floatFeatures.get("guest_loc").get("long") == 2.0); - assertTrue(ex.floatFeatures.get("host_loc").get("lat") == 3.1); - assertTrue(ex.floatFeatures.get("host_loc").get("long") == 4.2); - assertTrue(ex.denseFeatures.get("guest_loc_dense").contains(1.0)); - assertTrue(ex.denseFeatures.get("guest_loc_dense").contains(2.0)); - assertTrue(ex.denseFeatures.get("host_loc_dense").contains(3.1)); - assertTrue(ex.denseFeatures.get("host_loc_dense").contains(4.2)); + example.transform(transformer); + + assertTrue(Iterables.size(example) == 2); + MultiFamilyVector ex = example.iterator().next(); + + assertStringFamily(ex, "guest_loc_quantized", 2, + ImmutableSet.of("lat=10", "long=20")); + + assertStringFamily(ex, "host_loc_quantized", 2, + 
ImmutableSet.of("lat=31", "long=42")); + + assertStringFamily(ex, "gxh_loc", 4, + ImmutableSet.of("lat=10^lat=31", "long=20^lat=31", + "lat=10^long=42", "long=20^long=42")); + + assertSparseFamily(ex, "guest_loc", 2, + ImmutableMap.of("lat", 1.0, "long", 2.0)); + + assertSparseFamily(ex, "host_loc", 2, + ImmutableMap.of("lat", 3.1, "long", 4.2)); + + assertDenseFamily(ex, "guest_loc_dense", new double[]{1.0, 2.0}); + assertDenseFamily(ex, "host_loc_dense", new double[]{3.1, 4.2}); + log.info(example.toString()); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransformTest.java index b3d68f13..9fb0dbde 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringAndFloatTransformTest.java @@ -1,27 +1,18 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 4/8/16. 
*/ -public class MoveFloatToStringAndFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger( - MoveFloatToStringAndFloatTransformTest.class); +@Slf4j +public class MoveFloatToStringAndFloatTransformTest extends BaseTransformTest { public String makeConfig() { return "test_move_float_to_string_and_float {\n" + @@ -36,75 +27,49 @@ public String makeConfig() { "}"; } - public FeatureVector makeFeatureVector() { - Map> floatFeatures = new HashMap<>(); - - Map floatFeature1 = new HashMap<>(); - - floatFeature1.put("a", 0.0); - floatFeature1.put("b", 10.0); - floatFeature1.put("c", 9.9); - floatFeature1.put("d", 10.1); - floatFeature1.put("e", 11.01); - floatFeature1.put("f", -0.1); - floatFeature1.put("g", -1.01); - floatFeature1.put("h", 21.3); - floatFeature1.put("i", 2000.0); - floatFeature1.put("j", 1.0); - floatFeature1.put("k", 9000.0); - - floatFeatures.put("floatFeature1", floatFeature1); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .sparse("floatFeature1", "a", 0.0) + .sparse("floatFeature1", "b", 10.0) + .sparse("floatFeature1", "c", 9.9) + .sparse("floatFeature1", "d", 10.1) + .sparse("floatFeature1", "e", 11.01) + .sparse("floatFeature1", "f", -0.1) + .sparse("floatFeature1", "g", -1.01) + .sparse("floatFeature1", "h", 21.3) + .sparse("floatFeature1", "i", 2000.0) + .sparse("floatFeature1", "j", 1.0) + .sparse("floatFeature1", "k", 9000.0) + .build(); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_move_float_to_string_and_float"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); - 
assertTrue(featureVector.getFloatFeatures() == null); + @Override + public String configKey() { + return "test_move_float_to_string_and_float"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform( - config, "test_move_float_to_string_and_float"); - FeatureVector featureVector = makeFeatureVector(); - - transform.doTransform(featureVector); - - Map> stringFeatures = featureVector.getStringFeatures(); - Map> floatFeatures = featureVector.getFloatFeatures(); - - assertNotNull(stringFeatures); - assertEquals(1, stringFeatures.size()); - - assertNotNull(floatFeatures); - assertEquals(2, floatFeatures.size()); - - Set stringOutput = stringFeatures.get("stringOutput"); - Map floatOutput = floatFeatures.get("floatOutput"); - - assertEquals(5, stringOutput.size()); - assertEquals(4, floatOutput.size()); - - assertTrue(stringOutput.contains("a=0.0")); - assertTrue(stringOutput.contains("b=10.0")); - assertTrue(stringOutput.contains("c=9.0")); - assertTrue(stringOutput.contains("d=10.0")); - assertTrue(stringOutput.contains("f=0.0")); - - assertEquals(11.01, floatOutput.get("e"), 0.0); - assertEquals(-1.01, floatOutput.get("g"), 0.0); - assertEquals(21.3, floatOutput.get("h"), 0.0); - assertEquals(2000.0, floatOutput.get("i"), 0.0); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + // TODO (Brad): This test is failing because I intentionally output to a single family + // pending a conversation with Chris about this new transform. 
+ assertTrue(featureVector.numFamilies() == 2); + + assertStringFamily(featureVector, "stringOutput", 5, + ImmutableSet.of("a=0.0", + "b=10.0", + "c=9.0", + "d=10.0", + "f=0.0")); + + assertSparseFamily(featureVector, "floatOutput", 4, + ImmutableMap.of("e", 11.01, + "g", -1.01, + "h", 21.3, + "i", 2000.0)); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransformTest.java index 5b4722d5..60ea18be 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MoveFloatToStringTransformTest.java @@ -1,21 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class MoveFloatToStringTransformTest { - private static final Logger log = LoggerFactory.getLogger(MoveFloatToStringTransformTest.class); +public class MoveFloatToStringTransformTest extends BaseTransformTest { public String makeConfig(boolean moveAllKeys) { StringBuilder sb = new StringBuilder(); @@ -32,49 +26,38 @@ public String makeConfig(boolean moveAllKeys) { sb.append("}"); return sb.toString(); } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig(false)); - Transform transform = TransformFactory.createTransform(config, "test_move_float_to_string"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + 
@Override + public String configKey() { + return "test_move_float_to_string"; + } + + @Override + public String makeConfig() { + return makeConfig(false); } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig(false)); - Transform transform = TransformFactory.createTransform(config, "test_move_float_to_string"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - assertTrue(out.size() == 1); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("lat=37.0")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 5); + + assertStringFamily(featureVector, "loc_quantized", 1, + ImmutableSet.of("lat=37.0")); } @Test public void testTransformMoveAllKeys() { - Config config = ConfigFactory.parseString(makeConfig(true)); - Transform transform = TransformFactory.createTransform(config, "test_move_float_to_string"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - assertTrue(out.size() == 3); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("lat=37.0")); - assertTrue(out.contains("long=40.0")); - assertTrue(out.contains("z=-20.0")); + Transform transform = getTransform(makeConfig(true), configKey()); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + + 
assertTrue(featureVector.numFamilies() == 4); + + assertStringFamily(featureVector, "loc_quantized", 3, + ImmutableSet.of("lat=37.0", "long=40.0", "z=-20.0")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransformTest.java index 7f00e2dd..77236f46 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridContinuousTransformTest.java @@ -1,42 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertEquals; /** * @author Hector Yee */ -public class MultiscaleGridContinuousTransformTest { - private static final Logger log = LoggerFactory.getLogger(MultiscaleGridContinuousTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +public class MultiscaleGridContinuousTransformTest extends BaseTransformTest { public String makeConfig() { return "test_grid {\n" + 
@@ -48,35 +21,25 @@ public String makeConfig() { " output : loc_continuous\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_grid"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_grid"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_grid"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 2); - Map out = floatFeatures.get("loc_continuous"); - log.info("grid output"); - for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + "=" + entry.getValue()); - } - assertEquals(4, out.size()); - assertEquals(0.7, out.get("[1.0]=(37.0,40.0)@1"), 0.01); - assertEquals(2.7, out.get("[5.0]=(35.0,40.0)@1"), 0.01); - assertEquals(0.0, out.get("[1.0]=(37.0,40.0)@2"), 0.01); - assertEquals(0.0, out.get("[5.0]=(35.0,40.0)@2"), 0.01); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "loc_continuous", 4, ImmutableMap.of( + "[1.0]=(37.0,40.0)@1", 0.7, + "[5.0]=(35.0,40.0)@1", 2.7, + "[1.0]=(37.0,40.0)@2", 0.0, + "[5.0]=(35.0,40.0)@2", 0.0 + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransformTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransformTest.java index c207c56b..8aedf145 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleGridQuantizeTransformTest.java @@ -1,41 +1,17 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class MultiscaleGridQuantizeTransformTest { - private static final Logger log = LoggerFactory.getLogger(MultiscaleGridQuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +@Slf4j +public class MultiscaleGridQuantizeTransformTest extends BaseTransformTest { public String makeConfig() { return "test_quantize {\n" + @@ -47,31 +23,21 @@ public String makeConfig() { " output : loc_quantized\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new 
FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 2); - assertTrue(out.contains("[10.0]=(30.0,40.0)")); - assertTrue(out.contains("[1.0]=(37.0,40.0)")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 2, + ImmutableSet.of("[10.0]=(30.0,40.0)", "[1.0]=(37.0,40.0)")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransformTest.java index 20678bf3..09fe54c8 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleMoveFloatToStringTransformTest.java @@ -1,41 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import 
org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class MultiscaleMoveFloatToStringTransformTest { - private static final Logger log = LoggerFactory.getLogger(MultiscaleMoveFloatToStringTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +public class MultiscaleMoveFloatToStringTransformTest extends BaseTransformTest { public String makeConfig() { return "test_quantize {\n" + @@ -46,32 +20,22 @@ public String makeConfig() { " output : loc_quantized\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - 
assertTrue(out.size() == 2); - assertTrue(out.contains("lat[1.0]=37.0")); - assertTrue(out.contains("lat[10.0]=30.0")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 2, + ImmutableSet.of("lat[1.0]=37.0", "lat[10.0]=30.0")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransformTest.java index 4435da9d..9280ea2a 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/MultiscaleQuantizeTransformTest.java @@ -1,42 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class MultiscaleQuantizeTransformTest { - private static final Logger log = LoggerFactory.getLogger(MultiscaleQuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - map.put("zero", 0.0); - map.put("negative", -1.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - 
featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +public class MultiscaleQuantizeTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .location() + .sparse("loc", "zero", 0.0) + .sparse("loc", "negative", -1.0) + .build(); } public String makeConfig() { @@ -47,36 +28,27 @@ public String makeConfig() { " output : loc_quantized\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 7); - assertTrue(out.contains("lat[10.0]=30.0")); - assertTrue(out.contains("long[1.0]=40.0")); - assertTrue(out.contains("long[10.0]=40.0")); - assertTrue(out.contains("lat[1.0]=37.0")); - assertTrue(out.contains("zero=0")); - assertTrue(out.contains("negative[1.0]=-1.0")); - assertTrue(out.contains("negative[10.0]=0.0")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + 
assertStringFamily(featureVector, "loc_quantized", 7, ImmutableSet.of( + "lat[10.0]=30.0", + "long[1.0]=40.0", + "long[10.0]=40.0", + "lat[1.0]=37.0", + "zero=0", + "negative[1.0]=-1.0", + "negative[10.0]=0.0")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NearestTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NearestTransformTest.java index 98380b83..c6fb5d67 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NearestTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NearestTransformTest.java @@ -1,46 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class NearestTransformTest { - private static final Logger log = LoggerFactory.getLogger(NearestTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - Map map2 = new HashMap<>(); - map2.put("foo", 41.0); - floatFeatures.put("f2", map2); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +public class NearestTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return 
TransformTestingHelper.builder(registry) + .simpleStrings() + .location() + .sparse("f2", "foo", 41.0) + .build(); } - public String makeConfig() { return "test_nearest {\n" + " transform : nearest\n" + @@ -50,30 +27,21 @@ public String makeConfig() { " output : nearest\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_nearest"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_nearest"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_nearest"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("nearest"); - log.info("nearest output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 1); - assertTrue(out.contains("foo~=long")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 4); + + assertStringFamily(featureVector, "nearest", 1, + ImmutableSet.of("foo~=long")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransformTest.java index ca4f7b4e..0113ca86 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeFloatTransformTest.java @@ -1,22 +1,15 @@ 
package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class NormalizeFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(CapFloatTransformTest.class); +public class NormalizeFloatTransformTest extends BaseTransformTest { public String makeConfig() { return "test_norm {\n" + @@ -25,31 +18,25 @@ public String makeConfig() { "}"; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_norm"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + @Override + public String configKey() { + return "test_norm"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_norm"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); - Map feat1 = featureVector.getFloatFeatures().get("loc"); + assertTrue(featureVector.numFamilies() == 4); - assertEquals(3, feat1.size()); double scale = 1.0 / Math.sqrt(37.7 * 37.7 + 40.0 * 40.0 + 20.0 * 20.0); - 
assertEquals(scale * 37.7, feat1.get("lat"), 0.1); - assertEquals(scale * 40.0, feat1.get("long"), 0.1); - assertEquals(scale * -20.0, feat1.get("z"), 0.1); + + assertSparseFamily(featureVector, "loc", 3, + ImmutableMap.of("lat", scale * 37.7, + "long", scale * 40.0, + "z", scale * -20.0)); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8TransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8TransformTest.java index b3edb4ee..a11737a6 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8TransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/NormalizeUtf8TransformTest.java @@ -1,28 +1,17 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import java.text.Normalizer; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 1/27/16. 
*/ -public class NormalizeUtf8TransformTest { - private static final Logger log = LoggerFactory.getLogger(NormalizeUtf8TransformTest.class); +public class NormalizeUtf8TransformTest extends BaseTransformTest { + public static final String FUNKY_STRING = "Funky string: \u03D3\u03D4\u1E9B"; public String makeConfigWithoutNormalizationFormAndOutput() { return "test_normalize_utf_8 {\n" + @@ -31,6 +20,10 @@ public String makeConfigWithoutNormalizationFormAndOutput() { "}"; } + public String makeConfig() { + return makeConfigWithoutNormalizationFormAndOutput(); + } + public String makeConfigWithNormalizationForm(String normalizationForm) { return "test_normalize_utf_8 {\n" + " transform: normalize_utf_8\n" + @@ -40,122 +33,81 @@ public String makeConfigWithNormalizationForm(String normalizationForm) { "}"; } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - - Set list = new HashSet<>(); - list.add("Funky string: \u03D3\u03D4\u1E9B"); - stringFeatures.put("strFeature1", list); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("strFeature1", FUNKY_STRING) + .build(); } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfigWithoutNormalizationFormAndOutput()); - Transform transform = TransformFactory.createTransform(config, "test_normalize_utf_8"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); + @Override + public String configKey() { + return "test_normalize_utf_8"; } @Test public void testTransformDefaultNormalizationFormAndOverwriteInput() { - Config config = ConfigFactory.parseString(makeConfigWithoutNormalizationFormAndOutput()); - Transform transform = TransformFactory.createTransform(config, 
"test_normalize_utf_8"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - transform.doTransform(featureVector); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - assertNotNull(stringFeatures); - assertEquals(1, stringFeatures.size()); + assertTrue(featureVector.numFamilies() == 1); - Set output = stringFeatures.get("strFeature1"); - - assertNotNull(output); - assertEquals(1, output.size()); - assertTrue(output.contains(Normalizer.normalize( - "Funky string: \u03D3\u03D4\u1E9B", NormalizeUtf8Transform.DEFAULT_NORMALIZATION_FORM))); + assertStringFamily(featureVector, "strFeature1", 1, ImmutableSet.of( + Normalizer.normalize(FUNKY_STRING, NormalizeUtf8Transform.DEFAULT_NORMALIZATION_FORM) + )); } @Test public void testTransformNfcNormalizationForm() { - Config config = ConfigFactory.parseString(makeConfigWithNormalizationForm("NFC")); - Transform transform = TransformFactory.createTransform(config, "test_normalize_utf_8"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - transform.doTransform(featureVector); - - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); + Transform transform = getTransform( + makeConfigWithNormalizationForm("NFC"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - Set output = stringFeatures.get("bar"); + assertTrue(featureVector.numFamilies() == 2); - assertNotNull(output); - assertEquals(1, output.size()); - assertTrue(output.contains("Funky string: \u03D3\u03D4\u1E9B")); - assertTrue(Normalizer.isNormalized("Funky string: \u03D3\u03D4\u1E9B", Normalizer.Form.NFC)); + assertStringFamily(featureVector, "bar", 1, ImmutableSet.of(FUNKY_STRING)); + assertTrue(Normalizer.isNormalized(FUNKY_STRING, Normalizer.Form.NFC)); } @Test public void 
testTransformNfdNormalizationForm() { - Config config = ConfigFactory.parseString(makeConfigWithNormalizationForm("NFD")); - Transform transform = TransformFactory.createTransform(config, "test_normalize_utf_8"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - transform.doTransform(featureVector); - - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); - - Set output = stringFeatures.get("bar"); - - assertNotNull(output); - assertEquals(1, output.size()); - assertTrue(output.contains("Funky string: \u03D2\u0301\u03D2\u0308\u017F\u0307")); - assertTrue(Normalizer.isNormalized( - "Funky string: \u03D2\u0301\u03D2\u0308\u017F\u0307", Normalizer.Form.NFD)); + Transform transform = getTransform( + makeConfigWithNormalizationForm("NFD"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); + + String normalizedString = "Funky string: \u03D2\u0301\u03D2\u0308\u017F\u0307"; + assertStringFamily(featureVector, "bar", 1, ImmutableSet.of(normalizedString)); + assertTrue(Normalizer.isNormalized(normalizedString, Normalizer.Form.NFD)); } @Test public void testTransformNfkcNormalizationForm() { - Config config = ConfigFactory.parseString(makeConfigWithNormalizationForm("NFKC")); - Transform transform = TransformFactory.createTransform(config, "test_normalize_utf_8"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - transform.doTransform(featureVector); + Transform transform = getTransform( + makeConfigWithNormalizationForm("NFKC"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); + assertTrue(featureVector.numFamilies() == 2); - Set output = stringFeatures.get("bar"); - - assertNotNull(output); - 
assertEquals(1, output.size()); - assertTrue(output.contains("Funky string: \u038e\u03ab\u1e61")); - assertTrue(Normalizer.isNormalized("Funky string: \u038e\u03ab\u1e61", Normalizer.Form.NFKC)); + String normalizedString = "Funky string: \u038e\u03ab\u1e61"; + assertStringFamily(featureVector, "bar", 1, ImmutableSet.of(normalizedString)); + assertTrue(Normalizer.isNormalized(normalizedString, Normalizer.Form.NFKC)); } @Test public void testTransformNfkdNormalizationForm() { - Config config = ConfigFactory.parseString(makeConfigWithNormalizationForm("NFKD")); - Transform transform = TransformFactory.createTransform(config, "test_normalize_utf_8"); - FeatureVector featureVector = makeFeatureVector(); - Map> stringFeatures = featureVector.getStringFeatures(); - transform.doTransform(featureVector); - - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); - - Set output = stringFeatures.get("bar"); - - assertNotNull(output); - assertEquals(1, output.size()); - assertTrue(output.contains("Funky string: \u03a5\u0301\u03a5\u0308\u0073\u0307")); - assertTrue(Normalizer.isNormalized( - "Funky string: \u03a5\u0301\u03a5\u0308\u0073\u0307", Normalizer.Form.NFKD)); + Transform transform = getTransform(makeConfigWithNormalizationForm("NFKD"), configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); + + String normalizedString = "Funky string: \u03a5\u0301\u03a5\u0308\u0073\u0307"; + assertStringFamily(featureVector, "bar", 1, ImmutableSet.of(normalizedString)); + assertTrue(Normalizer.isNormalized(normalizedString, Normalizer.Form.NFKD)); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ProductTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ProductTransformTest.java index 81d2d222..38c1a3c6 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ProductTransformTest.java +++ 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ProductTransformTest.java @@ -1,37 +1,21 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class ProductTransformTest { - private static final Logger log = LoggerFactory.getLogger(ProductTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); +public class ProductTransformTest extends BaseTransformTest { - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - map.put("foo", 7.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .location() + .sparse("loc", "foo", 7.0) + .build(); } public String makeConfig() { @@ -42,26 +26,21 @@ public String makeConfig() { " output : loc_prod\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_prod"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_prod"; } @Test public void testTransform() { - Config config = 
ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_prod"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> floatFeatures = featureVector.getFloatFeatures(); - assertTrue(floatFeatures.size() == 2); - Map out = floatFeatures.get("loc_prod"); - assertTrue(out.size() == 1); - assertEquals((1 + 37.7)*(1+40.0), out.get("*"), 0.1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); + + assertSparseFamily(featureVector, "loc_prod", 1, + ImmutableMap.of("*", (1 + 37.7) * (1 + 40.0))); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/QuantizeTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/QuantizeTransformTest.java index b6bcdc6b..831184b4 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/QuantizeTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/QuantizeTransformTest.java @@ -1,41 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class QuantizeTransformTest { - private static final Logger log = LoggerFactory.getLogger(QuantizeTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - 
Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +public class QuantizeTransformTest extends BaseTransformTest { public String makeConfig() { return "test_quantize {\n" + @@ -45,31 +19,21 @@ public String makeConfig() { " output : loc_quantized\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_quantize"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_quantize"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("loc_quantized"); - assertTrue(out.size() == 2); - log.info("quantize output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.contains("lat=377")); - assertTrue(out.contains("long=400")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertStringFamily(featureVector, "loc_quantized", 2, + ImmutableSet.of("lat=377", "long=400")); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransformTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransformTest.java index 8c0f6583..ac75ebc4 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/ReplaceAllStringsTransformTest.java @@ -1,29 +1,23 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; - -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** * Created by christhetree on 1/27/16. 
*/ -public class ReplaceAllStringsTransformTest { - private static final Logger log = LoggerFactory.getLogger(ReplaceAllStringsTransformTest.class); +public class ReplaceAllStringsTransformTest extends BaseTransformTest { + + public String makeConfig() { + return makeConfig(makeReplacements(), false); + } public String makeConfig(List> replacements, boolean overwriteInput) { StringBuilder sb = new StringBuilder(); @@ -50,17 +44,16 @@ public String makeConfig(List> replacements, boolean overwri return sb.toString(); } - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - - Set list = new HashSet<>(); - list.add("I like blueberry pie, apple pie; and I also like blue!"); - list.add("I'm so excited: I like blue!?!!"); - stringFeatures.put("strFeature1", list); + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .string("strFeature1", "I like blueberry pie, apple pie; and I also like blue!") + .string("strFeature1", "I'm so excited: I like blue!?!!") + .build(); + } - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + @Override + public String configKey() { + return "test_replace_all_strings"; } public List> makeReplacements() { @@ -77,51 +70,32 @@ public List> makeReplacements() { return replacements; } - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig(makeReplacements(), false)); - Transform transform = TransformFactory.createTransform(config, "test_replace_all_strings"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - - assertTrue(featureVector.getStringFeatures() == null); - } - @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig(makeReplacements(), false)); - Transform transform = TransformFactory.createTransform(config, "test_replace_all_strings"); - FeatureVector 
featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - - assertNotNull(stringFeatures); - assertEquals(2, stringFeatures.size()); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - Set output = stringFeatures.get("bar"); + assertTrue(featureVector.numFamilies() == 2); - assertNotNull(output); - assertEquals(2, output.size()); - assertTrue(output.contains("you like blackberry pie, apple pie; and you also like black!")); - assertTrue(output.contains("I'm so excited: you like black!?!!")); + assertStringFamily(featureVector, "bar", 2, ImmutableSet.of( + "you like blackberry pie, apple pie; and you also like black!", + "I'm so excited: you like black!?!!" + )); } @Test public void testTransformOverwriteInput() { - Config config = ConfigFactory.parseString(makeConfig(makeReplacements(), true)); - Transform transform = TransformFactory.createTransform(config, "test_replace_all_strings"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - - assertNotNull(stringFeatures); - assertEquals(1, stringFeatures.size()); + Transform transform = getTransform(makeConfig(makeReplacements(), true), + configKey()); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - Set output = stringFeatures.get("strFeature1"); + assertTrue(featureVector.numFamilies() == 1); - assertNotNull(output); - assertEquals(2, output.size()); - assertTrue(output.contains("you like blackberry pie, apple pie; and you also like black!")); - assertTrue(output.contains("I'm so excited: you like black!?!!")); + assertStringFamily(featureVector, "strFeature1", 2, ImmutableSet.of( + "you like blackberry pie, apple pie; and you also like black!", + "I'm so excited: you like black!?!!" 
+ )); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransformTest.java index df038733..6e6cbcc4 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/SelfCrossTransformTest.java @@ -1,70 +1,45 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Map; -import java.util.HashSet; -import java.util.Set; -import java.util.HashMap; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class SelfCrossTransformTest { - private static final Logger log = LoggerFactory.getLogger(SelfCrossTransformTest.class); - - public FeatureVector makeFeatureVector() { - HashMap stringFeatures = new HashMap>(); +public class SelfCrossTransformTest extends BaseTransformTest { - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("feature1", list); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .build(); } public String makeConfig() { return "test_cross {\n" + " transform : self_cross\n" + - " field1 : feature1\n" + + " field1 : strFeature1\n" + " output : out\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, 
"test_cross"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_cross"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cross"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); - Set out = stringFeatures.get("out"); - log.info("Cross output"); - for (String string : out) { - log.info(string); - } - assertTrue(out.size() == 1); - assertTrue(out.contains("aaa^bbb")); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 2); + + assertStringFamily(featureVector, "out", 1, ImmutableSet.of( + "aaa^bbb" + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransformTest.java index f289f82c..59cd6a70 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StringCrossFloatTransformTest.java @@ -1,46 +1,15 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Map; -import java.util.HashSet; -import 
java.util.Set; -import java.util.HashMap; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertEquals; /** * @author Hector Yee */ -public class StringCrossFloatTransformTest { - private static final Logger log = LoggerFactory.getLogger(StringCrossFloatTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; - } +public class StringCrossFloatTransformTest extends BaseTransformTest { public String makeConfig() { return "test_cross {\n" + @@ -50,35 +19,24 @@ public String makeConfig() { " output : out\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cross"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_cross"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_cross"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertEquals(1, stringFeatures.size()); - Map> floatFeatures = featureVector.getFloatFeatures(); - assertEquals(2, floatFeatures.size()); - Map out = floatFeatures.get("out"); - assertTrue(out.size() == 4); - 
log.info("Cross output"); - for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + '=' + entry.getValue()); - } - assertEquals(37.7, out.get("aaa^lat"), 0.1); - assertEquals(37.7, out.get("bbb^lat"), 0.1); - assertEquals(40.0, out.get("aaa^long"), 0.1); - assertEquals(40.0, out.get("bbb^long"), 0.1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeSimpleVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 3); + + assertSparseFamily(featureVector, "out", 4, ImmutableMap.of( + "aaa^lat", 37.7, + "bbb^lat", 37.7, + "aaa^long", 40.0, + "bbb^long", 40.0)); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StuffIdTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StuffIdTransformTest.java index 9f3a03d2..60d6f247 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StuffIdTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StuffIdTransformTest.java @@ -1,44 +1,24 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class StuffIdTransformTest { - private static final Logger log = LoggerFactory.getLogger(StuffIdTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - Map map = new 
HashMap<>(); - map.put("searches", 37.7); - floatFeatures.put("FEAT", map); +public class StuffIdTransformTest extends BaseTransformTest { - Map map2 = new HashMap<>(); - map2.put("id", 123456789.0); - floatFeatures.put("ID", map2); - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .sparse("FEAT", "searches", 37.7) + .sparse("ID", "id", 123456789.0) + .build(); } public String makeConfig() { @@ -51,32 +31,21 @@ public String makeConfig() { " output : bar\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_stuff"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_stuff"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_stuff"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); - log.info(featureVector.toString()); + assertTrue(featureVector.numFamilies() == 4); - Map out = featureVector.floatFeatures.get("bar"); - for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + "=" + entry.getValue()); - } - assertTrue(out.size() == 1); - assertEquals(37.7, out.get("searches@123456789"), 0.1); + assertSparseFamily(featureVector, "bar", 1, 
+ ImmutableMap.of("searches@123456789", 37.7)); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StumpTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StumpTransformTest.java index cde6579c..a06245f5 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/StumpTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/StumpTransformTest.java @@ -1,45 +1,22 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class StumpTransformTest { - private static final Logger log = LoggerFactory.getLogger(StumpTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - floatFeatures.put("loc", map); - - Map map2 = new HashMap<>(); - map2.put("foo", 1.0); - floatFeatures.put("F", map2); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; +public class StumpTransformTest extends BaseTransformTest { + + public MultiFamilyVector makeFeatureVector() { + return TransformTestingHelper.builder(registry) + .simpleStrings() + .location() + .sparse("F", "foo", 1.0) + .build(); } public String makeConfig() { @@ -54,30 
+31,22 @@ public String makeConfig() { " output : bar\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_stump"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_stump"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_stump"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 2); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() == 4); - Set out = featureVector.stringFeatures.get("bar"); - for (String entry : out) { - log.info(entry); - } - assertTrue(out.contains("lat>=30.0")); - assertTrue(out.contains("foo>=0.0")); + assertStringFamily(featureVector, "bar", 2, ImmutableSet.of( + "lat>=30.0", "foo>=0.0" + )); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/SubtractTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/SubtractTransformTest.java index ddd605e7..9108f4a1 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/SubtractTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/SubtractTransformTest.java @@ -1,24 +1,17 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import 
com.google.common.collect.ImmutableMap; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.*; - -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ -public class SubtractTransformTest { - private static final Logger log = LoggerFactory.getLogger(SubtractTransformTest.class); +public class SubtractTransformTest extends BaseTransformTest { - public String makeConfigWithKeys() { + public String makeConfig() { return "test_subtract {\n" + " transform : subtract\n" + " field1 : loc\n" + @@ -28,31 +21,22 @@ public String makeConfigWithKeys() { " output : bar\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfigWithKeys()); - Transform transform = TransformFactory.createTransform(config, "test_subtract"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_subtract"; } @Test public void testTransformWithKeys() { - Config config = ConfigFactory.parseString(makeConfigWithKeys()); - Transform transform = TransformFactory.createTransform(config, "test_subtract"); - FeatureVector featureVector = TransformTestingHelper.makeFeatureVector(); - transform.doTransform(featureVector); - Map> stringFeatures = featureVector.getStringFeatures(); - assertTrue(stringFeatures.size() == 1); - - Map out = featureVector.floatFeatures.get("bar"); - for (Map.Entry entry : out.entrySet()) { - log.info(entry.getKey() + "=" + entry.getValue()); - } - assertTrue(out.size() == 2); - assertEquals(36.2, out.get("lat-foo"), 0.1); - assertEquals(1.0, out.get("bar_fv"), 0.1); + Transform transform = getTransform(); + MultiFamilyVector featureVector = TransformTestingHelper.makeFoobarVector(registry); + transform.apply(featureVector); + + assertTrue(featureVector.numFamilies() 
== 4); + + assertSparseFamily(featureVector, "bar", 2, + ImmutableMap.of("lat-foo", 36.2, + "bar_fv", 1.0)); } } \ No newline at end of file diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/TransformTestingHelper.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/TransformTestingHelper.java index a0770bd7..c5677ef7 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/TransformTestingHelper.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/TransformTestingHelper.java @@ -1,39 +1,106 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import com.airbnb.aerosolve.core.features.BasicMultiFamilyVector; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; public class TransformTestingHelper { - public static FeatureVector makeFeatureVector() { - Map> stringFeatures = new HashMap<>(); - Map> floatFeatures = new HashMap<>(); - - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - stringFeatures.put("strFeature1", list); - - Map map = new HashMap<>(); - map.put("lat", 37.7); - map.put("long", 40.0); - map.put("z", -20.0); - floatFeatures.put("loc", map); - - Map map2 = new HashMap<>(); - map2.put("foo", 1.5); - floatFeatures.put("F", map2); - - Map map3 = new HashMap<>(); - map3.put("bar_fv", 1.0); - floatFeatures.put("bar", map3); - - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public static MultiFamilyVector makeEmptyVector() { + FeatureRegistry registry = new FeatureRegistry(); + return makeEmptyVector(registry); + } + + public static MultiFamilyVector makeEmptyVector(FeatureRegistry registry) { + return new BasicMultiFamilyVector(registry); + } + + 
public static MultiFamilyVector makeSimpleVector(FeatureRegistry registry) { + return builder(registry) + .simpleStrings() + .location() + .build(); + } + + public static MultiFamilyVector makeFoobarVector(FeatureRegistry registry) { + return builder(registry) + .simpleStrings() + .location() + .foobar() + .build(); + } + + public static VectorBuilder builder(FeatureRegistry registry) { + return new VectorBuilder(registry); + } + + public static VectorBuilder builder(FeatureRegistry registry, MultiFamilyVector vector) { + return new VectorBuilder(registry, vector); + } + + // Not actually a real builder. But calling things twice should be idempotent so . . . + public static class VectorBuilder { + private final FeatureRegistry registry; + private final MultiFamilyVector vector; + + public VectorBuilder(FeatureRegistry registry) { + this(registry, new BasicMultiFamilyVector(registry)); + } + + public VectorBuilder(FeatureRegistry registry, MultiFamilyVector vector) { + this.registry = registry; + this.vector = vector; + } + + public MultiFamilyVector build() { + return vector; + } + + public VectorBuilder sparse(String family, String name, double value) { + vector.put(registry.feature(family, name), value); + return this; + } + + public VectorBuilder string(String family, String name) { + vector.putString(registry.feature(family, name)); + return this; + } + + public VectorBuilder dense(String family, double[] values) { + vector.putDense(registry.family(family), values); + return this; + } + + public VectorBuilder simpleStrings() { + return this + .string("strFeature1", "aaa") + .string("strFeature1", "bbb"); + } + + public VectorBuilder location() { + return this + .sparse("loc", "lat", 37.7) + .sparse("loc", "long", 40.0); + } + + public VectorBuilder foobar() { + return this + .sparse("loc", "z", -20.0) + .sparse("F", "foo", 1.5) + .sparse("bar", "bar_fv", 1.0); + } + + public VectorBuilder complexLocation() { + return this + .sparse("loc", "a", 0.0) + 
.sparse("loc", "b", 0.13) + .sparse("loc", "c", 1.23) + .sparse("loc", "d", 5.0) + .sparse("loc", "e", 17.5) + .sparse("loc", "f", 99.98) + .sparse("loc", "g", 365.0) + .sparse("loc", "h", 65537.0) + .sparse("loc", "i", -1.0) + .sparse("loc", "j", -23.0); + } } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/transforms/WtaTransformTest.java b/core/src/test/java/com/airbnb/aerosolve/core/transforms/WtaTransformTest.java index d660859b..5d4d81a4 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/transforms/WtaTransformTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/transforms/WtaTransformTest.java @@ -1,33 +1,28 @@ package com.airbnb.aerosolve.core.transforms; -import com.airbnb.aerosolve.core.FeatureVector; -import com.typesafe.config.Config; -import com.typesafe.config.ConfigFactory; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.google.common.collect.ImmutableSet; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertTrue; -public class WtaTransformTest { - private static final Logger log = LoggerFactory.getLogger(WtaTransformTest.class); - - public FeatureVector makeFeatureVector() { - Map> denseFeatures = new HashMap<>(); +@Slf4j +public class WtaTransformTest extends BaseTransformTest { - List feature = new ArrayList<>(); - List feature2 = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - feature.add(0.1 * i); - feature2.add(-0.1 * i); + public MultiFamilyVector makeFeatureVector() { + int max = 100; + double[] feature = new double[max]; + double[] feature2 = new double[max]; + for (int i = 0; i < max; i++) { + feature[i] = 0.1 * i; + feature2[i] = -0.1 * i; } - denseFeatures.put("a", feature); - denseFeatures.put("b", feature2); - FeatureVector featureVector = new FeatureVector(); - featureVector.setDenseFeatures(denseFeatures); - return featureVector; + + return 
TransformTestingHelper.builder(registry) + .dense("a", feature) + .dense("b", feature2) + .build(); } public String makeConfig() { @@ -40,34 +35,27 @@ public String makeConfig() { " num_tokens_per_word : 4\n" + "}"; } - - @Test - public void testEmptyFeatureVector() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_wta"); - FeatureVector featureVector = new FeatureVector(); - transform.doTransform(featureVector); - assertTrue(featureVector.getStringFeatures() == null); + + @Override + public String configKey() { + return "test_wta"; } @Test public void testTransform() { - Config config = ConfigFactory.parseString(makeConfig()); - Transform transform = TransformFactory.createTransform(config, "test_wta"); - FeatureVector featureVector = makeFeatureVector(); - transform.doTransform(featureVector); + Transform transform = getTransform(); + MultiFamilyVector featureVector = makeFeatureVector(); + transform.apply(featureVector); + log.info(featureVector.toString()); - assertTrue(featureVector.stringFeatures != null); - Set wta = featureVector.stringFeatures.get("wta"); - assertTrue(wta != null); - assertTrue(wta.size() == 8); - assertTrue(wta.contains("a0:71")); - assertTrue(wta.contains("a1:60")); - assertTrue(wta.contains("a2:81")); - assertTrue(wta.contains("a3:103")); - assertTrue(wta.contains("b0:34")); - assertTrue(wta.contains("b1:107")); - assertTrue(wta.contains("b2:7")); - assertTrue(wta.contains("b3:193")); + + assertTrue(featureVector.numFamilies() == 3); + + // TODO (Brad): Because of the change to not re-seed the random on each family, this now + // produces different result for the second family ("b"). Maybe revisit but this seems + // reasonable to change the values. 
+ assertStringFamily(featureVector, "wta", 8, ImmutableSet.of( + "a0:71", "a1:60", "a2:81", "a3:103", "b0:254", "b1:104", "b2:134", "b3:21" + )); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/DateUtilTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/DateUtilTest.java index d66d6088..871f4b6a 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/DateUtilTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/DateUtilTest.java @@ -1,9 +1,8 @@ package com.airbnb.aerosolve.core.util; -import org.junit.Test; - import java.util.Map; import java.util.function.Function; +import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureDictionaryTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureDictionaryTest.java index 3a83925c..e3541cca 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureDictionaryTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureDictionaryTest.java @@ -1,17 +1,13 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; import java.util.ArrayList; import java.util.List; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import java.util.HashSet; +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -19,132 +15,120 @@ /** * @author Hector Yee */ +@Slf4j public class FeatureDictionaryTest { - private static final Logger log = 
LoggerFactory.getLogger(FeatureDictionaryTest.class); + private final FeatureRegistry registry = new FeatureRegistry(); - public FeatureVector makeDenseFeatureVector(String id, + public MultiFamilyVector makeDenseFeatureVector(String id, double v1, double v2) { - ArrayList list = new ArrayList(); - list.add(v1); - list.add(v2); - HashMap denseFeatures = new HashMap>(); - denseFeatures.put("a", list); - ArrayList list2 = new ArrayList(); - list2.add(v2); - list2.add(v1); - denseFeatures.put("b", list2); - Set list3 = new HashSet<>(); - list3.add(id); - Map> stringFeatures = new HashMap<>(); - stringFeatures.put("id", list3); - FeatureVector featureVector = new FeatureVector(); - featureVector.setDenseFeatures(denseFeatures); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + vector.putDense(registry.family("a"), new double[]{v1, v2}); + vector.putDense(registry.family("b"), new double[]{v2, v1}); + vector.putString(registry.feature("id", id)); + return vector; } - public FeatureVector makeSparseFeatureVector(String id, + public MultiFamilyVector makeSparseFeatureVector(String id, String v1, String v2) { - Set set = new HashSet(); - set.add(v1); - set.add(v2); - Map> stringFeatures = new HashMap<>(); - stringFeatures.put("a", set); - Set set2 = new HashSet<>(); - set2.add(id); - stringFeatures.put("id", set2); - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + Family stringFamily = registry.family("a"); + vector.putString(stringFamily.feature(v1)); + vector.putString(stringFamily.feature(v2)); + vector.putString(registry.feature("id", id)); + return vector; } @Test public void testDictionaryMinKernel() { - List dict = new ArrayList<>(); + List dict = new ArrayList<>(); dict.add(makeDenseFeatureVector("0", 
0.3, 0.7)); dict.add(makeDenseFeatureVector("1", 0.7, 0.3)); dict.add(makeDenseFeatureVector("2", 0.5, 0.5)); dict.add(makeDenseFeatureVector("3", 0.0, 0.0)); - FeatureDictionary dictionary = new MinKernelDenseFeatureDictionary(); + FeatureDictionary dictionary = new MinKernelDenseFeatureDictionary(registry); dictionary.setDictionaryList(dict); - KNearestNeighborsOptions opt = new KNearestNeighborsOptions(); - opt.setIdKey("id"); - opt.setOutputKey("mk"); - opt.setNumNearest(2); + Family outputFamily = registry.family("mk"); + KNearestNeighborsOptions opt = KNearestNeighborsOptions.builder() + .idKey(registry.family("id")) + .outputKey(outputFamily) + .numNearest(2) + .build(); - FeatureVector response1 = dictionary.getKNearestNeighbors(opt, dict.get(0)); + MultiFamilyVector response1 = dictionary.getKNearestNeighbors(opt, dict.get(0)); log.info(response1.toString()); - assertTrue(response1.floatFeatures != null); - assertEquals(response1.getFloatFeatures().get("mk").size(), 2); - assertEquals(response1.getFloatFeatures().get("mk").get("1"), 1.2, 0.1); - assertEquals(response1.getFloatFeatures().get("mk").get("2"), 1.6, 0.1); + assertTrue(response1.size() > 0); + assertEquals(response1.get(outputFamily).size(), 2); + assertEquals(response1.get(outputFamily.feature("1")), 1.2, 0.1); + assertEquals(response1.get(outputFamily.feature("2")), 1.6, 0.1); - FeatureVector response2 = dictionary.getKNearestNeighbors(opt, dict.get(1)); + MultiFamilyVector response2 = dictionary.getKNearestNeighbors(opt, dict.get(1)); log.info(response2.toString()); - assertTrue(response2.floatFeatures != null); - assertEquals(response2.getFloatFeatures().get("mk").size(), 2); - assertEquals(response2.getFloatFeatures().get("mk").get("0"), 1.2, 0.1); - assertEquals(response2.getFloatFeatures().get("mk").get("2"), 1.6, 0.1); + assertTrue(response2.size() > 0); + assertEquals(response2.get(outputFamily).size(), 2); + assertEquals(response2.get(outputFamily.feature("0")), 1.2, 0.1); + 
assertEquals(response2.get(outputFamily.feature("2")), 1.6, 0.1); - FeatureVector response3 = dictionary.getKNearestNeighbors(opt, dict.get(2)); + MultiFamilyVector response3 = dictionary.getKNearestNeighbors(opt, dict.get(2)); log.info(response3.toString()); - assertTrue(response3.floatFeatures != null); - assertEquals(response3.getFloatFeatures().get("mk").size(), 2); - assertEquals(response3.getFloatFeatures().get("mk").get("0"), 1.6, 0.1); - assertEquals(response3.getFloatFeatures().get("mk").get("1"), 1.6, 0.1); + assertTrue(response3.size() > 0); + assertEquals(response3.get(outputFamily).size(), 2); + assertEquals(response3.get(outputFamily.feature("0")), 1.6, 0.1); + assertEquals(response3.get(outputFamily.feature("1")), 1.6, 0.1); } @Test public void testDictionaryLSH() { - List dict = new ArrayList<>(); + List dict = new ArrayList<>(); dict.add(makeSparseFeatureVector("0", "a", "b")); dict.add(makeSparseFeatureVector("1", "b", "c")); dict.add(makeSparseFeatureVector("2", "a", "e")); dict.add(makeSparseFeatureVector("3", "e", "f")); dict.add(makeSparseFeatureVector("4", "@@@", "$$$")); dict.add(makeSparseFeatureVector("5", "$$$", "@@@")); - FeatureDictionary dictionary = new LocalitySensitiveHashSparseFeatureDictionary(); + FeatureDictionary dictionary = new LocalitySensitiveHashSparseFeatureDictionary(registry); dictionary.setDictionaryList(dict); - KNearestNeighborsOptions opt = new KNearestNeighborsOptions(); - opt.setIdKey("id"); - opt.setOutputKey("sim"); - opt.setNumNearest(2); - opt.setFeatureKey("a"); + Family outputFamily = registry.family("sim"); + KNearestNeighborsOptions opt = KNearestNeighborsOptions.builder() + .idKey(registry.family("id")) + .outputKey(outputFamily) + .featureKey(registry.family("a")) + .numNearest(2) + .build(); - FeatureVector response1 = dictionary.getKNearestNeighbors(opt, dict.get(0)); + MultiFamilyVector response1 = dictionary.getKNearestNeighbors(opt, dict.get(0)); log.info(response1.toString()); - 
assertTrue(response1.floatFeatures != null); - assertEquals(response1.getFloatFeatures().get("sim").size(), 2); - assertEquals(response1.getFloatFeatures().get("sim").get("1"), 1.0, 0.1); - assertEquals(response1.getFloatFeatures().get("sim").get("2"), 1.0, 0.1); + assertTrue(response1.size() > 0); + assertEquals(response1.get(outputFamily).size(), 2); + assertEquals(response1.get(outputFamily.feature("1")), 1.0, 0.1); + assertEquals(response1.get(outputFamily.feature("2")), 1.0, 0.1); - FeatureVector response2 = dictionary.getKNearestNeighbors(opt, dict.get(1)); + MultiFamilyVector response2 = dictionary.getKNearestNeighbors(opt, dict.get(1)); log.info(response2.toString()); - assertTrue(response2.floatFeatures != null); - assertEquals(response2.getFloatFeatures().get("sim").size(), 1); - assertEquals(response2.getFloatFeatures().get("sim").get("0"), 1.0, 0.1); + assertTrue(response2.size() > 0); + assertEquals(response2.get(outputFamily).size(), 1); + assertEquals(response2.get(outputFamily.feature("0")), 1.0, 0.1); - FeatureVector response3 = dictionary.getKNearestNeighbors(opt, dict.get(2)); + MultiFamilyVector response3 = dictionary.getKNearestNeighbors(opt, dict.get(2)); log.info(response3.toString()); - assertTrue(response3.floatFeatures != null); - assertEquals(response3.getFloatFeatures().get("sim").size(), 2); - assertEquals(response3.getFloatFeatures().get("sim").get("0"), 1.0, 0.1); - assertEquals(response3.getFloatFeatures().get("sim").get("3"), 1.0, 0.1); + assertTrue(response3.size() > 0); + assertEquals(response3.get(outputFamily).size(), 2); + assertEquals(response3.get(outputFamily.feature("0")), 1.0, 0.1); + assertEquals(response3.get(outputFamily.feature("3")), 1.0, 0.1); - FeatureVector response4 = dictionary.getKNearestNeighbors(opt, dict.get(3)); + MultiFamilyVector response4 = dictionary.getKNearestNeighbors(opt, dict.get(3)); log.info(response4.toString()); - assertTrue(response4.floatFeatures != null); - 
assertEquals(response4.getFloatFeatures().get("sim").size(), 1); - assertEquals(response4.getFloatFeatures().get("sim").get("2"), 1.0, 0.1); + assertTrue(response4.size() > 0); + assertEquals(response4.get(outputFamily).size(), 1); + assertEquals(response4.get(outputFamily.feature("2")), 1.0, 0.1); - FeatureVector response5 = dictionary.getKNearestNeighbors(opt, dict.get(4)); + MultiFamilyVector response5 = dictionary.getKNearestNeighbors(opt, dict.get(4)); log.info(response5.toString()); - assertTrue(response5.floatFeatures != null); - assertEquals(response5.getFloatFeatures().get("sim").size(), 1); - assertEquals(response5.getFloatFeatures().get("sim").get("5"), 2.0, 0.1); + assertTrue(response5.size() > 0); + assertEquals(response5.get(outputFamily).size(), 1); + assertEquals(response5.get(outputFamily.feature("5")), 2.0, 0.1); } } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureVectorUtilTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureVectorUtilTest.java index ded3903d..06939a63 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureVectorUtilTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/FeatureVectorUtilTest.java @@ -1,39 +1,31 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.HashMap; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ +@Slf4j public class FeatureVectorUtilTest { - private static final Logger log = 
LoggerFactory.getLogger(FeatureVectorUtilTest.class); - - public FeatureVector makeFeatureVector(String key, double v1, double v2) { - ArrayList list = new ArrayList(); - list.add(v1); - list.add(v2); - HashMap denseFeatures = new HashMap>(); - denseFeatures.put(key, list); - FeatureVector featureVector = new FeatureVector(); - featureVector.setDenseFeatures(denseFeatures); - return featureVector; + private final FeatureRegistry registry = new FeatureRegistry(); + + public MultiFamilyVector makeFeatureVector(String key, double v1, double v2) { + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + vector.putDense(registry.family(key), new double[]{v1, v2}); + return vector; } @Test public void testEmptyFeatureVectorMinKernel() { - FeatureVector a = new FeatureVector(); - FeatureVector b = new FeatureVector(); + FeatureVector a = TransformTestingHelper.makeEmptyVector(registry); + FeatureVector b = TransformTestingHelper.makeEmptyVector(registry); assertEquals(0.0, FeatureVectorUtil.featureVectorMinKernel(a, b), 0.1); } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/FloatVectorTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/FloatVectorTest.java index aadf7473..a06a4f49 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/FloatVectorTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/FloatVectorTest.java @@ -1,18 +1,10 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.HashMap; -import java.util.Set; - import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @author Hector Yee diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/ReinforcementLearningTest.java 
b/core/src/test/java/com/airbnb/aerosolve/core/util/ReinforcementLearningTest.java index e395ed96..a0bc1cc5 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/ReinforcementLearningTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/ReinforcementLearningTest.java @@ -1,67 +1,77 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.models.AbstractModel; import com.airbnb.aerosolve.core.models.KernelModel; - +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ +@Slf4j public class ReinforcementLearningTest { - private static final Logger log = LoggerFactory.getLogger(ReinforcementLearningTest.class); - + private final FeatureRegistry registry = new FeatureRegistry(); + StringDictionary makeDictionary() { StringDictionary dict = new StringDictionary(); // The locations vary between 0 and 10 - dict.possiblyAdd("S", "x", 0.0, 1.0); - dict.possiblyAdd("S", "y", 0.0, 1.0); + dict.possiblyAdd(registry.feature("S", "x"), 0.0, 1.0); + dict.possiblyAdd(registry.feature("S", "y"), 0.0, 1.0); // The actions are +/- 1 - dict.possiblyAdd("A", "dx", 0.0, 1.0); - dict.possiblyAdd("A", "dy", 0.0, 1.0); + dict.possiblyAdd(registry.feature("A", "dx"), 0.0, 1.0); + dict.possiblyAdd(registry.feature("A", "dy"), 0.0, 1.0); return dict; } - 
public HashMap makeState(double x, double y) { - HashMap stateFeatures = new HashMap(); + public Map makeState(double x, double y) { + Map stateFeatures = new HashMap<>(); stateFeatures.put("x", x); stateFeatures.put("y", y); return stateFeatures; } - public HashMap makeAction(double dx, double dy) { - HashMap actionFeatures = new HashMap(); + public Map makeAction(double dx, double dy) { + Map actionFeatures = new HashMap<>(); actionFeatures.put("dx", dx); actionFeatures.put("dy", dy); return actionFeatures; } - public FeatureVector makeFeatureVector(HashMap state, HashMap action) { - HashMap floatFeatures = new HashMap>(); - floatFeatures.put("S", state); - floatFeatures.put("A", action); - FeatureVector featureVector = new FeatureVector(); - featureVector.setFloatFeatures(floatFeatures); - return featureVector; + public MultiFamilyVector makeFeatureVector(Map state, + Map action) { + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + loadVector(vector, registry.family("S"), state); + loadVector(vector, registry.family("A"), action); + + return vector; + } + + private void loadVector(MultiFamilyVector vector, Family family, Map values) { + for (Map.Entry entry : values.entrySet()) { + vector.put(family.feature(entry.getKey()), entry.getValue()); + } } - + public ArrayList stateActions(double x, double y) { - HashMap currState = makeState(x, y); + Map currState = makeState(x, y); ArrayList potential = new ArrayList<>(); - HashMap up = makeAction(0.0f, 1.0f); - HashMap down = makeAction(0.0f, -1.0f); - HashMap left = makeAction(-1.0f, 0.0f); - HashMap right = makeAction(1.0f, 0.0f); + Map up = makeAction(0.0f, 1.0f); + Map down = makeAction(0.0f, -1.0f); + Map left = makeAction(-1.0f, 0.0f); + Map right = makeAction(1.0f, 0.0f); potential.add(makeFeatureVector(currState, up)); potential.add(makeFeatureVector(currState, down)); @@ -71,16 +81,16 @@ public ArrayList stateActions(double x, double y) { } AbstractModel makeModel() { - 
KernelModel model = new KernelModel(); + KernelModel model = new KernelModel(registry); StringDictionary dict = makeDictionary(); - model.setDictionary(dict); - List supportVectors = model.getSupportVectors(); + model.dictionary(dict); + List supportVectors = model.supportVectors(); Random rnd = new Random(12345); for (double x = 0.0; x <= 10.0; x += 1.0) { for (double y = 0.0; y <= 10.0; y += 1.0) { ArrayList potential = stateActions(x, y); for (FeatureVector sa : potential) { - FloatVector vec = dict.makeVectorFromSparseFloats(Util.flattenFeature(sa)); + FloatVector vec = dict.makeVectorFromSparseFloats(sa); SupportVector sv = new SupportVector(vec, FunctionForm.RADIAL_BASIS_FUNCTION, 1.0f, 0.0f); supportVectors.add(sv); } diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/StringDictionaryTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/StringDictionaryTest.java index 6b5aee46..617da434 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/StringDictionaryTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/StringDictionaryTest.java @@ -1,38 +1,31 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; - -import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.DictionaryEntry; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import java.util.HashSet; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ +@Slf4j public class StringDictionaryTest { - private static final 
Logger log = LoggerFactory.getLogger(StringDictionaryTest.class); - + private final FeatureRegistry registry = new FeatureRegistry(); + StringDictionary makeDictionary() { StringDictionary dict = new StringDictionary(); DictionaryEntry result = dict.getEntry("foo", "bar"); assertEquals(null, result); - int idx = dict.possiblyAdd("LOC", "lat", 0.1, 0.1); + int idx = dict.possiblyAdd(registry.feature("LOC", "lat"), 0.1, 0.1); assertEquals(0, idx); - idx = dict.possiblyAdd("LOC", "lng", 0.2, 0.2); + idx = dict.possiblyAdd(registry.feature("LOC", "lng"), 0.2, 0.2); assertEquals(1, idx); - idx = dict.possiblyAdd("foo", "bar", 0.3, 0.3); + idx = dict.possiblyAdd(registry.feature("foo", "bar"), 0.3, 0.3); assertEquals(2, idx); return dict; } @@ -41,31 +34,28 @@ StringDictionary makeDictionary() { public void testStringDictionaryAdd() { StringDictionary dict = makeDictionary(); assertEquals(3, dict.getDictionary().getEntryCount()); - assertEquals(0, dict.getEntry("LOC", "lat").index); - assertEquals(1, dict.getEntry("LOC", "lng").index); - assertEquals(2, dict.getEntry("foo", "bar").index); - int result = dict.possiblyAdd("foo", "bar", 0.0, 0.0); + assertEquals(0, dict.getEntry("LOC", "lat").getIndex()); + assertEquals(1, dict.getEntry("LOC", "lng").getIndex()); + assertEquals(2, dict.getEntry("foo", "bar").getIndex()); + int result = dict.possiblyAdd(registry.feature("foo", "bar"), 0.0, 0.0); assertEquals(-1, result); assertEquals(3, dict.getDictionary().getEntryCount()); } @Test public void testStringDictionaryVector() { - Map> feature = new HashMap<>(); - Map loc = new HashMap<>(); - Map foo = new HashMap<>(); + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + Family loc = registry.family("LOC"); + Family foo = registry.family("foo"); - feature.put("LOC", loc); - feature.put("foo", foo); - - loc.put("lat", 1.0); - loc.put("lng", 2.0); - foo.put("bar", 3.0); - foo.put("baz", 4.0); + vector.put(loc.feature("lat"), 1.0); + 
vector.put(loc.feature("lng"), 2.0); + vector.put(foo.feature("bar"), 3.0); + vector.put(foo.feature("baz"), 4.0); StringDictionary dict = makeDictionary(); - FloatVector vec = dict.makeVectorFromSparseFloats(feature); + FloatVector vec = dict.makeVectorFromSparseFloats(vector); assertEquals(3, vec.values.length); assertEquals(0.1 * (1.0 - 0.1), vec.values[0], 0.1f); assertEquals(0.2 * (2.0 - 0.2), vec.values[1], 0.1f); diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/SupportVectorTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/SupportVectorTest.java index 341cf4f1..5735dfd2 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/SupportVectorTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/SupportVectorTest.java @@ -1,28 +1,17 @@ package com.airbnb.aerosolve.core.util; -import com.airbnb.aerosolve.core.Example; - -import com.airbnb.aerosolve.core.FeatureVector; import com.airbnb.aerosolve.core.FunctionForm; import com.airbnb.aerosolve.core.ModelRecord; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.HashMap; -import java.util.Set; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @author Hector Yee */ +@Slf4j public class SupportVectorTest { - private static final Logger log = LoggerFactory.getLogger(SupportVectorTest.class); - @Test public void testRbf() { FloatVector v1 = new FloatVector(new float[]{1.0f, 2.0f}); diff --git a/core/src/test/java/com/airbnb/aerosolve/core/util/UtilTest.java b/core/src/test/java/com/airbnb/aerosolve/core/util/UtilTest.java index 1711e94b..4b05ca8e 100644 --- a/core/src/test/java/com/airbnb/aerosolve/core/util/UtilTest.java +++ b/core/src/test/java/com/airbnb/aerosolve/core/util/UtilTest.java @@ -3,12 +3,16 @@ import com.airbnb.aerosolve.core.DebugScoreDiffRecord; import 
com.airbnb.aerosolve.core.DebugScoreRecord; import com.airbnb.aerosolve.core.Example; -import com.airbnb.aerosolve.core.FeatureVector; +import com.airbnb.aerosolve.core.features.Family; +import com.airbnb.aerosolve.core.features.FeatureRegistry; +import com.airbnb.aerosolve.core.features.MultiFamilyVector; +import com.airbnb.aerosolve.core.features.SimpleExample; +import com.airbnb.aerosolve.core.transforms.TransformTestingHelper; +import com.google.common.collect.Sets; +import java.util.ArrayList; +import java.util.List; +import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -16,59 +20,52 @@ /** * @author Hector Yee */ +@Slf4j public class UtilTest { - private static final Logger log = LoggerFactory.getLogger(UtilTest.class); - - public FeatureVector makeFeatureVector() { - Set list = new HashSet(); - list.add("aaa"); - list.add("bbb"); - HashMap stringFeatures = new HashMap>(); - stringFeatures.put("string_feature", list); - FeatureVector featureVector = new FeatureVector(); - featureVector.setStringFeatures(stringFeatures); - return featureVector; - } + private final FeatureRegistry registry = new FeatureRegistry(); @Test public void testEncodeDecodeFeatureVector() { - FeatureVector featureVector = makeFeatureVector(); - String str = Util.encode(featureVector); + MultiFamilyVector featureVector = makeFeatureVector(); + String str = Util.encodeFeatureVector(featureVector); assertTrue(str.length() > 0); log.info(str); - FeatureVector featureVector2 = Util.decodeFeatureVector(str); - assertTrue(featureVector2.stringFeatures != null); - assertTrue(featureVector2.stringFeatures.containsKey("string_feature")); - Set list2 = featureVector2.stringFeatures.get("string_feature"); - assertTrue(list2.size() == 2); + MultiFamilyVector featureVector2 = Util.decodeFeatureVector(str, registry); + 
assertTrue(featureVector2.numFamilies() == 2); + Family stringFamily = registry.family("string_feature"); + assertTrue(featureVector2.contains(stringFamily)); + assertTrue(featureVector2.get(stringFamily).size() == 2); + + Family sparseFamily = registry.family("sparse_feature"); + assertTrue(featureVector2.contains(sparseFamily)); + assertTrue(featureVector2.get(sparseFamily).size() == 2); + } + + private MultiFamilyVector makeFeatureVector() { + MultiFamilyVector vector = TransformTestingHelper.makeEmptyVector(registry); + + Family stringFamily = registry.family("string_feature"); + vector.putString(stringFamily.feature("aaa")); + vector.putString(stringFamily.feature("bbb")); + + Family sparseFamily = registry.family("sparse_feature"); + + vector.put(sparseFamily.feature("lat"), 37.7); + vector.put(sparseFamily.feature("long"), 40.0); + + return vector; } @Test public void testEncodeDecodeExample() { - FeatureVector featureVector = makeFeatureVector(); - Example example = new Example(); + MultiFamilyVector featureVector = makeFeatureVector(); + Example example = new SimpleExample(registry); example.addToExample(featureVector); - String str = Util.encode(example); + String str = Util.encodeExample(example); assertTrue(str.length() > 0); log.info(str); - Example example2 = Util.decodeExample(str); - assertTrue(example2.example.size() == 1); - } - - @Test - public void testFlattenFeature() { - FeatureVector featureVector = makeFeatureVector(); - Map> floatFeatures = new HashMap<>(); - Map tmp = new HashMap(); - floatFeatures.put("float_feature", tmp); - tmp.put("x", 0.3); - tmp.put("y", -0.2); - featureVector.floatFeatures = floatFeatures; - Map> flatFeature = Util.flattenFeature(featureVector); - assertEquals(1.0, flatFeature.get("string_feature").get("aaa"), 0.1); - assertEquals(1.0, flatFeature.get("string_feature").get("bbb"), 0.1); - assertEquals(0.3, flatFeature.get("float_feature").get("x"), 0.1); - assertEquals(-0.2, 
flatFeature.get("float_feature").get("y"), 0.1); + Example example2 = Util.decodeExample(str, registry); + assertTrue(Sets.newHashSet(example2).size() == 1); } @Test diff --git a/core/src/test/resources/income_prediction.conf b/core/src/test/resources/income_prediction.conf index 41b1457e..18c100a4 100644 --- a/core/src/test/resources/income_prediction.conf +++ b/core/src/test/resources/income_prediction.conf @@ -1,9 +1,8 @@ // This config file resides in src/main/resources for distribution with jars // but a local override can be made in the main directory. - -rootDir : ${PWD}"/output" -trainingInput : ${PWD}"/src/main/resources/adult.data" -testingInput : ${PWD}"/src/main/resources/adult.test" +rootDir : ${?PWD}"/output" +trainingInput : ${?PWD}"/src/main/resources/adult.data" +testingInput : ${?PWD}"/src/main/resources/adult.test" // Location of training data trainingData : ${rootDir}"/training_data" diff --git a/thrift-cli.gradle b/thrift-cli.gradle index 490c684b..e7df3d3b 100644 --- a/thrift-cli.gradle +++ b/thrift-cli.gradle @@ -17,7 +17,7 @@ task genThrift << { thriftFiles.each() { File file -> exec { executable = 'thrift' - args = ['--gen', 'java', '-o', "${->buildDir}/", file] + args = ['--gen', 'java:private-members', '-o', "${->buildDir}/", file] } } } diff --git a/training/build.gradle b/training/build.gradle index fdcb9b33..d0b34e6a 100644 --- a/training/build.gradle +++ b/training/build.gradle @@ -62,6 +62,8 @@ shadowJar { } } +idea.module.scopes.COMPILE.plus += [configurations.provided] + test { jvmArgs += [ "-XX:MaxPermSize=1024m" ] } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/AdditiveModelTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/AdditiveModelTrainer.scala index 9467989e..ab3d3a1a 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/AdditiveModelTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/AdditiveModelTrainer.scala @@ -1,9 +1,9 @@ package 
com.airbnb.aerosolve.training import com.airbnb.aerosolve.core._ -import com.airbnb.aerosolve.core.function.{Function, Linear, MultiDimensionSpline, Spline} +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry, MultiFamilyVector} +import com.airbnb.aerosolve.core.functions.{Function, Linear, Spline} import com.airbnb.aerosolve.core.models.AdditiveModel -import com.airbnb.aerosolve.core.util.Util import com.typesafe.config.Config import org.apache.spark.SparkContext import org.apache.spark.broadcast.Broadcast @@ -12,7 +12,6 @@ import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.util.Try /** @@ -32,7 +31,7 @@ object AdditiveModelTrainer { case class AdditiveTrainerParams(numBins: Int, numBags: Int, - rankKey: String, + labelFamily: Family, loss: String, minCount: Int, learningRate: Double, @@ -49,16 +48,18 @@ object AdditiveModelTrainer { rankMargin: Double, // The margin for ranking loss epsilon: Double, // epsilon used in epsilon-insensitive loss for regression training initModelPath: String, - linearFeatureFamilies: Array[String], - priors: Array[String]) + linearFeatureFamilies: Array[Family], + priors: Array[String], + registry: FeatureRegistry) def train(sc: SparkContext, input: RDD[Example], config: Config, - key: String): AdditiveModel = { + key: String, + registry: FeatureRegistry): AdditiveModel = { val trainConfig = config.getConfig(key) val iterations: Int = trainConfig.getInt("iterations") - val params = loadTrainingParameters(trainConfig) + val params = loadTrainingParameters(trainConfig, registry) val transformed = transformExamples(input, config, key, params) val output = config.getString(key + ".model_output") log.info("Training using " + params.loss) @@ -80,9 +81,9 @@ object AdditiveModelTrainer { * During each iteration, we: * * 1. Sample dataset with subsample (this is analogous to mini-batch sgd?) - * 1. 
Repartition to numBags (this is analogous to ensemble averaging?) - * 1. For each bag we run SGD (observation-wise gradient updates) - * 1. We then average fitted weights for each feature and return them as updated model + * 2. Repartition to numBags (this is analogous to ensemble averaging?) + * 3. For each bag we run SGD (observation-wise gradient updates) + * 4. We then average fitted weights for each feature and return them as updated model * * @param input collection of examples to be trained in sgd iteration * @param paramsBC broadcasted model params @@ -95,24 +96,32 @@ object AdditiveModelTrainer { val model = modelBC.value val params = paramsBC.value - input + // This returns the entire additive model in order to support familyWeights. + val fittedModels: RDD[AdditiveModel] = input .sample(false, params.subsample) .coalesce(params.numBags, true) .mapPartitionsWithIndex((index, partition) => sgdPartition(index, partition, modelBC, paramsBC)) + + val weightsPerFeature: Array[(AnyRef, Function)] = fittedModels + .flatMap(model => model.weights.asScala.iterator ++ model.familyWeights.asScala.iterator) + // TODO (Brad): This should be a reduceByKey. It's complicated by the aggregate method in + // Function. We'd need to aggregate one function at a time. But it should be possible and + // would likely help a lot with performance. 
.groupByKey() // Average the feature functions // Average the weights .mapValues(x => { - val scale = 1.0f / paramsBC.value.numBags.toFloat - aggregateFuncWeights(x, scale, paramsBC.value.numBins, paramsBC.value.smoothingTolerance.toFloat) - }) - .collect() - .foreach(entry => { - val family = model.getWeights.get(entry._1._1) - if (family != null && family.containsKey(entry._1._2)) { - family.put(entry._1._2, entry._2) - } + val scale = 1.0d / paramsBC.value.numBags + aggregateFuncWeights(x, scale, paramsBC.value.numBins, paramsBC.value.smoothingTolerance) }) + .collect() + + weightsPerFeature + .foreach { + case (key: Family, function: Function) => model.familyWeights.replace(key, function) + case (key: Feature, function: Function) => model.weights.replace(key, function) + } + deleteSmallFunctions(model, params.linfinityThreshold) model @@ -131,7 +140,7 @@ object AdditiveModelTrainer { def sgdPartition(index: Int, partition: Iterator[Example], modelBC: Broadcast[AdditiveModel], - paramsBC: Broadcast[AdditiveTrainerParams]): Iterator[((String, String), Function)] = { + paramsBC: Broadcast[AdditiveTrainerParams]): Iterator[AdditiveModel] = { val workingModel = modelBC.value val params = paramsBC.value val multiscale = params.multiscale @@ -140,15 +149,11 @@ object AdditiveModelTrainer { val newBins = multiscale(index % multiscale.length) log.info(s"Resampling to $newBins bins") - for(family <- workingModel.getWeights.values) { - for(feature <- family.values) { - feature.resample(newBins) - } - } + workingModel.weights.values.foreach(_.resample(newBins)) } val output = sgdPartitionInternal(partition, workingModel, params) - output.iterator + Set(output).iterator } /** @@ -162,18 +167,18 @@ object AdditiveModelTrainer { * @return */ private def aggregateFuncWeights(input: Iterable[Function], - scale: Float, + scale: Double, numBins: Int, - smoothingTolerance: Float): Function = { + smoothingTolerance: Double): Function = { val head: Function = input.head - // TODO: 
revisit asJava performance impact + // TODO (Forest): revisit asJava performance impact val output = head.aggregate(input.asJava, scale, numBins) output.smooth(smoothingTolerance) output } /** - * Actually perform SGD on examples by applying approriate gradient updates according + * Actually perform SGD on examples by applying appropriate gradient updates according * to model specification * * @param partition list of examples @@ -183,27 +188,18 @@ object AdditiveModelTrainer { */ private def sgdPartitionInternal(partition: Iterator[Example], workingModel: AdditiveModel, - params: AdditiveTrainerParams): mutable.HashMap[(String, String), Function] = { + params: AdditiveTrainerParams): AdditiveModel = { var lossSum: Double = 0.0 var lossCount: Int = 0 partition.foreach(example => { - lossSum += pointwiseLoss(example.example.get(0), workingModel, params.loss, params) + lossSum += pointwiseLoss(example.only(), workingModel, params.loss, params) lossCount = lossCount + 1 if (lossCount % params.lossMod == 0) { log.info(s"Loss = ${lossSum / params.lossMod.toDouble}, samples = $lossCount") lossSum = 0.0 } }) - val output = mutable.HashMap[(String, String), Function]() - // TODO: weights should be a vector instead of stored in hashmap workingModel - .getWeights - .foreach(family => { - family._2.foreach(feature => { - output.put((family._1, feature._1), feature._2) - }) - }) - output } /** @@ -215,98 +211,66 @@ object AdditiveModelTrainer { * @param params model params * @return */ - def pointwiseLoss(fv: FeatureVector, + def pointwiseLoss(fv: MultiFamilyVector, workingModel: AdditiveModel, loss: String, params: AdditiveTrainerParams): Double = { val label: Double = if (loss == "regression") { - TrainingUtils.getLabel(fv, params.rankKey) + TrainingUtils.getLabel(fv, params.labelFamily) } else { - TrainingUtils.getLabel(fv, params.rankKey, params.threshold) + TrainingUtils.getLabel(fv, params.labelFamily, params.threshold) } - loss match { - case "logistic" => 
updateLogistic(workingModel, fv, label, params) - case "hinge" => updateHinge(workingModel, fv, label, params) - case "regression" => updateRegressor(workingModel, fv, label, params) + val lossFunction = loss match { + case "logistic" => logisticLoss _ + case "hinge" => hingeLoss _ + case "regression" => regressionLoss _ } + updateWithLossFunction(workingModel, fv, label, lossFunction, params) } // http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf // We rescale by 1 / p so that at inference time we don't have to scale by p. // In our case p = 1.0 - dropout rate - def updateLogistic(model: AdditiveModel, - fv: FeatureVector, - label: Double, - params: AdditiveTrainerParams): Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - // only MultiDimensionSpline use denseFeatures for now - val denseFeatures = MultiDimensionSpline.featureDropout(fv, params.dropout) - val prediction = (model.scoreFlatFeatures(flatFeatures) + - model.scoreDenseFeatures(denseFeatures)) / - (1.0 - params.dropout) - // To prevent blowup. + def updateWithLossFunction(model: AdditiveModel, + fv: MultiFamilyVector, + label: Double, + lossAndGradientFunc: (Double, Double, AdditiveTrainerParams) => (Double, Double), + params: AdditiveTrainerParams): Double = { + val newVec = fv.withFamilyDropout(params.dropout) + val prediction = model.scoreItem(newVec) / (1.0 - params.dropout) + val (loss, grad) = lossAndGradientFunc(prediction, label, params) + val gradWithLearningRate = grad * params.learningRate + if (gradWithLearningRate != 0.0) { + model.update(gradWithLearningRate, + params.linfinityCap, + fv) + } + loss + } + + // TODO (Brad): I refactored out the loss functions to reduce code duplication but I'm a bit + // concerned I made a logical error. Would appreciate a close look. 
+ def logisticLoss(prediction: Double, label: Double, params: AdditiveTrainerParams): (Double, Double) = { val corr = scala.math.min(10.0, label * prediction) val expCorr = scala.math.exp(corr) val loss = scala.math.log(1.0 + 1.0 / expCorr) val grad = -label / (1.0 + expCorr) - val gradWithLearningRate = grad.toFloat * params.learningRate.toFloat - model.update(gradWithLearningRate, - params.linfinityCap.toFloat, - flatFeatures) - model.updateDense(gradWithLearningRate, - params.linfinityCap.toFloat, - denseFeatures) - loss + (loss, grad) } - def updateHinge(model: AdditiveModel, - fv: FeatureVector, - label: Double, - params: AdditiveTrainerParams): Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - // only MultiDimensionSpline use denseFeatures for now - val denseFeatures = MultiDimensionSpline.featureDropout(fv, params.dropout) - val prediction = (model.scoreFlatFeatures(flatFeatures) + - model.scoreDenseFeatures(denseFeatures)) / - (1.0 - params.dropout) + def hingeLoss(prediction: Double, label: Double, params : AdditiveTrainerParams): (Double, Double) = { val loss = scala.math.max(0.0, params.margin - label * prediction) - if (loss > 0.0) { - val gradWithLearningRate = -label.toFloat * params.learningRate.toFloat - model.update(gradWithLearningRate, - params.linfinityCap.toFloat, - flatFeatures) - model.updateDense(gradWithLearningRate, - params.linfinityCap.toFloat, - denseFeatures) - } - loss + val grad = if (loss > 0.0) -label else 0.0 + (loss, grad) } - def updateRegressor(model: AdditiveModel, - fv: FeatureVector, - label: Double, - params: AdditiveTrainerParams): Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - // only MultiDimensionSpline use denseFeatures for now - val denseFeatures = MultiDimensionSpline.featureDropout(fv, params.dropout) - val prediction = (model.scoreFlatFeatures(flatFeatures) + - model.scoreDenseFeatures(denseFeatures)) / - (1.0 - params.dropout) - // absolute 
difference + def regressionLoss(prediction: Double, label: Double, params : AdditiveTrainerParams): (Double, Double) = { val loss = math.abs(prediction - label) - if (prediction - label > params.epsilon) { - model.update(params.learningRate.toFloat, - params.linfinityCap.toFloat, flatFeatures) - model.updateDense(params.learningRate.toFloat, - params.linfinityCap.toFloat, denseFeatures) - } else if (prediction - label < -params.epsilon) { - model.update(-params.learningRate.toFloat, - params.linfinityCap.toFloat, flatFeatures) - model.updateDense(-params.learningRate.toFloat, - params.linfinityCap.toFloat, denseFeatures) - } - loss + val grad = if (prediction - label > params.epsilon) 1.0 + else if (prediction - label < -params.epsilon) -1.0 + else 0.0 + (loss, grad) } private def transformExamples(input: RDD[Example], @@ -314,9 +278,9 @@ object AdditiveModelTrainer { key: String, params: AdditiveTrainerParams): RDD[Example] = { if (params.isRanking) { - LinearRankerUtils.transformExamples(input, config, key) + LinearRankerUtils.transformExamples(input, config, key, params.registry) } else { - LinearRankerUtils.makePointwiseFloat(input, config, key) + LinearRankerUtils.makePointwiseFloat(input, config, key, params.registry) } } @@ -326,7 +290,7 @@ object AdditiveModelTrainer { val initialModel = if (params.initModelPath == "") { None } else { - TrainingUtils.loadScoreModel(params.initModelPath) + TrainingUtils.loadScoreModel(params.initModelPath, params.registry) } // sample examples to be used for model initialization @@ -336,7 +300,7 @@ object AdditiveModelTrainer { initModel(params.minCount, params, initExamples, newModel, false) newModel } else { - val newModel = new AdditiveModel() + val newModel = new AdditiveModel(params.registry) initModel(params.minCount, params, initExamples, newModel, true) setPrior(params.priors, newModel) newModel @@ -352,43 +316,27 @@ object AdditiveModelTrainer { val linearFeatureFamilies = params.linearFeatureFamilies val minMax = 
TrainingUtils .getFeatureStatistics(minCount, examples) - .filter(x => x._1._1 != params.rankKey) + .filter{ case (feature:Feature, _) => feature.family != params.labelFamily } + log.info("Num features = %d".format(minMax.length)) - val minMaxSpline = minMax.filter(x => !linearFeatureFamilies.contains(x._1._1)) - val minMaxLinear = minMax.filter(x => linearFeatureFamilies.contains(x._1._1)) - // add splines - for (((featureFamily, featureName), stats) <- minMaxSpline) { - val spline = new Spline(stats.min.toFloat, stats.max.toFloat, params.numBins) - model.addFunction(featureFamily, featureName, spline, overwrite) - } - // add linear - for (((featureFamily, featureName), stats) <- minMaxLinear) { - // set default linear function as f(x) = 0 - model.addFunction(featureFamily, featureName, - new Linear(stats.min.toFloat, stats.max.toFloat), overwrite) + + minMax.foreach{ case (feature:Feature, stats) => + val function = if (linearFeatureFamilies.contains(feature.family)) { + // set default linear function as f(x) = 0 + new Linear(stats.min, stats.max) + } else { + new Spline(stats.min, stats.max, params.numBins) + } + model.addFunction(feature, function, overwrite) } } def deleteSmallFunctions(model: AdditiveModel, linfinityThreshold: Double) = { - val toDelete = scala.collection.mutable.ArrayBuffer[(String, String)]() - - model.getWeights.asScala.foreach(family => { - family._2.asScala.foreach(entry => { - val func: Function = entry._2 - if (func.LInfinityNorm() < linfinityThreshold) { - toDelete.append((family._1, entry._1)) - } - }) - }) - - log.info("Deleting %d small functions".format(toDelete.size)) - toDelete.foreach(entry => { - val family = model.getWeights.get(entry._1) - if (family != null && family.containsKey(entry._2)) { - family.remove(entry._2) - } - }) + val newWeights = model.weights.asScala.filter { + case (_, func) => func.LInfinityNorm() >= linfinityThreshold + }.asJava + model.weights(newWeights) } def setPrior(priors: Array[String], model: 
AdditiveModel): Unit = { @@ -399,14 +347,12 @@ object AdditiveModelTrainer { if (tokens.length == 4) { val family = tokens(0) val name = tokens(1) - val params = Array(tokens(2).toFloat, tokens(3).toFloat) - val familyMap = model.getWeights.get(family) - if (!familyMap.isEmpty) { - val func: Function = familyMap.get(name) - if (func != null) { - log.info("Setting prior %s:%s <- %f to %f".format(family, name, params(0), params(1))) - func.setPriors(params) - } + val params = Array(tokens(2).toDouble, tokens(3).toDouble) + val feature = model.registry.feature(family, name) + val func: Function = model.weights.get(feature) + if (func != null) { + log.info("Setting prior %s:%s <- %f to %f".format(family, name, params(0), params(1))) + func.setPriors(params) } } else { log.error("Incorrect number of parameters for %s".format(prior)) @@ -417,7 +363,8 @@ object AdditiveModelTrainer { } } - def loadTrainingParameters(config: Config): AdditiveTrainerParams = { + def loadTrainingParameters(config: Config, + registry: FeatureRegistry): AdditiveTrainerParams = { val loss: String = config.getString("loss") val isRanking = loss match { case "logistic" => false @@ -428,7 +375,7 @@ object AdditiveModelTrainer { } val numBins: Int = config.getInt("num_bins") val numBags: Int = config.getInt("num_bags") - val rankKey: String = config.getString("rank_key") + val labelFamily: Family = registry.family(config.getString("rank_key")) val learningRate: Double = config.getDouble("learning_rate") val dropout: Double = config.getDouble("dropout") val subsample: Double = config.getDouble("subsample") @@ -443,9 +390,11 @@ object AdditiveModelTrainer { config.getDouble("epsilon") }.getOrElse(0.0) val minCount: Int = config.getInt("min_count") - val linearFeatureFamilies: Array[String] = Try( - config.getStringList("linear_feature").toList.toArray) - .getOrElse(Array[String]()) + val linearFeaturePath = "linear_feature" + val linearFeatureFamilies: Array[Family] = if 
(config.hasPath(linearFeaturePath)) { + config.getStringList(linearFeaturePath) + .map(familyName => registry.family(familyName)).toList.toArray + } else Array[Family]() val lossMod: Int = Try { config.getInt("loss_mod") }.getOrElse(100) @@ -464,7 +413,7 @@ object AdditiveModelTrainer { AdditiveTrainerParams( numBins, numBags, - rankKey, + labelFamily, loss, minCount, learningRate, @@ -482,14 +431,16 @@ object AdditiveModelTrainer { epsilon, initModelPath, linearFeatureFamilies, - priors) + priors, + registry) } def trainAndSaveToFile(sc: SparkContext, input: RDD[Example], config: Config, - key: String) = { - val model = train(sc, input, config, key) + key: String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/BoostedForestTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/BoostedForestTrainer.scala index 5ceaf0f6..b2273a5c 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/BoostedForestTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/BoostedForestTrainer.scala @@ -2,22 +2,17 @@ package com.airbnb.aerosolve.training import java.util -import com.airbnb.aerosolve.core.models.BoostedStumpsModel -import com.airbnb.aerosolve.core.models.DecisionTreeModel -import com.airbnb.aerosolve.core.models.ForestModel -import com.airbnb.aerosolve.core.{FeatureVector, Example, ModelRecord} -import com.airbnb.aerosolve.core.util.{FloatVector, Util} -import com.airbnb.aerosolve.training.GradientUtils.GradientContainer +import com.airbnb.aerosolve.core.features.{Family, FeatureRegistry, MultiFamilyVector} +import com.airbnb.aerosolve.core.models.{DecisionTreeModel, ForestModel} +import com.airbnb.aerosolve.core.{Example, ModelRecord} import com.typesafe.config.Config import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import 
org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} -import scala.util.Random -import scala.util.Try import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ +import scala.util.Try // A boosted tree forest trainer. // Alternates between fitting a tree and building one on the importance @@ -27,7 +22,7 @@ object BoostedForestTrainer { private final val log: Logger = LoggerFactory.getLogger("BoostedForestTrainer") case class BoostedForestTrainerParams(candidateSize : Int, - rankKey : String, + labelFamily : Family, rankThreshold : Double, maxDepth : Int, minLeafCount : Int, @@ -40,7 +35,8 @@ object BoostedForestTrainer { samplingStrategy : String, multiclass : Boolean, loss : String, - margin : Double) + margin : Double, + registry: FeatureRegistry) // A container class that returns the tree and leaf a feature vector ends up in case class ForestResponse(tree : Int, leaf : Int) // The sum of all responses to a feature vector of an entire forest. @@ -53,10 +49,11 @@ object BoostedForestTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : ForestModel = { + key : String, + registry: FeatureRegistry) : ForestModel = { val taskConfig = config.getConfig(key) val candidateSize : Int = taskConfig.getInt("num_candidates") - val rankKey : String = taskConfig.getString("rank_key") + val labelFamily : Family = registry.family(taskConfig.getString("rank_key")) val rankThreshold : Double = taskConfig.getDouble("rank_threshold") val maxDepth : Int = taskConfig.getInt("max_depth") val minLeafCount : Int = taskConfig.getInt("min_leaf_items") @@ -71,9 +68,9 @@ object BoostedForestTrainer { val loss : String = Try{taskConfig.getString("loss")}.getOrElse("logistic") val margin: Double = Try{taskConfig.getDouble("margin")}.getOrElse(1.0) val cache: String = Try{taskConfig.getString("cache")}.getOrElse("") - + val params = BoostedForestTrainerParams(candidateSize = candidateSize, - rankKey = rankKey, + 
labelFamily = labelFamily, rankThreshold = rankThreshold, maxDepth = maxDepth, minLeafCount = minLeafCount, @@ -86,13 +83,14 @@ object BoostedForestTrainer { samplingStrategy = samplingStrategy, multiclass = splitCriteria.contains("multiclass"), loss = loss, - margin = margin + margin = margin, + registry = registry ) - val forest = new ForestModel() - forest.setTrees(new java.util.ArrayList[DecisionTreeModel]()) + val forest = new ForestModel(registry) + forest.trees(new java.util.ArrayList[DecisionTreeModel]()) - val raw = LinearRankerUtils.makePointwiseFloat(input, config, key) + val raw = LinearRankerUtils.makePointwiseFloat(input, config, key, registry) val examples = cache match { case "memory" => raw.cache() @@ -118,10 +116,10 @@ object BoostedForestTrainer { } def optionalExample(ex : Example, forest : ForestModel, params : BoostedForestTrainerParams) - : Option[FeatureVector] = { - val item = ex.example(0) + : Option[MultiFamilyVector] = { + val item = ex.only if (params.multiclass) { - val labels = TrainingUtils.getLabelDistribution(item, params.rankKey) + val labels = TrainingUtils.getLabelDistribution(item, params.labelFamily) if (labels.isEmpty) return None // Assuming that this is a single label multi-class @@ -132,17 +130,17 @@ object BoostedForestTrainer { forest.scoreToProbability(scores) - val probs = scores.filter(x => x.label == label) + val probs = scores.filter(x => x.getLabel == label.name()) if (probs.isEmpty) return None - val importance = 1.0 - probs.head.probability + val importance = 1.0 - probs.head.getProbability if (scala.util.Random.nextDouble < importance) { Some(item) } else { None } } else { - val label = TrainingUtils.getLabel(item, params.rankKey, params.rankThreshold) + val label = TrainingUtils.getLabel(item, params.labelFamily, params.rankThreshold) val score = forest.scoreItem(item) val prob = forest.scoreProbability(score) val importance = if (label > 0) { @@ -167,15 +165,11 @@ object BoostedForestTrainer { val forestBC = 
sc.broadcast(forest) val paramsBC = sc.broadcast(params) val examples = input - .map(x => optionalExample(x, forestBC.value, paramsBC.value)) - .filter(x => x.isDefined) - .map(x => Util.flattenFeature(x.get)) + .flatMap(x => optionalExample(x, forestBC.value, paramsBC.value)) params.samplingStrategy match { // Picks the first few items that match the criteria. Better for massive data sets. - case "first" => { - examples.take(params.candidateSize) - } + case "first" => examples.take(params.candidateSize) // Picks uniformly. Better for small data sets. case "uniform" => examples.takeSample(false, params.candidateSize) } @@ -197,20 +191,20 @@ object BoostedForestTrainer { 0, 0, params.maxDepth, - params.rankKey, + params.labelFamily, params.rankThreshold, params.numTries, params.minLeafCount, SplitCriteria.splitCriteriaFromName(params.splitCriteria)) - val tree = new DecisionTreeModel() - tree.setStumps(stumps) + val tree = new DecisionTreeModel(params.registry) + tree.stumps(stumps) if (params.multiclass) { // Convert a pdf into something like a weight val scale = 1.0f / params.numTrees.toFloat - for (stump <- tree.getStumps) { - if (stump.labelDistribution != null) { - val dist = stump.labelDistribution.asScala + for (stump <- tree.stumps) { + if (stump.getLabelDistribution != null) { + val dist = stump.getLabelDistribution.asScala for (key <- dist.keys) { val v = dist.get(key) @@ -222,13 +216,13 @@ object BoostedForestTrainer { } } else { val scale = 1.0f / params.numTrees.toFloat - for (stump <- tree.getStumps) { - if (stump.featureWeight != 0.0f) { - stump.featureWeight *= scale + for (stump <- tree.stumps) { + if (stump.getFeatureWeight != 0.0f) { + stump.setFeatureWeight(stump.getFeatureWeight * scale) } } } - forest.getTrees.append(tree) + forest.trees.append(tree) } def boostForest(sc : SparkContext, @@ -272,12 +266,12 @@ object BoostedForestTrainer { .collectAsMap // Gradient step - val trees = forest.getTrees.asScala.toArray + val trees = 
forest.trees.asScala.toArray var sum : Double = 0.0 for (cg <- countAndGradient) { val key = cg._1 val tree = trees(key._1) - val stump = tree.getStumps.get(key._2) + val stump = tree.stumps.get(key._2) val (count, grad) = cg._2 val curr = stump.getFeatureWeight val avgGrad = grad / count @@ -330,20 +324,20 @@ object BoostedForestTrainer { .collectAsMap // Gradient step - val trees = forest.getTrees.asScala.toArray + val trees = forest.trees.asScala.toArray var sum : Double = 0.0 for (cg <- countAndGradient) { val key = cg._1 val tree = trees(key._1) - val stump = tree.getStumps.get(key._2) + val stump = tree.stumps.get(key._2) val label = key._3 val (count, grad) = cg._2 - val curr = stump.labelDistribution.get(label) + val curr = stump.getLabelDistribution.get(label) val avgGrad = grad / count if (curr == null) { - stump.labelDistribution.put(label, - params.learningRate * avgGrad) + stump.getLabelDistribution.put(label, - params.learningRate * avgGrad) } else { - stump.labelDistribution.put(label, curr - params.learningRate * avgGrad) + stump.getLabelDistribution.put(label, curr - params.learningRate * avgGrad) } sum += avgGrad * avgGrad } @@ -354,21 +348,20 @@ object BoostedForestTrainer { def getForestResponse(forest : ForestModel, ex : Example, params : BoostedForestTrainerParams) : ForestResult = { - val item = ex.example.get(0) - val floatFeatures = Util.flattenFeature(item) + val item = ex.only val result = scala.collection.mutable.ArrayBuffer[ForestResponse]() - val trees = forest.getTrees().asScala.toArray + val trees = forest.trees.asScala.toArray var sum : Double = 0.0 val sumResponses = scala.collection.mutable.HashMap[String, Double]() - for (i <- 0 until trees.size) { + for (i <- trees.indices) { val tree = trees(i) - val leaf = trees(i).getLeafIndex(floatFeatures); + val leaf = trees(i).getLeafIndex(item) if (leaf >= 0) { - val stump = tree.getStumps().get(leaf); - val weight = stump.featureWeight + val stump = tree.stumps.get(leaf); + val weight = 
stump.getFeatureWeight sum = sum + weight - if (params.multiclass && stump.labelDistribution != null) { - val dist = stump.labelDistribution.asScala + if (params.multiclass && stump.getLabelDistribution != null) { + val dist = stump.getLabelDistribution.asScala for (kv <- dist) { val v = sumResponses.getOrElse(kv._1, 0.0) sumResponses.put(kv._1, v + kv._2) @@ -378,16 +371,19 @@ object BoostedForestTrainer { result.append(response) } } - val label = TrainingUtils.getLabel(item, params.rankKey, params.rankThreshold) - val labels = TrainingUtils.getLabelDistribution(item, params.rankKey) + val label = TrainingUtils.getLabel(item, params.labelFamily, params.rankThreshold) + val labels: Map[String, Double] = TrainingUtils.getLabelDistribution(item, params.labelFamily) + .map(kv => (kv._1.name, kv._2)) + .toMap ForestResult(label, labels, sum, sumResponses, result.toArray) } def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/BoostedStumpsTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/BoostedStumpsTrainer.scala index b92a9132..7d42495b 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/BoostedStumpsTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/BoostedStumpsTrainer.scala @@ -2,13 +2,11 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.{Example, ModelRecord} +import com.airbnb.aerosolve.core.features.{Family, FeatureRegistry, SimpleExample} import com.airbnb.aerosolve.core.models.BoostedStumpsModel -import com.airbnb.aerosolve.core.Example -import com.airbnb.aerosolve.core.ModelRecord -import com.airbnb.aerosolve.core.util.Util import 
com.typesafe.config.Config import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} @@ -24,86 +22,82 @@ object BoostedStumpsTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : BoostedStumpsModel = { + key : String, + registry: FeatureRegistry) : BoostedStumpsModel = { val candidateSize : Int = config.getInt(key + ".num_candidates") - val rankKey : String = config.getString(key + ".rank_key") + val labelFamily : Family = registry.family(config.getString(key + ".rank_key")) val pointwise : RDD[Example] = LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, registry) - val candidates : Array[ModelRecord] = getCandidateStumps(pointwise, candidateSize, rankKey) + val candidates : Array[ModelRecord] = getCandidateStumps(pointwise, candidateSize, labelFamily) - val data : RDD[Example] = getResponses(sc, pointwise, candidates, rankKey).cache() + val data : RDD[Example] = getResponses(sc, pointwise, candidates, labelFamily).cache() - val weights = LinearRankerTrainer.train(sc, data, config, key).toMap + val weights = LinearRankerTrainer.train(sc, data, config, key, registry).toMap // Lookup each candidate's weights (0 until candidates.size).foreach(i => { val stump = candidates(i) - val pos = weights.getOrElse(("+", i.toString), 0.0) + val pos = weights.getOrElse(registry.feature("+", i.toString), 0.0) stump.setFeatureWeight(pos) }) - val sorted = candidates.toBuffer.sortWith((a, b) => math.abs(a.featureWeight) > math.abs(b.featureWeight)) + val sorted = candidates.toBuffer.sortWith((a, b) => + math.abs(a.getFeatureWeight) > math.abs(b.getFeatureWeight)) val stumps = new util.ArrayList[ModelRecord]() sorted.foreach(stump => { - if (math.abs(stump.featureWeight) > 0.0) { + if (math.abs(stump.getFeatureWeight) > 0.0) { stumps.add(stump) } }) - val model = new BoostedStumpsModel() - 
model.setStumps(stumps) + val model = new BoostedStumpsModel(registry) + model.stumps(stumps) model } def getCandidateStumps(pointwise : RDD[Example], candidateSize : Int, - rankKey : String) : Array[ModelRecord] = { + labelFamily : Family) : Array[ModelRecord] = { val result = collection.mutable.HashSet[ModelRecord]() pointwise - .flatMap(x => Util.flattenFeature(x.example(0))) - .filter(x => x._1 != rankKey) - .flatMap(x => { - val buffer = collection.mutable.HashMap[(String, String), Double]() - x._2.foreach(feature => { - buffer.put((x._1, feature._1), feature._2) - }) - buffer + .flatMap(example => example.only.iterator) + .filter(featureValue => featureValue.feature.family != labelFamily) + .take(candidateSize) + .foreach(featureValue => { + val rec = new ModelRecord() + rec.setFeatureFamily(featureValue.feature.family.name) + rec.setFeatureName(featureValue.feature.name) + rec.setThreshold(featureValue.value) + result.add(rec) }) - .take(candidateSize) - .foreach(x => { - val rec = new ModelRecord() - rec.setFeatureFamily(x._1._1) - rec.setFeatureName(x._1._2) - rec.setThreshold(x._2) - result.add(rec) - }) - result.toArray + result.toArray } def getResponses(sc : SparkContext, pointwise : RDD[Example], candidates : Array[ModelRecord], - rankKey : String) : RDD[Example] = { + labelFamily : Family) : RDD[Example] = { val candidatesBC = sc.broadcast(candidates) pointwise.map(example => { val cand = candidatesBC.value - val ex = Util.flattenFeature(example.example.get(0)) - val output = new Example() - val fv = Util.createNewFeatureVector() - output.addToExample(fv) - fv.floatFeatures.put(rankKey, example.example.get(0).floatFeatures.get(rankKey)) - val pos = new java.util.HashSet[String]() - fv.stringFeatures.put("+", pos) - - val count = cand.size + val ex = example.only + val output = new SimpleExample(ex.registry) + val fv = output.createVector() + val labelFamilyVector = ex.get(labelFamily) + labelFamilyVector.iterator.asScala.foreach(featureValue => + 
fv.put(featureValue.feature, featureValue.value) + ) + val plusFamily = fv.registry.family("+") + + val count = cand.length for (i <- 0 until count) { val resp = BoostedStumpsModel.getStumpResponse(cand(i), ex) if (resp) { - pos.add(i.toString) + fv.putString(plusFamily.feature(i.toString)) } } output @@ -113,8 +107,9 @@ object BoostedStumpsTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/DecisionTreeTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/DecisionTreeTrainer.scala index 56cfcb0c..936c2d9c 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/DecisionTreeTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/DecisionTreeTrainer.scala @@ -2,19 +2,16 @@ package com.airbnb.aerosolve.training import java.util -import com.airbnb.aerosolve.core.models.BoostedStumpsModel -import com.airbnb.aerosolve.core.models.DecisionTreeModel -import com.airbnb.aerosolve.core.Example -import com.airbnb.aerosolve.core.ModelRecord -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.{Example, ModelRecord} +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry, MultiFamilyVector} +import com.airbnb.aerosolve.core.models.{BoostedStumpsModel, DecisionTreeModel} import com.typesafe.config.Config import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} -import scala.util.Random -import scala.util.Try import scala.collection.JavaConversions._ +import scala.util.{Random, Try} // Types of split criteria object SplitCriteriaTypes extends Enumeration { @@ -56,9 +53,10 @@ object DecisionTreeTrainer { sc 
: SparkContext, input : RDD[Example], config : Config, - key : String) : DecisionTreeModel = { + key : String, + registry: FeatureRegistry) : DecisionTreeModel = { val candidateSize : Int = config.getInt(key + ".num_candidates") - val rankKey : String = config.getString(key + ".rank_key") + val labelFamily : Family = registry.family(config.getString(key + ".rank_key")) val rankThreshold : Double = config.getDouble(key + ".rank_threshold") val maxDepth : Int = config.getInt(key + ".max_depth") val minLeafCount : Int = config.getInt(key + ".min_leaf_items") @@ -67,9 +65,9 @@ object DecisionTreeTrainer { .getOrElse("gini") val examples = LinearRankerUtils - .makePointwiseFloat(input, config, key) - .map(x => Util.flattenFeature(x.example(0))) - .filter(x => x.contains(rankKey)) + .makePointwiseFloat(input, config, key, registry) + .map(example => example.only) + .filter(vector => vector.contains(labelFamily)) .take(candidateSize) val stumps = new util.ArrayList[ModelRecord]() @@ -81,38 +79,38 @@ object DecisionTreeTrainer { 0, 0, maxDepth, - rankKey, + labelFamily, rankThreshold, numTries, minLeafCount, SplitCriteria.splitCriteriaFromName(splitCriteriaName) ) - val model = new DecisionTreeModel() - model.setStumps(stumps) + val model = new DecisionTreeModel(registry) + model.stumps(stumps) model } def buildTree( stumps : util.ArrayList[ModelRecord], - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], + vectors : Array[MultiFamilyVector], currIdx : Int, currDepth : Int, maxDepth : Int, - rankKey : String, + labelFamily: Family, rankThreshold : Double, numTries : Int, minLeafCount : Int, splitCriteria : SplitCriteria.Value) : Unit = { if (currDepth >= maxDepth) { - stumps(currIdx) = makeLeaf(examples, rankKey, rankThreshold, splitCriteria) + stumps(currIdx) = makeLeaf(vectors, labelFamily, rankThreshold, splitCriteria) return } val split = getBestSplit( - examples, - rankKey, + vectors, + labelFamily, rankThreshold, numTries, 
minLeafCount, @@ -120,7 +118,7 @@ object DecisionTreeTrainer { ) if (split.isEmpty) { - stumps(currIdx) = makeLeaf(examples, rankKey, rankThreshold, splitCriteria) + stumps(currIdx) = makeLeaf(vectors, labelFamily, rankThreshold, splitCriteria) return } @@ -133,16 +131,16 @@ object DecisionTreeTrainer { stumps(currIdx).setLeftChild(left) stumps(currIdx).setRightChild(right) - val (rightExamples, leftExamples) = examples.partition( - x => BoostedStumpsModel.getStumpResponse(stumps(currIdx), x)) + val (rightVectors, leftVectors) = vectors.partition( + vector => BoostedStumpsModel.getStumpResponse(stumps(currIdx), vector)) buildTree( stumps, - leftExamples, + leftVectors, left, currDepth + 1, maxDepth, - rankKey, + labelFamily, rankThreshold, numTries, minLeafCount, @@ -151,11 +149,11 @@ object DecisionTreeTrainer { buildTree( stumps, - rightExamples, + rightVectors, right, currDepth + 1, maxDepth, - rankKey, + labelFamily, rankThreshold, numTries, minLeafCount, @@ -164,8 +162,8 @@ object DecisionTreeTrainer { } def makeLeaf( - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], - rankKey : String, + vectors : Array[MultiFamilyVector], + labelFamily : Family, rankThreshold : Double, splitCriteria : SplitCriteria.Value) = { val rec = new ModelRecord() @@ -175,8 +173,9 @@ object DecisionTreeTrainer { var numPos = 0.0 var numNeg = 0.0 - for (example <- examples) { - val label = example.get(rankKey).values().iterator().next() > rankThreshold + for (vector <- vectors) { + // This assumes there's only one label and the family exists + val label = vector.get(labelFamily).iterator().next().value() > rankThreshold if (label) numPos += 1.0 else numNeg += 1.0 } @@ -194,8 +193,8 @@ object DecisionTreeTrainer { var count : Double = 0.0 var sum : Double = 0.0 - for (example <- examples) { - val labelValue = example.get(rankKey).values().iterator().next() + for (vector <- vectors) { + val labelValue = 
vector.get(labelFamily).iterator().next().value() count += 1.0 sum += labelValue @@ -210,10 +209,10 @@ object DecisionTreeTrainer { var sum = 0.0 - for (example <- examples) { - for (kv <- example.get(rankKey).entrySet()) { - val key = kv.getKey - val value = kv.getValue + for (vector <- vectors) { + for (fv <- vector.get(labelFamily).iterator) { + val key = fv.feature.name + val value = fv.value val count = if (labelDistribution.containsKey(key)) { labelDistribution.get(key) @@ -243,13 +242,13 @@ object DecisionTreeTrainer { // Returns the best split if one exists. def getBestSplit( - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], - rankKey : String, + vectors : Array[MultiFamilyVector], + labelFamily : Family, rankThreshold : Double, numTries : Int, minLeafCount : Int, splitCriteria : SplitCriteria.Value) : Option[ModelRecord] = { - if (examples.length <= minLeafCount) { + if (vectors.length <= minLeafCount) { // If we're at or below the minLeafCount, then there's no point in splitting None } else { @@ -259,28 +258,28 @@ object DecisionTreeTrainer { for (i <- 0 until numTries) { // Pick an example index randomly - val idx = rnd.nextInt(examples.length) - val ex = examples(idx) - val candidateOpt = getCandidateSplit(ex, rankKey, rnd) + val idx = rnd.nextInt(vectors.length) + val vec = vectors(idx) + val candidateOpt = getCandidateSplit(vec, labelFamily, rnd) if (candidateOpt.isDefined) { val candidateValue = SplitCriteria.getCriteriaType(splitCriteria) match { case SplitCriteriaTypes.Classification => evaluateClassificationSplit( - examples, rankKey, + vectors, labelFamily, rankThreshold, minLeafCount, splitCriteria, candidateOpt ) case SplitCriteriaTypes.Regression => evaluateRegressionSplit( - examples, rankKey, + vectors, labelFamily, minLeafCount, splitCriteria, candidateOpt ) case SplitCriteriaTypes.Multiclass => evaluateMulticlassSplit( - examples, rankKey, + vectors, labelFamily, minLeafCount, splitCriteria, 
candidateOpt ) @@ -299,8 +298,8 @@ object DecisionTreeTrainer { // Evaluate a classification-type split def evaluateClassificationSplit( - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], - rankKey : String, + vectors : Array[MultiFamilyVector], + labelFamily: Family, rankThreshold : Double, minLeafCount : Int, splitCriteria : SplitCriteria.Value, @@ -310,9 +309,9 @@ object DecisionTreeTrainer { var leftNeg : Double = 0.0 var rightNeg : Double = 0.0 - for (example <- examples) { - val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, example) - val label = example.get(rankKey).values().iterator().next() > rankThreshold + for (vector <- vectors) { + val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, vector) + val label = vector.get(labelFamily).iterator.next.value > rankThreshold if (response) { if (label) { @@ -399,8 +398,8 @@ object DecisionTreeTrainer { // Evaluate a multiclass classification-type split def evaluateMulticlassSplit( - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], - rankKey : String, + vectors : Array[MultiFamilyVector], + labelFamily : Family, minLeafCount : Int, splitCriteria : SplitCriteria.Value, candidateOpt : Option[ModelRecord]): Option[Double] = { @@ -410,11 +409,11 @@ object DecisionTreeTrainer { var leftCount = 0 var rightCount = 0 - for (example <- examples) { - val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, example) - for (kv <- example.get(rankKey).entrySet()) { - val key = kv.getKey - val value = kv.getValue + for (vector <- vectors) { + val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, vector) + for (fv <- vector.get(labelFamily).iterator) { + val key = fv.feature.name + val value = fv.value if (response) { val v = rightDist.getOrElse(key, 0.0) @@ -453,8 +452,8 @@ object DecisionTreeTrainer { // Evaluate a regression-type split // See 
http://www.stat.cmu.edu/~cshalizi/350-2006/lecture-10.pdf for overview of algorithm used def evaluateRegressionSplit( - examples : Array[util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]]], - rankKey : String, + vectors : Array[MultiFamilyVector], + labelFamily : Family, minLeafCount : Int, splitCriteria : SplitCriteria.Value, candidateOpt : Option[ModelRecord]): Option[Double] = { @@ -465,9 +464,9 @@ object DecisionTreeTrainer { var leftMean : Double = 0.0 var leftSumSq : Double = 0.0 - for (example <- examples) { - val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, example) - val labelValue = example.get(rankKey).values().iterator().next() + for (vector <- vectors) { + val response = BoostedStumpsModel.getStumpResponse(candidateOpt.get, vector) + val labelValue = vector.get(labelFamily).iterator.next.value // Using Welford's Method for computing mean and sum-squared errors in numerically stable way; // more details can be found in @@ -499,17 +498,15 @@ object DecisionTreeTrainer { // Returns a candidate split sampled from an example. def getCandidateSplit( - ex : util.Map[java.lang.String, util.Map[java.lang.String, java.lang.Double]], - rankKey : String, + vector : MultiFamilyVector, + labelFamily : Family, rnd : Random) : Option[ModelRecord] = { // Flatten the features and pick one randomly. 
- val features = collection.mutable.ArrayBuffer[(String, String, Double)]() + val features = collection.mutable.ArrayBuffer[(Feature, Double)]() - for (family <- ex) { - if (!family._1.equals(rankKey)) { - for (feature <- family._2) { - features.append((family._1, feature._1, feature._2)) - } + for (featureValue <- vector.iterator()) { + if (!featureValue.feature().family().equals(labelFamily)) { + features.append((featureValue.feature, featureValue.value)) } } @@ -519,9 +516,9 @@ object DecisionTreeTrainer { val idx = rnd.nextInt(features.size) val rec = new ModelRecord() - rec.setFeatureFamily(features(idx)._1) - rec.setFeatureName(features(idx)._2) - rec.setThreshold(features(idx)._3) + rec.setFeatureFamily(features(idx)._1.family.name) + rec.setFeatureName(features(idx)._1.name) + rec.setThreshold(features(idx)._2) Some(rec) } @@ -531,8 +528,9 @@ object DecisionTreeTrainer { sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/Evaluation.scala b/training/src/main/scala/com/airbnb/aerosolve/training/Evaluation.scala index e2e4bbb4..ec83c7ec 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/Evaluation.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/Evaluation.scala @@ -1,13 +1,12 @@ package com.airbnb.aerosolve.training -import org.slf4j.{Logger, LoggerFactory} -import org.apache.spark.rdd.RDD import com.airbnb.aerosolve.core.EvaluationRecord -import org.apache.spark.SparkContext._ -import scala.collection.{mutable, Map} -import scala.collection.mutable.{ArrayBuffer, Buffer} -import scala.collection.JavaConversions._ +import org.apache.spark.rdd.RDD +import org.slf4j.{Logger, LoggerFactory} + import 
scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.{Map, mutable} /* * Given an RDD of EvaluationRecord return standard evaluation metrics @@ -21,7 +20,7 @@ object Evaluation { evalMetric : String) : Array[(String, Double)] = { var metrics = mutable.Buffer[(String, Double)]() var bestF1 = -1.0 - val thresholds = records.map(x => x.score).histogram(buckets)._1 + val thresholds = records.map(x => x.getScore).histogram(buckets)._1 // Search all thresholds for the best F1 // At the same time collect the precision and recall. val trainPR = new ArrayBuffer[(Double, Double)]() @@ -55,7 +54,7 @@ object Evaluation { evalMetric : String) : Array[(String, Double)] = { var metrics = mutable.Buffer[(String, Double)]() var bestF1 = -1.0 - val scores: List[Double] = records.map(x => x.score) + val scores: List[Double] = records.map(x => x.getScore) val thresholds = getThresholds(scores, buckets) // Search all thresholds for the best F1 // At the same time collect the precision and recall. @@ -90,20 +89,20 @@ object Evaluation { records.flatMap(rec => { // Metric, value, count val metrics = scala.collection.mutable.ArrayBuffer[(String, (Double, Double))]() - if (rec.scores != null && rec.labels != null) { - val prefix = if (rec.is_training) "TRAIN_" else "HOLD_" + if (rec.getScores != null && rec.getLabels != null) { + val prefix = if (rec.isIs_training) "TRAIN_" else "HOLD_" // Order by top scores. 
- val sorted = rec.scores.asScala.toBuffer.sortWith((a, b) => a._2 > b._2) + val sorted = rec.getScores.asScala.toBuffer.sortWith((a, b) => a._2 > b._2) // All pairs hinge loss val count = sorted.size var hingeLoss = 0.0 - for (label <- rec.labels.asScala) { + for (label <- rec.getLabels.asScala) { for (j <- 0 until count) { if (label._1 != sorted(j)._1) { - val scorei = rec.scores.get(label._1) + val scorei = rec.getScores.get(label._1) val scorej = sorted(j)._2 val truei = label._2 - var truej = rec.labels.get(sorted(j)._1) + var truej = rec.getLabels.get(sorted(j)._1) if (truej == null) truej = 0.0 if (truei > truej) { val margin = truei - truej @@ -118,7 +117,7 @@ object Evaluation { metrics.append((prefix + "ALL_PAIRS_HINGE_LOSS", (hingeLoss, 1.0))) var inTopK = false for (i <- 0 until sorted.size) { - if (rec.labels.containsKey(sorted(i)._1)) { + if (rec.getLabels.containsKey(sorted(i)._1)) { inTopK = true metrics.append((prefix + "MEAN_RECIPROCAL_RANK", (1.0 / (i + 1), 1.0))) } @@ -299,10 +298,10 @@ object Evaluation { private def getClassificationAUCTrainHold(records : RDD[EvaluationRecord]) : (Double, Double) = { // find minimal and maximal scores - var minScore = records.take(1).apply(0).score + var minScore = records.take(1).apply(0).getScore var maxScore = minScore records.foreach(record => { - val score = record.score + val score = record.getScore minScore = Math.min(minScore, score) maxScore = Math.max(maxScore, score) }) @@ -335,7 +334,7 @@ object Evaluation { private def getClassificationAUCTrainHold(records : List[EvaluationRecord]) : (Double, Double) = { // find minimal and maximal scores - val scores = records.map(rec => rec.score) + val scores = records.map(rec => rec.getScore) val minScore = scores.min var maxScore = scores.max if(minScore >= maxScore) { @@ -392,12 +391,12 @@ object Evaluation { private def evaluateRecordForAUC(record : EvaluationRecord, minScore : Double, maxScore : Double) : (Long, (Long, Long, Long, Long)) = { - var offset = 
if (record.is_training) 0 else 2
-    if (record.label <= 0) {
+    var offset = if (record.isIs_training) 0 else 2
+    if (record.getLabel <= 0) {
       offset += 1
     }
 
-    val score : Long = ((record.score - minScore) / (maxScore - minScore) * 100).toLong
+    val score : Long = ((record.getScore - minScore) / (maxScore - minScore) * 100).toLong
 
     offset match {
       case 0 => (score, (1, 0, 0, 0))
@@ -410,12 +409,12 @@
   private def evaluateRecordRegression(record : EvaluationRecord) : Iterator[(String, Double)] = {
     val out = collection.mutable.ArrayBuffer[(String, Double)]()
 
-    val prefix = if (record.is_training) "TRAIN_" else "HOLD_"
-    val diff = record.label - record.score
+    val prefix = if (record.isIs_training) "TRAIN_" else "HOLD_"
+    val diff = record.getLabel - record.getScore
     val sqErr = diff * diff
     // to compute SMAPE third version https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
     val absLabelMinusScore = Math.abs(diff)
-    val labelPlusScore = record.label + record.score
+    val labelPlusScore = record.getLabel + record.getScore
     out.append((prefix + "SQERR", sqErr))
     out.append((prefix + "ABS_LABEL_MINUS_SCORE", absLabelMinusScore))
     out.append((prefix + "LABEL_PLUS_SCORE", labelPlusScore))
@@ -429,25 +428,25 @@
                                              threshold : Double) : Iterator[(String, Double)] = {
     val out = collection.mutable.ArrayBuffer[(String, Double)]()
 
-    val prefix = if (record.is_training) "TRAIN_" else "HOLD_"
-    if (record.score > threshold) {
-      if (record.label > 0) {
+    val prefix = if (record.isIs_training) "TRAIN_" else "HOLD_"
+    if (record.getScore > threshold) {
+      if (record.getLabel > 0) {
         out.append((prefix + "TP", 1.0))
       } else {
         out.append((prefix + "FP", 1.0))
       }
     } else {
-      if (record.label <= 0) {
+      if (record.getLabel <= 0) {
         out.append((prefix + "TN", 1.0))
       } else {
         out.append((prefix + "FN", 1.0))
       }
     }
-    val error = if (record.label > 0) {
-      (1.0 - record.score) * (1.0 - record.score)
+    val error = if (record.getLabel > 0) {
+      (1.0 - 
record.getScore) * (1.0 - record.getScore) } else { - record.score * record.score + record.getScore * record.getScore } out.append((prefix + "SQERR", error)) out.append((prefix + "COUNT", 1.0)) diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/FeatureSelection.scala b/training/src/main/scala/com/airbnb/aerosolve/training/FeatureSelection.scala index bad6538f..67cce3b1 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/FeatureSelection.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/FeatureSelection.scala @@ -1,30 +1,12 @@ package com.airbnb.aerosolve.training -import java.io.BufferedWriter -import java.io.OutputStreamWriter -import java.util - -import com.airbnb.aerosolve.core.{ModelRecord, ModelHeader, FeatureVector, Example} -import com.airbnb.aerosolve.core.models.LinearModel -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.Example +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry} import com.typesafe.config.Config -import org.slf4j.{LoggerFactory, Logger} -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.slf4j.{Logger, LoggerFactory} -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Buffer import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ -import scala.util.Random -import scala.math.abs -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path object FeatureSelection { private final val log: Logger = LoggerFactory.getLogger("FeatureSelection") @@ -35,47 +17,49 @@ object FeatureSelection { def pointwiseMutualInformation(examples : RDD[Example], config : Config, key : String, - rankKey : String, + labelFamily : Family, posThreshold : Double, minPosCount 
: Double, - newCrosses : Boolean) : RDD[((String, String), Double)] = { - val pointwise = LinearRankerUtils.makePointwise(examples, config, key, rankKey) + newCrosses : Boolean, + registry : FeatureRegistry) : RDD[(Feature, Double)] = { + val pointwise = LinearRankerUtils.makePointwiseFloat(examples, config, key, registry) + val allFeature = registry.feature(allKey._1, allKey._2) + val features = pointwise .mapPartitions(part => { - // The tuple2 is var, var | positive - val output = scala.collection.mutable.HashMap[(String, String), (Double, Double)]() - part.foreach(example =>{ - val featureVector = example.example.get(0) - val isPos = if (featureVector.floatFeatures.get(rankKey).asScala.head._2 > posThreshold) 1.0 - else 0.0 - val all : (Double, Double) = output.getOrElse(allKey, (0.0, 0.0)) - output.put(allKey, (all._1 + 1.0, all._2 + 1.0 * isPos)) + // The tuple2 is var, var | positive + val output = scala.collection.mutable.HashMap[Feature, (Double, Double)]() + part.foreach(example =>{ + val featureVector = example.only + val labelVal = featureVector.get(labelFamily).iterator.next.value + val isPos = if (labelVal > posThreshold) 1.0 else 0.0 + val all : (Double, Double) = output.getOrElse(allFeature, (0.0, 0.0)) + output.put(allFeature, (all._1 + 1.0, all._2 + 1.0 * isPos)) - val features : Array[(String, String)] = - LinearRankerUtils.getFeatures(featureVector) - if (newCrosses) { - for (i <- features) { - for (j <- features) { - if (i._1 < j._1) { - val key = ("%s%s".format(i._1, j._1), - "%s%s".format(i._2, j._2)) - val x = output.getOrElse(key, (0.0, 0.0)) - output.put(key, (x._1 + 1.0, x._2 + 1.0 * isPos)) + if (newCrosses) { + for (fv1 <- featureVector.iterator) { + for (fv2 <- featureVector.iterator) { + if (fv1.feature.compareTo(fv2.feature) <= 0) { + val feature = registry.feature( + "%s%s".format(fv1.feature.family.name, fv2.feature.family.name), + "%s%s".format(fv1.feature.name, fv2.feature.name)) + val x = output.getOrElse(feature, (0.0, 0.0)) + 
output.put(feature, (x._1 + 1.0, x._2 + 1.0 * isPos)) + } } } } - } - for (feature <- features) { - val x = output.getOrElse(feature, (0.0, 0.0)) - output.put(feature, (x._1 + 1.0, x._2 + 1.0 * isPos)) - } + for (featureValue <- featureVector.iterator()) { + val x = output.getOrElse(featureValue.feature, (0.0, 0.0)) + output.put(featureValue.feature, (x._1 + 1.0, x._2 + 1.0 * isPos)) + } + }) + output.iterator }) - output.iterator - }) - .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2)) - .filter(x => x._2._2 >= minPosCount) + .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2)) + .filter(x => x._2._2 >= minPosCount) - val allCount = features.filter(x => x._1.equals(allKey)).take(1).head + val allCount = features.filter(x => x._1.equals(allFeature)).take(1).head features.map(x => { val prob = x._2._1 / allCount._2._1 diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/ForestTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/ForestTrainer.scala index 11ef209f..1a69741e 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/ForestTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/ForestTrainer.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.features.{Family, FeatureRegistry} import com.airbnb.aerosolve.core.models.BoostedStumpsModel import com.airbnb.aerosolve.core.models.DecisionTreeModel import com.airbnb.aerosolve.core.models.ForestModel @@ -26,9 +27,10 @@ object ForestTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : ForestModel = { + key : String, + registry: FeatureRegistry) : ForestModel = { val candidateSize : Int = config.getInt(key + ".num_candidates") - val rankKey : String = config.getString(key + ".rank_key") + val labelFamily : Family = registry.family(config.getString(key + ".rank_key")) val rankThreshold : Double = config.getDouble(key + ".rank_threshold") val maxDepth : Int = 
config.getInt(key + ".max_depth") val minLeafCount : Int = config.getInt(key + ".min_leaf_items") @@ -39,9 +41,9 @@ object ForestTrainer { val numTrees : Int = config.getInt(key + ".num_trees") val examples = LinearRankerUtils - .makePointwiseFloat(input, config, key) - .map(x => Util.flattenFeature(x.example(0))) - .filter(x => x.contains(rankKey)) + .makePointwiseFloat(input, config, key, registry) + .map(example => example.only) + .filter(vector => vector.contains(labelFamily)) .coalesce(numTrees, true) val trees = examples.mapPartitions(part => { @@ -56,38 +58,37 @@ object ForestTrainer { 0, 0, maxDepth, - rankKey, + labelFamily, rankThreshold, numTries, minLeafCount, SplitCriteria.splitCriteriaFromName(splitCriteriaName)) - val tree = new DecisionTreeModel() - tree.setStumps(stumps) + val tree = new DecisionTreeModel(registry) + tree.stumps(stumps) Array(tree).iterator }) .collect - .toArray log.info("%d trees trained".format(trees.size)) - val forest = new ForestModel() + val forest = new ForestModel(registry) val scale = 1.0f / numTrees.toFloat - forest.setTrees(new java.util.ArrayList[DecisionTreeModel]()) + forest.trees(new java.util.ArrayList[DecisionTreeModel]()) for (tree <- trees) { - for (stump <- tree.getStumps) { - if (stump.featureWeight != 0.0f) { - stump.featureWeight *= scale + for (stump <- tree.stumps) { + if (stump.getFeatureWeight != 0.0f) { + stump.setFeatureWeight(stump.getFeatureWeight * scale) } - if (stump.labelDistribution != null) { - val dist = stump.labelDistribution.asScala + if (stump.getLabelDistribution != null) { + val dist = stump.getLabelDistribution.asScala for (rec <- dist) { dist.put(rec._1, rec._2 * scale) } } } - forest.getTrees().append(tree) + forest.trees.append(tree) } forest } @@ -95,8 +96,9 @@ object ForestTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val 
model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/FullRankLinearTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/FullRankLinearTrainer.scala index 1df26d78..4c468ca0 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/FullRankLinearTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/FullRankLinearTrainer.scala @@ -1,5 +1,6 @@ package com.airbnb.aerosolve.training +import com.airbnb.aerosolve.core.features.{Feature, FeatureRegistry, Family} import com.airbnb.aerosolve.core.{Example, LabelDictionaryEntry} import com.airbnb.aerosolve.core.models.FullRankLinearModel import com.airbnb.aerosolve.core.util.FloatVector @@ -25,23 +26,25 @@ object FullRankLinearTrainer { case class FullRankLinearTrainerOptions(loss : String, iterations : Int, - rankKey : String, + labelFamily : Family, lambda : Double, subsample : Double, minCount : Int, cache : String, solver : String, - labelMinCount: Option[Int]) + labelMinCount: Option[Int], + registry: FeatureRegistry) def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : FullRankLinearModel = { - val options = parseTrainingOptions(config.getConfig(key)) + key : String, + registry: FeatureRegistry) : FullRankLinearModel = { + val options = parseTrainingOptions(config.getConfig(key), registry) val raw : RDD[Example] = LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, options.registry) val pointwise = options.cache match { case "memory" => raw.cache() @@ -64,8 +67,9 @@ object FullRankLinearTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) 
TrainingUtils.saveModel(model, config, key + ".model_output") } @@ -73,19 +77,19 @@ object FullRankLinearTrainer { options : FullRankLinearTrainerOptions, model : FullRankLinearModel, pointwise : RDD[Example]) = { - var prevGradients : Map[(String, String), GradientContainer] = Map() - val step = scala.collection.mutable.HashMap[(String, String), FloatVector]() + var prevGradients : Map[Feature, GradientContainer] = Map() + val step = scala.collection.mutable.HashMap[Feature, FloatVector]() for (iter <- 0 until options.iterations) { log.info(s"Iteration $iter") val sample = pointwise.sample(false, options.subsample) - val gradients: Map[(String, String), GradientContainer] = options.loss match { + val gradients: Map[Feature, GradientContainer] = options.loss match { case "softmax" => softmaxGradient(sc, options, model, sample) case "hinge" => hingeGradient(sc, options ,model, sample, "l1") case "squared_hinge" => hingeGradient(sc, options, model, sample, "l2") case _: String => softmaxGradient(sc, options, model, sample) } - val weightVector = model.getWeightVector() - val dim = model.getLabelDictionary.size() + val weightVector = model.weightVector + val dim = model.labelDictionary.size options.solver match { case "sparse_boost" => GradientUtils .sparseBoost(gradients, weightVector, dim, options.lambda) @@ -98,40 +102,38 @@ object FullRankLinearTrainer { } def filterZeros(model : FullRankLinearModel) = { - val weightVector = model.getWeightVector() - for (family <- weightVector) { - val toDelete = scala.collection.mutable.ArrayBuffer[String]() - for (feature <- family._2) { - if (feature._2.dot(feature._2) < 1e-6) { - toDelete.add(feature._1) - } - } - for (deleteFeature <- toDelete) { - family._2.remove(deleteFeature) + val toDelete = scala.collection.mutable.ArrayBuffer[Feature]() + for (entry <- model.weightVector) { + if (entry._2.dot(entry._2) < 1e-6) { + toDelete.add(entry._1) } } + // TODO (Brad): Encapsulation + for (deleteFeature <- toDelete) { + 
model.weightVector.remove(deleteFeature) + } } def softmaxGradient(sc : SparkContext, options : FullRankLinearTrainerOptions, model : FullRankLinearModel, - pointwise : RDD[Example]) : Map[(String, String), GradientContainer] = { + pointwise : RDD[Example]) : Map[Feature, GradientContainer] = { val modelBC = sc.broadcast(model) pointwise .mapPartitions(partition => { val model = modelBC.value - val labelToIdx = model.getLabelToIndex() - val dim = model.getLabelDictionary.size() - val gradient = scala.collection.mutable.HashMap[(String, String), GradientContainer]() - val weightVector = model.getWeightVector() + val labelToIdx = model.labelToIndex + val dim = model.labelDictionary.size + val gradient = scala.collection.mutable.HashMap[Feature, GradientContainer]() + val weightVector = model.weightVector() - partition.foreach(examples => { - val flatFeatures = Util.flattenFeature(examples.example.get(0)) - val labels = flatFeatures.get(options.rankKey) + partition.foreach(example => { + val vector = example.only + val labels = vector.get(options.labelFamily) if (labels != null) { - val posLabels = labels.keySet().asScala - val scores = model.scoreFlatFeature(flatFeatures) + val posLabels = labels.iterator.map(fv => fv.feature.name) + val scores = model.scoreFlatFeature(vector) // Convert to multinomial using softmax. scores.softmax() // The importance is prob - 1 for positive labels, prob otherwise. @@ -142,21 +144,18 @@ object FullRankLinearTrainer { } } // Gradient is importance * feature value - for (family <- flatFeatures) { - for (feature <- family._2) { - val key = (family._1, feature._1) - // We only care about features in the model. 
- if (weightVector.containsKey(key._1) && weightVector.get(key._1).containsKey(key._2)) { - val featureVal = feature._2 - val gradContainer = gradient.getOrElse(key, - GradientContainer(new FloatVector(dim), 0.0)) - gradContainer.grad.multiplyAdd(featureVal.toFloat, scores) - val norm = math.max(featureVal * featureVal, 1.0) - gradient.put(key, - GradientContainer(gradContainer.grad, - gradContainer.featureSquaredSum + norm - )) - } + for (fv <- vector.iterator) { + val key = fv.feature + // We only care about features in the model. + if (weightVector.containsKey(key)) { + val gradContainer = gradient.getOrElse(key, + GradientContainer(new FloatVector(dim), 0.0)) + gradContainer.grad.multiplyAdd(fv.value, scores) + val norm = math.max(fv.value * fv.value, 1.0) + gradient.put(key, + GradientContainer(gradContainer.grad, + gradContainer.featureSquaredSum + norm + )) } } } @@ -164,7 +163,7 @@ object FullRankLinearTrainer { gradient.iterator }) .reduceByKey((a, b) => GradientUtils.sumGradients(a,b)) - .collectAsMap + .collectAsMap() .toMap } @@ -172,25 +171,25 @@ object FullRankLinearTrainer { options : FullRankLinearTrainerOptions, model : FullRankLinearModel, pointwise : RDD[Example], - lossType : String) : Map[(String, String), GradientContainer] = { + lossType : String) : Map[Feature, GradientContainer] = { val modelBC = sc.broadcast(model) pointwise .mapPartitions(partition => { val model = modelBC.value - val labelToIdx = model.getLabelToIndex() - val dim = model.getLabelDictionary.size() - val gradient = scala.collection.mutable.HashMap[(String, String), GradientContainer]() - val weightVector = model.getWeightVector() + val labelToIdx = model.labelToIndex + val dim = model.labelDictionary.size + val gradient = scala.collection.mutable.HashMap[Feature, GradientContainer]() + val weightVector = model.weightVector val rnd = new Random() partition.foreach(examples => { - val flatFeatures = Util.flattenFeature(examples.example.get(0)) - val labels = 
flatFeatures.get(options.rankKey) + val vector = examples.only + val labels = vector.get(options.labelFamily) if (labels != null && labels.size() > 0) { - val posLabels = labels.toArray + val posLabels = labels.iterator.map(fv => (fv.feature.name, fv.value)).toArray // Pick a random positive label - val posLabelRnd = rnd.nextInt(posLabels.size) + val posLabelRnd = rnd.nextInt(posLabels.length) val (posLabel, posMargin) = posLabels(posLabelRnd) val posIdx = labelToIdx.get(posLabel) // Pick a random other label. This can be a negative or a positive with a smaller margin. @@ -198,11 +197,14 @@ object FullRankLinearTrainer { while (negIdx == posIdx) { negIdx = rnd.nextInt(dim) } - val negLabel = model.getLabelDictionary.get(negIdx).label - val negMargin : Double = if (labels.containsKey(negLabel)) labels.get(negLabel) else 0.0 + val negLabel = model.labelDictionary.get(negIdx).getLabel + val negLabelFeature = options.labelFamily.feature(negLabel) + val negMargin : Double = if (labels.containsKey(negLabelFeature)) { + labels.getDouble(negLabelFeature) + } else 0.0 if (posMargin > negMargin) { - val scores = model.scoreFlatFeature(flatFeatures) + val scores = model.scoreFlatFeature(vector) val posScore = scores.values(posIdx) val negScore = scores.values(negIdx) // loss = max(0, margin + w(-) * x - w(+) * x) @@ -219,21 +221,19 @@ object FullRankLinearTrainer { grad.values(negIdx) = loss.toFloat } - for (family <- flatFeatures) { - for (feature <- family._2) { - val key = (family._1, feature._1) - // We only care about features in the model. 
- if (weightVector.containsKey(key._1) && weightVector.get(key._1).containsKey(key._2)) { - val featureVal = feature._2 - val gradContainer = gradient.getOrElse(key, - GradientContainer(new FloatVector(dim), 0.0)) - gradContainer.grad.multiplyAdd(featureVal.toFloat, grad) - val norm = math.max(featureVal * featureVal, 1.0) - gradient.put(key, - GradientContainer(gradContainer.grad, - gradContainer.featureSquaredSum + norm - )) - } + for (fv <- vector.iterator) { + val key = fv.feature + // We only care about features in the model. + if (weightVector.containsKey(key)) { + val featureVal = fv.value + val gradContainer = gradient.getOrElse(key, + GradientContainer(new FloatVector(dim), 0.0)) + gradContainer.grad.multiplyAdd(featureVal, grad) + val norm = math.max(featureVal * featureVal, 1.0) + gradient.put(key, + GradientContainer(gradContainer.grad, + gradContainer.featureSquaredSum + norm + )) } } } @@ -247,52 +247,49 @@ object FullRankLinearTrainer { .toMap } - def parseTrainingOptions(config : Config) : FullRankLinearTrainerOptions = { - + def parseTrainingOptions(config : Config, registry: FeatureRegistry) : FullRankLinearTrainerOptions = { FullRankLinearTrainerOptions( loss = config.getString("loss"), iterations = config.getInt("iterations"), - rankKey = config.getString("rank_key"), + labelFamily = registry.family(config.getString("rank_key")), lambda = config.getDouble("lambda"), subsample = config.getDouble("subsample"), minCount = config.getInt("min_count"), cache = Try(config.getString("cache")).getOrElse(""), solver = Try(config.getString("solver")).getOrElse("rprop"), - labelMinCount = Try(Some(config.getInt("label_min_count"))).getOrElse(None) + labelMinCount = Try(Some(config.getInt("label_min_count"))).getOrElse(None), + registry = registry ) } def setupModel(options : FullRankLinearTrainerOptions, pointwise : RDD[Example]) : FullRankLinearModel = { val stats = TrainingUtils.getFeatureStatistics(options.minCount, pointwise) val labelCounts = if 
(options.labelMinCount.isDefined) { - TrainingUtils.getLabelCounts(options.labelMinCount.get, pointwise, options.rankKey) + TrainingUtils.getLabelCounts(options.labelMinCount.get, pointwise, options.labelFamily) } else { - TrainingUtils.getLabelCounts(options.minCount, pointwise, options.rankKey) + TrainingUtils.getLabelCounts(options.minCount, pointwise, options.labelFamily) } - val model = new FullRankLinearModel() - val weights = model.getWeightVector() - val dict = model.getLabelDictionary() + val model = new FullRankLinearModel(options.registry) + val weights = model.weightVector + val dict = model.labelDictionary for (kv <- stats) { - val (family, feature) = kv._1 - if (family != options.rankKey) { - if (!weights.containsKey(family)) { - weights.put(family, new java.util.HashMap[java.lang.String, FloatVector]()) - } - val familyMap = weights.get(family) - if (!familyMap.containsKey(feature)) { + val (feature, _) = kv + if (feature.family != options.labelFamily) { + if (!weights.containsKey(feature)) { // Dummy entry until we know the number of labels. - familyMap.put(feature, null) + // TODO (Brad): Awkward. Intentionally setting the key to null in a map is scary. + weights.put(feature, null) } } } for (kv <- labelCounts) { - val (family, feature) = kv._1 + val (feature, count) = kv val entry = new LabelDictionaryEntry() - entry.setLabel(feature) - entry.setCount(kv._2) + entry.setLabel(feature.name()) + entry.setCount(count) dict.add(entry) } @@ -301,14 +298,11 @@ object FullRankLinearTrainer { // Now fill all the feature vectors with length dim. 
var count : Int = 0 - for (family <- weights) { - val keys = family._2.keySet() - for (key <- keys) { - count = count + 1 - family._2.put(key, new FloatVector(dim)) - } + for (feature <- weights.keySet()) { + count = count + 1 + weights.put(feature, new FloatVector(dim)) } - log.info(s"Total number of features is $count") + log.info(s"Total number of inputFeatures is $count") model.buildLabelToIndex() model diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/GradientUtils.scala b/training/src/main/scala/com/airbnb/aerosolve/training/GradientUtils.scala index 4c3d06b8..894627d8 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/GradientUtils.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/GradientUtils.scala @@ -1,5 +1,6 @@ package com.airbnb.aerosolve.training +import com.airbnb.aerosolve.core.features.Feature import com.airbnb.aerosolve.core.util.FloatVector import org.slf4j.{LoggerFactory, Logger} @@ -23,30 +24,27 @@ object GradientUtils { } // Gradient update rule from "boosting with structural sparsity Duchi et al 2009" - def sparseBoost(gradients : Map[(String, String), GradientContainer], - weightVector : java.util.Map[String,java.util.Map[String,com.airbnb.aerosolve.core.util.FloatVector]], + def sparseBoost(gradients : Map[Feature, GradientContainer], + weightVector : java.util.Map[Feature, FloatVector], dim :Int, lambda : Double) = { var gradientNorm = 0.0 var featureCount = 0 gradients.foreach(kv => { val (key, gradient) = kv - val featureMap = weightVector.get(key._1) - if (featureMap != null) { - val weight = featureMap.get(key._2) - if (weight != null) { - // Just a proxy measure for convergence. 
- gradientNorm = gradientNorm + gradient.grad.dot(gradient.grad) - val scale = 2.0 / math.max(1e-6, gradient.featureSquaredSum) - weight.multiplyAdd(-scale.toFloat, gradient.grad) - val hingeScale = 1.0 - lambda * scale / math.sqrt(weight.dot(weight)) - if (hingeScale <= 0.0f) { - // Weights may get re-activated ... filter at the end. - weight.setZero(dim) - } else { - featureCount = featureCount + 1 - weight.scale(hingeScale.toFloat) - } + val weight = weightVector.get(key) + if (weight != null) { + // Just a proxy measure for convergence. + gradientNorm = gradientNorm + gradient.grad.dot(gradient.grad) + val scale = 2.0 / math.max(1e-6, gradient.featureSquaredSum) + weight.multiplyAdd(-scale.toFloat, gradient.grad) + val hingeScale = 1.0 - lambda * scale / math.sqrt(weight.dot(weight)) + if (hingeScale <= 0.0f) { + // Weights may get re-activated ... filter at the end. + weight.setZero(dim) + } else { + featureCount = featureCount + 1 + weight.scale(hingeScale.toFloat) } } }) @@ -57,10 +55,10 @@ object GradientUtils { // Improved RPROP- algorithm // https://en.wikipedia.org/wiki/Rprop // http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.3428 - def rprop(gradients : Map[(String, String), GradientContainer], - prevGradients : Map[(String, String), GradientContainer], - step : scala.collection.mutable.HashMap[(String, String), FloatVector], - weightVector : java.util.Map[String,java.util.Map[String,com.airbnb.aerosolve.core.util.FloatVector]], + def rprop(gradients : Map[Feature, GradientContainer], + prevGradients : Map[Feature, GradientContainer], + step : scala.collection.mutable.HashMap[Feature, FloatVector], + weightVector : java.util.Map[Feature, FloatVector], dim :Int, lambda : Double, deltaMax : Float = 1.0f) = { @@ -75,35 +73,32 @@ object GradientUtils { gradients.foreach(kv => { val (key, gradient) = kv - val featureMap = weightVector.get(key._1) - if (featureMap != null) { - val weight = featureMap.get(key._2) - if (weight != null) { - val prev = 
prevGradients.get(key) - val prevGrad = if (prev.isEmpty) new FloatVector(dim) else prev.get.grad - val currOpt = step.get(key) - // Create a new step vector if we don't have one. - if (currOpt.isEmpty) { - val tmp = new FloatVector(dim) - tmp.setConstant(deltaInitial) - step.put(key, tmp) - } - val currStep = step.get(key).get - for (i <- 0 until dim) { - // L2 regularization term - gradient.grad.values(i) = gradient.grad.values(i) + lambda.toFloat * weight.values(i) - - val prod = prevGrad.values(i) * gradient.grad.values(i) - if (prod > 0) { - currStep.values(i) = math.min(currStep.values(i) * etaPlus, deltaMax) - } else if (prod < 0) { - currStep.values(i) = math.max(currStep.values(i) * etaMinus, deltaMin) - gradient.grad.values(i) = 0 - } - val sign = if (gradient.grad.values(i) > 0) 1.0f else -1.0f - weight.values(i) = weight.values(i) - sign * currStep.values(i) + val weight = weightVector.get(key) + if (weight != null) { + val prev = prevGradients.get(key) + val prevGrad = if (prev.isEmpty) new FloatVector(dim) else prev.get.grad + val currOpt = step.get(key) + // Create a new step vector if we don't have one. 
+ if (currOpt.isEmpty) { + val tmp = new FloatVector(dim) + tmp.setConstant(deltaInitial) + step.put(key, tmp) + } + val currStep = step.get(key).get + for (i <- 0 until dim) { + // L2 regularization term + gradient.grad.values(i) = gradient.grad.values(i) + lambda.toFloat * weight.values(i) + val prod = prevGrad.values(i) * gradient.grad.values(i) + if (prod > 0) { + currStep.values(i) = math.min(currStep.values(i) * etaPlus, deltaMax) + } else if (prod < 0) { + currStep.values(i) = math.max(currStep.values(i) * etaMinus, deltaMin) + gradient.grad.values(i) = 0 } + val sign = if (gradient.grad.values(i) > 0) 1.0f else -1.0f + weight.values(i) = weight.values(i) - sign * currStep.values(i) + } } }) @@ -115,22 +110,19 @@ object GradientUtils { log.info("Median Step L2 norms = " + median) } - def gradientDescent(gradients : Map[(String, String), GradientContainer], - weightVector : java.util.Map[String,java.util.Map[String,com.airbnb.aerosolve.core.util.FloatVector]], + def gradientDescent(gradients : Map[Feature, GradientContainer], + weightVector : java.util.Map[Feature, FloatVector], dim :Int, learningRate: Double, lambda : Double) = { gradients.foreach(kv => { val (key, gradient) = kv - val featureMap = weightVector.get(key._1) - if (featureMap != null) { - val weight: FloatVector = featureMap.get(key._2) - if (weight != null) { - for (i <- 0 until dim) { - // L2 regularization term determined by lambda - gradient.grad.values(i) = gradient.grad.values(i) + lambda.toFloat * weight.values(i) - weight.values(i) -= learningRate.toFloat * gradient.grad.values(i) - } + val weight = weightVector.get(key) + if (weight != null) { + for (i <- 0 until dim) { + // L2 regularization term determined by lambda + gradient.grad.values(i) = gradient.grad.values(i) + lambda.toFloat * weight.values(i) + weight.values(i) -= learningRate.toFloat * gradient.grad.values(i) } } }) diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/HistogramCalibrator.scala 
b/training/src/main/scala/com/airbnb/aerosolve/training/HistogramCalibrator.scala index daa3cd33..1c8bc359 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/HistogramCalibrator.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/HistogramCalibrator.scala @@ -1,7 +1,7 @@ package com.airbnb.aerosolve.training -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD + import scala.collection.mutable.ArrayBuffer // Calibrates scores into the [0 .. 1] range by taking a histogram of scores diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/KDTree.scala b/training/src/main/scala/com/airbnb/aerosolve/training/KDTree.scala index c4f8b61c..be799fd0 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/KDTree.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/KDTree.scala @@ -35,26 +35,26 @@ class KDTree(val nodes: Array[KDTreeNode]) extends Serializable { val node = nodes(currIdx) builder.append("%d,%f,%f,%f,%f,%d".format( - currIdx, node.minX, node.minY, node.maxX, node.maxY, node.count)) + currIdx, node.getMinX, node.getMinY, node.getMaxX, node.getMaxY, node.getCount)) if (parent < 0) { builder.append(",") } else { builder.append(",%d".format(parent)) } - if (node.nodeType == KDTreeNodeType.LEAF) { + if (node.getNodeType == KDTreeNodeType.LEAF) { builder.append(",TRUE,,,,") } else { builder.append(",FALSE,%d,%d,%s,%f".format( - node.leftChild, - node.rightChild, - (if (node.nodeType == KDTreeNodeType.X_SPLIT) "TRUE" else "FALSE"), - node.splitValue + node.getLeftChild, + node.getRightChild, + (if (node.getNodeType == KDTreeNodeType.X_SPLIT) "TRUE" else "FALSE"), + node.getSplitValue )) } csv.append(builder.toString) - if (node.nodeType != KDTreeNodeType.LEAF) { - getCSVRecursive(node.leftChild, currIdx, csv) - getCSVRecursive(node.rightChild, currIdx, csv) + if (node.getNodeType != KDTreeNodeType.LEAF) { + getCSVRecursive(node.getLeftChild, currIdx, csv) + 
getCSVRecursive(node.getRightChild, currIdx, csv) } } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/KernelTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/KernelTrainer.scala index 770ae649..b828bfb4 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/KernelTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/KernelTrainer.scala @@ -1,25 +1,15 @@ package com.airbnb.aerosolve.training -import java.util - +import com.airbnb.aerosolve.core.{Example, FeatureVector, FunctionForm} +import com.airbnb.aerosolve.core.features.{MultiFamilyVector, Family, FeatureRegistry} import com.airbnb.aerosolve.core.models.KernelModel -import com.airbnb.aerosolve.core.Example -import com.airbnb.aerosolve.core.FeatureVector -import com.airbnb.aerosolve.core.FunctionForm -import com.airbnb.aerosolve.core.ModelRecord -import com.airbnb.aerosolve.core.util.FloatVector -import com.airbnb.aerosolve.core.util.StringDictionary -import com.airbnb.aerosolve.core.util.SupportVector -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.util.{FloatVector, SupportVector} import com.typesafe.config.Config import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} -import scala.util.Random import scala.util.Try -import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ // Simple SGD based kernel trainer. Mostly so we can test the kernel model for online use. // TODO(hector_yee) : if this gets more heavily used add in regularization and better training. 
@@ -29,20 +19,21 @@ object KernelTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : KernelModel = { + key : String, + registry: FeatureRegistry) : KernelModel = { val modelConfig = config.getConfig(key) val candidateSize : Int = modelConfig.getInt("num_candidates") val kernel : String = modelConfig.getString("kernel") val maxSV : Int = modelConfig.getInt("max_vectors") val scale : Float = modelConfig.getDouble("scale").toFloat - val learningRate : Float = Try(modelConfig.getDouble("learning_rate").toFloat).getOrElse(0.1f) - val rankKey : String = modelConfig.getString("rank_key") - val rankThreshold : Double = Try(modelConfig.getDouble("rank_threshold")).getOrElse(0.0f) + val learningRate : Double = Try(modelConfig.getDouble("learning_rate")).getOrElse(0.1d) + val labelFamily : Family = registry.family(modelConfig.getString("rank_key")) + val rankThreshold : Double = Try(modelConfig.getDouble("rank_threshold")).getOrElse(0.0d) val examples = LinearRankerUtils - .makePointwiseFloat(input, config, key) - val model = initModel(modelConfig, examples) + .makePointwiseFloat(input, config, key, registry) + val model = initModel(modelConfig, examples, registry) val loss : String = modelConfig.getString("loss") @@ -50,12 +41,12 @@ object KernelTrainer { // Super simple SGD trainer. 
Mostly to get the unit test to pass for (candidate <- candidates) { - val gradient = computeGradient(model, candidate.example(0), loss, rankKey, rankThreshold) + val gradient = computeGradient(model, candidate.only, loss, labelFamily, rankThreshold) if (gradient != 0.0) { - val flatFeatures = Util.flattenFeature(candidate.example(0)); - val vec = model.getDictionary().makeVectorFromSparseFloats(flatFeatures); + val vector = candidate.only + val vec = model.dictionary.makeVectorFromSparseFloats(vector) addNewSupportVector(model, kernel, scale, vec, maxSV) - model.onlineUpdate(gradient, learningRate, flatFeatures) + model.onlineUpdate(gradient, learningRate, vector) } } @@ -63,7 +54,7 @@ object KernelTrainer { } def addNewSupportVector(model : KernelModel, kernel : String, scale : Float, vec : FloatVector, maxSV : Int) = { - val supportVectors = model.getSupportVectors() + val supportVectors = model.supportVectors if (supportVectors.size() < maxSV) { val form = kernel match { case "rbf" => FunctionForm.RADIAL_BASIS_FUNCTION @@ -73,16 +64,21 @@ object KernelTrainer { case 1 => FunctionForm.ARC_COSINE } } - val sv = new SupportVector(vec, form, scale, 0.0f); - supportVectors.add(sv); + val sv = new SupportVector(vec, form, scale, 0.0f) + supportVectors.add(sv) } } def computeGradient(model : KernelModel, - fv : FeatureVector, - loss : String, rankKey : String, rankThreshold : Double) : Float = { + fv : MultiFamilyVector, + loss : String, labelFamily : Family, rankThreshold : Double) : Double = { val prediction = model.scoreItem(fv) - val label = if (loss == "hinge") TrainingUtils.getLabel(fv, rankKey, rankThreshold) else TrainingUtils.getLabel(fv, rankKey) + val label = if (loss == "hinge") { + TrainingUtils.getLabel(fv, labelFamily, rankThreshold) + } else { + TrainingUtils.getLabel(fv, labelFamily) + } + loss match { case "hinge" => { val lossVal = scala.math.max(0.0, 1.0 - label * prediction) @@ -93,27 +89,27 @@ object KernelTrainer { case "regression" => { val diff 
= prediction - label if (diff > 1.0) { - return 1.0f + return 1.0d } if (diff < -1.0) { - return -1.0f + return -1.0d } } } - return 0.0f; + 0.0d } - def initModel(modelConfig : Config, examples : RDD[Example]) : KernelModel = { + def initModel(modelConfig : Config, examples : RDD[Example], registry: FeatureRegistry) : KernelModel = { val minCount : Int = modelConfig.getInt("min_count") - val rankKey : String = modelConfig.getString("rank_key") + val labelFamily : Family = registry.family(modelConfig.getString("rank_key")) log.info("Building dictionary") val stats = TrainingUtils.getFeatureStatistics(minCount, examples) log.info(s"Dictionary size is ${stats.size}") - val dictionary = TrainingUtils.createStringDictionaryFromFeatureStatistics(stats, Set(rankKey)) + val dictionary = TrainingUtils.createStringDictionaryFromFeatureStatistics(stats, Set(labelFamily)) - val model = new KernelModel() - model.setDictionary(dictionary) + val model = new KernelModel(registry) + model.dictionary(dictionary) model } @@ -121,8 +117,9 @@ object KernelTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + "model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerTrainer.scala index 43900035..f4eaa475 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerTrainer.scala @@ -4,10 +4,12 @@ import java.io.BufferedWriter import java.io.OutputStreamWriter import java.util.concurrent.ConcurrentHashMap +import com.airbnb.aerosolve.core.features._ import com.airbnb.aerosolve.core.{ModelRecord, ModelHeader, 
FeatureVector, Example} import com.airbnb.aerosolve.core.models.LinearModel import com.airbnb.aerosolve.core.util.Util import com.typesafe.config.Config +import org.apache.spark.broadcast.Broadcast import org.slf4j.{LoggerFactory, Logger} import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ @@ -21,6 +23,7 @@ import scala.math.abs import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import scala.collection.mutable /* * A trainer that generates a linear ranker. @@ -69,290 +72,322 @@ object LinearRankerTrainer { key : String, loss : String, numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, Double)], + weights : collection.mutable.Map[Feature, (Double, Double)], iteration : Int) : - RDD[((String, String), (Double, Double))] = { - loss match { - case "ranking" => rankingTrain(sc, input, config, key, numBags, weights, iteration) - case "regression" => regressionTrain(sc, input, config, key, numBags, weights, iteration) - case "regressionL2" => regressionL2Train(sc, input, config, key, numBags, weights, iteration) - case "hinge" => classificationTrain(sc, input, config, key, numBags, weights, iteration) - case "logistic" => logisticTrain(sc, input, config, key, numBags, weights, iteration) - case _ => { - log.error("Unknown loss type %s".format(loss)) - System.exit(-1) - rankingTrain(sc, input, config, key, numBags, weights, iteration) + RDD[(Feature, (Double, Double))] = { + val registry : FeatureRegistry = new FeatureRegistry + val labelFamily: Family = registry.family(config.getString(key + ".rank_key")) + val lossFeature : Feature = registry.feature(lossKey._1, lossKey._2) + val lambda = config.getDouble(key + ".lambda") + val lambda2 = config.getDouble(key + ".lambda2") + val learningRate = config.getDouble(key + ".learning_rate") + + val dropout = if (config.hasPath(key + ".dropout")) { + Option(config.getDouble(key + ".dropout")) + } else None + val 
weightsBC = sc.broadcast(weights) + val examples = LinearRankerUtils + .makePointwiseFloat(input, config, key, registry) + .coalesce(numBags, true) + + // The following code might be a bit confusing for people who aren't experienced with Scala. + // To re-use more code, I'm using currying and function passing. + // If you're more used to OOP, imagine each type of loss is a subclass with a bunch of + // constructor params. Each class implements .updateWeights for a vector. + // Ranking is different and (conceptually) overrides the default .updateWeights for an example + // so it can behave a little differently. + // Since OOP can be implemented with closures and currying is simplified closures, these are + // just two ways to do the same thing. + // I think this functional approach is actually a lot easier to write and read even though + // it's likely a little unfamiliar. + // TODO (Brad) : This can all be simplified a lot more. + val weightUpdateFunction = if (loss == "ranking") { + ranking(lambda, lambda2, learningRate, lossFeature) _ + } else { + val function = loss match { + case "regression" => regression(lambda, lambda2, config.getDouble(key + ".epsilon"), + learningRate, lossFeature) _ + case "regressionL2" => regressionL2(lambda, lambda2, learningRate, lossFeature) _ + case "hinge" => classification(lambda, lambda2, learningRate, + config.getDouble(key + ".rank_threshold"), lossFeature) _ + case "logistic" => logistic(lambda, lambda2, learningRate, + config.getDouble(key + ".rank_threshold"), lossFeature) _ + case _ => { + val message = "Unknown loss type %s".format(loss) + log.error(message) + throw new IllegalArgumentException(message) + } } + basicWeightUpdateFunction(function) _ } + + updateWeights(examples, weightsBC, iteration, labelFamily, dropout, weightUpdateFunction) } - def regressionTrain(sc : SparkContext, - input : RDD[Example], - config : Config, - key : String, - numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, 
Double)], - iteration : Int) : - RDD[((String, String), (Double, Double))] = { - val rankKey: String = config.getString(key + ".rank_key") - val lambda : Double = config.getDouble(key + ".lambda") - val lambda2 : Double = config.getDouble(key + ".lambda2") - val epsilon : Double = config.getDouble(key + ".epsilon") - val learningRate: Double = config.getDouble(key + ".learning_rate") - val weightsBC = sc.broadcast(weights) - LinearRankerUtils - .makePointwise(input, config, key, rankKey) - .coalesce(numBags, true) + def updateWeights(input : RDD[Example], + weightsBC : Broadcast[collection.mutable.Map[Feature, (Double, Double)]], + iteration : Int, + labelFamily : Family, + dropout : Option[Double], + weightUpdateFunction : (Iterable[MultiFamilyVector], + collection.mutable.Map[Feature, (Double, Double)], + Family, Option[Double], Int) => Unit): RDD[(Feature, (Double, Double))] = { + input .mapPartitions(partition => { - // The keys the feature (family, value) - // The values are the weight, sum of squared gradients. 
- val weightMap = weightsBC.value - val rnd = new Random() - partition.foreach(examples => { - examples - .example - .filter(x => x.stringFeatures != null && - x.floatFeatures != null && - x.floatFeatures.containsKey(rankKey)) - .foreach(sample => { - val target = sample.floatFeatures.get(rankKey).iterator.next()._2 - val features = LinearRankerUtils.getFeatures(sample) - val prediction = LinearRankerUtils.score(features, weightMap) - // The loss function the epsilon insensitive loss L = max(0,|w'x - y| - epsilon) - // So if prediction = w'x and prediction > y then - // the dloss / dw is + val weightMap = weightsBC.value + partition.foreach(examples => { + val vectors = examples + .filter(vector => vector.contains(labelFamily)) + weightUpdateFunction(vectors, weightMap, labelFamily, dropout, iteration) + }) + weightMap.iterator + }) + } - val diff = prediction - target - val loss = Math.abs(diff) - epsilon - val lossEntry = weightMap.getOrElse(lossKey, (0.0, 0.0)) + def basicWeightUpdateFunction(function: + (Iterable[Feature], + collection.mutable.Map[Feature, (Double, Double)], + Double, Double, Int) => Unit) + (vectors: Iterable[MultiFamilyVector], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + labelFamily: Family, + dropout: Option[Double], + iteration: Int): Unit = { + vectors.foreach(vector => { + val label = LinearRankerUtils.getLabel(vector, labelFamily) + val scoreVector: Iterable[FeatureValue] = dropout + .map(d => vector.withDropout(d).asScala) + .getOrElse(vector.iterator.asScala.toIterable) + val features = LinearRankerUtils.getFeatures(scoreVector) + val prediction = LinearRankerUtils.score(features, weightMap) + val finalPrediction = dropout + .map(dropoutVal => prediction / (1.0 - dropoutVal)) + .getOrElse(prediction) + function(features, weightMap, finalPrediction, label, iteration) + }) + } - if (loss <= 0) { - // No loss suffered - weightMap.put(lossKey, (lossEntry._1, lossEntry._2 + 1.0)) - } else { - val grad = if (diff > 0) 
1.0 else -1.0 - features.foreach(v => { - val wt = weightMap.getOrElse(v, (0.0, 0.0)) + def ranking(lambda: Double, + lambda2: Double, + learningRate: Double, + lossFeature : Feature) + (vectors: Iterable[MultiFamilyVector], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + labelFamily: Family, + dropout: Option[Double], + iteration: Int): Unit = { + val rnd = new Random(java.util.Calendar.getInstance().getTimeInMillis) + var size = weightMap.size + LinearRankerUtils.rankingCompression(vectors, labelFamily) + .foreach(ce => { + val pos = ce.pos.filter(x => rnd.nextDouble() > dropout.get) + val neg = ce.neg.filter(x => rnd.nextDouble() > dropout.get) + val posScore = LinearRankerUtils.score(pos, weightMap) / (1.0 - dropout.get) + val negScore = LinearRankerUtils.score(neg, weightMap) / (1.0 - dropout.get) + val loss = 1.0 - posScore + negScore + val lossEntry = weightMap.getOrElse(lossFeature, (0.0, 0.0)) + if (loss > 0.0) { + def update(feature : Feature, grad : Double) = { + val wt = weightMap.getOrElse(feature, (0.0, 0.0)) + // Allow puts only in the first iteration and size is less than MAX_SIZE + // or if it already exists. 
+ if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { val newGradSum = wt._2 + 1.0 val newWeight = fobosUpdate(currWeight = wt._1, gradient = grad, - eta = learningRate, + eta = learningRate, l1Reg = lambda, l2Reg = lambda2, sum = newGradSum) - weightMap.put(v, (newWeight, newGradSum)) - }) - weightMap.put(lossKey, (lossEntry._1 + loss, lossEntry._2 + 1.0)) + if (newWeight == 0.0) { + weightMap.remove(feature) + } else { + weightMap.put(feature, (newWeight, newGradSum)) + } + if (wt._1 == 0.0) { + size = size + 1 + } + } } - }) + pos.foreach(v => { + update(v, -1.0) + }) + neg.foreach(v => { + update(v, 1.0) + }) + weightMap.put(lossFeature, (lossEntry._1 + loss, lossEntry._2 + 1.0)) + } else { + weightMap.put(lossFeature, (lossEntry._1, lossEntry._2 + 1.0)) + } }) - weightMap - .iterator - }) } - // Squared difference loss - def regressionL2Train(sc : SparkContext, - input : RDD[Example], - config : Config, - key : String, - numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, Double)], - iteration : Int) : - RDD[((String, String), (Double, Double))] = { - val rankKey: String = config.getString(key + ".rank_key") - val lambda : Double = config.getDouble(key + ".lambda") - val lambda2 : Double = config.getDouble(key + ".lambda2") - val learningRate: Double = config.getDouble(key + ".learning_rate") - val weightsBC = sc.broadcast(weights) - LinearRankerUtils - .makePointwise(input, config, key, rankKey) - .coalesce(numBags, true) - .mapPartitions(partition => { - // The keys the feature (family, value) - // The values are the weight, sum of squared gradients. 
- val weightMap = weightsBC.value - val rnd = new Random() - partition.foreach(examples => { - examples - .example - .filter(x => x.stringFeatures != null && - x.floatFeatures != null && - x.floatFeatures.containsKey(rankKey)) - .foreach(sample => { - val target = sample.floatFeatures.get(rankKey).iterator.next()._2 - val features = LinearRankerUtils.getFeatures(sample) - val prediction = LinearRankerUtils.score(features, weightMap) - // The loss function is the squared error loss L = 0.5 * ||w'x - y||^2 - // So dloss / dw = (w'x - y) * x - - val diff = prediction - target - val sqdiff = diff * diff - val loss = 0.5 * sqdiff - val lossEntry = weightMap.getOrElse(lossKey, (0.0, 0.0)) + def regression(lambda: Double, + lambda2: Double, + epsilon: Double, + learningRate: Double, + lossFeature : Feature) + (features : Iterable[Feature], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + prediction: Double, + target : Double, + iteration: Int) : Unit = { + // The loss function the epsilon insensitive loss L = max(0,|w'x - y| - epsilon) + // So if prediction = w'x and prediction > y then + // the dloss / dw is + + val diff = prediction - target + val loss = Math.abs(diff) - epsilon + val lossEntry = weightMap.getOrElse(lossFeature, (0.0, 0.0)) - val grad = diff - features.foreach(v => { - val wt = weightMap.getOrElse(v, (0.0, 0.0)) - val newGradSum = wt._2 + sqdiff - val newWeight = fobosUpdate(currWeight = wt._1, - gradient = grad, - eta = learningRate, - l1Reg = lambda, - l2Reg = lambda2, - sum = newGradSum) - weightMap.put(v, (newWeight, newGradSum)) - }) - weightMap.put(lossKey, (lossEntry._1 + loss, lossEntry._2 + 1.0)) - }) + if (loss <= 0) { + // No loss suffered + weightMap.put(lossFeature, (lossEntry._1, lossEntry._2 + 1.0)) + } else { + val grad = if (diff > 0) 1.0 else -1.0 + features.foreach(feature => { + val wt = weightMap.getOrElse(feature, (0.0, 0.0)) + val newGradSum = wt._2 + 1.0 + val newWeight = fobosUpdate(currWeight = wt._1, + 
gradient = grad, + eta = learningRate, + l1Reg = lambda, + l2Reg = lambda2, + sum = newGradSum) + weightMap.put(feature, (newWeight, newGradSum)) }) - weightMap - .iterator + weightMap.put(lossFeature, (lossEntry._1 + loss, lossEntry._2 + 1.0)) + } + } + + def regressionL2(lambda: Double, + lambda2: Double, + learningRate: Double, + lossFeature : Feature) + (features : Iterable[Feature], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + prediction: Double, + target : Double, + iteration : Int) : Unit = { + // The loss function is the squared error loss L = 0.5 * ||w'x - y||^2 + // So dloss / dw = (w'x - y) * x + + val diff = prediction - target + val sqdiff = diff * diff + val loss = 0.5 * sqdiff + val lossEntry = weightMap.getOrElse(lossFeature, (0.0, 0.0)) + + val grad = diff + features.foreach(feature => { + val wt = weightMap.getOrElse(feature, (0.0, 0.0)) + val newGradSum = wt._2 + sqdiff + val newWeight = fobosUpdate(currWeight = wt._1, + gradient = grad, + eta = learningRate, + l1Reg = lambda, + l2Reg = lambda2, + sum = newGradSum) + weightMap.put(feature, (newWeight, newGradSum)) }) + weightMap.put(lossFeature, (lossEntry._1 + loss, lossEntry._2 + 1.0)) } - def classificationTrain(sc : SparkContext, - input : RDD[Example], - config : Config, - key : String, - numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, Double)], - iteration : Int) : - RDD[((String, String), (Double, Double))] = { - val rankKey: String = config.getString(key + ".rank_key") - val weightsBC = sc.broadcast(weights) - LinearRankerUtils - .makePointwise(input, config, key, rankKey) - .coalesce(numBags, true) - .mapPartitions(partition => { - // The keys the feature (family, value) - // The values are the weight, sum of squared gradients. 
- val weightMap = weightsBC.value - val lambda : Double = config.getDouble(key + ".lambda") - val lambda2 : Double = config.getDouble(key + ".lambda2") - var size = weightMap.size - val rnd = new Random() - val learningRate: Double = config.getDouble(key + ".learning_rate") - val threshold: Double = config.getDouble(key + ".rank_threshold") - val dropout : Double = config.getDouble(key + ".dropout") - partition.foreach(examples => { - examples - .example - .filter(x => x.stringFeatures != null && - x.floatFeatures != null && - x.floatFeatures.containsKey(rankKey)) - .foreach(sample => { - val rank = sample.floatFeatures.get(rankKey).iterator.next()._2 - val features = LinearRankerUtils.getFeatures(sample).filter(x => rnd.nextDouble() > dropout) - val prediction = LinearRankerUtils.score(features, weightMap) / (1.0 - dropout) - val label = if (rank <= threshold) { - -1.0 - } else { - 1.0 + def classification(lambda : Double, + lambda2 : Double, + learningRate : Double, + threshold : Double, + lossFeature : Feature) + (features : Iterable[Feature], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + prediction: Double, + rank : Double, + iteration : Int) : Unit = { + val label = if (rank <= threshold) { + -1.0 + } else { + 1.0 + } + val loss = 1.0 - label * prediction + val lossEntry = weightMap.getOrElse(lossFeature, (0.0, 0.0)) + var size = weightMap.size + if (loss > 0.0) { + features.foreach(feature => { + val wt = weightMap.getOrElse(feature, (0.0, 0.0)) + // Allow puts only in the first iteration and size is less than MAX_SIZE + // or if it already exists. + if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { + if (wt._1 == 0.0) { + // We added a weight increase the size. 
+ size = size + 1 } - val loss = 1.0 - label * prediction - val lossEntry = weightMap.getOrElse(lossKey, (0.0, 0.0)) - if (loss > 0.0) { - features.foreach(v => { - val wt = weightMap.getOrElse(v, (0.0, 0.0)) - // Allow puts only in the first iteration and size is less than MAX_SIZE - // or if it already exists. - if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { - if (wt._1 == 0.0) { - // We added a weight increase the size. - size = size + 1 - } - val newGradSum = wt._2 + 1.0 - val newWeight = fobosUpdate(currWeight = wt._1, - gradient = -label, - eta = learningRate, - l1Reg = lambda, - l2Reg = lambda2, - sum = newGradSum) - if (newWeight == 0.0) { - weightMap.remove(v) - } else { - weightMap.put(v, (newWeight, newGradSum)) - } - } - }) - weightMap.put(lossKey, (lossEntry._1 + loss, lossEntry._2 + 1.0)) + val newGradSum = wt._2 + 1.0 + val newWeight = fobosUpdate(currWeight = wt._1, + gradient = -label, + eta = learningRate, + l1Reg = lambda, + l2Reg = lambda2, + sum = newGradSum) + if (newWeight == 0.0) { + weightMap.remove(feature) } else { - weightMap.put(lossKey, (lossEntry._1, lossEntry._2 + 1.0)) + weightMap.put(feature, (newWeight, newGradSum)) } - }) + } }) - weightMap - .iterator - }) + weightMap.put(lossFeature, (lossEntry._1 + loss, lossEntry._2 + 1.0)) + } else { + weightMap.put(lossFeature, (lossEntry._1, lossEntry._2 + 1.0)) + } } - def logisticTrain(sc : SparkContext, - input : RDD[Example], - config : Config, - key : String, - numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, Double)], - iteration : Int) : - RDD[((String, String), (Double, Double))] = { - val weightsBC = sc.broadcast(weights) - LinearRankerUtils - .makePointwiseCompressed(input, config, key) - .coalesce(numBags, true) - .mapPartitions(partition => { - // The keys the feature (family, value) - // The values are the weight, sum of squared gradients. 
- val weightMap = weightsBC.value - var size = weightMap.size - val rnd = new Random() - val learningRate: Double = config.getDouble(key + ".learning_rate") - val threshold: Double = config.getDouble(key + ".rank_threshold") - val lambda : Double = config.getDouble(key + ".lambda") - val lambda2 : Double = config.getDouble(key + ".lambda2") - val dropout : Double = config.getDouble(key + ".dropout") - partition.foreach(sample => { - val prediction = LinearRankerUtils.score(sample.pos, weightMap) / (1.0 - dropout) - val label = if (sample.label <= threshold) { - -1.0 + def logistic(lambda : Double, + lambda2 : Double, + learningRate : Double, + threshold : Double, + lossFeature : Feature) + (features : Iterable[Feature], + weightMap: collection.mutable.Map[Feature, (Double, Double)], + prediction: Double, + rank : Double, + iteration : Int) : Unit = { + val label = if (rank <= threshold) { + -1.0 + } else { + 1.0 + } + // To prevent blowup. + val corr = scala.math.min(10.0, label * prediction) + val expCorr = scala.math.exp(corr) + val loss = scala.math.log(1.0 + 1.0 / expCorr) + val lossEntry = weightMap.getOrElse(lossFeature, (0.0, 0.0)) + var size = weightMap.size + features.foreach(feature => { + val wt = weightMap.getOrElse(feature, (0.0, 0.0)) + // Allow puts only in the first iteration and size is less than MAX_SIZE + // or if it already exists. + if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { + if (wt._1 == 0.0) { + // We added a weight increase the size. + size = size + 1 + } + val newGradSum = wt._2 + 1.0 + val grad = -label / (1.0 + expCorr) + val newWeight = fobosUpdate(currWeight = wt._1, + gradient = grad, + eta = learningRate, + l1Reg = lambda, + l2Reg = lambda2, + sum = newGradSum) + if (newWeight == 0.0) { + weightMap.remove(feature) } else { - 1.0 + weightMap.put(feature, (newWeight, newGradSum)) } - // To prevent blowup. 
- val corr = scala.math.min(10.0, label * prediction) - val expCorr = scala.math.exp(corr) - val loss = scala.math.log(1.0 + 1.0 / expCorr) - val lossEntry = weightMap.getOrElse(lossKey, (0.0, 0.0)) - sample.pos.foreach(v => { - val wt = weightMap.getOrElse(v, (0.0, 0.0)) - // Allow puts only in the first iteration and size is less than MAX_SIZE - // or if it already exists. - if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { - if (wt._1 == 0.0) { - // We added a weight increase the size. - size = size + 1 - } - val newGradSum = wt._2 + 1.0 - val grad = -label / (1.0 + expCorr) - val newWeight = fobosUpdate(currWeight = wt._1, - gradient = grad, - eta = learningRate, - l1Reg = lambda, - l2Reg = lambda2, - sum = newGradSum) - if (newWeight == 0.0) { - weightMap.remove(v) - } else { - weightMap.put(v, (newWeight, newGradSum)) - } - } - }) - weightMap.put(lossKey, (lossEntry._1 + loss, lossEntry._2 + 1.0)) - }) - weightMap - .iterator + } }) } + // http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf def fobosUpdate(currWeight : Double, gradient : Double, @@ -370,93 +405,24 @@ object LinearRankerTrainer { sign * step } - def rankingTrain(sc : SparkContext, - input : RDD[Example], - config : Config, - key : String, - numBags : Int, - weights : collection.mutable.Map[(String, String), (Double, Double)], - iteration : Int) : - RDD[((String, String), (Double, Double))] = { - val examples = LinearRankerUtils.rankingTrain(input, config, key) - - val weightsBC = sc.broadcast(weights) - - examples - .coalesce(numBags, true) - .mapPartitions(partition => { - // The keys the feature (family, value) - // The values are the weight, sum of squared gradients. 
- val weightMap = weightsBC.value - var size = weightMap.size - val rnd = new Random(java.util.Calendar.getInstance().getTimeInMillis) - val learningRate: Double = config.getDouble(key + ".learning_rate") - val lambda : Double = config.getDouble(key + ".lambda") - val lambda2 : Double = config.getDouble(key + ".lambda2") - val dropout : Double = config.getDouble(key + ".dropout") - partition.foreach(ce => { - val pos = ce.pos.filter(x => rnd.nextDouble() > dropout) - val neg = ce.neg.filter(x => rnd.nextDouble() > dropout) - val posScore = LinearRankerUtils.score(pos, weightMap) / (1.0 - dropout) - val negScore = LinearRankerUtils.score(neg, weightMap) / (1.0 - dropout) - val loss = 1.0 - posScore + negScore - val lossEntry = weightMap.getOrElse(lossKey, (0.0, 0.0)) - if (loss > 0.0) { - def update(v : (String, String), grad : Double) = { - val wt = weightMap.getOrElse(v, (0.0, 0.0)) - // Allow puts only in the first iteration and size is less than MAX_SIZE - // or if it already exists. 
- if ((iteration == 1 && size < MAX_WEIGHTS) || wt._1 != 0.0) { - val newGradSum = wt._2 + 1.0 - val newWeight = fobosUpdate(currWeight = wt._1, - gradient = grad, - eta = learningRate, - l1Reg = lambda, - l2Reg = lambda2, - sum = newGradSum) - if (newWeight == 0.0) { - weightMap.remove(v) - } else { - weightMap.put(v, (newWeight, newGradSum)) - } - if (wt._1 == 0.0) { - size = size + 1 - } - } - } - pos.foreach(v => { - update(v, -1.0) - }) - neg.foreach(v => { - update(v, 1.0) - }) - weightMap.put(lossKey, (lossEntry._1 + loss, lossEntry._2 + 1.0)) - } else { - weightMap.put(lossKey, (lossEntry._1, lossEntry._2 + 1.0)) - } - }) - // Strip off the sum of squared gradients for the result - weightMap - .iterator - }) - } - def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : Array[((String, String), Double)] = { + key : String, + registry : FeatureRegistry) : Array[(Feature, Double)] = { val loss: String = try { config.getString(key + ".loss") } catch { case _: Throwable => "ranking" } log.info("Training using " + loss) - sgdTrain(sc, input, config, key, loss) + sgdTrain(sc, input, config, key, loss, registry) } def setPrior(config : Config, key : String, - weights : collection.mutable.Map[(String, String), (Double, Double)]) = { + weights : mutable.Map[Feature, (Double, Double)], + registry : FeatureRegistry) = { try { val priors = config.getStringList(key + ".prior") for (prior <- priors) { @@ -466,7 +432,7 @@ object LinearRankerTrainer { val name = tokens(1) val weight = tokens(2).toDouble log.info("Setting prior %s:%s = %f".format(family, name, weight)) - weights.put((family, name), (weight, 1.0)) + weights.put(registry.feature(family, name), (weight, 1.0)) } } } catch { @@ -478,15 +444,16 @@ object LinearRankerTrainer { input : RDD[Example], config : Config, key : String, - loss : String) : Array[((String, String), Double)] = { + loss : String, + registry : FeatureRegistry) : Array[(Feature, Double)] = { val numBags : Int = 
config.getInt(key + ".num_bags") val iterations : Int = config.getInt(key + ".iterations") val subsample : Double = Try(config.getDouble(key + ".subsample")).getOrElse(1.0) // The keys the feature (family, value) // The values are the weight. - var weights = new ConcurrentHashMap[(String, String), (Double, Double)]().asScala - setPrior(config, key, weights) + var weights = mutable.HashMap[Feature, (Double, Double)]() + setPrior(config, key, weights, registry) // Since we are bagging models, average them by numBags val scale : Double = 1.0 / numBags.toDouble @@ -499,7 +466,8 @@ object LinearRankerTrainer { .reduceByKey((a,b) => (a._1 + b._1, a._2 + b._2)) .persist() - val lossV = resultsRDD.filter(x => x._1 == lossKey).take(1) + val lossFeature = registry.feature(lossKey._1, lossKey._2) + val lossV = resultsRDD.filter(x => x._1 == lossFeature).take(1) var lossSum = 0.0 var count = 0.0 if (!lossV.isEmpty) { @@ -509,13 +477,13 @@ object LinearRankerTrainer { 0.0 } val results = resultsRDD - .filter(x => x._1 != lossKey) + .filter(x => x._1 != lossFeature) .map(x => (scala.math.abs(x._2._1), (x._1, x._2))) .top(LinearRankerTrainer.MAX_WEIGHTS) .map(x => x._2) // Nuke the old weights - weights = new ConcurrentHashMap[(String, String), (Double, Double)]().asScala + weights = mutable.HashMap[Feature, (Double, Double)]() var sz = 0 results .foreach(value => { @@ -533,7 +501,7 @@ object LinearRankerTrainer { .toArray } - def save(writer : BufferedWriter, weights : Array[((String, String), Double)]) = { + def save(writer : BufferedWriter, weights : Array[(Feature, Double)]) = { val header = new ModelHeader() header.setModelType("linear") header.setNumRecords(weights.size) @@ -544,12 +512,13 @@ object LinearRankerTrainer { log.info("Top 50 weights") for(i <- 0 until weights.size) { val weight = weights(i) - val (family, name) = weight._1 + val family = weight._1.family.name + val name = weight._1.name val wt = weight._2 if (i < 50) { log.info("%s : %s = %f".format(family, 
name, wt)) } - val record = new ModelRecord(); + val record = new ModelRecord() record.setFeatureFamily(family) record.setFeatureName(name) record.setFeatureWeight(wt) @@ -562,8 +531,9 @@ object LinearRankerTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val weights = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val weights = train(sc, input, config, key, registry) val output : String = config.getString(key + ".model_output") val fileSystem = FileSystem.get(new java.net.URI(output), new Configuration()) diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerUtils.scala b/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerUtils.scala index ffc1f146..98c1c3f7 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerUtils.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/LinearRankerUtils.scala @@ -2,6 +2,8 @@ package com.airbnb.aerosolve.training import com.airbnb.aerosolve.core.Example import com.airbnb.aerosolve.core.FeatureVector +import com.airbnb.aerosolve.core.features._ +import com.airbnb.aerosolve.core.models.AbstractModel import com.airbnb.aerosolve.core.transforms.Transformer import com.typesafe.config.Config import org.slf4j.{LoggerFactory, Logger} @@ -17,88 +19,65 @@ import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.util.Random -case class CompressedExample(pos : Array[(String, String)], - neg : Array[(String, String)], - label : Double); +case class CompressedExample(pos : Array[Feature], + neg : Array[Feature], + label : Double) object LinearRankerUtils { private final val log: Logger = LoggerFactory.getLogger("LinearRankerUtils") - def getFeatures(sample : FeatureVector) : Array[(String, String)] = { - val features = HashSet[(String, String)]() - sample.stringFeatures.foreach(family => { - family._2.foreach(value => { - 
features.add((family._1, value)) - }) - }) - features.toArray + def getLabel(vector : MultiFamilyVector, labelFamily: Family): Int = { + vector.get(labelFamily).iterator.next.value.toInt + } + + def getFeatures(vector : Iterable[FeatureValue]) : Iterable[Feature] = { + vector.map(fv => fv.feature) } // Does feature expansion on an example and buckets them by rank. - def expandAndBucketizeExamples( - examples : Example, - transformer : Transformer, - rankKey : String) : - Array[Array[Array[(String, String)]]] = { - transformer.combineContextAndItems(examples) - val samples : Seq[FeatureVector] = examples.example - val buckets = HashMap[Int, Buffer[Array[(String, String)]]]() - samples - .filter(x => x.stringFeatures != null && - x.floatFeatures != null && - x.floatFeatures.get(rankKey) != null) - .foreach(sample => { - val rankBucket : Int = sample.floatFeatures.get(rankKey).toSeq.head._2.toInt - val features = getFeatures(sample) - val entryOpt = buckets.get(rankBucket) - if (entryOpt.isEmpty) { - buckets.put(rankBucket, ArrayBuffer(features)) - } else { - entryOpt.get.append(features) - } - }) - // Sort buckets in ascending order. - buckets - .toBuffer - .sortWith((x,y) => x._1 < y._1) - .map(x => x._2.toArray) + // Assumes the example is transformed and contains a label. + def expandAndBucketizeExample(example : Iterable[MultiFamilyVector], + labelFamily : Family) : + Array[Array[Iterable[Feature]]] = { + example + .map(sample => { + val labelBucket : Int = getLabel(sample, labelFamily) + val features = getFeatures(sample) + (labelBucket, features) + }) + .groupBy(_._1) + .toSeq + // Sort buckets in ascending order. 
+ .sortBy(_._1) + .map{ case (_, iter) => iter.map(_._2).toArray } .toArray } - // Makes ranking training data - def rankingTrain(input : RDD[Example], config : Config, key : String) : - RDD[CompressedExample] = { - input - .mapPartitions(partition => { - val output = ArrayBuffer[CompressedExample]() - val rnd = new Random() - val rankKey: String = config.getString(key + ".rank_key") - val transformer = new Transformer(config, key) - partition.foreach(examples => { - val buckets = LinearRankerUtils.expandAndBucketizeExamples(examples, transformer, rankKey) - for (i <- 0 to buckets.size - 2) { - for (j <- i + 1 to buckets.size - 1) { - val neg = buckets(i)(rnd.nextInt(buckets(i).size)).toSet - val pos = buckets(j)(rnd.nextInt(buckets(j).size)).toSet - val intersect = pos.intersect(neg) - // For ranking we have pairs of examples with label always 1.0. - val out = CompressedExample(pos.diff(intersect).toArray, - neg.diff(intersect).toArray, - label = 1.0) - output.append(out) - } - } - }) - output.iterator - }) + def rankingCompression(example : Iterable[MultiFamilyVector], labelFamily : Family) : Seq[CompressedExample] = { + val output = ArrayBuffer[CompressedExample]() + val buckets = expandAndBucketizeExample(example, labelFamily) + val rnd = new Random() + for (i <- 0 to buckets.length - 2) { + for (j <- i + 1 to buckets.length - 1) { + val neg = buckets(i)(rnd.nextInt(buckets(i).length)).toSet + val pos = buckets(j)(rnd.nextInt(buckets(j).length)).toSet + val intersect = pos.intersect(neg) + // For ranking we have pairs of examples with label always 1.0. 
+ val out = CompressedExample(pos.diff(intersect).toArray, + neg.diff(intersect).toArray, + label = 1.0) + output.append(out) + } + } + output.toSeq } - def score(feature : Array[(String, String)], - weightMap : collection.mutable.Map[(String, String), (Double, Double)]) : Double = { + def score(vector : Iterable[Feature], + weightMap : collection.mutable.Map[Feature, (Double, Double)]) : Double = { var sum : Double = 0 - feature.foreach(v => { - val opt = weightMap.get(v) - if (opt != None) { + vector.iterator.foreach(feature => { + val opt = weightMap.get(feature) + if (opt.isDefined) { sum += opt.get._1 } }) @@ -107,78 +86,34 @@ object LinearRankerUtils { // Makes an example pointwise while preserving the float features. def makePointwiseFloat( - examples : RDD[Example], - config : Config, - key : String) : RDD[Example] = { - val transformer = new Transformer(config, key) - examples.map(example => { - val buffer = collection.mutable.ArrayBuffer[Example]() - example.example.asScala.foreach(x => { - val newExample = new Example() - newExample.setContext(example.context) - newExample.addToExample(x) - transformer.combineContextAndItems(newExample) - buffer.append(newExample) + examples : RDD[Example], + config : Config, + key : String, + registry: FeatureRegistry) : RDD[Example] = { + val transformer = new Transformer(config, key, registry) + examples.flatMap(example => { + example.map(vector => { + // TODO (Brad): Why make a new one? Also, why make separate examples for each vector? 
+ val newExample = new SimpleExample(registry) + newExample.context().merge(example.context()) + newExample.createVector().merge(vector) + newExample.transform(transformer) }) - buffer }) - .flatMap(x => x) } // Transforms an RDD of examples def transformExamples( - examples : RDD[Example], - config : Config, - key : String) : RDD[Example] = { - val transformer = new Transformer(config, key) - examples.map(example => { - transformer.combineContextAndItems(example) - example.unsetContext() - example - }) - } - - // Since examples are bags of user impressions, for pointwise algoriths - // we need to shuffle each feature vector separately. - def makePointwise(examples : RDD[Example], - config : Config, - key : String, - rankKey : String) : RDD[Example] = { - val transformer = new Transformer(config, key) - examples.map(example => { - val buffer = collection.mutable.ArrayBuffer[Example]() - example.example.asScala.foreach{x => { - val newExample = new Example() - newExample.setContext(example.context) - newExample.addToExample(x) - transformer.combineContextAndItems(newExample) - // For space reasons remove all float features except rankKey. 
- val floatFeatures = newExample.example.get(0).getFloatFeatures - if (floatFeatures != null) { - val rank : java.util.Map[java.lang.String, java.lang.Double] = - floatFeatures.get(rankKey) - val newFloat : java.util.Map[java.lang.String, - java.util.Map[java.lang.String, java.lang.Double]] = - new java.util.HashMap() - newFloat.put(rankKey, rank) - newExample.example.get(0).setFloatFeatures(newFloat) - } - buffer.append(newExample) - }} - buffer - }) - .flatMap(x => x) + examples : RDD[Example], + config : Config, + key : String, + registry: FeatureRegistry) : RDD[Example] = { + val transformer = new Transformer(config, key, registry) + examples.map(_.transform(transformer)) } - def makePointwiseCompressed(examples : RDD[Example], - config : Config, - key : String) : RDD[CompressedExample] = { - val rankKey: String = config.getString(key + ".rank_key") - val pointwise = makePointwise(examples, config, key, rankKey) - pointwise.map(example => { - val ex = example.example.get(0) - val label = ex.floatFeatures.get(rankKey).entrySet().iterator().next().getValue - CompressedExample(getFeatures(ex), Array[(String, String)](), label) - }) - } + // TODO (Brad): I removed makePointwise. It removed all float features besides the label family. + // Now that we don't distinguish float and String, I'm not sure how to do this and I'm worried + // it will break something. The comment said it was for space reasons and if so, + // maybe we can skip it with the more efficient representation. 
} diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/LowRankLinearTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/LowRankLinearTrainer.scala index eb5dd002..7ac8339a 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/LowRankLinearTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/LowRankLinearTrainer.scala @@ -1,18 +1,18 @@ package com.airbnb.aerosolve.training -import com.airbnb.aerosolve.core.{Example, LabelDictionaryEntry} +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry} import com.airbnb.aerosolve.core.models.LowRankLinearModel import com.airbnb.aerosolve.core.util.FloatVector -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.{Example, LabelDictionaryEntry} import com.airbnb.aerosolve.training.GradientUtils._ import com.typesafe.config.Config -import org.slf4j.{LoggerFactory, Logger} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConversions._ -import scala.util.{Try, Random} +import scala.collection.mutable +import scala.util.{Random, Try} /* * A trainer that generates a low rank linear model. 
@@ -24,7 +24,7 @@ object LowRankLinearTrainer { private final val LABEL_EMBEDDING_KEY = "$label_embedding" case class LowRankLinearTrainerOptions(loss : String, iterations : Int, - rankKey : String, + labelFamily : Family, lambda : Double, subsample : Double, minCount : Int, @@ -32,17 +32,19 @@ object LowRankLinearTrainer { solver : String, embeddingDimension : Int, rankLossType: String, - maxNorm: Double) + maxNorm: Double, + registry: FeatureRegistry) def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : LowRankLinearModel = { - val options = parseTrainingOptions(config.getConfig(key)) + key : String, + registry: FeatureRegistry) : LowRankLinearModel = { + val options = parseTrainingOptions(config.getConfig(key), registry) val raw : RDD[Example] = LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, options.registry) val pointwise = options.cache match { case "memory" => raw.cache() @@ -63,21 +65,23 @@ object LowRankLinearTrainer { options : LowRankLinearTrainerOptions, model : LowRankLinearModel, pointwise : RDD[Example]) = { - var prevGradients : Map[(String, String), GradientContainer] = Map() - val step = scala.collection.mutable.HashMap[(String, String), FloatVector]() + var prevGradients : Map[Feature, GradientContainer] = Map() + val step = scala.collection.mutable.HashMap[Feature, FloatVector]() for (iter <- 0 until options.iterations) { log.info(s"Iteration $iter") val sample = pointwise .sample(false, options.subsample) - val gradients: Map[(String, String), GradientContainer] = options.loss match { + val gradients: Map[Feature, GradientContainer] = options.loss match { case "hinge" => hingeGradient(sc, options ,model, sample) case _: String => hingeGradient(sc, options, model, sample) } - val featureWeightVector = model.getFeatureWeightVector - val labelWeightVector = model.getLabelWeightVector - val labelWeightVectorWrapper = new 
java.util.HashMap[String,java.util.Map[String,com.airbnb.aerosolve.core.util.FloatVector]]() - labelWeightVectorWrapper.put(LABEL_EMBEDDING_KEY, labelWeightVector) + val featureWeightVector = model.featureWeightVector + val labelWeightVector = model.labelWeightVector() + val labelWeightVectorWrapper = labelWeightVector.map{ case (featureName, vector) => + (options.registry.feature(LABEL_EMBEDDING_KEY, featureName), vector) + }.toMap + options.solver match { // TODO (Peng): implement alternating optimization with bagging case "rprop" => { @@ -93,26 +97,27 @@ object LowRankLinearTrainer { def hingeGradient(sc : SparkContext, options : LowRankLinearTrainerOptions, model : LowRankLinearModel, - pointwise : RDD[Example]) : Map[(String, String), GradientContainer] = { + pointwise : RDD[Example]) : Map[Feature, GradientContainer] = { val modelBC = sc.broadcast(model) pointwise .mapPartitions(partition => { val model = modelBC.value - val labelToIdx = model.getLabelToIndex - val dim = model.getLabelDictionary.size() - val gradient = scala.collection.mutable.HashMap[(String, String), GradientContainer]() - val featureWeightVector = model.getFeatureWeightVector - val labelWeightVector = model.getLabelWeightVector + val labelToIdx = model.labelToIndex + val dim = model.labelDictionary.size + val gradient = mutable.HashMap[Feature, GradientContainer]() + val featureWeightVector = model.featureWeightVector + val labelWeightVector = model.labelWeightVector val rnd = new Random() - partition.foreach(examples => { - val flatFeatures = Util.flattenFeature(examples.example.get(0)) - val scores = model.scoreFlatFeature(flatFeatures) - val labels = flatFeatures.get(options.rankKey) + partition.foreach(example => { + val vector = example.only + val scores = model.scoreFlatFeature(vector) + val labels = vector.get(options.labelFamily) + val labelMap = labels.iterator.map(fv => (fv.feature.name, fv.value)).toMap if (labels != null && labels.size() > 0) { - val posLabels = labels.toArray 
+ val posLabels = labelMap.toArray // Pick a random positive label val posLabelRnd = rnd.nextInt(posLabels.length) val (posLabel, posMargin) = posLabels(posLabelRnd) @@ -125,7 +130,7 @@ object LowRankLinearTrainer { var N = 0 do { // Pick a random other label - val (idx, label, margin, iter) = pickRandomOtherLabel(model, labels, posIdx, posMargin, rnd, dim) + val (idx, label, margin, iter) = pickRandomOtherLabel(model, labelMap, posIdx, posMargin, rnd, dim) if (iter < dim) { // we successfully get a random other label negScore = scores.values(negIdx) @@ -150,9 +155,9 @@ object LowRankLinearTrainer { val loss = ((posMargin - negMargin) + (negScore - posScore)) * rankLoss if (loss > 0.0) { // compute gradient w.r.t W (labelWeightVector) - val fvProjection = model.projectFeatureToEmbedding(flatFeatures) + val fvProjection = model.projectFeatureToEmbedding(vector) // update w- - val negLabelKey = (LABEL_EMBEDDING_KEY, negLabel) + val negLabelKey = options.registry.feature(LABEL_EMBEDDING_KEY, negLabel) val gradContainerNeg = gradient.getOrElse(negLabelKey, GradientContainer(new FloatVector(options.embeddingDimension), 0.0)) gradContainerNeg.grad.multiplyAdd(rankLoss, fvProjection) @@ -161,7 +166,7 @@ object LowRankLinearTrainer { // but with rprop solver, this is not used, so we don't compute it here to improve speed gradient.put(negLabelKey, GradientContainer(gradContainerNeg.grad, 0.0)) // update w+ - val posLabelKey = (LABEL_EMBEDDING_KEY, posLabel) + val posLabelKey = options.registry.feature(LABEL_EMBEDDING_KEY, posLabel) val gradContainerPos = gradient.getOrElse(posLabelKey, GradientContainer(new FloatVector(options.embeddingDimension), 0.0)) gradContainerPos.grad.multiplyAdd(-rankLoss, fvProjection) @@ -170,18 +175,15 @@ object LowRankLinearTrainer { // compute gradient w.r.t V (featureWeightVector) val posLabelWeightVector = labelWeightVector.get(posLabel) val negLabelWeightVector = labelWeightVector.get(negLabel) - for (family <- flatFeatures) { - for (feature 
<- family._2) { - val key = (family._1, feature._1) - // We only care about features in the model. - if (featureWeightVector.containsKey(key._1) && featureWeightVector.get(key._1).containsKey(key._2)) { - val featureVal = feature._2 - val gradContainer = gradient.getOrElse(key, - GradientContainer(new FloatVector(options.embeddingDimension), 0.0)) - gradContainer.grad.multiplyAdd(featureVal.toFloat * rankLoss, negLabelWeightVector) - gradContainer.grad.multiplyAdd(-featureVal.toFloat * rankLoss, posLabelWeightVector) - gradient.put(key, GradientContainer(gradContainer.grad, 0.0)) - } + for (fv <- vector.iterator) { + // We only care about features in the model. + if (featureWeightVector.containsKey(fv.feature)) { + val featureVal = fv.value + val gradContainer = gradient.getOrElse(fv.feature, + GradientContainer(new FloatVector(options.embeddingDimension), 0.0)) + gradContainer.grad.multiplyAdd(featureVal.toFloat * rankLoss, negLabelWeightVector) + gradContainer.grad.multiplyAdd(-featureVal.toFloat * rankLoss, posLabelWeightVector) + gradient.put(fv.feature, GradientContainer(gradContainer.grad, 0.0)) } } } @@ -198,32 +200,28 @@ object LowRankLinearTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } private def normalizeWeightVectors(model: LowRankLinearModel, maxNorm: Double) = { - val featureWeightVector = model.getFeatureWeightVector - val labelWeightVector = model.getLabelWeightVector - for (family <- featureWeightVector.entrySet()) { - for (feature <- family.getValue.entrySet()) { - val weight = feature.getValue - weight.capNorm(maxNorm.toFloat) - } + for (weight <- model.featureWeightVector.values()) { + weight.capNorm(maxNorm.toFloat) } - for (labelWeight <- labelWeightVector.entrySet()) { - 
val weight = labelWeight.getValue + for (weight <- model.labelWeightVector.values) { weight.capNorm(maxNorm.toFloat) } } - private def parseTrainingOptions(config : Config) : LowRankLinearTrainerOptions = { + private def parseTrainingOptions(config : Config, registry: FeatureRegistry) + : LowRankLinearTrainerOptions = { LowRankLinearTrainerOptions( loss = config.getString("loss"), iterations = config.getInt("iterations"), - rankKey = config.getString("rank_key"), + labelFamily = registry.family(config.getString("rank_key")), lambda = config.getDouble("lambda"), subsample = config.getDouble("subsample"), minCount = config.getInt("min_count"), @@ -231,27 +229,28 @@ object LowRankLinearTrainer { solver = Try(config.getString("solver")).getOrElse("rprop"), embeddingDimension = config.getInt("embedding_dimension"), rankLossType = Try(config.getString("rank_loss")).getOrElse("non_uniform"), - maxNorm = Try(config.getDouble("max_norm")).getOrElse(1.0) + maxNorm = Try(config.getDouble("max_norm")).getOrElse(1.0), + registry = registry ) } private def pickRandomOtherLabel(model: LowRankLinearModel, - posLabels: java.util.Map[java.lang.String, java.lang.Double], + posLabels: Map[String, Double], posIdx: Int, posMargin: Double, rnd: Random, dim: Int) : (Int, String, Double, Int) = { // Pick a random other label. This can be a negative or a positive with a smaller margin. 
var negIdx = rnd.nextInt(dim) - var negLabel = model.getLabelDictionary.get(negIdx).label - var negMargin : Double = if (posLabels.containsKey(negLabel)) posLabels.get(negLabel) else 0.0 + var negLabel = model.labelDictionary.get(negIdx).getLabel + var negMargin : Double = posLabels.getOrElse(negLabel, 0.0d) var iter = 0 while ((negIdx == posIdx || negMargin > posMargin) && iter < dim) { // we only want to pick a label that has smaller margin, we try at most dim times // we try this for at most dim times negIdx = rnd.nextInt(dim) - negLabel = model.getLabelDictionary.get(negIdx).label - negMargin = if (posLabels.containsKey(negLabel)) posLabels.get(negLabel) else 0.0 + negLabel = model.labelDictionary.get(negIdx).getLabel + negMargin = posLabels.getOrElse(negLabel, 0.0d) iter += 1 } (negIdx, negLabel, negMargin, iter) @@ -288,28 +287,23 @@ object LowRankLinearTrainer { def setupModel(options : LowRankLinearTrainerOptions, pointwise : RDD[Example]) : LowRankLinearModel = { val stats = TrainingUtils.getFeatureStatistics(options.minCount, pointwise) - val model = new LowRankLinearModel() - val featureWeights = model.getFeatureWeightVector - val labelWeights = model.getLabelWeightVector - val dict = model.getLabelDictionary + val model = new LowRankLinearModel(options.registry) + val featureWeights = model.featureWeightVector + val labelWeights = model.labelWeightVector + val dict = model.labelDictionary val embeddingSize = options.embeddingDimension - model.setEmbeddingDimension(embeddingSize) + model.embeddingDimension(embeddingSize) var count : Int = 0 - for (kv <- stats) { - val (family, feature) = kv._1 - if (family == options.rankKey) { + for ((feature, featureStats) <- stats) { + if (feature.family == options.labelFamily) { val entry = new LabelDictionaryEntry() - entry.setLabel(feature) - entry.setCount(kv._2.count.toInt) + entry.setLabel(feature.name) + entry.setCount(featureStats.count.toInt) dict.add(entry) } else { - if (!featureWeights.containsKey(family)) 
{ - featureWeights.put(family, new java.util.HashMap[java.lang.String, FloatVector]()) - } - val familyMap = featureWeights.get(family) - if (!familyMap.containsKey(feature)) { + if (!featureWeights.containsKey(feature)) { count = count + 1 - familyMap.put(feature, FloatVector.getUniformVector(embeddingSize)) + featureWeights.put(feature, FloatVector.getUniformVector(embeddingSize)) } } } @@ -323,7 +317,7 @@ object LowRankLinearTrainer { model.buildLabelToIndex() normalizeWeightVectors(model, options.maxNorm) log.info(s"Total number of labels is ${dict.size()}") - log.info(s"Total number of features is $count") + log.info(s"Total number of inputFeatures is $count") model } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/MaxoutTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/MaxoutTrainer.scala index 7f8661af..6bed4340 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/MaxoutTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/MaxoutTrainer.scala @@ -3,14 +3,13 @@ package com.airbnb.aerosolve.training import java.io.{BufferedWriter, OutputStreamWriter} import java.util.concurrent.ConcurrentHashMap -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry, MultiFamilyVector} import com.airbnb.aerosolve.core.models.MaxoutModel import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.Config import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} @@ -24,11 +23,12 @@ object MaxoutTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : MaxoutModel = { + key : String, + registry: FeatureRegistry) : MaxoutModel = { val loss : String = config.getString(key + ".loss") val numHidden : Int = 
config.getInt(key + ".num_hidden") val iterations : Int = config.getInt(key + ".iterations") - val rankKey : String = config.getString(key + ".rank_key") + val labelFamily : Family = registry.family(config.getString(key + ".rank_key")) val learningRate : Double = config.getDouble(key + ".learning_rate") val lambda : Double = config.getDouble(key + ".lambda") val lambda2 : Double = config.getDouble(key + ".lambda2") @@ -40,12 +40,12 @@ object MaxoutTrainer { val pointwise : RDD[Example] = LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, registry) .cache() - var model = new MaxoutModel() + var model = new MaxoutModel(registry) model.initForTraining(numHidden) - initModel(minCount, rankKey, pointwise, model) + initModel(minCount, labelFamily, pointwise, model) log.info("Computing max values for all features") log.info("Training using " + loss) @@ -55,7 +55,7 @@ object MaxoutTrainer { key, pointwise, numHidden, - rankKey, + labelFamily, loss, learningRate, lambda, @@ -73,42 +73,38 @@ object MaxoutTrainer { // Intializes the model def initModel(minCount : Int, - rankKey : String, + labelFamily : Family, input : RDD[Example], model : MaxoutModel) = { - val maxScale = getMaxScale(minCount, rankKey, input) + val maxScale = getMaxScale(minCount, labelFamily, input) log.info("Num features = %d".format(maxScale.length)) - for (entry <- maxScale) { - model.addVector(entry._1._1, entry._1._2, entry._2.toFloat) + for ((feature, value) <- maxScale) { + model.addVector(feature, value.toFloat) } } // Returns 1 / largest absolute value of the feature def getMaxScale(minCount : Int, - rankKey : String, - input : RDD[Example]) : Array[((String, String), Double)] = { + labelFamily: Family, + input : RDD[Example]) : Array[(Feature, Double)] = { input .mapPartitions(partition => { - val weights = new ConcurrentHashMap[(String, String), (Double, Int)]().asScala - partition.foreach(example => { - val flatFeature = 
Util.flattenFeature(example.example.get(0)).asScala - flatFeature.foreach(familyMap => { - if (!rankKey.equals(familyMap._1)) { - familyMap._2.foreach(feature => { - val key = (familyMap._1, feature._1) - val curr = weights.getOrElse(key, (0.0, 0)) - weights.put(key, (scala.math.max(curr._1, feature._2), curr._2 + 1)) - }) - } + val weights = new ConcurrentHashMap[Feature, (Double, Int)]().asScala + partition.foreach(example => { + val vector = example.only + vector.iterator.foreach(fv => { + if (!labelFamily.equals(fv.feature.family)) { + val curr = weights.getOrElse(fv.feature, (0.0, 0)) + weights.put(fv.feature, (scala.math.max(curr._1, fv.value), curr._2 + 1)) + } + }) }) - }) weights.iterator }) .reduceByKey((a, b) => (scala.math.max(a._1, b._1), a._2 + b._2)) .filter(x => x._2._1 > 1e-10 && x._2._2 >= minCount) - .map(x => (x._1, 1.0 / x._2._1)) - .collect - .toArray + .mapValues(x => 1.0 / x._1) + .collect() } def sgdTrain(sc : SparkContext, @@ -116,7 +112,7 @@ object MaxoutTrainer { key : String, input : RDD[Example], numHidden : Int, - rankKey : String, + labelFamily : Family, loss : String, learningRate : Double, lambda : Double, @@ -143,41 +139,41 @@ object MaxoutTrainer { .sample(false, subsample) .coalesce(1, true) .mapPartitions(partition => { - val workingModel = modelBC.value - @volatile var lossSum : Double = 0.0 - @volatile var lossCount : Int = 0 - partition.foreach(example => { - val fv = example.example.get(0) - val rank = fv.floatFeatures.get(rankKey).asScala.head._2 - val label = if (rank <= threshold) { - -1.0 - } else { - 1.0 - } - loss match { - case "logistic" => lossSum = lossSum + updateLogistic(workingModel, fv, label, learningRate, lambda, lambda2, dropout, dropoutHidden, momentum) - case "hinge" => lossSum = lossSum + updateHinge(workingModel, fv, label, learningRate, lambda, lambda2, dropout, dropoutHidden, momentum) - case _ => { - log.error("Unknown loss function %s".format(loss)) - System.exit(-1) + val workingModel = 
modelBC.value + @volatile var lossSum : Double = 0.0 + @volatile var lossCount : Int = 0 + partition.foreach(example => { + val fv = example.only + val rank = LinearRankerUtils.getLabel(fv, labelFamily) + val label = if (rank <= threshold) { + -1.0 + } else { + 1.0 + } + loss match { + case "logistic" => lossSum = lossSum + updateLogistic(workingModel, fv, label, learningRate, lambda, lambda2, dropout, dropoutHidden, momentum) + case "hinge" => lossSum = lossSum + updateHinge(workingModel, fv, label, learningRate, lambda, lambda2, dropout, dropoutHidden, momentum) + case _ => { + log.error("Unknown loss function %s".format(loss)) + System.exit(-1) + } } - } - lossCount = lossCount + 1 - if (lossCount % lossMod == 0) { - log.info("Loss = %f, samples = %d".format(lossSum / lossMod.toDouble, lossCount)) - lossSum = 0.0 - } + lossCount = lossCount + 1 + if (lossCount % lossMod == 0) { + log.info("Loss = %f, samples = %d".format(lossSum / lossMod.toDouble, lossCount)) + lossSum = 0.0 + } + }) + Array[MaxoutModel](workingModel).iterator }) - Array[MaxoutModel](workingModel).iterator - }) - .collect + .collect() .head saveModel(modelRet, config, key) - return modelRet + modelRet } def updateLogistic(model : MaxoutModel, - fv : FeatureVector, + fv : MultiFamilyVector, label : Double, learningRate : Double, lambda : Double, @@ -185,8 +181,8 @@ object MaxoutTrainer { dropout : Double, dropoutHidden : Double, momentum : Double) : Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, dropout) - val response = model.getResponse(flatFeatures) + val vector = fv.withDropout(dropout) + val response = model.getResponse(vector) val values = response.getValues for (i <- 0 until values.length) { if (scala.util.Random.nextDouble() < dropoutHidden) { @@ -206,8 +202,8 @@ object MaxoutTrainer { lambda2.toFloat, momentum.toFloat, result, - flatFeatures) - return loss + vector) + loss } def updateHinge(model : MaxoutModel, @@ -219,8 +215,8 @@ object MaxoutTrainer { dropout : 
Double, dropoutHidden : Double, momentum : Double) : Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, dropout) - val response = model.getResponse(flatFeatures) + val vector = fv.withDropout(dropout) + val response = model.getResponse(vector) val values = response.getValues for (i <- 0 until values.length) { if (scala.util.Random.nextDouble() < dropoutHidden) { @@ -238,9 +234,9 @@ object MaxoutTrainer { lambda2.toFloat, momentum.toFloat, result, - flatFeatures) + vector) } - return loss + loss } def saveModel(model : MaxoutModel, @@ -263,8 +259,9 @@ object MaxoutTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) saveModel(model, config, key) } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/MlpModelTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/MlpModelTrainer.scala index 55edb5d2..e791a35f 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/MlpModelTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/MlpModelTrainer.scala @@ -1,17 +1,16 @@ package com.airbnb.aerosolve.training -import com.airbnb.aerosolve.core.{FeatureVector, Example, FunctionForm} +import com.airbnb.aerosolve.core.features.{Family, Feature, FeatureRegistry, MultiFamilyVector} import com.airbnb.aerosolve.core.models.MlpModel - import com.airbnb.aerosolve.core.util.FloatVector -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.{Example, FunctionForm} import com.typesafe.config.Config -import org.slf4j.{LoggerFactory, Logger} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.util.Try /** 
@@ -29,7 +28,7 @@ object MlpModelTrainer { iteration: Int, // number of iterations to run subsample : Double, // determine mini-batch size threshold : Double, // threshold for binary classification - rankKey: String, + labelFamily : Family, learningRateInit : Double, // initial learning rate learningRateDecay : Double, // learning rate decay rate momentumInit : Double, // initial momentum value @@ -40,7 +39,8 @@ object MlpModelTrainer { weightDecay : Double, // l2 regularization parameter weightInitStd : Double, // weight initialization std cache : String, - minCount : Int + minCount : Int, + registry: FeatureRegistry ) case class NetWorkParams(activationFunctions: java.util.ArrayList[FunctionForm], @@ -49,13 +49,14 @@ object MlpModelTrainer { def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : MlpModel = { - val trainerOptions = parseTrainingOptions(config.getConfig(key)) + key : String, + registry: FeatureRegistry) : MlpModel = { + val trainerOptions = parseTrainingOptions(config.getConfig(key), registry) val networkOptions = parseNetworkOptions(config.getConfig(key)) val raw : RDD[Example] = LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, trainerOptions.registry) val pointwise = trainerOptions.cache match { case "memory" => raw.cache() @@ -116,20 +117,19 @@ object MlpModelTrainer { def computeGradient(sc : SparkContext, options : TrainerOptions, model : MlpModel, - miniBatch : RDD[Example]) : Map[(String, String), FloatVector] = { + miniBatch : RDD[Example]) : Map[Feature, FloatVector] = { // compute the sum of gradient of examples in the mini-batch val modelBC = sc.broadcast(model) miniBatch .mapPartitions(partition => { val model = modelBC.value - val gradient = scala.collection.mutable.HashMap[(String, String), FloatVector]() + val gradient = mutable.HashMap[Feature, FloatVector]() partition.foreach(example => { - val fv = example.example.get(0) - val flatFeatures: 
java.util.Map[String, java.util.Map[java.lang.String, java.lang.Double]] = Util.flattenFeature(fv) + val fv = example.only val score = if (options.dropout > 0) { - model.forwardPropagationWithDropout(flatFeatures, options.dropout) + model.forwardPropagationWithDropout(fv, options.dropout) } else { - model.forwardPropagation(flatFeatures) + model.forwardPropagation(fv) } val grad = options.loss match { case "hinge" => computeHingeGradient(score, fv, options) @@ -138,13 +138,13 @@ object MlpModelTrainer { } // back-propagation for updating gradient // note: activations have been computed in forwardPropagation - val outputLayerId = model.getNumHiddenLayers - val func = model.getActivationFunction.get(outputLayerId) + val outputLayerId = model.numHiddenLayers + val func = model.activationFunction.get(outputLayerId) // delta: gradient of loss function w.r.t. node input // activation: the output of a node val outputNodeDelta = computeActivationGradient(score, func) * grad - backPropagation(model, outputNodeDelta.toFloat, gradient, flatFeatures, options.weightDecay.toFloat) + backPropagation(model, outputNodeDelta.toFloat, gradient, fv, options.weightDecay.toFloat) }) gradient.iterator }) @@ -157,46 +157,44 @@ object MlpModelTrainer { x._1.scale(1.0f / x._2.toFloat) x._1 }) - .collectAsMap + .collectAsMap() .toMap } def backPropagation(model: MlpModel, outputNodeDelta: Float, - gradient: scala.collection.mutable.HashMap[(String, String), FloatVector], - flatFeatures: java.util.Map[String, java.util.Map[java.lang.String, java.lang.Double]], - weightDecay: Float = 0.0f) = { + gradient: mutable.Map[Feature, FloatVector], + vector : MultiFamilyVector, + weightDecay: Double = 0.0d) = { // outputNodeDelta: gradient of the loss function w.r.t the input of the output node - val numHiddenLayers = model.getNumHiddenLayers - val layerNodeNumber = model.getLayerNodeNumber - val activationFunctions = model.getActivationFunction // set delta for the output layer var upperLayerDelta = 
new FloatVector(1) upperLayerDelta.set(0, outputNodeDelta) // compute gradient for bias at the output node val outputBiasGrad = new FloatVector(1) outputBiasGrad.set(0, outputNodeDelta) - val outputBiasKey = (LAYER_PREFIX + numHiddenLayers.toString, BIAS_PREFIX) + val outputBiasKey = vector.registry.feature( + LAYER_PREFIX + model.numHiddenLayers.toString, BIAS_PREFIX) outputBiasGrad.add(gradient.getOrElse(outputBiasKey, new FloatVector(1))) gradient.put(outputBiasKey, outputBiasGrad) // update for hidden layers - for (i <- (0 until numHiddenLayers).reverse) { + for (i <- (0 until model.numHiddenLayers).reverse) { // i decreases from numHiddenLayers-1 to 0 - val numNode = layerNodeNumber.get(i) - val numNodeUpperLayer = layerNodeNumber.get(i + 1) - val func = activationFunctions.get(i) + val numNode = model.layerNodeNumber.get(i) + val numNodeUpperLayer = model.layerNodeNumber.get(i + 1) + val func = model.activationFunction.get(i) val thisLayerDelta = new FloatVector(numNode) // compute gradient of weights from the i-th layer to the (i+1)-th layer - val activations = model.getLayerActivations.get(i) - val hiddenLayerWeights = model.getHiddenLayerWeights.get(i) - val biasKey = (LAYER_PREFIX + i.toString, BIAS_PREFIX) + val activations = model.layerActivations.get(i) + val hiddenLayerWeights = model.hiddenLayerWeights.get(i) + val biasKey = vector.registry.feature(LAYER_PREFIX + i.toString, BIAS_PREFIX) val biasGrad = gradient.getOrElse(biasKey, new FloatVector(numNode)) for (j <- 0 until numNode) { - val key = (LAYER_PREFIX + i.toString, NODE_PREFIX + j.toString) + val key = vector.registry.feature(LAYER_PREFIX + i.toString, NODE_PREFIX + j.toString) val gradFv = gradient.getOrElse(key, new FloatVector(numNodeUpperLayer)) gradFv.multiplyAdd(activations.get(j), upperLayerDelta) if (weightDecay > 0.0f) { - val weight = model.getHiddenLayerWeights.get(i).get(j) + val weight = model.hiddenLayerWeights.get(i).get(j) gradFv.multiplyAdd(weightDecay, weight) } 
gradient.put(key, gradFv) @@ -207,53 +205,49 @@ object MlpModelTrainer { thisLayerDelta.set(j, delta.toFloat) } biasGrad.add(thisLayerDelta) - if (weightDecay > 0.0f) { - biasGrad.multiplyAdd(weightDecay, model.getBias.get(i)) + if (weightDecay > 0.0d) { + biasGrad.multiplyAdd(weightDecay, model.bias.get(i)) } gradient.put(biasKey, biasGrad) upperLayerDelta = thisLayerDelta } - val inputLayerWeights = model.getInputLayerWeights // update for the input layer - val numNodeUpperLayer = layerNodeNumber.get(0) - for (family <- flatFeatures) { - for (feature <- family._2) { - val key = (family._1, feature._1) - // We only care about features in the model. - if (inputLayerWeights.containsKey(key._1) && inputLayerWeights.get(key._1).containsKey(key._2)) { - val gradFv = gradient.getOrElse(key, new FloatVector(numNodeUpperLayer)) - gradFv.multiplyAdd(feature._2.toFloat, upperLayerDelta) - if (weightDecay > 0.0f) { - val weight = inputLayerWeights.get(key._1).get(key._2) - gradFv.multiplyAdd(weightDecay, weight) - } - gradient.put(key, gradFv) + val numNodeUpperLayer = model.layerNodeNumber.get(0) + for (fv <- vector.iterator) { + val key = fv.feature + // We only care about features in the model. 
+ if (model.inputLayerWeights.containsKey(key)) { + val gradFv = gradient.getOrElse(key, new FloatVector(numNodeUpperLayer)) + gradFv.multiplyAdd(fv.value, upperLayerDelta) + if (weightDecay > 0.0d) { + val weight = model.inputLayerWeights.get(key) + gradFv.multiplyAdd(weightDecay, weight) } + gradient.put(key, gradFv) } } } def updateModel(model: MlpModel, - gradientContainer: Map[(String, String), FloatVector], - updateContainer: scala.collection.mutable.HashMap[(String, String), FloatVector], - momentum: Float, - learningRate: Float, + gradientContainer: Map[Feature, FloatVector], + updateContainer: mutable.Map[Feature, FloatVector], + momentum: Double, + learningRate: Double, dropout: Double) = { // computing current updates based on previous updates and new gradient // then update model weights (also update the prevUpdateContainer) - val numHiddenLayers = model.getNumHiddenLayers for ((key, prevUpdate) <- updateContainer) { - val weightToUpdate : FloatVector = if (key._1.startsWith(LAYER_PREFIX)) { - val layerId: Int = key._1.substring(LAYER_PREFIX.length).toInt - assert(layerId >= 0 && layerId <= numHiddenLayers) - if (key._2.equals(BIAS_PREFIX)) { + val weightToUpdate : FloatVector = if (key.family.name.startsWith(LAYER_PREFIX)) { + val layerId: Int = key.family.name.substring(LAYER_PREFIX.length).toInt + assert(layerId >= 0 && layerId <= model.numHiddenLayers) + if (key.name.equals(BIAS_PREFIX)) { // node bias updates - model.getBias.get(layerId) - } else if (key._2.startsWith(NODE_PREFIX)) { - val nodeId = key._2.substring(NODE_PREFIX.length).toInt + model.bias.get(layerId) + } else if (key.name.startsWith(NODE_PREFIX)) { + val nodeId = key.name.substring(NODE_PREFIX.length).toInt // hidden layer weight updates - model.getHiddenLayerWeights.get(layerId).get(nodeId) + model.hiddenLayerWeights.get(layerId).get(nodeId) } else { // error assert(false) @@ -261,9 +255,9 @@ object MlpModelTrainer { } } else { // input layer weight updates - val inputLayerWeight = 
model.getInputLayerWeights.get(key._1) + val inputLayerWeight = model.inputLayerWeights.get(key) if (inputLayerWeight != null) { - inputLayerWeight.get(key._2) + inputLayerWeight } else { new FloatVector() } @@ -282,19 +276,21 @@ object MlpModelTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } - private def parseTrainingOptions(config : Config) : TrainerOptions = { + private def parseTrainingOptions(config : Config, + registry: FeatureRegistry) : TrainerOptions = { TrainerOptions( loss = config.getString("loss"), margin = config.getDouble("margin"), iteration = config.getInt("iterations"), subsample = config.getDouble("subsample"), threshold = Try(config.getDouble("rank_threshold")).getOrElse(0.0), - rankKey = config.getString("rank_key"), + labelFamily = registry.family(config.getString("rank_key")), learningRateInit = config.getDouble("learning_rate_init"), learningRateDecay = Try(config.getDouble("learning_rate_decay")).getOrElse(1.0), momentumInit = Try(config.getDouble("momentum_init")).getOrElse(0.0), @@ -305,7 +301,8 @@ object MlpModelTrainer { weightDecay = Try(config.getDouble("weight_decay")).getOrElse(0.0), weightInitStd = config.getDouble("weight_init_std"), cache = Try(config.getString("cache")).getOrElse(""), - minCount = Try(config.getInt("min_count")).getOrElse(0) + minCount = Try(config.getInt("min_count")).getOrElse(0), + registry = registry ) } @@ -332,63 +329,49 @@ object MlpModelTrainer { pointwise : RDD[Example]) : MlpModel = { val model = new MlpModel( networkOptions.activationFunctions, - networkOptions.nodeNumber) - val hiddenLayerWeights = model.getHiddenLayerWeights - val inputLayerWeights = model.getInputLayerWeights - val layerNodeNumber = model.getLayerNodeNumber - val 
numHiddenLayers = model.getNumHiddenLayers + networkOptions.nodeNumber, + trainerOptions.registry) val std = trainerOptions.weightInitStd.toFloat val stats = TrainingUtils.getFeatureStatistics(trainerOptions.minCount, pointwise) // set up input layer weights var count : Int = 0 - for (kv <- stats) { - val (family, feature) = kv._1 - if (family != trainerOptions.rankKey) { - if (!inputLayerWeights.containsKey(family)) { - inputLayerWeights.put(family, new java.util.HashMap[java.lang.String, FloatVector]()) - } - val familyMap = inputLayerWeights.get(family) - if (!familyMap.containsKey(feature)) { - count = count + 1 - familyMap.put(feature, FloatVector.getGaussianVector(layerNodeNumber.get(0), std)) - } + for ((feature, featureStats) <- stats) { + if (feature.family != trainerOptions.labelFamily && !model.inputLayerWeights.containsKey(feature)) { + count = count + 1 + model.inputLayerWeights.put(feature, FloatVector.getGaussianVector(model.layerNodeNumber.get(0), std)) } } // set up hidden layer weights - for (i <- 0 until numHiddenLayers) { + for (i <- 0 until model.numHiddenLayers) { val arr = new java.util.ArrayList[FloatVector]() - for (j <- 0 until layerNodeNumber.get(i)) { - val fv = FloatVector.getGaussianVector(layerNodeNumber.get(i + 1), std) + for (j <- 0 until model.layerNodeNumber.get(i)) { + val fv = FloatVector.getGaussianVector(model.layerNodeNumber.get(i + 1), std) arr.add(fv) } - hiddenLayerWeights.put(i, arr) + model.hiddenLayerWeights.put(i, arr) } // note: bias at each node initialized to zero in this trainer - log.info(s"Total number of features is $count") + log.info(s"Total number of inputFeatures is $count") model } - private def setupUpdateContainer(model: MlpModel) : scala.collection.mutable.HashMap[(String, String), FloatVector] = { - val container = scala.collection.mutable.HashMap[(String, String), FloatVector]() + private def setupUpdateContainer(model: MlpModel) : mutable.Map[Feature, FloatVector] = { + val container = 
mutable.HashMap[Feature, FloatVector]() // set up input layer weights gradient - val inputLayerWeights = model.getInputLayerWeights - val n0 = model.getLayerNodeNumber.get(0) - for (family <- inputLayerWeights) { - for (feature <- family._2) { - val key = (family._1, feature._1) - container.put(key, new FloatVector(n0)) - } + val n0 = model.layerNodeNumber.get(0) + for (feature <- model.inputLayerWeights.keySet) { + container.put(feature, new FloatVector(n0)) } // set up hidden layer weights gradient - val numHiddenLayers = model.getNumHiddenLayers + val numHiddenLayers = model.numHiddenLayers for (i <- 0 until numHiddenLayers) { - val thisLayerNodeNum = model.getLayerNodeNumber.get(i) - val nextLayerNodeNum = model.getLayerNodeNumber.get(i + 1) + val thisLayerNodeNum = model.layerNodeNumber.get(i) + val nextLayerNodeNum = model.layerNodeNumber.get(i + 1) for (j <- 0 until thisLayerNodeNum) { - val key = (LAYER_PREFIX + i.toString, NODE_PREFIX + j.toString) + val key = model.registry.feature(LAYER_PREFIX + i.toString, NODE_PREFIX + j.toString) container.put(key, new FloatVector(nextLayerNodeNum)) } } @@ -396,16 +379,16 @@ object MlpModelTrainer { // set up bias gradient for (i <- 0 to numHiddenLayers) { // all bias in the same layer are put to the same FloatVector - val key = (LAYER_PREFIX + i.toString, BIAS_PREFIX) - container.put(key, new FloatVector(model.getLayerNodeNumber.get(i))) + val key = model.registry.feature(LAYER_PREFIX + i.toString, BIAS_PREFIX) + container.put(key, new FloatVector(model.layerNodeNumber.get(i))) } container } private def computeUpdates(prevUpdate: FloatVector, - momentum: Float, - learningRate: Float, + momentum: Double, + learningRate: Double, gradient: FloatVector): FloatVector = { // based on hinton's dropout paper: http://arxiv.org/pdf/1207.0580.pdf val update: FloatVector = new FloatVector(prevUpdate.length) @@ -435,11 +418,11 @@ object MlpModelTrainer { } private def computeHingeGradient(prediction: Double, - fv: FeatureVector, 
+ fv: MultiFamilyVector, option: TrainerOptions): Double = { // Returns d_loss / d_output_activation // gradient of loss function w.r.t the output node activation - val label = TrainingUtils.getLabel(fv, option.rankKey, option.threshold) + val label = TrainingUtils.getLabel(fv, option.labelFamily, option.threshold) // loss = max(0.0, option.margin - label * prediction) if (option.margin - label * prediction > 0) { -label @@ -449,13 +432,13 @@ object MlpModelTrainer { } private def computeRegressionGradient(prediction: Double, - fv: FeatureVector, + fv: MultiFamilyVector, option: TrainerOptions): Double = { // epsilon-insensitive loss for regression (as in SVM regression) // loss = max(0.0, |prediction - label| - epsilon) // where epsilon = option.margin assert(option.margin > 0) - val label = TrainingUtils.getLabel(fv, option.rankKey) + val label = TrainingUtils.getLabel(fv, option.labelFamily) if (prediction - label > option.margin) { 1.0 } else if (prediction - label < - option.margin) { diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/NDTree.scala b/training/src/main/scala/com/airbnb/aerosolve/training/NDTree.scala index 62a3d692..e75348b6 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/NDTree.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/NDTree.scala @@ -309,14 +309,14 @@ class NDTree(val nodes: Array[NDTreeNode]) extends Serializable { builder.append("%d,".format(current)) // TODO: instead of converting entire collection to Scala, use a Java to Scala foreach iterator - node.min.asScala.map(_.doubleValue()).foreach((minimum: Double) => { + node.getMin.asScala.map(_.doubleValue()).foreach((minimum: Double) => { builder.append("%f,".format(minimum)) }) - node.max.asScala.map(_.doubleValue()).foreach((maximum: Double) => { + node.getMax.asScala.map(_.doubleValue()).foreach((maximum: Double) => { builder.append("%f,".format(maximum)) }) - builder.append("%d".format(node.count)) + 
builder.append("%d".format(node.getCount)) if (parent < 0) { builder.append(",") @@ -324,22 +324,22 @@ class NDTree(val nodes: Array[NDTreeNode]) extends Serializable { builder.append(",%d".format(parent)) } - if (node.axisIndex == NDTreeModel.LEAF) { + if (node.getAxisIndex == NDTreeModel.LEAF) { builder.append(",TRUE,,,,") } else { builder.append(",FALSE,%d,%d,%d,%f".format( - node.leftChild, - node.rightChild, - node.axisIndex, - node.splitValue + node.getLeftChild, + node.getRightChild, + node.getAxisIndex, + node.getSplitValue )) } csv.append(builder.toString) - if (node.axisIndex != NDTreeModel.LEAF) { - getCSVRecursive(node.leftChild, current, csv) - getCSVRecursive(node.rightChild, current, csv) + if (node.getAxisIndex != NDTreeModel.LEAF) { + getCSVRecursive(node.getLeftChild, current, csv) + getCSVRecursive(node.getRightChild, current, csv) } } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/SplineTrainer.scala b/training/src/main/scala/com/airbnb/aerosolve/training/SplineTrainer.scala index 22c9b4d3..43656a24 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/SplineTrainer.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/SplineTrainer.scala @@ -1,5 +1,6 @@ package com.airbnb.aerosolve.training +import com.airbnb.aerosolve.core.features._ import com.airbnb.aerosolve.core.models.SplineModel import com.airbnb.aerosolve.core.models.SplineModel.WeightSpline import com.airbnb.aerosolve.core.util.Util @@ -13,7 +14,7 @@ import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap +import scala.collection.mutable import scala.util.{Random, Try} /* @@ -26,7 +27,7 @@ object SplineTrainer { case class SplineTrainerParams( numBins : Int, numBags : Int, - rankKey : String, + labelFamily : Family, loss : String, learningRate : Double, dropout : Double, @@ -40,13 +41,15 @@ object SplineTrainer { rankFraction : 
Double, // Fraction of time to use ranking loss when loss is rank_and_hinge rankMargin : Double, // The margin for ranking loss maxSamplesPerExample : Int, // Max number of samples to use per example - epsilon : Double // epsilon used in epsilon-insensitive loss for regression training + epsilon : Double, // epsilon used in epsilon-insensitive loss for regression training + registry : FeatureRegistry ) def train(sc : SparkContext, input : RDD[Example], config : Config, - key : String) : SplineModel = { + key : String, + registry: FeatureRegistry) : SplineModel = { val loss : String = config.getString(key + ".loss") val isRanking = loss match { case "rank_and_hinge" => true @@ -62,7 +65,7 @@ object SplineTrainer { val numBins : Int = config.getInt(key + ".num_bins") val numBags : Int = config.getInt(key + ".num_bags") val iterations : Int = config.getInt(key + ".iterations") - val rankKey : String = config.getString(key + ".rank_key") + val labelFamily : Family = registry.family(config.getString(key + ".rank_key")) val learningRate : Double = config.getDouble(key + ".learning_rate") val dropout : Double = config.getDouble(key + ".dropout") val minCount : Int = config.getInt(key + ".min_count") @@ -93,7 +96,7 @@ object SplineTrainer { val params = SplineTrainerParams( numBins = numBins, numBags = numBags, - rankKey = rankKey, + labelFamily = labelFamily, loss = loss, learningRate = learningRate, dropout = dropout, @@ -107,30 +110,31 @@ object SplineTrainer { rankFraction = rankFraction, rankMargin = rankMargin, maxSamplesPerExample = maxSamplesPerExample, - epsilon = epsilon) + epsilon = epsilon, + registry = registry) val transformed : RDD[Example] = if (isRanking) { - LinearRankerUtils.transformExamples(input, config, key) + LinearRankerUtils.transformExamples(input, config, key, registry) } else { LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, registry) } val initialModel = if(initModelPath == "") { None } 
else { - TrainingUtils.loadScoreModel(initModelPath) + TrainingUtils.loadScoreModel(initModelPath, registry) } var model = if(initialModel.isDefined) { val newModel = initialModel.get.asInstanceOf[SplineModel] newModel.setSplineNormCap(linfinityCap.toFloat) - initModel(minCount, subsample, rankKey, transformed, newModel, false) + initModel(minCount, subsample, labelFamily, transformed, newModel, false) newModel } else { - val newModel = new SplineModel() + val newModel = new SplineModel(registry) newModel.initForTraining(numBins) newModel.setSplineNormCap(linfinityCap.toFloat) - initModel(minCount, subsample, rankKey, transformed, newModel, true) + initModel(minCount, subsample, labelFamily, transformed, newModel, true) setPrior(config, key, newModel) newModel } @@ -162,17 +166,17 @@ object SplineTrainer { // Initializes the model def initModel(minCount : Int, subsample : Double, - rankKey : String, + labelFamily : Family, input : RDD[Example], model : SplineModel, overwrite : Boolean) = { log.info("Computing min/max values for all features") val stats = TrainingUtils .getFeatureStatistics(minCount, input.sample(false, subsample)) - .filter(x => x._1._1 != rankKey) + .filter(x => x._1.family != labelFamily) log.info("Num features = %d".format(stats.length)) for (entry <- stats) { - model.addSpline(entry._1._1, entry._1._2, entry._2.min.toFloat, entry._2.max.toFloat, overwrite) + model.addSpline(entry._1, entry._2.min.toFloat, entry._2.max.toFloat, overwrite) } } @@ -188,16 +192,14 @@ object SplineTrainer { val name = tokens(1) val start = tokens(2).toDouble val end = tokens(3).toDouble - val familyMap = model.getWeightSpline.asScala.get(family) - if (familyMap != None) { - val spline = familyMap.get.get(name) - if (spline != null) { - log.info("Setting prior %s:%s <- %f to %f".format(family, name, start, end)) - val len = spline.splineWeights.length - for (i <- 0 until len) { - val t = i.toDouble / (len.toDouble - 1.0) - spline.splineWeights(i) = ((1.0 - t) * start 
+ t * end).toFloat - } + val feature = model.registry.feature(family, name) + val weightSpline = model.getWeightSpline.asScala.get(feature) + if (weightSpline.isDefined) { + log.info("Setting prior %s:%s <- %f to %f".format(family, name, start, end)) + val len = weightSpline.get.splineWeights.length + for (i <- 0 until len) { + val t = i.toDouble / (len.toDouble - 1.0) + weightSpline.get.splineWeights(i) = ((1.0 - t) * start + t * end).toFloat } } } else { @@ -282,13 +284,11 @@ object SplineTrainer { .coalesce(params.numBags, true) .mapPartitions(partition => sgdPartition(partition, modelBC, params)) - .groupByKey + .groupByKey() // Average the spline weights .map(x => { val head = x._2.head - val spline = new WeightSpline(head.spline.getMinVal, - head.spline.getMaxVal, - params.numBins) + val spline = new WeightSpline(head.spline.getMinVal, head.spline.getMaxVal, params.numBins) val scale = 1.0f / params.numBags.toFloat x._2.foreach(entry => { for (i <- 0 until params.numBins) { @@ -298,18 +298,15 @@ object SplineTrainer { smoothSpline(params.smoothingTolerance, spline) (x._1, spline) }) - .collect - .foreach(entry => { - val family = model.getWeightSpline.get(entry._1._1) - if (family != null && family.containsKey(entry._1._2)) { - family.put(entry._1._2, entry._2) - } - }) + .collect() + .foreach{ case (feature, spline) => + model.getWeightSpline.replace(feature, spline) + } deleteSmallSplines(model, params.linfinityThreshold) TrainingUtils.saveModel(model, config, key + ".model_output") - return model + model } def sgdMultiscaleTrain(sc : SparkContext, @@ -330,13 +327,11 @@ object SplineTrainer { .mapPartitionsWithIndex((index, partition) => sgdPartitionMultiscale(index, partition, multiscale, modelBC, params)) - .groupByKey + .groupByKey() // Average the spline weights .map(x => { val head = x._2.head - val spline = new WeightSpline(head.spline.getMinVal, - head.spline.getMaxVal, - params.numBins) + val spline = new WeightSpline(head.spline.getMinVal, 
head.spline.getMaxVal, params.numBins) val scale = 1.0f / params.numBags.toFloat x._2.foreach(entry => { entry.resample(params.numBins) @@ -347,48 +342,34 @@ object SplineTrainer { smoothSpline(params.smoothingTolerance, spline) (x._1, spline) }) - .collect - .foreach(entry => { - val family = model.getWeightSpline.get(entry._1._1) - if (family != null && family.containsKey(entry._1._2)) { - family.put(entry._1._2, entry._2) - } - }) + .collect() + .foreach { case (feature, spline) => + model.getWeightSpline.replace(feature, spline) + } deleteSmallSplines(model, params.linfinityThreshold) TrainingUtils.saveModel(model, config, key + ".model_output") - return model + model } def deleteSmallSplines(model : SplineModel, linfinityThreshold : Double) = { - val toDelete = scala.collection.mutable.ArrayBuffer[(String, String)]() + val toDelete = mutable.ArrayBuffer[Feature]() - model.getWeightSpline.asScala.foreach(family => { - family._2.asScala.foreach(entry => { - if (entry._2.LInfinityNorm < linfinityThreshold) { - toDelete.append((family._1, entry._1)) - } - }) - }) + model.getWeightSpline.asScala.foreach { case (feature, spline) => + if (spline.LInfinityNorm < linfinityThreshold) toDelete.append(feature) + } log.info("Deleting %d empty splines".format(toDelete.size)) - toDelete.foreach(entry => { - val family = model.getWeightSpline.get(entry._1) - if (family != null && family.containsKey(entry._2)) { - family.remove(entry._2) - } - }) + toDelete.foreach(feature => model.getWeightSpline.remove(feature)) } def sgdPartition(partition : Iterator[Example], modelBC : Broadcast[SplineModel], params : SplineTrainerParams) = { - val workingModel = modelBC.value - val output = sgdPartitionInternal(partition, workingModel, params) - output.iterator + sgdPartitionInternal(partition, modelBC.value, params).iterator } def sgdPartitionMultiscale( @@ -404,12 +385,9 @@ object SplineTrainer { log.info("Resampling to %d bins".format(newBins)) workingModel .getWeightSpline - 
.foreach(family => { - family._2.foreach(feature => { - feature._2.resample(newBins) - }) - }) - + .values + .foreach(_.resample(newBins)) + val output = sgdPartitionInternal(partition, workingModel, params) output.iterator } @@ -417,20 +395,20 @@ object SplineTrainer { def sgdPartitionInternal(partition : Iterator[Example], workingModel : SplineModel, params : SplineTrainerParams) : - HashMap[(String, String), SplineModel.WeightSpline] = { + mutable.Map[Feature, SplineModel.WeightSpline] = { @volatile var lossSum : Double = 0.0 @volatile var lossCount : Int = 0 partition.foreach(example => { if (params.isRanking) { // Since this is SGD we don't want to over sample from one example // but we also want to make good use of the example already in RAM - val count = scala.math.min(params.maxSamplesPerExample, example.example.size) + val count = scala.math.min(params.maxSamplesPerExample, example.size) for (i <- 0 until count) { lossSum += rankAndHingeLoss(example, workingModel, params) lossCount = lossCount + 1 } } else { - lossSum += pointwiseLoss(example.example.get(0), workingModel, params.loss, params) + lossSum += pointwiseLoss(example.only, workingModel, params.loss, params) lossCount = lossCount + 1 } if (lossCount % params.lossMod == 0) { @@ -438,37 +416,28 @@ object SplineTrainer { lossSum = 0.0 } }) - val output = HashMap[(String, String), SplineModel.WeightSpline]() - workingModel - .getWeightSpline - .foreach(family => { - family._2.foreach(feature => { - output.put((family._1, feature._1), feature._2) - }) - }) - output + workingModel.getWeightSpline } def rankAndHingeLoss(example : Example, workingModel : SplineModel, params : SplineTrainerParams) : Double = { - val count = example.example.size + val count = example.size val idx1 = Random.nextInt(count) - val fv1 = example.example.get(idx1) + val exampleSeq: Seq[MultiFamilyVector] = example.asScala.toSeq + val fv1 = exampleSeq(idx1) var doHinge : Boolean = false var loss : Double = 0.0 if 
(Random.nextDouble() < params.rankFraction) { - val label1 = TrainingUtils.getLabel(fv1, params.rankKey, params.threshold) + val label1 = TrainingUtils.getLabel(fv1, params.labelFamily, params.threshold) val idx2 = pickCounterExample(example, idx1, label1, count, params) if (idx2 >= 0) { - val fv2 = example.example.get(idx2) - val label2 = TrainingUtils.getLabel(fv2, params.rankKey, params.threshold) + val fv2 = exampleSeq(idx2) + val label2 = TrainingUtils.getLabel(fv2, params.labelFamily, params.threshold) // Can't do dropout for ranking loss since we are relying on difference of features. - val flatFeatures1 = Util.flattenFeature(fv1) - val prediction1 = workingModel.scoreFlatFeatures(flatFeatures1) - val flatFeatures2 = Util.flattenFeature(fv2) - val prediction2 = workingModel.scoreFlatFeatures(flatFeatures2) + val prediction1 = workingModel.scoreItem(fv1) + val prediction2 = workingModel.scoreItem(fv2) if (label1 > label2) { loss = scala.math.max(0.0, params.rankMargin - prediction1 + prediction2) } else { @@ -478,10 +447,10 @@ object SplineTrainer { if (loss > 0) { workingModel.update(-label1.toFloat, params.learningRate.toFloat, - flatFeatures1) + fv1) workingModel.update(-label2.toFloat, params.learningRate.toFloat, - flatFeatures2) + fv2) } } else { // No counter example. 
@@ -497,7 +466,7 @@ object SplineTrainer { "hinge", params) } - return loss + loss } // Picks the first random counter example to idx1 @@ -507,27 +476,28 @@ object SplineTrainer { count : Int, params : SplineTrainerParams) : Int = { val shuffle = Random.shuffle((0 until count).toBuffer) - + val exampleSeq = example.asScala.toSeq + for (idx2 <- shuffle) { if (idx2 != idx1) { val label2 = TrainingUtils.getLabel( - example.example.get(idx2), params.rankKey, params.threshold) + exampleSeq(idx2), params.labelFamily, params.threshold) if (label2 != label1) { return idx2 } } } - return -1; + -1 } - def pointwiseLoss(fv : FeatureVector, + def pointwiseLoss(fv : MultiFamilyVector, workingModel : SplineModel, loss : String, params : SplineTrainerParams) : Double = { val label: Double = if (loss == "regression") { - TrainingUtils.getLabel(fv, params.rankKey) + TrainingUtils.getLabel(fv, params.labelFamily) } else { - TrainingUtils.getLabel(fv, params.rankKey, params.threshold) + TrainingUtils.getLabel(fv, params.labelFamily, params.threshold) } loss match { @@ -541,11 +511,13 @@ object SplineTrainer { // We rescale by 1 / p so that at inference time we don't have to scale by p. // In our case p = 1.0 - dropout rate def updateLogistic(model : SplineModel, - fv : FeatureVector, + fv : MultiFamilyVector, label : Double, params : SplineTrainerParams) : Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - val prediction = model.scoreFlatFeatures(flatFeatures) / (1.0 - params.dropout) + // TODO (Brad): withFamilyDropout will randomly drop out some DenseVectors but we shouldn't + // have any. + val flatFeatures = fv.withFamilyDropout(params.dropout) + val prediction = model.scoreItem(flatFeatures) / (1.0 - params.dropout) // To prevent blowup. 
val corr = scala.math.min(10.0, label * prediction) val expCorr = scala.math.exp(corr) @@ -558,11 +530,11 @@ object SplineTrainer { } def updateHinge(model : SplineModel, - fv : FeatureVector, + fv : MultiFamilyVector, label : Double, params : SplineTrainerParams) : Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - val prediction = model.scoreFlatFeatures(flatFeatures) / (1.0 - params.dropout) + val flatFeatures = fv.withFamilyDropout(params.dropout) + val prediction = model.scoreItem(flatFeatures) / (1.0 - params.dropout) val loss = scala.math.max(0.0, params.margin - label * prediction) if (loss > 0.0) { val grad = -label @@ -574,11 +546,11 @@ object SplineTrainer { } def updateRegressor(model: SplineModel, - fv: FeatureVector, + fv: MultiFamilyVector, label: Double, params : SplineTrainerParams) : Double = { - val flatFeatures = Util.flattenFeatureWithDropout(fv, params.dropout) - val prediction = model.scoreFlatFeatures(flatFeatures) / (1.0 - params.dropout) + val flatFeatures = fv.withFamilyDropout(params.dropout) + val prediction = model.scoreItem(flatFeatures) / (1.0 - params.dropout) val loss = math.abs(prediction - label) // absolute difference if (prediction - label > params.epsilon) { model.update(1.0f, params.learningRate.toFloat, flatFeatures) @@ -591,8 +563,9 @@ object SplineTrainer { def trainAndSaveToFile(sc : SparkContext, input : RDD[Example], config : Config, - key : String) = { - val model = train(sc, input, config, key) + key : String, + registry: FeatureRegistry) = { + val model = train(sc, input, config, key, registry) TrainingUtils.saveModel(model, config, key + ".model_output") } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/TrainingUtils.scala b/training/src/main/scala/com/airbnb/aerosolve/training/TrainingUtils.scala index dfa1f084..d07e4055 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/TrainingUtils.scala +++ 
b/training/src/main/scala/com/airbnb/aerosolve/training/TrainingUtils.scala @@ -4,12 +4,13 @@ import java.io.BufferedWriter import java.io.InputStreamReader import java.io.BufferedReader import java.io.OutputStreamWriter +import java.lang import java.net.URI import java.util.concurrent.ConcurrentHashMap import com.airbnb.aerosolve.core.Example import com.airbnb.aerosolve.core.FeatureVector -import com.airbnb.aerosolve.core.util.Util +import com.airbnb.aerosolve.core.features._ import com.airbnb.aerosolve.core.util.StringDictionary import com.airbnb.aerosolve.core.models.AbstractModel import com.airbnb.aerosolve.core.models.ModelFactory @@ -21,12 +22,11 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.slf4j.{LoggerFactory, Logger} +import scala.collection.{Map, mutable, JavaConversions, JavaConverters} import scala.collection.JavaConverters._ import scala.collection.JavaConversions._ import com.typesafe.config.Config -import scala.collection.mutable.ArrayBuffer - object TrainingUtils { val log: Logger = LoggerFactory.getLogger("TrainingUtils") @@ -82,11 +82,11 @@ object TrainingUtils { writer.close() file.close() } catch { - case _ : Throwable => log.error("Could not save model") + case e : Throwable => log.error("Could not save model", e) } } - def loadScoreModel(modelName: String): Option[AbstractModel] = { + def loadScoreModel(modelName: String, registry:FeatureRegistry): Option[AbstractModel] = { val fs = FileSystem.get(new URI(modelName), hadoopConfiguration) val modelPath = new Path(modelName) if (!fs.exists(modelPath)) { @@ -95,12 +95,12 @@ object TrainingUtils { } val modelStream = fs.open(modelPath) val reader = new BufferedReader(new InputStreamReader(modelStream)) - val modelOpt = ModelFactory.createFromReader(reader) + val modelOpt = ModelFactory.createFromReader(reader, registry) if (!modelOpt.isPresent) { return None } val model = modelOpt.get() - return Some(model) + Some(model) } def loadPlattScaleWeights(filePath: 
String): Option[Array[Double]] = { @@ -115,15 +115,15 @@ object TrainingUtils { for (weight <- calibrationModelStr.split(" ")) calibrationWeights += weight.toDouble - return Some(calibrationWeights.toArray) + Some(calibrationWeights.toArray) } def debugScore(example: Example, model: AbstractModel, transformer: Transformer) = { - transformer.combineContextAndItems(example) - for (ex <- example.example.asScala) { + example.transform(transformer) + for (vector <- example) { val builder = new java.lang.StringBuilder() - model.debugScoreItem(example.example.get(0), builder) - val result = builder.toString() + model.debugScoreItem(vector, builder) + val result = builder.toString println(result) } } @@ -170,42 +170,54 @@ object TrainingUtils { def trainAndSaveToFile(sc: SparkContext, input: RDD[Example], config: Config, - key: String) = { + key: String, + registry: FeatureRegistry) = { val trainer: String = config.getString(key + ".trainer") trainer match { - case "linear" => LinearRankerTrainer.trainAndSaveToFile(sc, input, config, key) - case "maxout" => MaxoutTrainer.trainAndSaveToFile(sc, input, config, key) - case "spline" => SplineTrainer.trainAndSaveToFile(sc, input, config, key) - case "boosted_stumps" => BoostedStumpsTrainer.trainAndSaveToFile(sc, input, config, key) - case "decision_tree" => DecisionTreeTrainer.trainAndSaveToFile(sc, input, config, key) - case "forest" => ForestTrainer.trainAndSaveToFile(sc, input, config, key) - case "boosted_forest" => BoostedForestTrainer.trainAndSaveToFile(sc, input, config, key) - case "additive" => AdditiveModelTrainer.trainAndSaveToFile(sc, input, config, key) - case "kernel" => KernelTrainer.trainAndSaveToFile(sc, input, config, key) - case "full_rank_linear" => FullRankLinearTrainer.trainAndSaveToFile(sc, input, config, key) - case "low_rank_linear" => LowRankLinearTrainer.trainAndSaveToFile(sc, input, config, key) - case "mlp" => MlpModelTrainer.trainAndSaveToFile(sc, input, config, key) + case "linear" => 
LinearRankerTrainer.trainAndSaveToFile(sc, input, config, key, registry) + case "maxout" => MaxoutTrainer.trainAndSaveToFile(sc, input, config, key, registry) + case "spline" => SplineTrainer.trainAndSaveToFile(sc, input, config, key, registry) + case "boosted_stumps" => BoostedStumpsTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "decision_tree" => DecisionTreeTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "forest" => ForestTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "boosted_forest" => BoostedForestTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "additive" => AdditiveModelTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "kernel" => KernelTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "full_rank_linear" => FullRankLinearTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "low_rank_linear" => LowRankLinearTrainer.trainAndSaveToFile(sc, input, config, key, + registry) + case "mlp" => MlpModelTrainer.trainAndSaveToFile(sc, input, config, key, + registry) } } - def getLabel(fv : FeatureVector, rankKey : String, threshold : Double) : Double = { + def getLabel(fv : MultiFamilyVector, labelFamily : Family, threshold : Double) : Double = { // get label for classification - val rank = fv.floatFeatures.get(rankKey).asScala.head._2 - val label = if (rank <= threshold) { + val rank = getLabel(fv, labelFamily) + if (rank <= threshold) { -1.0 } else { 1.0 } - return label } - def getLabelDistribution(fv : FeatureVector, rankKey : String) : Map[String, Double] = { - fv.floatFeatures.get(rankKey).asScala.map(x => (x._1.toString, x._2.toDouble)).toMap + def getLabelDistribution(fv : MultiFamilyVector, labelFamily : Family) : + Map[Feature, Double] = { + JavaConversions.mapAsScalaMap(fv.get(labelFamily)) + .mapValues(d => d.doubleValue()) } - def getLabel(fv : FeatureVector, rankKey : String) : Double = { + def getLabel(fv : 
MultiFamilyVector, labelFamily : Family) : Double = { // get label for regression - fv.floatFeatures.get(rankKey).asScala.head._2.toDouble + // TODO (Brad): If we were confident about the feature name, this would be simpler. + fv.get(labelFamily).iterator.next.value } // Returns the statistics of a feature @@ -213,34 +225,16 @@ object TrainingUtils { def getFeatureStatistics( minCount : Int, - input : RDD[Example]) : Array[((String, String), FeatureStatistics)] = { + input : RDD[Example]) : Array[(Feature, FeatureStatistics)] = { // ignore features present in less than minCount examples // output: Array[((featureFamily, featureName), (minValue, maxValue))] input - .mapPartitions(partition => { - // family, feature name => count, min, max, sum x, sum x ^ 2 - val weights = new ConcurrentHashMap[(String, String), FeatureStatistics]().asScala - partition.foreach(examples => { - for (i <- 0 until examples.example.size()) { - val flatFeature = Util.flattenFeature(examples.example.get(i)).asScala - flatFeature.foreach(familyMap => { - familyMap._2.foreach(feature => { - val key = (familyMap._1, feature._1) - val curr = weights.getOrElse(key, FeatureStatistics(0, Double.MaxValue, -Double.MaxValue, 0.0, 0.0)) - val v = feature._2 - weights.put(key, - FeatureStatistics(curr.count + 1, - scala.math.min(curr.min, v), - scala.math.max(curr.max, v), - curr.mean + v, // actually the sum - curr.variance + v * v) // actually the sum of squares - ) - }) - }) - } - }) - weights.iterator - }) + .flatMap(example => example + .flatMap((vector:java.lang.Iterable[FeatureValue]) => vector + .map(fv => { + val v = fv.value + (fv.feature(), FeatureStatistics(1, v, v, v, v*v)) + }))) .reduceByKey((a, b) => FeatureStatistics(a.count + b.count, scala.math.min(a.min, b.min), @@ -256,56 +250,46 @@ object TrainingUtils { mean = x._2.mean / x._2.count, variance = (x._2.variance - x._2.mean * x._2.mean / x._2.count) / (x._2.count - 1.0) ))) - .collect + .collect() } def getLabelCounts(minCount : 
Int, input : RDD[Example], - rankKey: String) : Array[((String, String), Int)] = { + labelFamily: Family) : Array[(Feature, Int)] = { input .mapPartitions(partition => { // family, feature name => count - val weights = new ConcurrentHashMap[(String, String), Int]().asScala + val weights = new ConcurrentHashMap[Feature, Int]().asScala partition.foreach(examples => { - for (i <- 0 until examples.example.size()) { - val example = examples.example.get(i) - val floatFeatures = example.getFloatFeatures - val stringFeatures = example.getStringFeatures - if (floatFeatures.containsKey(rankKey)) { - for (labelEntry <- floatFeatures.get(rankKey)) { - val key = (rankKey, labelEntry._1) - val cur = weights.getOrElse(key, 0) - weights.put(key, 1 + cur) - } - } else if (stringFeatures.containsKey(rankKey)) { - for (labelName <- stringFeatures.get(rankKey)) { - val key = (rankKey, labelName) + for (vector <- examples) { + val labelVector = vector.get(labelFamily) + if (labelVector != null) { + for (fv <- labelVector.iterator) { + val key = fv.feature val cur = weights.getOrElse(key, 0) weights.put(key, 1 + cur) } } } - } - ) + }) weights.iterator }) .reduceByKey((a, b) => a + b) .collect } - def createStringDictionaryFromFeatureStatistics(stats : Array[((String, String), FeatureStatistics)], - excludedFamilies : Set[String]) : StringDictionary = { + def createStringDictionaryFromFeatureStatistics(stats : Array[(Feature, FeatureStatistics)], + excludedFamilies : Set[Family]) : StringDictionary = { val dictionary = new StringDictionary() - for (stat <- stats) { - val (family, feature) = stat._1 - if (!excludedFamilies.contains(family)) { - if (stat._2.variance < 1e-6) { + for ((feature, featureStats) <- stats) { + if (!excludedFamilies.contains(feature.family)) { + if (featureStats.variance < 1e-6) { // Categorical feature, just pass through - dictionary.possiblyAdd(family, feature, 0.0f, 1.0f) + dictionary.possiblyAdd(feature, 0.0f, 1.0f) } else { - val mean = stat._2.mean - val 
scale = Math.sqrt(1.0 / stat._2.variance) - dictionary.possiblyAdd(family, feature, mean, scale) + val mean = featureStats.mean + val scale = Math.sqrt(1.0 / featureStats.variance) + dictionary.possiblyAdd(feature, mean, scale) } } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/EvalUtil.scala b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/EvalUtil.scala index 77b09d4f..0fd46d73 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/EvalUtil.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/EvalUtil.scala @@ -1,8 +1,10 @@ package com.airbnb.aerosolve.training.pipeline -import com.airbnb.aerosolve.core.{EvaluationRecord, Example} +import com.airbnb.aerosolve.core.features.Family +import com.airbnb.aerosolve.core.{FeatureVector, EvaluationRecord, Example} import com.airbnb.aerosolve.core.models.AbstractModel import com.airbnb.aerosolve.core.transforms.Transformer +import com.airbnb.aerosolve.training.LinearRankerUtils import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.slf4j.{Logger, LoggerFactory} @@ -20,7 +22,7 @@ object EvalUtil { transformer: Transformer, modelOpt: AbstractModel, examples: RDD[Example], - label: String, + labelFamily: Family, useProb: Boolean, isMulticlass: Boolean, isTraining: Example => Boolean): RDD[EvaluationRecord] = { @@ -28,7 +30,7 @@ object EvalUtil { val transformerBC = sc.broadcast(transformer) examples.map(example => exampleToEvaluationRecord( example, transformerBC.value, - modelBC.value, useProb, isMulticlass, label, isTraining) + modelBC.value, useProb, isMulticlass, labelFamily, isTraining) ) } @@ -38,15 +40,15 @@ object EvalUtil { model: AbstractModel, useProb: Boolean, isMulticlass: Boolean, - label: String, + labelFamily: Family, isTraining: Example => Boolean): EvaluationRecord = { val result = new EvaluationRecord result.setIs_training(isTraining(example)) - transformer.combineContextAndItems(example) + 
example.transform(transformer) if (isMulticlass) { - val score = model.scoreItemMulticlass(example.example.get(0)).asScala - val multiclassLabel = example.example.get(0).floatFeatures.get(label).asScala + val score = model.scoreItemMulticlass(example.only).asScala + val multiclassLabel: FeatureVector = example.only.get(labelFamily) val evalScores = new java.util.HashMap[java.lang.String, java.lang.Double]() val evalLabels = new java.util.HashMap[java.lang.String, java.lang.Double]() @@ -54,16 +56,16 @@ object EvalUtil { result.setLabels(evalLabels) for (s <- score) { - evalScores.put(s.label, s.score) + evalScores.put(s.getLabel, s.getScore) } - for (l <- multiclassLabel) { - evalLabels.put(l._1, l._2) + for (fv <- multiclassLabel.iterator.asScala) { + evalLabels.put(fv.feature.name, fv.value) } } else { - val score = model.scoreItem(example.example.get(0)) + val score = model.scoreItem(example.only) val prob = if (useProb) model.scoreProbability(score) else score - val rank = example.example.get(0).floatFeatures.get(label).values().iterator().next() + val rank = LinearRankerUtils.getLabel(example.only, labelFamily) result.setScore(prob) result.setLabel(rank) @@ -83,10 +85,10 @@ object EvalUtil { val result = new EvaluationRecord result.setIs_training(isTraining(example)) - transformerBC.value.combineContextAndItems(example) - val score = modelBC.value.scoreItem(example.example.get(0)) + example.transform(transformerBC.value) + val score = modelBC.value.scoreItem(example.only) val prob = modelBC.value.scoreProbability(score) - val rank = example.example.get(0).floatFeatures.get("$rank").get("") + val rank = example.only.get("$rank", "") result.setScore(prob) result.setLabel(rank) diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/ExampleUtil.scala b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/ExampleUtil.scala index e151f25e..8bba7b85 100644 --- 
a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/ExampleUtil.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/ExampleUtil.scala @@ -1,14 +1,13 @@ package com.airbnb.aerosolve.training.pipeline -import com.airbnb.aerosolve.core.features.Features +import com.airbnb.aerosolve.core.Example import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer object ExampleUtil { - def getFeatures(row: Row, schema: Array[StructField]) = { - val features = Features.builder() + def putFeatures(example: Example, row: Row, schema: Array[StructField]) = { val names = ArrayBuffer[String]() val values = ArrayBuffer[AnyRef]() @@ -45,8 +44,7 @@ object ExampleUtil { } } - features.names(names.toArray) - features.values(values.toArray) - features.build() + val fv = example.createVector() + fv.putAll(names.toArray, values.toArray) } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/GenericPipeline.scala b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/GenericPipeline.scala index b1f46312..15227da9 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/GenericPipeline.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/GenericPipeline.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training.pipeline import java.io.{BufferedWriter, OutputStreamWriter} +import com.airbnb.aerosolve.core.features.{FeatureRegistry, SimpleExample} import com.airbnb.aerosolve.core.models.{AbstractModel, ForestModel, FullRankLinearModel} import com.airbnb.aerosolve.core.transforms.Transformer import com.airbnb.aerosolve.core.util.Util @@ -39,7 +40,7 @@ object GenericPipeline { training .coalesce(numShards, true) - .map(Util.encode) + .map(Util.encodeExample) .saveAsTextFile(output, classOf[GzipCodec]) } @@ -60,11 +61,12 @@ object GenericPipeline { val count = cfg.getInt("count") val key = cfg.getString("model_config") val isMulticlass 
= Try(cfg.getBoolean("is_multiclass")).getOrElse(false) + val registry = new FeatureRegistry val input = makeTraining(sc, query, isMulticlass) LinearRankerUtils - .makePointwiseFloat(input, config, key) + .makePointwiseFloat(input, config, key, registry) .take(count) .foreach(logPrettyExample) } @@ -77,27 +79,29 @@ object GenericPipeline { val inputPattern = cfg.getString("input") val subsample = cfg.getDouble("subsample") val modelConfig = cfg.getString("model_config") + val registry = new FeatureRegistry - val input = getExamples(sc, inputPattern) + val input = getExamples(sc, inputPattern, registry) .filter(isTraining) val filteredInput = input.sample(false, subsample) - TrainingUtils.trainAndSaveToFile(sc, filteredInput, config, modelConfig) + TrainingUtils.trainAndSaveToFile(sc, filteredInput, config, modelConfig, registry) } def getModelAndTransform( config : Config, modelCfgName : String, - modelName : String ) = { - val modelOpt = TrainingUtils.loadScoreModel(modelName) + modelName : String, + registry: FeatureRegistry) = { + val modelOpt = TrainingUtils.loadScoreModel(modelName, registry) if (modelOpt.isEmpty) { log.error("Could not load model") System.exit(-1) } - val transformer = new Transformer(config, modelCfgName) + val transformer = new Transformer(config, modelCfgName, registry) (modelOpt.get, transformer) } @@ -126,7 +130,8 @@ object GenericPipeline { val isRegression = Try(cfg.getBoolean("is_regression")).getOrElse(false) val isMulticlass = Try(cfg.getBoolean("is_multiclass")).getOrElse(false) val metric = cfg.getString("metric_to_maximize") - val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName) + val registry = new FeatureRegistry + val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName, registry) val metrics = evalModelInternal( sc, @@ -139,7 +144,8 @@ object GenericPipeline { isRegression, isMulticlass, metric, - isTraining + isTraining, + registry ) metrics @@ -152,11 +158,12 @@ object 
GenericPipeline { val plattsConfig = config.getConfig("calibrate_model") val modelCfgName = plattsConfig.getString("model_config") val modelName = plattsConfig.getString("model_name") + val registry = new FeatureRegistry - val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName) + val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName, registry) val input = plattsConfig.getString("input") // training_data_with_ds // get calibration training data - val data = getExamples(sc, input) + val data = getExamples(sc, input, registry) .sample(false, plattsConfig.getDouble("subsample")) val scoresAndLabel = PipelineUtil.scoreExamples(sc, transformer, model, data, isTraining, LABEL) @@ -205,8 +212,8 @@ object GenericPipeline { if (success) { // If calibration is successful, update offset and slope of the model // otherwise, use the default offset = 0 and slope = 1 in the model - model.setOffset(offset) - model.setSlope(slope) + model.offset(offset) + model.slope(slope) } // Save the model with updated calibration parameters @@ -224,11 +231,11 @@ object GenericPipeline { } def modelRecordToString(x: ModelRecord) : String = { - if (x.weightVector != null && !x.weightVector.isEmpty) { + if (x.getWeightVector != null && !x.getWeightVector.isEmpty) { "%s\t%s\t%f\t%f\t%s".format( - x.featureFamily, x.featureName, x.minVal, x.maxVal, x.weightVector.toString) + x.getFeatureFamily, x.getFeatureName, x.getMinVal, x.getMaxVal, x.getWeightVector.toString) } else { - "%s\t%s\t%f".format(x.featureFamily, x.featureName, x.featureWeight) + "%s\t%s\t%f".format(x.getFeatureFamily, x.getFeatureName, x.getFeatureWeight) } } @@ -240,7 +247,7 @@ object GenericPipeline { val model = sc .textFile(modelName) .map(Util.decodeModel) - .filter(x => x.featureName != null) + .filter(x => x.getFeatureName != null) .map(modelRecordToString) PipelineUtil.saveAndCommitAsTextFile(model, modelDump) @@ -250,10 +257,11 @@ object GenericPipeline { val cfg = 
config.getConfig("dump_forest") val modelName = cfg.getString("model_name") val modelDump = cfg.getString("model_dump") - val model = TrainingUtils.loadScoreModel(modelName).get + val registry = new FeatureRegistry + val model = TrainingUtils.loadScoreModel(modelName, registry).get val forest = model.asInstanceOf[ForestModel] - val trees = forest.getTrees().asScala.toArray + val trees = forest.trees().asScala.toArray val builder = new StringBuilder() val count = trees.size @@ -271,30 +279,29 @@ object GenericPipeline { val modelName = cfg.getString("model_name") val modelDump = cfg.getString("model_dump") val featuresPerLabel = cfg.getInt("features_per_label") - val model = TrainingUtils.loadScoreModel(modelName).get.asInstanceOf[FullRankLinearModel] + val registry = new FeatureRegistry + val model = TrainingUtils.loadScoreModel(modelName, registry) + .get.asInstanceOf[FullRankLinearModel] val builder = new StringBuilder() - model.getLabelDictionary.asScala.foreach(entry => { + model.labelDictionary.asScala.foreach(entry => { val label = entry.getLabel val count = entry.getCount - val index = model.getLabelToIndex.get(label) + val index = model.labelToIndex.get(label) - val weights = model.getWeightVector.asScala.flatMap({ - case (family, features) => features.asScala.map({ - case (feature, fv) => - Tuple3(family, feature, fv.getValues.apply(index)) - }) - }).toSeq + val weights = model.weightVector.asScala.mapValues( + floatVector => floatVector.getValues.apply(index) + ).toSeq // Sort by weight, descending and take top featuresPerLabel - val sortedWeights = weights.sortBy(entry => -1.0 * entry._3).take(featuresPerLabel) + val sortedWeights = weights.sortBy(entry => -1.0 * entry._2).take(featuresPerLabel) - sortedWeights.foreach(weightTuple => { + sortedWeights.foreach{ case (feature, value) => { builder ++= "%s\t%s\t%s\t%f\n".format( - label, weightTuple._1, weightTuple._2, weightTuple._3 + label, feature.family.name, feature.name, value ) - }) + }} }) 
PipelineUtil.writeStringToFile(builder.toString, modelDump) @@ -308,8 +315,9 @@ object GenericPipeline { val modelCfgName = cfg.getString("model_config") val modelName = cfg.getString("model_name") val isMulticlass = Try(cfg.getBoolean("is_multiclass")).getOrElse(false) + val registry = new FeatureRegistry - val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName) + val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName, registry) val hc = new HiveContext(sc) val hiveTraining = hc.sql(query) @@ -329,7 +337,7 @@ object GenericPipeline { val examples = hiveTraining // ID, example - .map(x => (x.getString(lastIdx), hiveTrainingToExample(x, origSchema, isMulticlass))) + .map(x => (x.getString(lastIdx), hiveTrainingToExample(x, origSchema, registry, isMulticlass))) .coalesce(numShards, true) if (isMulticlass) { @@ -353,8 +361,9 @@ object GenericPipeline { val modelName = cfg.getString("model_name") val count = cfg.getInt("count") val isMulticlass = Try(cfg.getBoolean("is_multiclass")).getOrElse(false) + val registry = new FeatureRegistry - val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName) + val (model, transformer) = getModelAndTransform(config, modelCfgName, modelName, registry) val hc = new HiveContext(sc) val hiveTraining = hc.sql(query) @@ -371,13 +380,13 @@ object GenericPipeline { val ex = hiveTraining // ID, example - .map(x => (x.getString(lastIdx), hiveTrainingToExample(x, origSchema, isMulticlass))) + .map(x => (x.getString(lastIdx), hiveTrainingToExample(x, origSchema, registry, isMulticlass))) .take(count) ex.foreach(ex => { - transformer.combineContextAndItems(ex._2) + ex._2.transform(transformer) val builder = new java.lang.StringBuilder() - val score = model.debugScoreItem(ex._2.example.get(0), builder) + val score = model.debugScoreItem(ex._2.only, builder) builder.append("Debug score for %s\n".format(ex._1)) log.info(builder.toString) }) @@ -387,9 +396,9 @@ object 
GenericPipeline { example: Example, model: AbstractModel, transformer: Transformer) = { - transformer.combineContextAndItems(example) + example.transform(transformer) - val score = model.scoreItem(example.example.get(0)) + val score = model.scoreItem(example.only) val prob = model.scoreProbability(score) (score, prob) @@ -399,9 +408,9 @@ object GenericPipeline { example: Example, model: AbstractModel, transformer: Transformer) = { - transformer.combineContextAndItems(example) + example.transform(transformer) - val multiclassResults = model.scoreItemMulticlass(example.example.get(0)) + val multiclassResults = model.scoreItemMulticlass(example.only) model.scoreToProbability(multiclassResults) multiclassResults.asScala @@ -418,9 +427,10 @@ object GenericPipeline { isRegression: Boolean, isMulticlass: Boolean, metric: String, - isTraining: Example => Boolean) : Array[(String, Double)] = { + isTraining: Example => Boolean, + registry: FeatureRegistry) : Array[(String, Double)] = { val examples = sc.textFile(inputPattern) - .map(Util.decodeExample) + .map(example => Util.decodeExample(example, registry)) .sample(false, subSample) val records = EvalUtil @@ -429,7 +439,7 @@ object GenericPipeline { transformer, modelOpt, examples, - LABEL, + registry.family(LABEL), isProb, isMulticlass, isTraining) @@ -449,26 +459,15 @@ object GenericPipeline { } def logPrettyExample(ex : Example) = { - val fv = ex.example.get(0) + val fv = ex.only val builder = new StringBuilder() - builder ++= "\nString Features:" - - if (fv.stringFeatures != null) { - fv.stringFeatures.asScala.foreach(x => { - builder ++= "FAMILY : " + x._1 + '\n' - x._2.asScala.foreach(y => {builder ++= "--> " + y + '\n'}) - }) - } - - builder ++= "\nFloat Features:" + builder ++= "\nFeatures:" - if (fv.floatFeatures != null) { - fv.floatFeatures.asScala.foreach(x => { - builder ++= "FAMILY : " + x._1 + '\n' - x._2.asScala.foreach(y => {builder ++= "--> " + y.toString + '\n'}) - }) - } + fv.iterator.asScala.foreach(fv 
=> { + builder ++= "FAMILY : " + fv.feature.family.name + '\n' + builder ++= "--> " + fv.feature.name + " : " + fv.value + '\n' + }) log.info(builder.toString) } @@ -480,9 +479,10 @@ object GenericPipeline { val hc = new HiveContext(sc) val hiveTraining = hc.sql(query) val schema: Array[StructField] = hiveTraining.schema.fields.toArray + val registry = new FeatureRegistry hiveTraining - .map(x => hiveTrainingToExample(x, schema, isMulticlass)) + .map(x => hiveTrainingToExample(x, schema, registry, isMulticlass)) } def isTraining(examples : Example) : Boolean = { @@ -495,10 +495,11 @@ object GenericPipeline { (examples.toString.hashCode & 0xFF) <= 16 } - def getExamples(sc : SparkContext, inputPattern : String) : RDD[Example] = { + def getExamples(sc : SparkContext, inputPattern : String, registry : FeatureRegistry) + : RDD[Example] = { val examples : RDD[Example] = sc .textFile(inputPattern) - .map(Util.decodeExample) + .map(example => Util.decodeExample(example, registry)) examples } @@ -644,8 +645,10 @@ object GenericPipeline { def hiveTrainingToExample( row: Row, schema: Array[StructField], + registry: FeatureRegistry, isMulticlass: Boolean = false): Example = { - val features = ExampleUtil.getFeatures(row, schema) - features.toExample(isMulticlass) + val example = new SimpleExample(registry) + ExampleUtil.putFeatures(example, row, schema) + example } } diff --git a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtil.scala b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtil.scala index 5bf0341f..46f3e535 100644 --- a/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtil.scala +++ b/training/src/main/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtil.scala @@ -115,14 +115,14 @@ object PipelineUtil { modelOpt: AbstractModel, examples: RDD[Example], isTraining: Example => Boolean, - labelKey: String): RDD[(Float, String)] = { + labelKey: String): RDD[(Double, String)] = { val modelBC = 
sc.broadcast(modelOpt) val transformerBC = sc.broadcast(transformer) val scoreAndLabel = examples .map(example => { - transformerBC.value.combineContextAndItems(example) - val score = modelBC.value.scoreItem(example.example.get(0)) - val rank = example.example.get(0).floatFeatures.get(labelKey).get("") + example.transform(transformerBC.value) + val score = modelBC.value.scoreItem(example.only) + val rank = example.only.get(labelKey, "") val label = (if (isTraining(example)) "TRAIN_" else "HOLD_") + (if (rank > 0) "P" else "N") (score, label) }) diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/AdditiveModelTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/AdditiveModelTrainerTest.scala index 2c5acc98..4c9d1fb7 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/AdditiveModelTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/AdditiveModelTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.{ModelFactory, AdditiveModel} import com.airbnb.aerosolve.core.Example import com.typesafe.config.ConfigFactory @@ -13,6 +14,8 @@ import scala.collection.mutable.ArrayBuffer class AdditiveModelTrainerTest { val log = LoggerFactory.getLogger("AdditiveModelTrainerTest") + val registry = new FeatureRegistry + def makeConfig(loss : String, dropout : Double, extraArgs : String) : String = { """ |identity_transform { @@ -191,13 +194,13 @@ class AdditiveModelTrainerTest { var sc = new SparkContext("local", "AdditiveModelTest") try { val (examples, label, numPos) = if (exampleFunc.equals("poly")) { - TrainingTestHelper.makeClassificationExamples + TrainingTestHelper.makeClassificationExamples(registry) } else { - TrainingTestHelper.makeLinearClassificationExamples + 
TrainingTestHelper.makeLinearClassificationExamples(registry) } val config = ConfigFactory.parseString(makeConfig(loss, dropout, extraArgs)) val input = sc.parallelize(examples) - val model = AdditiveModelTrainer.train(sc, input, config, "model_config") + val model = AdditiveModelTrainer.train(sc, input, config, "model_config", registry) testClassificationModel(model, examples, label, numPos) } finally { sc.stop @@ -211,15 +214,15 @@ class AdditiveModelTrainerTest { var sc = new SparkContext("local", "AdditiveModelTest") try { val (examples, label) = if (exampleFunc.equals("linear")) { - TrainingTestHelper.makeLinearRegressionExamples() + TrainingTestHelper.makeLinearRegressionExamples(registry) } else { - TrainingTestHelper.makeRegressionExamples() + TrainingTestHelper.makeRegressionExamples(registry) } val (testingExample, testingLabel) = if (exampleFunc.equals("linear")) { - TrainingTestHelper.makeLinearRegressionExamples(25) + TrainingTestHelper.makeLinearRegressionExamples(registry, 25) } else { - TrainingTestHelper.makeRegressionExamples(25) + TrainingTestHelper.makeRegressionExamples(registry, 25) } val threshold = if (exampleFunc.equals("linear")) { @@ -230,7 +233,7 @@ class AdditiveModelTrainerTest { val config = ConfigFactory.parseString(makeRegressionConfig(extraArgs)) val input = sc.parallelize(examples) - val model = AdditiveModelTrainer.train(sc, input, config, "model_config") + val model = AdditiveModelTrainer.train(sc, input, config, "model_config", registry) testRegressionModel(model, examples, label, testingExample, testingLabel, threshold) } finally { @@ -250,7 +253,7 @@ class AdditiveModelTrainerTest { var i : Int = 0 val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -269,12 +272,12 @@ class AdditiveModelTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) 
log.info(str) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent()) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } } @@ -291,7 +294,7 @@ class AdditiveModelTrainerTest { var i = 0 // compute training error for (ex <- trainingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = trainLabelArr(i) trainTotalError += math.abs(score - label) i += 1 @@ -306,7 +309,7 @@ class AdditiveModelTrainerTest { // compute training error i = 0 for (ex <- testingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = testLabelArr(i) testTotalError += math.abs(score - label) i += 1 diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/BoostedStumpsModelTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/BoostedStumpsModelTest.scala index 1ca44529..500150ce 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/BoostedStumpsModelTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/BoostedStumpsModelTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.ConfigFactory @@ -16,6 +17,7 @@ import scala.collection.mutable.ArrayBuffer class BoostedStumpsModelTest { val log = LoggerFactory.getLogger("BoostedStumpsModelTest") + val registry = new FeatureRegistry def makeConfig(loss : String) : String = { """ @@ -53,7 +55,7 @@ class 
BoostedStumpsModelTest { } def testBoostedStumpTrainer(loss : String) = { - val (examples, label, numPos) = TrainingTestHelper.makeSimpleClassificationExamples + val (examples, label, numPos) = TrainingTestHelper.makeSimpleClassificationExamples(registry) var sc = new SparkContext("local", "BoostedStumpsTEst") @@ -61,16 +63,16 @@ class BoostedStumpsModelTest { val config = ConfigFactory.parseString(makeConfig(loss)) val input = sc.parallelize(examples) - val model = BoostedStumpsTrainer.train(sc, input, config, "model_config") + val model = BoostedStumpsTrainer.train(sc, input, config, "model_config", registry) - val stumps = model.getStumps.asScala + val stumps = model.stumps.asScala stumps.foreach(stump => log.info(stump.toString)) var numCorrect : Int = 0; var i : Int = 0; val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -89,13 +91,13 @@ class BoostedStumpsModelTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent()) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/DecisionTreeTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/DecisionTreeTrainerTest.scala index 4ab0f084..da4f5d20 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/DecisionTreeTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/DecisionTreeTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import 
java.io.{BufferedReader, BufferedWriter, StringReader, StringWriter} +import com.airbnb.aerosolve.core.features.{MultiFamilyVector, BasicMultiFamilyVector, FeatureRegistry} import com.airbnb.aerosolve.core.models.ModelFactory import com.airbnb.aerosolve.core.ModelRecord import com.airbnb.aerosolve.training.pipeline.PipelineTestingUtil @@ -16,6 +17,7 @@ import scala.collection.JavaConversions class DecisionTreeTrainerTest { val log = LoggerFactory.getLogger("DecisionTreeModelTest") + val registry = new FeatureRegistry def makeConfig(splitCriteria : String) : String = { """ @@ -72,15 +74,15 @@ class DecisionTreeTrainerTest { def testDecisionTreeClassificationTrainer( splitCriteria : String, expectedCorrect : Double) = { - val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples + val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples(registry) PipelineTestingUtil.withSparkContext(sc => { val config = ConfigFactory.parseString(makeConfig(splitCriteria)) val input = sc.parallelize(examples) - val model = DecisionTreeTrainer.train(sc, input, config, "model_config") + val model = DecisionTreeTrainer.train(sc, input, config, "model_config", registry) - val stumps = model.getStumps.asScala + val stumps = model.stumps.asScala stumps.foreach(stump => log.info(stump.toString)) var numCorrect : Int = 0 @@ -88,7 +90,7 @@ class DecisionTreeTrainerTest { val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -110,13 +112,13 @@ class DecisionTreeTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = 
model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } }) @@ -125,15 +127,16 @@ class DecisionTreeTrainerTest { def testDecisionTreeMulticlassTrainer( splitType : String, expectedCorrect : Double) = { - val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(false) + val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(false, + registry) PipelineTestingUtil.withSparkContext(sc => { val config = ConfigFactory.parseString(makeConfig(splitType)) val input = sc.parallelize(examples) - val model = DecisionTreeTrainer.train(sc, input, config, "model_config") + val model = DecisionTreeTrainer.train(sc, input, config, "model_config", registry) - val stumps = model.getStumps.asScala + val stumps = model.stumps.asScala stumps.foreach(stump => log.info(stump.toString)) log.info(model.toDot) @@ -142,9 +145,9 @@ class DecisionTreeTrainerTest { for (i <- examples.indices) { val ex = examples(i) - val scores = model.scoreItemMulticlass(ex.example.get(0)) - val best = scores.asScala.sortWith((a, b) => a.score > b.score).head - if (best.label == labels(i)) { + val scores = model.scoreItemMulticlass(ex.only) + val best = scores.asScala.sortWith((a, b) => a.getScore > b.getScore).head + if (best.getLabel == labels(i)) { numCorrect = numCorrect + 1 } } @@ -164,31 +167,31 @@ class DecisionTreeTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItemMulticlass(ex.example.get(0)) - val score2 = model2.scoreItemMulticlass(ex.example.get(0)) + val score = model.scoreItemMulticlass(ex.only) + val score2 = model2.scoreItemMulticlass(ex.only) assertEquals(score.size, 
score2.size) for (i <- 0 until score.size) { - assertEquals(score.get(i).score, score2.get(i).score, 0.1f) + assertEquals(score.get(i).getScore, score2.get(i).getScore, 0.1f) } } }) } def testDecisionTreeRegressionTrainer(splitCriteria : String) = { - val (examples, label) = TrainingTestHelper.makeRegressionExamples() + val (examples, label) = TrainingTestHelper.makeRegressionExamples(registry) PipelineTestingUtil.withSparkContext(sc => { val config = ConfigFactory.parseString(makeConfig(splitCriteria)) val input = sc.parallelize(examples) - val model = DecisionTreeTrainer.train(sc, input, config, "model_config") + val model = DecisionTreeTrainer.train(sc, input, config, "model_config", registry) - val stumps = model.getStumps.asScala + val stumps = model.stumps.asScala stumps.foreach(stump => log.info(stump.toString)) val labelArr = label.toArray @@ -196,7 +199,7 @@ class DecisionTreeTrainerTest { var totalError : Double = 0 for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val exampleLabel = labelArr(i) totalError += math.abs(score - exampleLabel) @@ -210,12 +213,12 @@ class DecisionTreeTrainerTest { // Points in flat region result in score of min value (-8.0) val flatRegionExamples = List( - TrainingTestHelper.makeExample(0, -3.5, 0), - TrainingTestHelper.makeExample(0, 3.2, 0) + TrainingTestHelper.makeExample(0, -3.5, 0, registry), + TrainingTestHelper.makeExample(0, 3.2, 0, registry) ) flatRegionExamples.foreach { flatRegionExample => - val score = model.scoreItem(flatRegionExample.example.get(0)) + val score = model.scoreItem(flatRegionExample.only) assertEquals(score, -8.0, 2.0f) } @@ -224,14 +227,14 @@ class DecisionTreeTrainerTest { @Test def testEvaluateRegressionSplit() = { - val examples = Array( - createJavaExample(1.1, 5.0), - createJavaExample(1.2, 5.6), - createJavaExample(1.25, 11.9), - createJavaExample(1.5, 10.2), - createJavaExample(1.8, 12.5), - createJavaExample(2.5, 8.3), - 
createJavaExample(2.9, 18.4) + val vectors = Array( + createVector(1.1, 5.0), + createVector(1.2, 5.6), + createVector(1.25, 11.9), + createVector(1.5, 10.2), + createVector(1.8, 12.5), + createVector(2.5, 8.3), + createVector(2.9, 18.4) ) val testSplit = new ModelRecord() @@ -240,7 +243,7 @@ class DecisionTreeTrainerTest { testSplit.setThreshold(1.3) val result = DecisionTreeTrainer.evaluateRegressionSplit( - examples, "$rank", 1, SplitCriteria.Variance, Some(testSplit) + vectors, registry.family("$rank"), 1, SplitCriteria.Variance, Some(testSplit) ) // Verify that Welford's Method is consistent with standard, two-pass calculation @@ -256,14 +259,11 @@ class DecisionTreeTrainerTest { assertEquals(result.get, -1.0 * (leftSumSq + rightSumSq), 0.000001f) } - def createJavaExample(x : Double, rank : Double) = { - JavaConversions.mapAsJavaMap( - Map( - "loc" -> JavaConversions.mapAsJavaMap(Map("x" -> x)), - "$rank" -> JavaConversions.mapAsJavaMap(Map("" -> rank)) - ) - ).asInstanceOf[ - java.util.Map[java.lang.String, java.util.Map[java.lang.String, java.lang.Double]]] + def createVector(x : Double, rank : Double) : MultiFamilyVector = { + val vector = new BasicMultiFamilyVector(registry) + vector.put("loc", "x", x) + vector.put("$rank", "", rank) + vector } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/ForestTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/ForestTrainerTest.scala index 49ea9aae..2a8a0c00 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/ForestTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/ForestTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.Config @@ -16,6 +17,7 @@ 
import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer class ForestTrainerTest { + def makeConfig(splitCriteria : String) : String = { """ |identity_transform { @@ -85,6 +87,7 @@ class ForestTrainerTest { object ForestTrainerTestHelper { val log = LoggerFactory.getLogger("ForestTrainerTest") + val registry = new FeatureRegistry def testForestTrainer(config : Config, boost : Boolean, expectedCorrect : Double) = { testForestTrainerHelper(config, boost, expectedCorrect, false, false) @@ -111,12 +114,12 @@ object ForestTrainerTestHelper { if (multiclass) { val (tmpEx, tmpLabels) = if (nonlinear) - TrainingTestHelper.makeNonlinearMulticlassClassificationExamples() else - TrainingTestHelper.makeSimpleMulticlassClassificationExamples(false) + TrainingTestHelper.makeNonlinearMulticlassClassificationExamples(registry) else + TrainingTestHelper.makeSimpleMulticlassClassificationExamples(false, registry) examples = tmpEx labels = tmpLabels } else { - val (tmpEx, tmpLabel, tmpNumPos) = TrainingTestHelper.makeClassificationExamples + val (tmpEx, tmpLabel, tmpNumPos) = TrainingTestHelper.makeClassificationExamples(registry) examples = tmpEx label = tmpLabel numPos = tmpNumPos @@ -127,15 +130,15 @@ object ForestTrainerTestHelper { try { val input = sc.parallelize(examples) val model = if (boost) { - BoostedForestTrainer.train(sc, input, config, "model_config") + BoostedForestTrainer.train(sc, input, config, "model_config", registry) } else { - ForestTrainer.train(sc, input, config, "model_config") + ForestTrainer.train(sc, input, config, "model_config", registry) } - val trees = model.getTrees.asScala + val trees = model.trees.asScala for (tree <- trees) { log.info("Tree:") - val stumps = tree.getStumps.asScala + val stumps = tree.stumps.asScala stumps.foreach(stump => log.info(stump.toString)) } @@ -143,9 +146,9 @@ object ForestTrainerTestHelper { var numCorrect: Int = 0 for (i <- 0 until examples.length) { val ex = examples(i) - val scores = 
model.scoreItemMulticlass(ex.example.get(0)) - val best = scores.asScala.sortWith((a, b) => a.score > b.score).head - if (best.label == labels(i)) { + val scores = model.scoreItemMulticlass(ex.only) + val best = scores.asScala.sortWith((a, b) => a.getScore > b.getScore).head + if (best.getLabel == labels(i)) { numCorrect = numCorrect + 1 } } @@ -158,7 +161,7 @@ object ForestTrainerTestHelper { var i : Int = 0; val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -170,32 +173,32 @@ object ForestTrainerTestHelper { assertTrue(fracCorrect > expectedCorrect) } - val swriter = new StringWriter(); - val writer = new BufferedWriter(swriter); - model.save(writer); + val swriter = new StringWriter() + val writer = new BufferedWriter(swriter) + model.save(writer) writer.close() - val str = swriter.toString() + val str = swriter.toString val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) - assertTrue(model2Opt.isPresent()) + val model2Opt = ModelFactory.createFromReader(reader, registry) + assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() if (multiclass) { for (ex <- examples) { - val score = model.scoreItemMulticlass(ex.example.get(0)) - val score2 = model2.scoreItemMulticlass(ex.example.get(0)) + val score = model.scoreItemMulticlass(ex.only) + val score2 = model2.scoreItemMulticlass(ex.only) assertEquals(score.size, score2.size) for (i <- 0 until score.size) { - assertEquals(score.get(i).score, score2.get(i).score, 0.1f) + assertEquals(score.get(i).getScore, score2.get(i).getScore, 0.1d) } } } else { for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) - assertEquals(score, score2, 0.01f) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) + 
assertEquals(score, score2, 0.01d) } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/FullRankLinearModelTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/FullRankLinearModelTest.scala index 15b04be3..a546f790 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/FullRankLinearModelTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/FullRankLinearModelTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.typesafe.config.ConfigFactory import org.apache.spark.SparkContext @@ -13,6 +14,7 @@ import scala.collection.JavaConverters._ class FullRankLinearModelTest { val log = LoggerFactory.getLogger("FullRankLinearModelTest") + val registry = new FeatureRegistry def makeConfig(loss : String, lambda : Double, solver : String) : String = { """ @@ -82,7 +84,8 @@ class FullRankLinearModelTest { solver : String, multiLabel : Boolean, expectedCorrect : Double) = { - val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(multiLabel) + val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(multiLabel, + registry) var sc = new SparkContext("local", "FullRankLinearTest") @@ -90,9 +93,9 @@ class FullRankLinearModelTest { val config = ConfigFactory.parseString(makeConfig(loss, lambda, solver)) val input = sc.parallelize(examples) - val model = FullRankLinearTrainer.train(sc, input, config, "model_config") + val model = FullRankLinearTrainer.train(sc, input, config, "model_config", registry) - val weightVector = model.getWeightVector.asScala + val weightVector = model.weightVector.asScala for (wv <- weightVector) { log.info(wv.toString()) } @@ -100,9 +103,9 @@ class FullRankLinearModelTest { var numCorrect: Int = 0 for (i <- 0 until 
examples.length) { val ex = examples(i) - val scores = model.scoreItemMulticlass(ex.example.get(0)) - val best = scores.asScala.sortWith((a, b) => a.score > b.score).head - if (best.label == labels(i)) { + val scores = model.scoreItemMulticlass(ex.only) + val best = scores.asScala.sortWith((a, b) => a.getScore > b.getScore).head + if (best.getLabel == labels(i)) { numCorrect = numCorrect + 1 } } @@ -119,18 +122,18 @@ class FullRankLinearModelTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() val labelCount = if (multiLabel) 6 else 4 for (ex <- examples) { - val score = model.scoreItemMulticlass(ex.example.get(0)) - val score2 = model2.scoreItemMulticlass(ex.example.get(0)) + val score = model.scoreItemMulticlass(ex.only) + val score2 = model2.scoreItemMulticlass(ex.only) assertEquals(score.size, labelCount) assertEquals(score.size, score2.size) for (i <- 0 until labelCount) { - assertEquals(score.get(i).score, score2.get(i).score, 0.1f) + assertEquals(score.get(i).getScore, score2.get(i).getScore, 0.1d) } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/KDTreeTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/KDTreeTest.scala index a4c2ff1a..f18963be 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/KDTreeTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/KDTreeTest.scala @@ -24,27 +24,27 @@ class KDTreeTest { val nodes = tree.nodes log.info("Num nodes = %d".format(nodes.size)) // Since the x dimension is largest we expect the first node to be an xsplit - assertEquals(KDTreeNodeType.X_SPLIT, nodes(0).nodeType) - assertEquals(-1.81, nodes(0).splitValue, 0.1) - assertEquals(1, nodes(0).leftChild) - assertEquals(2, nodes(0).rightChild) // Ensure every point is bounded in the box of 
the kdtree + assertEquals(KDTreeNodeType.X_SPLIT, nodes(0).getNodeType) + assertEquals(-1.81, nodes(0).getSplitValue, 0.1) + assertEquals(1, nodes(0).getLeftChild) + assertEquals(2, nodes(0).getRightChild) // Ensure every point is bounded in the box of the kdtree for (pt <- pts) { val res = tree.query(pt) for (idx <- res) { - assert(pt._1 >= nodes(idx).minX) - assert(pt._1 <= nodes(idx).maxX) - assert(pt._2 >= nodes(idx).minY) - assert(pt._2 <= nodes(idx).maxY) + assert(pt._1 >= nodes(idx).getMinX) + assert(pt._1 <= nodes(idx).getMaxX) + assert(pt._2 >= nodes(idx).getMinY) + assert(pt._2 <= nodes(idx).getMaxY) } } // Ensure all nodes are sensible for (node <- nodes) { - assert(node.count > 0) - assert(node.minX <= node.maxX) - assert(node.minY <= node.maxY) - if (node.nodeType != KDTreeNodeType.LEAF) { - assert(node.leftChild >= 0 && node.leftChild < nodes.size) - assert(node.rightChild >= 0 && node.rightChild < nodes.size) + assert(node.getCount > 0) + assert(node.getMinX <= node.getMaxX) + assert(node.getMinY <= node.getMaxY) + if (node.getNodeType != KDTreeNodeType.LEAF) { + assert(node.getLeftChild >= 0 && node.getLeftChild < nodes.size) + assert(node.getRightChild >= 0 && node.getRightChild < nodes.size) } } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/KernelTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/KernelTrainerTest.scala index abc8f6b5..0782c960 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/KernelTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/KernelTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.airbnb.aerosolve.core.ModelRecord import com.typesafe.config.ConfigFactory @@ -15,6 +16,7 @@ import scala.collection.JavaConversions class 
KernelTrainerTest { val log = LoggerFactory.getLogger("KernelTrainerTest") + val registry = new FeatureRegistry def makeConfig(loss : String, kernel : String) : String = { """ @@ -73,7 +75,7 @@ class KernelTrainerTest { def testKernelClassificationTrainer(loss : String, kernel : String, expectedCorrect : Double) = { - val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples + val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples(registry) var sc = new SparkContext("local", "KernelTrainerTest") @@ -81,13 +83,13 @@ class KernelTrainerTest { val config = ConfigFactory.parseString(makeConfig(loss, kernel)) val input = sc.parallelize(examples) - val model = KernelTrainer.train(sc, input, config, "model_config") + val model = KernelTrainer.train(sc, input, config, "model_config", registry) var numCorrect : Int = 0 var i : Int = 0 val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -98,7 +100,7 @@ class KernelTrainerTest { .format(numCorrect, fracCorrect, numPos, examples.length - numPos)) assertTrue(fracCorrect > expectedCorrect) - for (sv <- model.getSupportVectors().asScala) { + for (sv <- model.supportVectors.asScala) { log.info(sv.toModelRecord.toString) } @@ -110,13 +112,13 @@ class KernelTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent()) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } @@ -131,7 +133,7 @@ class KernelTrainerTest { def testKernelRegressionTrainer(loss : 
String, kernel : String, maxError : Double) = { - val (examples, label) = TrainingTestHelper.makeRegressionExamples() + val (examples, label) = TrainingTestHelper.makeRegressionExamples(registry) var sc = new SparkContext("local", "testKernelRegressionTrainerTest") @@ -139,9 +141,9 @@ class KernelTrainerTest { val config = ConfigFactory.parseString(makeConfig(loss, kernel)) val input = sc.parallelize(examples) - val model = KernelTrainer.train(sc, input, config, "model_config") + val model = KernelTrainer.train(sc, input, config, "model_config", registry) - for (sv <- model.getSupportVectors().asScala) { + for (sv <- model.supportVectors.asScala) { log.info(sv.toModelRecord.toString) } @@ -150,7 +152,7 @@ class KernelTrainerTest { var totalError : Double = 0 for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val exampleLabel = labelArr(i) totalError += math.abs(score - exampleLabel) diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/LinearClassificationTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/LinearClassificationTrainerTest.scala index 598a1df5..81666fa4 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/LinearClassificationTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/LinearClassificationTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.features.{SimpleExample, FeatureRegistry} import com.airbnb.aerosolve.core.util.Util import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.{ConfigFactory, Config} @@ -15,26 +16,16 @@ import scala.collection.mutable.ArrayBuffer class LinearClassificationTrainerTest { val log = LoggerFactory.getLogger("LinearClassificicationTrainerTest") - + val registry = new FeatureRegistry // Creates an example with name and target. 
def makeExamples(examples : ArrayBuffer[Example], name : String, target : Double) = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setStringFeatures(new java.util.HashMap) - val itemSet = new java.util.HashSet[String]() - itemSet.add(name) - val stringFeatures = item.getStringFeatures - stringFeatures.put("name", itemSet) - val biasSet = new java.util.HashSet[String]() - biasSet.add("1") - stringFeatures.put("bias", biasSet) - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", target) - example.addToExample(item) + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.putString("name", name) + item.putString("bias", "1") + item.put("$rank", "", target) examples += example } @@ -80,20 +71,20 @@ class LinearClassificationTrainerTest { val config = ConfigFactory.parseString(makeConfig) val input = sc.parallelize(examples) - val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config") + val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config", registry) val weights = origWeights.toMap origWeights .foreach(wt => { - log.info("%s:%s=%f".format(wt._1._1, wt._1._2, wt._2)) + log.info("%s:%s=%f".format(wt._1.family.name, wt._1.name, wt._2)) }) for (j <- 0 until 10) { val name = j.toString if (j % 2 == 0) { - assertTrue(weights.getOrElse(("name", name), 0.0) >= 1.0) + assertTrue(weights.getOrElse(registry.feature("name", name), 0.0) >= 1.0) } else { - assertTrue(weights.getOrElse(("name", name), 0.0) <= -1.0) + assertTrue(weights.getOrElse(registry.feature("name", name), 0.0) <= -1.0) } } } finally { diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/LinearLogisticClassificationTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/LinearLogisticClassificationTrainerTest.scala index 
cb9da4e8..7fce8ecb 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/LinearLogisticClassificationTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/LinearLogisticClassificationTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.features.{SimpleExample, FeatureRegistry} import com.airbnb.aerosolve.core.util.Util import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.{ConfigFactory, Config} @@ -15,26 +16,16 @@ import scala.collection.mutable.ArrayBuffer class LinearLogisticClassificationTrainerTest { val log = LoggerFactory.getLogger("LinearLogisticClassificationTrainerTest") - + val registry = new FeatureRegistry // Creates an example with name and target. def makeExamples(examples : ArrayBuffer[Example], name : String, target : Double) = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setStringFeatures(new java.util.HashMap) - val itemSet = new java.util.HashSet[String]() - itemSet.add(name) - val stringFeatures = item.getStringFeatures - stringFeatures.put("name", itemSet) - val biasSet = new java.util.HashSet[String]() - biasSet.add("1") - stringFeatures.put("bias", biasSet) - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", target) - example.addToExample(item) + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.putString("name", name) + item.putString("bias", "1") + item.put("$rank", "", target) examples += example } @@ -80,20 +71,20 @@ class LinearLogisticClassificationTrainerTest { val config = ConfigFactory.parseString(makeConfig) val input = sc.parallelize(examples) - val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config") + val origWeights = LinearRankerTrainer.train(sc, input, 
config, "model_config", registry) val weights = origWeights.toMap origWeights .foreach(wt => { - log.info("%s:%s=%f".format(wt._1._1, wt._1._2, wt._2)) + log.info("%s:%s=%f".format(wt._1.family.name, wt._1.name, wt._2)) }) for (j <- 0 until 10) { val name = j.toString if (j % 2 == 0) { - assertTrue(weights.getOrElse(("name", name), 0.0) >= 1.0) + assertTrue(weights.getOrElse(registry.feature("name", name), 0.0) >= 1.0) } else { - assertTrue(weights.getOrElse(("name", name), 0.0) <= -1.0) + assertTrue(weights.getOrElse(registry.feature("name", name), 0.0) <= -1.0) } } } finally { diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/LinearRankerTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/LinearRankerTrainerTest.scala index 8c464ce9..27317d86 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/LinearRankerTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/LinearRankerTrainerTest.scala @@ -1,5 +1,6 @@ package com.airbnb.aerosolve.training +import com.airbnb.aerosolve.core.features.{SimpleExample, FeatureRegistry} import com.airbnb.aerosolve.core.util.Util import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.{ConfigFactory, Config} @@ -14,47 +15,34 @@ import scala.collection.JavaConversions._ class LinearRankerTrainerTest { val log = LoggerFactory.getLogger("LinearRankerTrainerTest") + val registry = new FeatureRegistry + // Creates an example with the context being a name // and items being integers from 0 to 10 and the user "name" // likes integers mod like_mod == 0 def makeExamples(examples : ArrayBuffer[Example], name : String, likes_mod : Int) = { - - val context: FeatureVector = new FeatureVector - context.setStringFeatures(new java.util.HashMap) - val nameSet = new java.util.HashSet[String]() - nameSet.add(name) - val stringFeatures = context.getStringFeatures - stringFeatures.put("name", nameSet) - for (i <- 0 to 10) { for (j <- 1 to 10) { - val 
example = new Example - example.setContext(context) + val example = new SimpleExample(registry) + example.context.putString("name", name) addItem(example, likes_mod, i) addItem(example, likes_mod, j) examples += example } } } + def addItem(example : Example, likes_mod : Int, i : Int) = { val rank : Double = if (i % likes_mod == 0) { 1.0 } else { 0.0 } - val item: FeatureVector = new FeatureVector - item.setStringFeatures(new java.util.HashMap) - val itemSet = new java.util.HashSet[String]() - itemSet.add("%d".format(i)) - val stringFeatures = item.getStringFeatures - stringFeatures.put("number", itemSet) - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", rank) - example.addToExample(item) + val item: FeatureVector = example.createVector() + item.putString("number", s"$i") + item.put("$rank", "", rank) } def makeConfig: String = { @@ -96,25 +84,25 @@ class LinearRankerTrainerTest { val config = ConfigFactory.parseString(makeConfig) val input = sc.parallelize(examples) - val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config") + val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config", registry) val weights = origWeights.toMap origWeights .foreach(wt => { - log.info("%s:%s=%f".format(wt._1._1, wt._1._2, wt._2)) + log.info("%s:%s=%f".format(wt._1.family.name, wt._1.name, wt._2)) }) // Ensure alice likes even numbers - assertTrue(weights.getOrElse(("name_X_number", "alice^2"), 0.0) > - weights.getOrElse(("name_X_number", "alice^5"), 0.0)) + assertTrue(weights.getOrElse(registry.feature("name_X_number", "alice^2"), 0.0) > + weights.getOrElse(registry.feature("name_X_number", "alice^5"), 0.0)) // Ensure bob likes multiples of 3 - assertTrue(weights.getOrElse(("name_X_number", "bob^6"), 0.0) > - weights.getOrElse(("name_X_number", "bob^1"), 0.0)) + 
assertTrue(weights.getOrElse(registry.feature("name_X_number", "bob^6"), 0.0) > + weights.getOrElse(registry.feature("name_X_number", "bob^1"), 0.0)) // Ensure charlie likes multiples of 4 - assertTrue(weights.getOrElse(("name_X_number", "charlie^8"), 0.0) > - weights.getOrElse(("name_X_number", "charlie^7"), 0.0)) + assertTrue(weights.getOrElse(registry.feature("name_X_number", "charlie^8"), 0.0) > + weights.getOrElse(registry.feature("name_X_number", "charlie^7"), 0.0)) } finally { sc.stop sc = null @@ -134,18 +122,22 @@ class LinearRankerTrainerTest { val input = sc.parallelize(examples) val xform = LinearRankerUtils - .transformExamples(input, config, "model_config") + .transformExamples(input, config, "model_config", registry) .collect .toArray .head log.info(xform.toString) - val fv = xform.example.asScala.toArray - assertTrue(fv(0).stringFeatures.get("name").asScala.head.equals("alice")) - assertTrue(fv(0).stringFeatures.get("number").asScala.head.equals("0")) - assertTrue(fv(0).stringFeatures.get("name_X_number").asScala.head.equals("alice^0")) - assertTrue(fv(1).stringFeatures.get("name").asScala.head.equals("alice")) - assertTrue(fv(1).stringFeatures.get("number").asScala.head.equals("1")) - assertTrue(fv(1).stringFeatures.get("name_X_number").asScala.head.equals("alice^1")) + val fv = xform.asScala.toArray + val expected = List((0, "name", "alice"), + (0, "number", "0"), + (0, "name_X_number", "alice^0"), + (1, "name", "alice"), + (1, "number", "1"), + (1, "name_X_number", "alice^1")) + expected.foreach{ case (index, familyName, expectedValue) => + assertTrue(fv(index).get(registry.family(familyName)).iterator.next + .feature.name.equals(expectedValue)) + } } finally { sc.stop sc = null diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/LinearRegressionTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/LinearRegressionTrainerTest.scala index c9853c67..e91b3ab9 100644 --- 
a/training/src/test/scala/com/airbnb/aerosolve/training/LinearRegressionTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/LinearRegressionTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.features.{SimpleExample, FeatureRegistry} import com.airbnb.aerosolve.core.util.Util import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.typesafe.config.{ConfigFactory, Config} @@ -14,26 +15,16 @@ import scala.collection.mutable.ArrayBuffer class LinearRegressionTrainerTest { val log = LoggerFactory.getLogger("LinearRegressionTrainerTest") - + val registry = new FeatureRegistry // Creates an example with name and target. def makeExamples(examples : ArrayBuffer[Example], name : String, target : Double) = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setStringFeatures(new java.util.HashMap) - val itemSet = new java.util.HashSet[String]() - itemSet.add(name) - val stringFeatures = item.getStringFeatures - stringFeatures.put("name", itemSet) - val biasSet = new java.util.HashSet[String]() - biasSet.add("1") - stringFeatures.put("bias", biasSet) - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", target) - example.addToExample(item) + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.putString("name", name) + item.putString("bias", "1") + item.put("$rank", "", target) examples += example } @@ -83,30 +74,30 @@ class LinearRegressionTrainerTest { val config = ConfigFactory.parseString(makeConfig(loss)) val input = sc.parallelize(examples) - val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config") + val origWeights = LinearRankerTrainer.train(sc, input, config, "model_config", registry) val weights = origWeights.toMap origWeights 
.foreach(wt => { - log.info("%s:%s=%f".format(wt._1._1, wt._1._2, wt._2)) + log.info("%s:%s=%f".format(wt._1.family.name, wt._1.name, wt._2)) }) // Ensure alice likes 2 assertEquals(2.0, - weights.getOrElse(("name", "alice"), 0.0) + - weights.getOrElse(("bias", "1"), 0.0), + weights.getOrElse(registry.feature("name", "alice"), 0.0) + + weights.getOrElse(registry.feature("bias", "1"), 0.0), 0.5) // Ensure bob likes 3 assertEquals(3.0, - weights.getOrElse(("name", "bob"), 0.0) + - weights.getOrElse(("bias", "1"), 0.0), + weights.getOrElse(registry.feature("name", "bob"), 0.0) + + weights.getOrElse(registry.feature("bias", "1"), 0.0), 0.5) // Ensure charlie 7 assertEquals(7.0, - weights.getOrElse(("name", "charlie"), 0.0) + - weights.getOrElse(("bias", "1"), 0.0), + weights.getOrElse(registry.feature("name", "charlie"), 0.0) + + weights.getOrElse(registry.feature("bias", "1"), 0.0), 0.5) } finally { diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/LowRankLinearTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/LowRankLinearTrainerTest.scala index f260a9d9..23f5afda 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/LowRankLinearTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/LowRankLinearTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.typesafe.config.ConfigFactory import org.apache.spark.SparkContext @@ -13,6 +14,7 @@ import scala.collection.JavaConverters._ class LowRankLinearTrainerTest { val log = LoggerFactory.getLogger("LowRankLinearTrainerTest") + val registry = new FeatureRegistry def makeConfig(lambda : Double, embeddingDim : Int, rankLossType : String) : String = { """ @@ -53,7 +55,8 @@ class LowRankLinearTrainerTest { rankLossType: String, multiLabel : 
Boolean, expectedCorrect : Double) = { - val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(multiLabel) + val (examples, labels) = TrainingTestHelper.makeSimpleMulticlassClassificationExamples(multiLabel, + registry) var sc = new SparkContext("local", "LowRankLinearTest") @@ -61,14 +64,14 @@ class LowRankLinearTrainerTest { val config = ConfigFactory.parseString(makeConfig(lambda, embeddingDim, rankLossType)) val input = sc.parallelize(examples) - val model = LowRankLinearTrainer.train(sc, input, config, "model_config") + val model = LowRankLinearTrainer.train(sc, input, config, "model_config", registry) - val featureWeightVector = model.getFeatureWeightVector.asScala + val featureWeightVector = model.featureWeightVector.asScala for (wv <- featureWeightVector) { log.info(wv.toString()) } - val labelWeightVector = model.getLabelWeightVector.asScala + val labelWeightVector = model.labelWeightVector.asScala for (wv <- labelWeightVector) { log.info(wv.toString()) } @@ -76,9 +79,9 @@ class LowRankLinearTrainerTest { var numCorrect: Int = 0 for (i <- examples.indices) { val ex = examples(i) - val scores = model.scoreItemMulticlass(ex.example.get(0)) - val best = scores.asScala.sortWith((a, b) => a.score > b.score).head - if (best.label == labels(i)) { + val scores = model.scoreItemMulticlass(ex.only) + val best = scores.asScala.sortWith((a, b) => a.getScore > b.getScore).head + if (best.getLabel == labels(i)) { numCorrect = numCorrect + 1 } } @@ -95,18 +98,18 @@ class LowRankLinearTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() val labelCount = if (multiLabel) 6 else 4 for (ex <- examples) { - val score = model.scoreItemMulticlass(ex.example.get(0)) - val score2 = model2.scoreItemMulticlass(ex.example.get(0)) + 
val score = model.scoreItemMulticlass(ex.only) + val score2 = model2.scoreItemMulticlass(ex.only) assertEquals(score.size, labelCount) assertEquals(score.size, score2.size) for (i <- 0 until labelCount) { - assertEquals(score.get(i).score, score2.get(i).score, 0.1f) + assertEquals(score.get(i).getScore, score2.get(i).getScore, 0.1f) } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/MaxoutTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/MaxoutTrainerTest.scala index 879a9cc6..810b75f1 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/MaxoutTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/MaxoutTrainerTest.scala @@ -3,6 +3,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} import java.util +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.ModelFactory import com.airbnb.aerosolve.core.util.Util import com.airbnb.aerosolve.core.{Example, FeatureVector} @@ -18,6 +19,7 @@ import scala.collection.mutable.ArrayBuffer class MaxoutTrainerTest { val log = LoggerFactory.getLogger("MaxoutTrainerTest") + val registry = new FeatureRegistry def makeConfig(loss : String) : String = { """ @@ -58,7 +60,7 @@ class MaxoutTrainerTest { } def testMaxoutTrainer(loss : String) = { - val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples + val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples(registry) var sc = new SparkContext("local", "MaxoutTrainerText") @@ -66,23 +68,21 @@ class MaxoutTrainerTest { val config = ConfigFactory.parseString(makeConfig(loss)) val input = sc.parallelize(examples) - val model = MaxoutTrainer.train(sc, input, config, "model_config") - - val weights = model.getWeightVector.asScala - for (familyMap <- weights) { - for (featureMap <- familyMap._2.asScala) { - log.info(("family=%s,feature=%s," - + "scale=%f, 
weights=%s") - .format(familyMap._1, featureMap._1, featureMap._2.scale, - featureMap._2.weights.toString())) - } + val model = MaxoutTrainer.train(sc, input, config, "model_config", registry) + + val weights = model.weightVector.asScala + for ((feature, vec) <- weights) { + log.info(("family=%s,feature=%s," + + "scale=%f, weights=%s") + .format(feature.family.name, feature.name, vec.scale, + vec.weights.toString)) } - var numCorrect : Int = 0; - var i : Int = 0; + var numCorrect : Int = 0 + var i : Int = 0 val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -93,23 +93,23 @@ class MaxoutTrainerTest { .format(numCorrect, fracCorrect, numPos, examples.length - numPos)) assertTrue(fracCorrect > 0.6) - val swriter = new StringWriter(); - val writer = new BufferedWriter(swriter); - model.save(writer); + val swriter = new StringWriter() + val writer = new BufferedWriter(swriter) + model.save(writer) writer.close() - val str = swriter.toString() + val str = swriter.toString val sreader = new StringReader(str) val reader = new BufferedReader(sreader) log.info(str) - val model2Opt = ModelFactory.createFromReader(reader) - assertTrue(model2Opt.isPresent()) + val model2Opt = ModelFactory.createFromReader(reader, registry) + assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/MlpModelTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/MlpModelTrainerTest.scala index b1f8f445..23eb36d3 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/MlpModelTrainerTest.scala +++ 
b/training/src/test/scala/com/airbnb/aerosolve/training/MlpModelTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.core.models.{ModelFactory, MlpModel} import com.airbnb.aerosolve.core.Example import com.typesafe.config.ConfigFactory @@ -14,6 +15,8 @@ import scala.collection.JavaConverters._ class MlpModelTrainerTest { val log = LoggerFactory.getLogger("MlpModelTrainerTest") + val registry = new FeatureRegistry + def makeConfig(dropout : Double, momentumT : Int, loss : String, @@ -114,13 +117,13 @@ class MlpModelTrainerTest { var sc = new SparkContext("local", "MlpModelTrainerTest") try { val (examples, label, numPos) = if (exampleFunc.equals("poly")) { - TrainingTestHelper.makeClassificationExamples + TrainingTestHelper.makeClassificationExamples(registry) } else { - TrainingTestHelper.makeLinearClassificationExamples + TrainingTestHelper.makeLinearClassificationExamples(registry) } val config = ConfigFactory.parseString(makeConfig(dropout, momentumT, loss, extraArgs, weightDecay)) val input = sc.parallelize(examples) - val model = MlpModelTrainer.train(sc, input, config, "model_config") + val model = MlpModelTrainer.train(sc, input, config, "model_config", registry) testClassificationModel(model, examples, label, numPos) } finally { sc.stop @@ -137,7 +140,7 @@ class MlpModelTrainerTest { var i : Int = 0 val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -156,12 +159,12 @@ class MlpModelTrainerTest { val sreader = new StringReader(str) val reader = new BufferedReader(sreader) log.info(str) - val model2Opt = ModelFactory.createFromReader(reader) + val model2Opt = ModelFactory.createFromReader(reader, registry) assertTrue(model2Opt.isPresent) val 
model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } } @@ -172,33 +175,33 @@ class MlpModelTrainerTest { weightDecay : Double, epsilon: Double = 0.1, learningRateInit: Double = 0.1): Unit = { - val (trainingExample, trainingLabel) = TrainingTestHelper.makeRegressionExamples() + val (trainingExample, trainingLabel) = TrainingTestHelper.makeRegressionExamples(registry) var sc = new SparkContext("local", "MlpRegressionTest") try { val config = ConfigFactory.parseString(makeConfig( dropout, momentumT, "regression", extraArgs, weightDecay = weightDecay, margin = epsilon, learningRateInit = learningRateInit)) val input = sc.parallelize(trainingExample) - val model = MlpModelTrainer.train(sc, input, config, "model_config") + val model = MlpModelTrainer.train(sc, input, config, "model_config", registry) val trainLabelArr = trainingLabel.toArray var trainTotalError : Double = 0 var i = 0 // compute training error for (ex <- trainingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = trainLabelArr(i) trainTotalError += math.abs(score - label) i += 1 } val trainError = trainTotalError / trainingExample.size.toDouble // compute testing error - val (testingExample, testingLabel) = TrainingTestHelper.makeRegressionExamples(25) + val (testingExample, testingLabel) = TrainingTestHelper.makeRegressionExamples(registry, 25) val testLabelArr = testingLabel.toArray var testTotalError : Double = 0 // compute training error i = 0 for (ex <- testingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = testLabelArr(i) testTotalError += math.abs(score - label) i += 1 diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/NDTreeTest.scala 
b/training/src/test/scala/com/airbnb/aerosolve/training/NDTreeTest.scala index 213c4fc1..b14ec8ac 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/NDTreeTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/NDTreeTest.scala @@ -36,33 +36,33 @@ class NDTreeTest { log.info("Number of nodes = %d".format(nodes.length)) // Since the y dimension is largest, we expect the first node to be a split along the y axis - assertEquals(1, nodes(0).axisIndex) - assertEquals(21.0, nodes(0).splitValue, 0) - assertEquals(1, nodes(0).leftChild) - assertEquals(2, nodes(0).rightChild) + assertEquals(1, nodes(0).getAxisIndex) + assertEquals(21.0, nodes(0).getSplitValue, 0) + assertEquals(1, nodes(0).getLeftChild) + assertEquals(2, nodes(0).getRightChild) // Ensure every point is bounded in the box of the kdtree for (point <- points) { val res = tree.query(point) for (index <- res) { for (i <- 0 until dimensions) { - assert(point(i) >= nodes(index).min.get(i)) - assert(point(i) <= nodes(index).max.get(i)) + assert(point(i) >= nodes(index).getMin.get(i)) + assert(point(i) <= nodes(index).getMax.get(i)) } } } // Ensure all nodes are sensible for (node <- nodes) { - assert(node.count > 0) + assert(node.getCount > 0) for (i <- 0 until dimensions) { - assert(node.min.get(i) <= node.max.get(i)) + assert(node.getMin.get(i) <= node.getMax.get(i)) } - if (node.axisIndex != NDTreeModel.LEAF) { - assert(node.leftChild >= 0 && node.leftChild < nodes.length) - assert(node.rightChild >= 0 && node.rightChild < nodes.length) + if (node.getAxisIndex != NDTreeModel.LEAF) { + assert(node.getLeftChild >= 0 && node.getLeftChild < nodes.length) + assert(node.getRightChild >= 0 && node.getRightChild < nodes.length) } } } @@ -92,33 +92,33 @@ class NDTreeTest { log.info("Number of nodes = %d".format(nodes.length)) // Since the x dimension is largest, we expect the first node to be a split along the x axis - assertEquals(0, nodes(0).axisIndex) - assertEquals(-1.81, 
nodes(0).splitValue, 0.1) - assertEquals(1, nodes(0).leftChild) - assertEquals(2, nodes(0).rightChild) + assertEquals(0, nodes(0).getAxisIndex) + assertEquals(-1.81, nodes(0).getSplitValue, 0.1) + assertEquals(1, nodes(0).getLeftChild) + assertEquals(2, nodes(0).getRightChild) // Ensure every point is bounded in the box of the kdtree for (point <- points) { val res = tree.query(point) for (index <- res) { for (i <- 0 until dimensions) { - assert(point(i) >= nodes(index).min.get(i)) - assert(point(i) <= nodes(index).max.get(i)) + assert(point(i) >= nodes(index).getMin.get(i)) + assert(point(i) <= nodes(index).getMax.get(i)) } } } // Ensure all nodes are sensible for (node <- nodes) { - assert(node.count > 0) + assert(node.getCount > 0) for (i <- 0 until dimensions) { - assert(node.min.get(i) <= node.max.get(i)) + assert(node.getMin.get(i) <= node.getMax.get(i)) } - if (node.axisIndex != NDTreeModel.LEAF) { - assert(node.leftChild >= 0 && node.leftChild < nodes.length) - assert(node.rightChild >= 0 && node.rightChild < nodes.length) + if (node.getAxisIndex != NDTreeModel.LEAF) { + assert(node.getLeftChild >= 0 && node.getLeftChild < nodes.length) + assert(node.getRightChild >= 0 && node.getRightChild < nodes.length) } } } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/SplineRankingTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/SplineRankingTrainerTest.scala index 515e90e5..7ff649e3 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/SplineRankingTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/SplineRankingTrainerTest.scala @@ -3,13 +3,14 @@ package com.airbnb.aerosolve.training import java.io.{BufferedReader, BufferedWriter, StringReader, StringWriter} import java.util +import com.airbnb.aerosolve.core.features.{SimpleExample, FeatureRegistry} import com.airbnb.aerosolve.core.models.SplineModel.WeightSpline import com.airbnb.aerosolve.core.models.{ModelFactory, 
SplineModel} import com.airbnb.aerosolve.core.{Example, FeatureVector} import com.airbnb.aerosolve.core.transforms.Transformer import java.util.{HashMap, Scanner} -import com.airbnb.aerosolve.core.function.Spline +import com.airbnb.aerosolve.core.functions.Spline import com.typesafe.config.ConfigFactory import org.apache.spark.SparkContext import org.junit.Test @@ -22,40 +23,32 @@ import scala.collection.mutable.ArrayBuffer class SplineRankingTrainerTest { val log = LoggerFactory.getLogger("SplineTrainerTest") + val registry = new FeatureRegistry // Creates an example with the context being a name // and items being integers from 0 to 10 and the user "name" // likes integers mod like_mod == 0 - def makeExample(name : String, likes_mod : Int) = { - val example = new Example() - val context: FeatureVector = new FeatureVector - context.setStringFeatures(new java.util.HashMap) - val nameSet = new java.util.HashSet[String]() - nameSet.add(name) - val stringFeatures = context.getStringFeatures - stringFeatures.put("name", nameSet) - example.setContext(context) + def makeExample(name : String, likes_mod : Int): Example = { + val example = new SimpleExample(registry) + example.context.putString("name", name) for (i <- 0 to 10) { addItem(example, likes_mod, i) } example } + def addItem(example : Example, likes_mod : Int, i : Int) = { val rank : Double = if (i % likes_mod == 0) { 1.0 } else { -1.0 } - val item: FeatureVector = new FeatureVector - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", rank) - floatFeatures.put("number", new java.util.HashMap) - floatFeatures.get("number").put("n", i) - example.addToExample(item) + val item: FeatureVector = example.createVector() + item.put("number","n", i.toDouble) + item.put("$rank", "", rank) } + def makeConfig() : String = { """ |identity_transform { @@ -106,10 +99,11 @@ class SplineRankingTrainerTest 
{ val config = ConfigFactory.parseString(makeConfig()) val input = sc.parallelize(examples) - val model = SplineTrainer.train(sc, input, config, "model_config") - val transformer = new Transformer(config, "model_config") - transformer.combineContextAndItems(alice) - val aliceEx = alice.example.asScala + val model = SplineTrainer.train(sc, input, config, "model_config", registry) + + val transformer = new Transformer(config, "model_config", registry) + alice.transform(transformer) + val aliceEx = alice.asScala.toSeq for (i <- 0 until 10) { log.info(model.scoreItem(aliceEx(i)).toString) log.info(aliceEx(i).toString) @@ -118,8 +112,9 @@ class SplineRankingTrainerTest { assertTrue(model.scoreItem(aliceEx(2)) > model.scoreItem(aliceEx(1))) assertTrue(model.scoreItem(aliceEx(4)) > model.scoreItem(aliceEx(3))) assertTrue(model.scoreItem(aliceEx(6)) > model.scoreItem(aliceEx(9))) - transformer.combineContextAndItems(bob) - val bobEx = bob.example.asScala + + bob.transform(transformer) + val bobEx = bob.asScala.toSeq for (i <- 0 until 10) { log.info(model.scoreItem(bobEx(i)).toString) log.info(bobEx(i).toString) diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/SplineTrainerTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/SplineTrainerTest.scala index 3e4ebe1c..e1c28ceb 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/SplineTrainerTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/SplineTrainerTest.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.io.{StringReader, BufferedWriter, BufferedReader, StringWriter} +import com.airbnb.aerosolve.core.features.{Feature, FeatureRegistry} import com.airbnb.aerosolve.core.models.SplineModel.WeightSpline import com.airbnb.aerosolve.core.models.{ModelFactory, SplineModel} import com.typesafe.config.ConfigFactory @@ -14,26 +15,23 @@ import scala.collection.JavaConverters._ class SplineTrainerTest { val log = 
LoggerFactory.getLogger("SplineTrainerTest") + val registry = new FeatureRegistry def makeSplineModel() : SplineModel = { - val model: SplineModel = new SplineModel() - val weights = new java.util.HashMap[String, java.util.Map[String, WeightSpline]]() - val innerA = new java.util.HashMap[String, WeightSpline] - val innerB = new java.util.HashMap[String, WeightSpline] + val model: SplineModel = new SplineModel(registry) + val weights = new java.util.HashMap[Feature, WeightSpline]() val a = new WeightSpline(1.0f, 10.0f, 2) val b = new WeightSpline(1.0f, 10.0f, 2) a.splineWeights(0) = 1.0f a.splineWeights(1) = 2.0f b.splineWeights(0) = 1.0f b.splineWeights(1) = 5.0f - weights.put("A", innerA) - weights.put("B", innerB) - innerA.put("a", a) - innerB.put("b", b) + weights.put(registry.feature("A", "a"), a) + weights.put(registry.feature("B", "b"), b) model.setNumBins(2) model.setWeightSpline(weights) - model.setOffset(0.5f) - model.setSlope(1.5f) + model.offset(0.5f) + model.slope(1.5f) model } @@ -136,7 +134,7 @@ class SplineTrainerTest { } def testSplineTrainer(loss : String, dropout : Double, extraArgs : String) = { - val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples + val (examples, label, numPos) = TrainingTestHelper.makeClassificationExamples(registry) var sc = new SparkContext("local", "SplineTest") @@ -144,7 +142,7 @@ class SplineTrainerTest { val config = ConfigFactory.parseString(makeConfig(loss, dropout, extraArgs)) val input = sc.parallelize(examples) - val model = SplineTrainer.train(sc, input, config, "model_config") + val model = SplineTrainer.train(sc, input, config, "model_config", registry) TrainingTestHelper.printSpline(model) @@ -152,7 +150,7 @@ class SplineTrainerTest { var i : Int = 0 val labelArr = label.toArray for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) if (score * labelArr(i) > 0) { numCorrect += 1 } @@ -167,19 +165,19 @@ class SplineTrainerTest { val 
writer = new BufferedWriter(swriter) model.save(writer) writer.close() - val str = swriter.toString() + val str = swriter.toString val sreader = new StringReader(str) val reader = new BufferedReader(sreader) log.info(str) - val model2Opt = ModelFactory.createFromReader(reader) - assertTrue(model2Opt.isPresent()) + val model2Opt = ModelFactory.createFromReader(reader, registry) + assertTrue(model2Opt.isPresent) val model2 = model2Opt.get() for (ex <- examples) { - val score = model.scoreItem(ex.example.get(0)) - val score2 = model2.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) + val score2 = model2.scoreItem(ex.only) assertEquals(score, score2, 0.01f) } @@ -195,54 +193,51 @@ class SplineTrainerTest { def testAddSpline(): Unit = { val model = makeSplineModel() // add an existing feature without overwrite - model.addSpline("A", "a", 0.0f, 1.0f, false) + model.addSpline(registry.feature("A", "a"), 0.0d, 1.0d, false) // add an existing feature with overwrite - model.addSpline("B", "b", 0.0f, 1.0f, true) + model.addSpline(registry.feature("B", "b"), 0.0d, 1.0d, true) // add a new family with overwrite - model.addSpline("C", "c", 0.0f, 1.0f, true) + model.addSpline(registry.feature("C", "c"), 0.0d, 1.0d, true) val weights = model.getWeightSpline.asScala - for (familyMap <- weights) { - for (featureMap <- familyMap._2.asScala) { - val family = familyMap._1 - val feature = featureMap._1 - val spline = featureMap._2.spline - log.info(("family=%s,feature=%s,minVal=%f, maxVal=%f, weights=%s") - .format(family, feature, spline.getMinVal, spline.getMaxVal, spline.toString)) - if (family.equals("A")) { - assertTrue(feature.equals("a")) - assertEquals(spline.getMaxVal, 10.0f, 0.01f) - assertEquals(spline.getMinVal, 1.0f, 0.01f) - assertEquals(spline.evaluate(1.0f), 1.0f, 0.01f) - assertEquals(spline.evaluate(10.0f), 2.0f, 0.01f) - } else if (family.equals("B")) { - assertTrue(feature.equals("b")) - assertEquals(spline.getMaxVal, 1.0f, 0.01f) - 
assertEquals(spline.getMinVal, 0.0f, 0.01f) - } else { - assertTrue(family.equals("C")) - assertTrue(feature.equals("c")) - assertEquals(spline.getMaxVal, 1.0f, 0.01f) - assertEquals(spline.getMinVal, 0.0f, 0.01f) - } + for ((feature, weightSpline) <- weights) { + val family = feature.family.name + val spline = weightSpline.spline + log.info("family=%s,feature=%s,minVal=%f, maxVal=%f, weights=%s" + .format(family, feature.name, spline.getMinVal, spline.getMaxVal, spline.toString)) + if (family.equals("A")) { + assertTrue(feature.name.equals("a")) + assertEquals(spline.getMaxVal, 10.0f, 0.01f) + assertEquals(spline.getMinVal, 1.0f, 0.01f) + assertEquals(spline.evaluate(1.0f), 1.0f, 0.01f) + assertEquals(spline.evaluate(10.0f), 2.0f, 0.01f) + } else if (family.equals("B")) { + assertTrue(feature.name.equals("b")) + assertEquals(spline.getMaxVal, 1.0f, 0.01f) + assertEquals(spline.getMinVal, 0.0f, 0.01f) + } else { + assertTrue(family.equals("C")) + assertTrue(feature.name.equals("c")) + assertEquals(spline.getMaxVal, 1.0f, 0.01f) + assertEquals(spline.getMinVal, 0.0f, 0.01f) } } } @Test def testSplineRegression(): Unit = { - val (trainingExample, trainingLabel) = TrainingTestHelper.makeRegressionExamples() + val (trainingExample, trainingLabel) = TrainingTestHelper.makeRegressionExamples(registry) var sc = new SparkContext("local", "SplineRegressionTest") try { - val config = ConfigFactory.parseString(makeRegressionConfig) + val config = ConfigFactory.parseString(makeRegressionConfig()) val input = sc.parallelize(trainingExample) - val model = SplineTrainer.train(sc, input, config, "model_config") + val model = SplineTrainer.train(sc, input, config, "model_config", registry) TrainingTestHelper.printSpline(model) val trainLabelArr = trainingLabel.toArray var trainTotalError : Double = 0 var i = 0 // compute training error for (ex <- trainingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = trainLabelArr(i) 
trainTotalError += math.abs(score - label) i += 1 @@ -253,13 +248,13 @@ class SplineTrainerTest { assertTrue(trainError < 3.0) // compute testing error - val (testingExample, testingLabel) = TrainingTestHelper.makeRegressionExamples(25) + val (testingExample, testingLabel) = TrainingTestHelper.makeRegressionExamples(registry, 25) val testLabelArr = testingLabel.toArray var testTotalError : Double = 0 // compute training error i = 0 for (ex <- testingExample) { - val score = model.scoreItem(ex.example.get(0)) + val score = model.scoreItem(ex.only) val label = testLabelArr(i) testTotalError += math.abs(score - label) i += 1 diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/TrainingTestHelper.scala b/training/src/test/scala/com/airbnb/aerosolve/training/TrainingTestHelper.scala index e5335e8b..83045463 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/TrainingTestHelper.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/TrainingTestHelper.scala @@ -2,6 +2,7 @@ package com.airbnb.aerosolve.training import java.util +import com.airbnb.aerosolve.core.features.{FeatureRegistry, SimpleExample} import com.airbnb.aerosolve.core.models.{AdditiveModel, SplineModel} import com.airbnb.aerosolve.core.{Example, FeatureVector} import org.slf4j.LoggerFactory @@ -14,32 +15,21 @@ object TrainingTestHelper { def makeExample(x : Double, y : Double, - target : Double) : Example = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setFloatFeatures(new java.util.HashMap) - item.setStringFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - val stringFeatures = item.getStringFeatures - // A string feature that is always on. 
- stringFeatures.put("BIAS", new java.util.HashSet) - stringFeatures.get("BIAS").add("B") - // A string feature that is sometimes on + target : Double, + registry: FeatureRegistry) : Example = { + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.putString("BIAS", "B") if (x + y < 0) { - stringFeatures.put("NEG", new java.util.HashSet) - stringFeatures.get("NEG").add("T") + item.putString("NEG", "T") } - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", target) - floatFeatures.put("loc", new java.util.HashMap) - val loc = floatFeatures.get("loc") - loc.put("x", x) - loc.put("y", y) - example.addToExample(item) + item.put("$rank", "", target) + item.put("loc", "x", x) + item.put("loc", "y", y) example } - def makeSimpleClassificationExamples = { + def makeSimpleClassificationExamples(registry: FeatureRegistry) = { val examples = ArrayBuffer[Example]() val label = ArrayBuffer[Double]() val rnd = new java.util.Random(1234) @@ -55,7 +45,7 @@ object TrainingTestHelper { } if (rank > 0) numPos = numPos + 1 label += rank - examples += makeExample(x, y, rank) + examples += makeExample(x, y, rank, registry) } (examples, label, numPos) } @@ -64,31 +54,23 @@ object TrainingTestHelper { y : Double, z : Double, label : (String, Double), - label2 : Option[(String, Double)]) : Example = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setFloatFeatures(new java.util.HashMap) - item.setStringFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - val stringFeatures = item.getStringFeatures - // A string feature that is always on. 
- stringFeatures.put("BIAS", new java.util.HashSet) - stringFeatures.get("BIAS").add("B") - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put(label._1, label._2) + label2 : Option[(String, Double)], + registry: FeatureRegistry) : Example = { + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.putString("BIAS", "B") + item.put("$rank", label._1, label._2) if (label2.isDefined) { - floatFeatures.get("$rank").put(label2.get._1, label2.get._2) + item.put("$rank", label2.get._1, label2.get._2) } - floatFeatures.put("loc", new java.util.HashMap) - val loc = floatFeatures.get("loc") - loc.put("x", x) - loc.put("y", y) - loc.put("z", z) - example.addToExample(item) + item.put("loc", "x", x) + item.put("loc", "y", y) + item.put("loc", "z", z) example } - def makeSimpleMulticlassClassificationExamples(multiLabel : Boolean) = { + def makeSimpleMulticlassClassificationExamples(multiLabel : Boolean, + registry : FeatureRegistry) = { val examples = ArrayBuffer[Example]() val labels = ArrayBuffer[String]() val rnd = new java.util.Random(1234) @@ -122,15 +104,15 @@ object TrainingTestHelper { labels += label if (multiLabel) { val label2 = if (x > 0) "right" else "left" - examples += makeMulticlassExample(x, y, z, (label, 1.0), Some((label2, 0.1))) + examples += makeMulticlassExample(x, y, z, (label, 1.0), Some((label2, 0.1)), registry) } else { - examples += makeMulticlassExample(x, y, z, (label, 1.0), None) + examples += makeMulticlassExample(x, y, z, (label, 1.0), None, registry) } } (examples, labels) } - def makeNonlinearMulticlassClassificationExamples() = { + def makeNonlinearMulticlassClassificationExamples(registry: FeatureRegistry) = { val examples = ArrayBuffer[Example]() val labels = ArrayBuffer[String]() val rnd = new java.util.Random(1234) @@ -143,32 +125,25 @@ object TrainingTestHelper { val label : String = if (d < 5) "inner" else if (d < 10) "middle" else "outer" labels += label - 
examples += makeMulticlassExample(x, y, z, (label, 1.0), None) + examples += makeMulticlassExample(x, y, z, (label, 1.0), None, registry) } (examples, labels) } def makeHybridExample(x : Double, y : Double, - target : Double) : Example = { - val example = new Example - val item: FeatureVector = new FeatureVector - item.setFloatFeatures(new java.util.HashMap) - val floatFeatures = item.getFloatFeatures - floatFeatures.put("$rank", new java.util.HashMap) - floatFeatures.get("$rank").put("", target) - floatFeatures.put("loc", new java.util.HashMap) - floatFeatures.put("xy", new util.HashMap) - val loc = floatFeatures.get("loc") - loc.put("x", x) - loc.put("y", y) - val xy = floatFeatures.get("xy") - xy.put("xy", x * y) - example.addToExample(item) + target : Double, + registry : FeatureRegistry) : Example = { + val example = new SimpleExample(registry) + val item: FeatureVector = example.createVector() + item.put("$rank", "", target) + item.put("loc", "x", x) + item.put("loc", "y", y) + item.put("xy", "xy", x*y) example } - def makeClassificationExamples = { + def makeClassificationExamples(registry: FeatureRegistry) = { val examples = ArrayBuffer[Example]() val label = ArrayBuffer[Double]() val rnd = new java.util.Random(1234) @@ -184,12 +159,12 @@ object TrainingTestHelper { } if (rank > 0) numPos = numPos + 1 label += rank - examples += makeExample(x, y, rank) + examples += makeExample(x, y, rank, registry) } (examples, label, numPos) } - def makeLinearClassificationExamples = { + def makeLinearClassificationExamples(registry: FeatureRegistry) = { val examples = ArrayBuffer[Example]() val label = ArrayBuffer[Double]() val rnd = new java.util.Random(1234) @@ -205,12 +180,12 @@ object TrainingTestHelper { } if (rank > 0) numPos = numPos + 1 label += rank - examples += makeHybridExample(x, y, rank) + examples += makeHybridExample(x, y, rank, registry) } (examples, label, numPos) } - def makeRegressionExamples(randomSeed: Int = 1234) = { + def 
makeRegressionExamples(registry: FeatureRegistry, randomSeed: Int = 1234) = { val examples = ArrayBuffer[Example]() val label = ArrayBuffer[Double]() val rnd = new java.util.Random(randomSeed) @@ -222,14 +197,14 @@ object TrainingTestHelper { // Curve will be a "saddle" with flat regions where, for instance, x = 0 and y > 2.06 or y < -1.96 val flattenedQuadratic = math.max(x * x - 2 * y * y - 0.5 * x + 0.2 * y, -8.0) - examples += makeExample(x, y, flattenedQuadratic) + examples += makeExample(x, y, flattenedQuadratic, registry) label += flattenedQuadratic } (examples, label) } - def makeLinearRegressionExamples(randomSeed: Int = 1234) = { + def makeLinearRegressionExamples(registry: FeatureRegistry, randomSeed: Int = 1234) = { val examples = ArrayBuffer[Example]() val label = ArrayBuffer[Double]() val rnd = new java.util.Random(randomSeed) @@ -238,7 +213,7 @@ object TrainingTestHelper { val x = 2.0 * (rnd.nextDouble() - 0.5) val y = 2.0 * (rnd.nextDouble() - 0.5) val z = 0.1 * x * y - 0.5 * x + 0.2 * y + 1.0 - examples += makeHybridExample(x, y, z) + examples += makeHybridExample(x, y, z, registry) label += z } @@ -247,29 +222,24 @@ object TrainingTestHelper { def printSpline(model: SplineModel) = { val weights = model.getWeightSpline.asScala - for (familyMap <- weights) { - for (featureMap <- familyMap._2.asScala) { + for ((feature, weightSpline) <- weights) { log.info(("family=%s,feature=%s," + "minVal=%f, maxVal=%f, weights=%s") - .format(familyMap._1, - featureMap._1, - featureMap._2.spline.getMinVal, - featureMap._2.spline.getMaxVal, - featureMap._2.spline.getWeights.mkString(",") + .format(feature.family.name, + feature.name, + weightSpline.spline.getMinVal, + weightSpline.spline.getMaxVal, + weightSpline.spline.getWeights.mkString(",") ) ) - } } } def printAdditiveModel(model: AdditiveModel) = { - val weights = model.getWeights.asScala - for (familyMap <- weights) { - for (featureMap <- familyMap._2.asScala) { - 
log.info("family=%s,feature=%s".format(familyMap._1, featureMap._1)) - val func = featureMap._2 - log.info(func.toString) - } + val weights = model.weights.asScala + for ((feature, func) <- weights) { + log.info("family=%s,feature=%s".format(feature.family.name, feature.name)) + log.info(func.toString) } } } \ No newline at end of file diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/TrainingUtilsTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/TrainingUtilsTest.scala index abd2ddc7..bf93f946 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/TrainingUtilsTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/TrainingUtilsTest.scala @@ -1,55 +1,55 @@ package com.airbnb.aerosolve.training -import com.typesafe.config.ConfigFactory -import org.slf4j.LoggerFactory -import org.junit.Test +import com.airbnb.aerosolve.core.features.FeatureRegistry import org.apache.spark.SparkContext import org.junit.Assert.assertEquals -import org.junit.Assert.assertTrue -import com.airbnb.aerosolve.core.util.Util +import org.junit.Test +import org.slf4j.LoggerFactory class TrainingUtilsTest { val log = LoggerFactory.getLogger("TrainingUtilsTest") + val registry = new FeatureRegistry @Test def testFeatureStatistics(): Unit = { - val examples = TrainingTestHelper.makeSimpleClassificationExamples._1 + val examples = TrainingTestHelper.makeSimpleClassificationExamples(registry)._1 var sc = new SparkContext("local", "TrainingUtilsTest") try { val statsArr = TrainingUtils.getFeatureStatistics(0, sc.parallelize(examples)) val stats = statsArr.toMap - val statsX = stats.get(("loc", "x")).get + val statsX = stats.get(registry.feature("loc", "x")).get assertEquals(-1, statsX.min, 0.1) assertEquals(1, statsX.max, 0.1) assertEquals(0, statsX.mean, 0.1) // X is uniform between -1 and 1 val xvar = 2.0 * 2.0 / 12.0 assertEquals(xvar, statsX.variance, 0.1) - val statsY = stats.get(("loc", "y")).get + val statsY = 
stats.get(registry.feature("loc", "y")).get assertEquals(-10.0, statsY.min, 1.0) assertEquals(10.0, statsY.max, 1.0) assertEquals(0.0, statsY.mean, 1.0) // Y is uniform between -10 and 10 val yvar = 20.0 * 20.0 / 12.0 assertEquals(yvar, statsY.variance, 1.0) - val statsBias = stats.get(("BIAS", "B")).get + val statsBias = stats.get(registry.feature("BIAS", "B")).get assertEquals(1.0, statsBias.min, 0.1) assertEquals(1.0, statsBias.max, 0.1) assertEquals(1.0, statsBias.mean, 0.1) assertEquals(0.0, statsBias.variance, 0.1) - val statsNeg = stats.get(("NEG", "T")).get + val statsNeg = stats.get(registry.feature("NEG", "T")).get assertEquals(1.0, statsNeg.min, 0.1) assertEquals(1.0, statsNeg.max, 0.1) assertEquals(1.0, statsNeg.mean, 0.1) assertEquals(0.0, statsNeg.variance, 0.1) - val dictionary = TrainingUtils.createStringDictionaryFromFeatureStatistics(statsArr, Set("$rank")) - assertEquals(4, dictionary.getDictionary().getEntryCount()) - val ex = TrainingTestHelper.makeExample(2.0, 1.0, 2) + val dictionary = TrainingUtils.createStringDictionaryFromFeatureStatistics(statsArr, + Set(registry.family("$rank"))) + assertEquals(4, dictionary.getDictionary.getEntryCount) + val ex = TrainingTestHelper.makeExample(2.0, 1.0, 2, registry) log.info(ex.toString) - val vec = dictionary.makeVectorFromSparseFloats(Util.flattenFeature(ex.example.get(0))); - assertEquals(4, vec.values.length); + val vec = dictionary.makeVectorFromSparseFloats(ex.only) + assertEquals(4, vec.values.length) log.info(vec.toString) val arr = scala.collection.mutable.ArrayBuffer[Float]() for (v <- vec.values) { @@ -57,7 +57,7 @@ class TrainingUtilsTest { } val sorted = arr.sortWith((a, b) => a < b) // Neg is missing - assertEquals(0.0, sorted(0), 0.2) + assertEquals(0.0, sorted.head, 0.2) // Y assertEquals(1.0 / Math.sqrt(yvar), sorted(1), 0.2) // Bias diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/EvalUtilTest.scala 
b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/EvalUtilTest.scala index 9cdfecc3..8c217988 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/EvalUtilTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/EvalUtilTest.scala @@ -1,11 +1,15 @@ package com.airbnb.aerosolve.training.pipeline +import com.airbnb.aerosolve.core.features.FeatureRegistry import com.airbnb.aerosolve.training.pipeline.PipelineTestingUtil._ import com.google.common.collect.ImmutableMap import org.junit.Assert._ import org.junit.Test + class EvalUtilTest { + val registry = new FeatureRegistry + @Test def testExampleToEvaluationRecordMulticlass() = { val evalResult = EvalUtil.exampleToEvaluationRecord( @@ -14,7 +18,7 @@ class EvalUtilTest { fullRankLinearModel, false, true, - "LABEL", + registry.family("LABEL"), _ => false ) @@ -36,7 +40,7 @@ class EvalUtilTest { linearModel, false, false, - "LABEL", + registry.family("LABEL"), _ => false ) @@ -53,7 +57,7 @@ class EvalUtilTest { linearModel, true, false, - "LABEL", + registry.family("LABEL"), _ => true ) @@ -68,10 +72,11 @@ class EvalUtilTest { val examples = sc.parallelize(Seq(multiclassExample1, multiclassExample2)) val results = EvalUtil.scoreExamplesForEvaluation( - sc, transformer, fullRankLinearModel, examples, "LABEL", false, true, _ => false + sc, transformer, fullRankLinearModel, examples, registry.family("LABEL"), + false, true, _ => false ).collect() - assertEquals(results.size, 2) + assertEquals(results.length, 2) assertEquals(results(0).getLabels, ImmutableMap.of("label1", 10.0, "label2", 9.0)) assertEquals(results(1).getLabels, ImmutableMap.of("label1", 8.0, "label2", 4.0)) diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/GenericPipelineTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/GenericPipelineTest.scala index 64c0854a..db891b9c 100644 --- 
a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/GenericPipelineTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/GenericPipelineTest.scala @@ -1,9 +1,12 @@ package com.airbnb.aerosolve.training.pipeline -import com.google.common.collect.{ImmutableMap, ImmutableSet} +import com.airbnb.aerosolve.core.features.FeatureRegistry +import com.google.common.collect.ImmutableMap import org.junit.Assert._ import org.junit.Test +import scala.collection.JavaConverters._ + case class FakeDataRow( i_intFeature1: Int, i_intFeature2: Int, @@ -19,6 +22,8 @@ case class FakeDataRowMulticlass( LABEL: String) class GenericPipelineTest { + val registry = new FeatureRegistry + @Test def hiveTrainingToExample() = { val fakeRow = FakeDataRow( @@ -28,38 +33,21 @@ class GenericPipelineTest { PipelineTestingUtil.withSparkContext(sc => { val (sqlRow, schema) = PipelineTestingUtil.createFakeRowAndSchema(sc, fakeRow) - val example = GenericPipeline.hiveTrainingToExample(sqlRow, schema.fields.toArray, false) - val stringFeatures = example.getExample.get(0).getStringFeatures - val floatFeatures = example.getExample.get(0).getFloatFeatures + val example = GenericPipeline.hiveTrainingToExample(sqlRow, schema.fields.toArray, registry, + false) + val fv = example.only - assertEquals( - stringFeatures.get("s"), - ImmutableSet.of("stringFeature:some string") - ) - assertEquals( - stringFeatures.get("s2"), - ImmutableSet.of("some other string") - ) - assertEquals( - floatFeatures.get("i"), - ImmutableMap.of("intFeature1", 10.0, "intFeature2", 7.0) - ) - assertEquals( - floatFeatures.get("f"), - ImmutableMap.of("floatFeature", 4.1f.toDouble) - ) - assertEquals( - floatFeatures.get("d"), - ImmutableMap.of("doubleFeature", 11.0) - ) - assertEquals( - stringFeatures.get("b"), - ImmutableSet.of("boolFeature:F") - ) - assertEquals( - floatFeatures.get("LABEL"), - ImmutableMap.of("", 4.5) - ) + assertTrue(fv.containsKey("s", "stringFeature:some string")) + 
assertTrue(fv.containsKey("s2", "some other string")) + + assertEquals(fv.get("i", "intFeature1"), 10.0, 0.01) + assertEquals(fv.get("i", "intFeature2"), 7.0, 0.01) + assertEquals(fv.get("f", "floatFeature"), 4.1, 0.01) + assertEquals(fv.get("d", "sparseFeature"), 11.0, 0.01) + + assertTrue(fv.containsKey("b", "boolFeature:F")) + + assertEquals(fv.get("LABEL", ""), 4.5, 0.01) }) } @@ -72,15 +60,18 @@ class GenericPipelineTest { PipelineTestingUtil.withSparkContext(sc => { val (sqlRow, schema) = PipelineTestingUtil.createFakeRowAndSchema(sc, fakeRow) - val example = GenericPipeline.hiveTrainingToExample(sqlRow, schema.fields.toArray, true) - val floatFeatures = example.getExample.get(0).getFloatFeatures + val example = GenericPipeline.hiveTrainingToExample(sqlRow, schema.fields.toArray, registry, + true) + val fv = example.only + + assertEquals(fv.get("i", "intFeature"), 10.0, 0.01) + + val labelMap = fv.get(registry.family("label")) + .iterator.asScala.map(fv => (fv.feature.name, fv.value)) + .toMap assertEquals( - floatFeatures.get("i"), - ImmutableMap.of("intFeature", 10.0) - ) - assertEquals( - floatFeatures.get("LABEL"), + labelMap, ImmutableMap.of( "CLASS1", 2.1, "CLASS2", 4.5, diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineTestingUtil.scala b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineTestingUtil.scala index a25aea37..8a1c6094 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineTestingUtil.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineTestingUtil.scala @@ -1,10 +1,11 @@ package com.airbnb.aerosolve.training.pipeline -import com.airbnb.aerosolve.core.{Example, FeatureVector, LabelDictionaryEntry} +import com.airbnb.aerosolve.core.{Example, LabelDictionaryEntry} +import com.airbnb.aerosolve.core.features.{FeatureRegistry, SimpleExample} import com.airbnb.aerosolve.core.models.{FullRankLinearModel, LinearModel} import 
com.airbnb.aerosolve.core.transforms.Transformer import com.airbnb.aerosolve.core.util.FloatVector -import com.google.common.collect.{ImmutableMap, ImmutableSet} +import com.google.common.collect.ImmutableMap import com.typesafe.config.ConfigFactory import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StructType @@ -21,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag * Misc. utilities that may be useful for testing Spark pipelines. */ object PipelineTestingUtil { + val registry = new FeatureRegistry val transformer = { val config = """ @@ -36,14 +38,14 @@ object PipelineTestingUtil { |} """.stripMargin - new Transformer(ConfigFactory.parseString(config), "model_transforms") + new Transformer(ConfigFactory.parseString(config), "model_transforms", registry) } // Simple full rank linear model with 2 label classes and 2 features val fullRankLinearModel = { - val model = new FullRankLinearModel() + val model = new FullRankLinearModel(registry) - model.setLabelToIndex(ImmutableMap.of("label1", 0, "label2", 1)) + model.labelToIndex(ImmutableMap.of("label1", 0, "label2", 1)) val labelDictEntry1 = new LabelDictionaryEntry() labelDictEntry1.setLabel("label1") @@ -58,15 +60,15 @@ object PipelineTestingUtil { labelDictionary.add(labelDictEntry1) labelDictionary.add(labelDictEntry2) - model.setLabelDictionary(labelDictionary) + model.labelDictionary(labelDictionary) val floatVector1 = new FloatVector(Array(1.2f, 2.1f)) val floatVector2 = new FloatVector(Array(3.4f, -1.2f)) - model.setWeightVector( + model.weightVector.putAll( ImmutableMap.of( - "f", ImmutableMap.of("feature1", floatVector1, "feature2", floatVector2) - ) + registry.feature("f", "feature1"), floatVector1, + registry.feature("f", "feature2"), floatVector2) ) model @@ -74,71 +76,57 @@ object PipelineTestingUtil { // Simple linear model with 2 features val linearModel = { - val model = new LinearModel() + val model = new LinearModel(registry) - model.setWeights(ImmutableMap.of("s", 
ImmutableMap.of("feature1", 1.4f, "feature2", 1.3f))) + model.weights.putAll( + ImmutableMap.of( + registry.feature("s", "feature1"), 1.4d, + registry.feature("s", "feature2"), 1.3d)) model } val multiclassExample1 = { - val example = new Example() - val fv = new FeatureVector() - - fv.setFloatFeatures(ImmutableMap.of( - "f", ImmutableMap.of("feature1", 1.2, "feature2", 5.6), - "LABEL", ImmutableMap.of("label1", 10.0, "label2", 9.0) - )) + val example = new SimpleExample(registry) + val fv = example.createVector() - example.addToExample(fv) + fv.put("f", "feature1", 1.2) + fv.put("f", "feature2", 5.6) + fv.put("LABEL", "label1", 10.0) + fv.put("LABEL", "label2", 9.0) example } - val multiclassExample2 = { - val example = new Example() - val fv = new FeatureVector() + val multiclassExample2: Example = { + val example = new SimpleExample(registry) + val fv = example.createVector() - fv.setFloatFeatures(ImmutableMap.of( - "f", ImmutableMap.of("feature1", 1.8, "feature2", -1.6), - "LABEL", ImmutableMap.of("label1", 8.0, "label2", 4.0) - )) - - example.addToExample(fv) + fv.put("f", "feature1", 1.8) + fv.put("f", "feature2", -1.6) + fv.put("LABEL", "label1", 8.0) + fv.put("LABEL", "label2", 4.0) example } - val linearExample1 = { - val example = new Example() - val fv = new FeatureVector() - - fv.setFloatFeatures(ImmutableMap.of( - "LABEL", ImmutableMap.of("", 3.5) - )) + val linearExample1: Example = { + val example = new SimpleExample(registry) + val fv = example.createVector() - fv.setStringFeatures(ImmutableMap.of( - "s", ImmutableSet.of("feature1", "feature2") - )) - - example.addToExample(fv) + fv.putString("s", "feature1") + fv.putString("s", "feature2") + fv.put("LABEL", "", 3.5) example } - val linearExample2 = { - val example = new Example() - val fv = new FeatureVector() - - fv.setFloatFeatures(ImmutableMap.of( - "LABEL", ImmutableMap.of("", -2.0) - )) - - fv.setStringFeatures(ImmutableMap.of( - "s", ImmutableSet.of("feature1") - )) + val linearExample2: 
Example = { + val example = new SimpleExample(registry) + val fv = example.createVector() - example.addToExample(fv) + fv.putString("s", "feature1") + fv.put("LABEL", "", -2.0) example } diff --git a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtilTest.scala b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtilTest.scala index d6bdb16e..b93bc4e7 100644 --- a/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtilTest.scala +++ b/training/src/test/scala/com/airbnb/aerosolve/training/pipeline/PipelineUtilTest.scala @@ -12,7 +12,7 @@ class PipelineUtilTest { val examples = sc.parallelize(Seq(linearExample1, linearExample2)) val trainingPredicate = (example: Example) => { - example.getExample.get(0).getStringFeatures.get("s").size() == 1 + example.only.get("s").size() == 1 } val results = PipelineUtil.scoreExamples(