diff --git a/core/src/main/java/com/alibaba/alink/common/comqueue/communication/AllReduce.java b/core/src/main/java/com/alibaba/alink/common/comqueue/communication/AllReduce.java
index ad08f151d..01601d752 100644
--- a/core/src/main/java/com/alibaba/alink/common/comqueue/communication/AllReduce.java
+++ b/core/src/main/java/com/alibaba/alink/common/comqueue/communication/AllReduce.java
@@ -33,8 +33,8 @@
*
* There're mainly three stages:
*
- * - 1. All workers send the there partial data to other workers for reduce.
- * - 2. All workers do reduce on all data it received and then send partial results to others.
+ * - 1. All workers send their partial data to the other workers for reduce.
+ * - 2. All workers do reduce on all data they received and then send partial results to the others.
* - 3. All workers merge partial results into final result and put it into session context with pre-defined
* object name.
*
diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java b/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java
index 87c35e26f..546517af9 100644
--- a/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java
+++ b/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java
@@ -8,7 +8,7 @@
import java.util.function.Consumer;
/**
- * An interface indicating a BatchOperator can information for its training process.
+ * An interface indicating that a BatchOperator can provide information about its training process.
*
* @param <S> the class which conveys the train information.
* @param <T> the BatchOperator class which provides the train information.
diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/MultilayerPerceptronTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/MultilayerPerceptronTrainBatchOp.java
index 3feaa8e31..a9befc9ee 100644
--- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/MultilayerPerceptronTrainBatchOp.java
+++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/MultilayerPerceptronTrainBatchOp.java
@@ -29,7 +29,7 @@
/**
* MultilayerPerceptronClassifier is a neural network based multi-class classifier.
- * Valina neural network with all dense layers are used, the output layer is a softmax layer.
+ * A vanilla neural network with all dense layers is used; the output layer is a softmax layer.
* Number of inputs has to be equal to the size of feature vectors.
* Number of outputs has to be equal to the total number of labels.
*/
@@ -143,7 +143,8 @@ public MultilayerPerceptronTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
final int[] layerSize = getLayers();
final int blockSize = getBlockSize();
final DenseVector initialWeights = getInitialWeights();
- Topology topology = FeedForwardTopology.multiLayerPerceptron(layerSize, true);
+ final double dropoutRate = getDropoutRate();
+ Topology topology = FeedForwardTopology.multiLayerPerceptron(layerSize, true, dropoutRate);
FeedForwardTrainer trainer = new FeedForwardTrainer(topology,
layerSize[0], layerSize[layerSize.length - 1], true, blockSize, initialWeights);
DataSet<DenseVector> weights = trainer.train(trainData, getParams());
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayer.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayer.java
new file mode 100644
index 000000000..3c82a2d39
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayer.java
@@ -0,0 +1,32 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+/**
+ * Layer properties of the dropout layer.
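+ *
+ * In the forward pass each activation is kept with probability (1 - dropoutRate) and the kept
+ * values are scaled by 1 / (1 - dropoutRate) (inverted dropout), so no extra rescaling is
+ * needed at prediction time.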
+ */
+public class DropoutLayer extends Layer {
+ public double dropoutRate;
+
+ public DropoutLayer(double dropoutRate) {
+ this.dropoutRate = dropoutRate;
+ }
+
+ @Override
+ public LayerModel createModel() {
+ return new DropoutLayerModel(this);
+ }
+
+ @Override
+ public int getWeightSize() {
+ return 0;
+ }
+
+ @Override
+ public int getOutputSize(int inputSize) {
+ return inputSize;
+ }
+
+ @Override
+ public boolean isInPlace() {
+ return true;
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayerModel.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayerModel.java
new file mode 100644
index 000000000..7bfdf1345
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/DropoutLayerModel.java
@@ -0,0 +1,48 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+import com.alibaba.alink.common.linalg.DenseMatrix;
+import com.alibaba.alink.common.linalg.DenseVector;
+import org.apache.commons.math3.distribution.BinomialDistribution;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well19937c;
+
+public class DropoutLayerModel extends LayerModel {
+ private DropoutLayer layer;
+
+ public DropoutLayerModel(DropoutLayer layer) {
+ this.layer = layer;
+ }
+
+ @Override
+ public void resetModel(DenseVector weights, int offset) {
+ }
+
+ @Override
+ public void eval(DenseMatrix data, DenseMatrix output) {
+ double dropoutRate = layer.dropoutRate;
+
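+ // Inverted dropout: sample a Bernoulli(1 - dropoutRate) mask per entry and scale the kept
+ // entries by 1 / (1 - dropoutRate), so the expected activation stays unchanged.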
+ RandomGenerator randomGenerator = new Well19937c(1L);
+ BinomialDistribution binomialDistribution = new BinomialDistribution(randomGenerator, 1, 1 - dropoutRate);
+
+ for (int i = 0; i < data.numRows(); i++) {
+ for (int j = 0; j < data.numCols(); j++) {
+ output.set(i, j, data.get(i, j) * binomialDistribution.sample() * (1.0 / (1 - dropoutRate)));
+ }
+ }
+ }
+
+ @Override
+ public void computePrevDelta(DenseMatrix delta, DenseMatrix output, DenseMatrix prevDelta) {
+ for (int i = 0; i < delta.numRows(); i++) {
+ for (int j = 0; j < delta.numCols(); j++) {
+ double y = output.get(i, j);
+ prevDelta.set(i, j, y * delta.get(i, j));
+ }
+ }
+ }
+
+ @Override
+ public void grad(DenseMatrix delta, DenseMatrix input, DenseVector cumGrad, int offset) {
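+ // Dropout has no trainable weights, so there is no gradient to accumulate.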
+
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayer.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayer.java
new file mode 100644
index 000000000..b8d154fb3
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayer.java
@@ -0,0 +1,34 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+/**
+ * Layer properties of the embedding transformation, which is y = A * x.
+ */
+public class EmbeddingLayer extends Layer {
+ public int numIn;
+ public int embeddingSize;
+
+ public EmbeddingLayer(int numIn, int embeddingSize) {
+ this.numIn = numIn;
+ this.embeddingSize = embeddingSize;
+ }
+
+ @Override
+ public LayerModel createModel() {
+ return new EmbeddingLayerModel(this);
+ }
+
+ @Override
+ public int getWeightSize() {
+ return numIn * embeddingSize;
+ }
+
+ @Override
+ public int getOutputSize(int inputSize) {
+ return embeddingSize;
+ }
+
+ @Override
+ public boolean isInPlace() {
+ return false;
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayerModel.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayerModel.java
new file mode 100644
index 000000000..df5c826fc
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/EmbeddingLayerModel.java
@@ -0,0 +1,64 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+import com.alibaba.alink.common.linalg.BLAS;
+import com.alibaba.alink.common.linalg.DenseMatrix;
+import com.alibaba.alink.common.linalg.DenseVector;
+
+/**
+ * The LayerModel for {@link EmbeddingLayer}.
+ */
+public class EmbeddingLayerModel extends LayerModel {
+ private DenseMatrix w;
+
+ // buffer for holding gradw
+ private DenseMatrix gradw;
+
+ public EmbeddingLayerModel(EmbeddingLayer layer) {
+ this.w = new DenseMatrix(layer.numIn, layer.embeddingSize);
+ this.gradw = new DenseMatrix(layer.numIn, layer.embeddingSize);
+ }
+
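+ // pack/unpack copy the embedding matrix between its flat slice of the global weight
+ // vector (starting at offset) and the local DenseMatrix buffer.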
+ private void pack(DenseVector weights, int offset, DenseMatrix w) {
+ int pos = 0;
+ for (int i = 0; i < this.w.numRows(); i++) {
+ for (int j = 0; j < this.w.numCols(); j++) {
+ weights.set(offset + pos, w.get(i, j));
+ pos++;
+ }
+ }
+ }
+
+ private void unpack(DenseVector weights, int offset, DenseMatrix w) {
+ int pos = 0;
+ for (int i = 0; i < this.w.numRows(); i++) {
+ for (int j = 0; j < this.w.numCols(); j++) {
+ w.set(i, j, weights.get(offset + pos));
+ pos++;
+ }
+ }
+ }
+
+ @Override
+ public void resetModel(DenseVector weights, int offset) {
+ unpack(weights, offset, this.w);
+ }
+
+ @Override
+ public void eval(DenseMatrix data, DenseMatrix output) {
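+ // Forward pass: output = data * W, a dense matrix-multiplication form of the embedding lookup.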
+ BLAS.gemm(1., data, false, this.w, false, 0., output);
+ }
+
+ @Override
+ public void computePrevDelta(DenseMatrix delta, DenseMatrix output, DenseMatrix prevDelta) {
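+ // Back-propagate to the previous layer: prevDelta = delta * W^T.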
+ BLAS.gemm(1.0, delta, false, this.w, true, 0., prevDelta);
+ }
+
+ @Override
+ public void grad(DenseMatrix delta, DenseMatrix input, DenseVector cumGrad, int offset) {
+ unpack(cumGrad, offset, this.gradw);
+ int batchSize = input.numRows();
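+ // Accumulate the weight gradient gradW += input^T * delta, then write it back into cumGrad.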
+ BLAS.gemm(1.0, input, true, delta, false, 1.0, this.gradw);
+ pack(cumGrad, offset, this.gradw);
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/FeedForwardTopology.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/FeedForwardTopology.java
index c389e6cac..573589c64 100644
--- a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/FeedForwardTopology.java
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/FeedForwardTopology.java
@@ -18,7 +18,14 @@ public FeedForwardTopology(List<Layer> layers) {
this.layers = layers;
}
- public static FeedForwardTopology multiLayerPerceptron(int[] layerSize, boolean softmaxOnTop) {
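+ /**
+ * Builds an MLP topology with the default sigmoid activation for the hidden layers and the given dropout rate.
+ */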
+ public static FeedForwardTopology multiLayerPerceptron(int[] layerSize, boolean softmaxOnTop, double dropoutRate) {
+ return multiLayerPerceptron(layerSize, softmaxOnTop, dropoutRate, "sigmoid");
+ }
+
+ public static FeedForwardTopology multiLayerPerceptron(int[] layerSize,
+ boolean softmaxOnTop,
+ double dropoutRate,
+ String activation) {
List<Layer> layers = new ArrayList<>((layerSize.length - 1) * 2);
for (int i = 0; i < layerSize.length - 1; i++) {
layers.add(new AffineLayer(layerSize[i], layerSize[i + 1]));
@@ -29,7 +36,17 @@ public static FeedForwardTopology multiLayerPerceptron(int[] layerSize, boolean
layers.add(new SigmoidLayerWithSquaredError());
}
} else {
- layers.add(new FuntionalLayer(new SigmoidFunction()));
+ if ("sigmoid".equalsIgnoreCase(activation)) {
+ layers.add(new FuntionalLayer(new SigmoidFunction()));
+ } else if ("relu".equalsIgnoreCase(activation)) {
+ layers.add(new FuntionalLayer(new ReluFunction()));
+ } else if ("tanh".equalsIgnoreCase(activation)) {
+ layers.add(new FuntionalLayer(new TanhFunction()));
+ } else {
+ throw new RuntimeException("Unsupported activation function: " + activation);
+ }
+
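+ // Apply dropout after the activation of each hidden layer.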
+ layers.add(new DropoutLayer(dropoutRate));
}
}
return new FeedForwardTopology(layers);
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/MlpcModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/MlpcModelMapper.java
index 4d19da6ea..a52ded90e 100644
--- a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/MlpcModelMapper.java
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/MlpcModelMapper.java
@@ -64,7 +64,7 @@ public void loadModel(List<Row> modelRows) {
this.labels = model.labels;
int[] layerSize0 = model.meta.get(MultilayerPerceptronTrainParams.LAYERS);
- Topology topology = FeedForwardTopology.multiLayerPerceptron(layerSize0, true);
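+ // A dropout rate of 0 leaves the dropout layers as an identity map at prediction time.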
+ Topology topology = FeedForwardTopology.multiLayerPerceptron(layerSize0, true, 0);
this.topo = topology.getModel(model.weights);
this.predDetailMap = new HashMap<>(layerSize0[layerSize0.length - 1]);
isVectorInput = model.meta.get(ModelParamName.IS_VECTOR_INPUT);
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/ReluFunction.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/ReluFunction.java
new file mode 100644
index 000000000..c1c3dcbf6
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/ReluFunction.java
@@ -0,0 +1,22 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+/**
+ * The ReLU function.
+ * f(x) = max(0, x)
+ * f'(x) = 1 if x > 0 else 0
+ */
+public class ReluFunction implements ActivationFunction {
+ @Override
+ public double eval(double x) {
+ return Math.max(0, x);
+ }
+
+ @Override
+ public double derivative(double x) {
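+ // As with SigmoidFunction, the framework passes the activation output here; for ReLU it is
+ // positive exactly when the input was positive, so the check is unchanged.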
+ if (x > 0) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/TanhFunction.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/TanhFunction.java
new file mode 100644
index 000000000..679db9c7c
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/TanhFunction.java
@@ -0,0 +1,18 @@
+package com.alibaba.alink.operator.common.classification.ann;
+
+/**
+ * The tanh function.
+ * tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
+ * tanh'(x) = 1 - tanh(x)^2
+ */
+public class TanhFunction implements ActivationFunction {
+ @Override
+ public double eval(double x) {
+ return Math.tanh(x); // Math.tanh is numerically stable for large |x|, unlike the explicit exp form.
+ }
+
+ @Override
+ public double derivative(double z) {
+ return 1 - Math.pow(z, 2.0);
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/java/com/alibaba/alink/params/classification/MultilayerPerceptronTrainParams.java b/core/src/main/java/com/alibaba/alink/params/classification/MultilayerPerceptronTrainParams.java
index ea4dc5cc4..c377eea0f 100644
--- a/core/src/main/java/com/alibaba/alink/params/classification/MultilayerPerceptronTrainParams.java
+++ b/core/src/main/java/com/alibaba/alink/params/classification/MultilayerPerceptronTrainParams.java
@@ -36,6 +36,11 @@ public interface MultilayerPerceptronTrainParams extends
.setDescription("Initial weights.")
.setHasDefaultValue(null)
.build();
+ ParamInfo<Double> DROPOUT_RATE = ParamInfoFactory
+ .createParamInfo("dropoutRate", Double.class)
+ .setDescription("Dropout rate for MLP")
+ .setHasDefaultValue(0.)
+ .build();
default int[] getLayers() {
return get(LAYERS);
@@ -60,4 +65,12 @@ default DenseVector getInitialWeights() {
default T setInitialWeights(DenseVector value) {
return set(INITIAL_WEIGHTS, value);
}
+
+ default Double getDropoutRate() {
+ return get(DROPOUT_RATE);
+ }
+
+ default T setDropoutRate(Double value) {
+ return set(DROPOUT_RATE, value);
+ }
}
diff --git a/core/src/test/java/com/alibaba/alink/pipeline/classification/MultilayerPerceptronClassifierTest.java b/core/src/test/java/com/alibaba/alink/pipeline/classification/MultilayerPerceptronClassifierTest.java
index 4e8a24561..790007476 100644
--- a/core/src/test/java/com/alibaba/alink/pipeline/classification/MultilayerPerceptronClassifierTest.java
+++ b/core/src/test/java/com/alibaba/alink/pipeline/classification/MultilayerPerceptronClassifierTest.java
@@ -17,6 +17,7 @@ public void testMLPC() throws Exception {
.setFeatureCols(Iris.getFeatureColNames())
.setLabelCol(Iris.getLabelColName())
.setLayers(new int[]{4, 5, 3})
+ .setDropoutRate(0.2)
.setMaxIter(100)
.setPredictionCol("pred_label")
.setPredictionDetailCol("pred_detail");
diff --git a/docs/cn/samplewithsizebatchop.md b/docs/cn/samplewithsizebatchop.md
index e5ec2972d..510d85016 100644
--- a/docs/cn/samplewithsizebatchop.md
+++ b/docs/cn/samplewithsizebatchop.md
@@ -31,7 +31,7 @@ df = pd.DataFrame({"Y": data[:, 0]})
# batch source
inOp = dataframeToOperator(df, schemaStr='Y string', op_type='batch')
-sampleOp = SampleBatchOp()\
+sampleOp = SampleWithSizeBatchOp()\
.setSize(2)\
.setWithReplacement(False)
diff --git a/docs/en/samplewithsizebatchop.md b/docs/en/samplewithsizebatchop.md
index 2ac8ee63f..20cfc7073 100644
--- a/docs/en/samplewithsizebatchop.md
+++ b/docs/en/samplewithsizebatchop.md
@@ -27,7 +27,7 @@ df = pd.DataFrame({"Y": data[:, 0]})
# batch source
inOp = dataframeToOperator(df, schemaStr='Y string', op_type='batch')
-sampleOp = SampleBatchOp()\
+sampleOp = SampleWithSizeBatchOp()\
.setSize(2)\
.setWithReplacement(False)