Skip to content

Commit

Permalink
Compatible-Array-Sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
apete committed Mar 18, 2024
1 parent 6615d05 commit 5464ff8
Show file tree
Hide file tree
Showing 26 changed files with 1,200 additions and 38 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@ Added / Changed / Deprecated / Fixed / Removed / Security

- Added the ability to specify the initial capacity of a `SparseStore`.

#### org.ojalgo.structure

- There are new methods fillCompatile(...), modifyCompatile(...) and onCompatile(...) to complement the already existing "fill", "modify" and "on" methods. (Inspired by MATLAB's concept of compatible array sizes for binary operation.)
- The factory interfaces got methods to construct instances of compatible sizes/shapes. (Inspired by MATLAB's concept of compatible array sizes for binary operation.)

### Changed

#### org.ojalgo.matrix

- Vector space method like "add" and "subtract" have been redefined to no longer throw exceptions if dimensions are not equal, but instead broadcast/repeat rows or columns. (Inspired by MATLAB's concept of compatible array sizes for binary operation.)

#### org.ojalgo.structure

- Refactored the builder/factory interfaces to better support creating immutable 1D, 2D or AnyD structures. This has implications for most ojAlgo data structures. There are deprecations in all factory classes, but everything that worked before still works (I believe).
Expand Down
35 changes: 32 additions & 3 deletions src/main/java/org/ojalgo/array/Array1D.java
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,13 @@ protected void compute() {

}

public static final Factory<ComplexNumber> C128 = Array1D.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = Array1D.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = Array1D.factory(ArrayQ128.FACTORY);
public static final Factory<Double> R032 = Array1D.factory(ArrayR032.FACTORY);
public static final Factory<Double> R064 = Array1D.factory(ArrayR064.FACTORY);
public static final Factory<Quadruple> R128 = Array1D.factory(ArrayR128.FACTORY);
public static final Factory<BigDecimal> R256 = Array1D.factory(ArrayR256.FACTORY);
public static final Factory<ComplexNumber> C128 = Array1D.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = Array1D.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = Array1D.factory(ArrayQ128.FACTORY);
public static final Factory<Double> Z008 = Array1D.factory(ArrayZ008.FACTORY);
public static final Factory<Double> Z016 = Array1D.factory(ArrayZ016.FACTORY);
public static final Factory<Double> Z032 = Array1D.factory(ArrayZ032.FACTORY);
Expand Down Expand Up @@ -765,6 +765,35 @@ void exchange(final long indexA, final long indexB) {
}
}

Factory1D<Array1D<N>> factory() {

return new Factory1D<>() {

ArrayFactory<N, ?> delegate = myDelegate.factory();

public FunctionSet<?> function() {
return delegate.function();
}

public MathType getMathType() {
return delegate.getMathType();
}

public Array1D<N> make(final int size) {
return this.make((long) size);
}

public Array1D<N> make(final long count) {
return delegate.make(count).wrapInArray1D();
}

public Scalar.Factory<?> scalar() {
return delegate.scalar();
}

};
}

BasicArray<N> getDelegate() {
return myDelegate;
}
Expand Down
35 changes: 32 additions & 3 deletions src/main/java/org/ojalgo/array/Array2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ public TensorFactory2D<N, Array2D<N>> tensor() {

}

public static final Factory<ComplexNumber> C128 = Array2D.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = Array2D.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = Array2D.factory(ArrayQ128.FACTORY);
public static final Factory<Double> R032 = Array2D.factory(ArrayR032.FACTORY);
public static final Factory<Double> R064 = Array2D.factory(ArrayR064.FACTORY);
public static final Factory<Quadruple> R128 = Array2D.factory(ArrayR128.FACTORY);
public static final Factory<BigDecimal> R256 = Array2D.factory(ArrayR256.FACTORY);
public static final Factory<ComplexNumber> C128 = Array2D.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = Array2D.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = Array2D.factory(ArrayQ128.FACTORY);
public static final Factory<Double> Z008 = Array2D.factory(ArrayZ008.FACTORY);
public static final Factory<Double> Z016 = Array2D.factory(ArrayZ016.FACTORY);
public static final Factory<Double> Z032 = Array2D.factory(ArrayZ032.FACTORY);
Expand Down Expand Up @@ -857,6 +857,35 @@ public void visitRow(final long row, final long col, final VoidFunction<N> visit
myDelegate.visit(Structure2D.index(myRowsCount, row, col), Structure2D.index(myRowsCount, row, myColumnsCount), myRowsCount, visitor);
}

Factory2D<Array2D<N>> factory() {

return new Factory2D<>() {

ArrayFactory<N, ?> delegate = myDelegate.factory();

public FunctionSet<?> function() {
return delegate.function();
}

public MathType getMathType() {
return delegate.getMathType();
}

public Array2D<N> make(final int nbRows, final int nbCols) {
return this.make((long) nbRows, (long) nbCols);
}

public Array2D<N> make(final long nbRows, final long nbCols) {
return delegate.make(Structure2D.count(nbRows, nbCols)).wrapInArray2D(nbRows);
}

public Scalar.Factory<?> scalar() {
return delegate.scalar();
}

};
}

BasicArray<N> getDelegate() {
return myDelegate;
}
Expand Down
35 changes: 32 additions & 3 deletions src/main/java/org/ojalgo/array/ArrayAnyD.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ public TensorFactoryAnyD<N, ArrayAnyD<N>> tensor() {

}

public static final Factory<ComplexNumber> C128 = ArrayAnyD.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = ArrayAnyD.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = ArrayAnyD.factory(ArrayQ128.FACTORY);
public static final Factory<Double> R032 = ArrayAnyD.factory(ArrayR032.FACTORY);
public static final Factory<Double> R064 = ArrayAnyD.factory(ArrayR064.FACTORY);
public static final Factory<Quadruple> R128 = ArrayAnyD.factory(ArrayR128.FACTORY);
public static final Factory<BigDecimal> R256 = ArrayAnyD.factory(ArrayR256.FACTORY);
public static final Factory<ComplexNumber> C128 = ArrayAnyD.factory(ArrayC128.FACTORY);
public static final Factory<Quaternion> H256 = ArrayAnyD.factory(ArrayH256.FACTORY);
public static final Factory<RationalNumber> Q128 = ArrayAnyD.factory(ArrayQ128.FACTORY);
public static final Factory<Double> Z008 = ArrayAnyD.factory(ArrayZ008.FACTORY);
public static final Factory<Double> Z016 = ArrayAnyD.factory(ArrayZ016.FACTORY);
public static final Factory<Double> Z032 = ArrayAnyD.factory(ArrayZ032.FACTORY);
Expand Down Expand Up @@ -805,6 +805,35 @@ public void visitSet(final long[] initial, final int dimension, final VoidFuncti
this.loop(initial, dimension, (f, l, s) -> myDelegate.visit(f, l, s, visitor));
}

FactoryAnyD<ArrayAnyD<N>> factory() {

return new FactoryAnyD<>() {

ArrayFactory<N, ?> delegate = myDelegate.factory();

public FunctionSet<?> function() {
return delegate.function();
}

public MathType getMathType() {
return delegate.getMathType();
}

@Override
public ArrayAnyD<N> make(final int... shape) {
return this.make(Structure1D.toLongIndexes(shape));
}

public ArrayAnyD<N> make(final long... shape) {
return delegate.make(StructureAnyD.count(shape)).wrapInArrayAnyD(shape);
}

public Scalar.Factory<?> scalar() {
return delegate.scalar();
}
};
}

BasicArray<N> getDelegate() {
return myDelegate;
}
Expand Down
221 changes: 221 additions & 0 deletions src/main/java/org/ojalgo/array/operation/FillCompatible.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package org.ojalgo.array.operation;

import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.special.MissingMath;
import org.ojalgo.structure.*;

/**
* https://se.mathworks.com/help/matlab/matlab_prog/compatible-array-sizes-for-basic-operations.html
*/
public class FillCompatible {

public static void invoke(final double[] target, final int structure, final Access2D<?> left, final BinaryFunction<?> operator, final Access2D<?> right) {

int nbRows = structure;
int nbCols = target.length / structure;

int modRowsL = left.getRowDim();
int modColsL = left.getColDim();

int modRowsR = right.getRowDim();
int modColsR = right.getColDim();

for (int j = 0; j < nbCols; j++) {

int colL = j % modColsL;
int colR = j % modColsR;

for (int i = 0; i < nbRows; i++) {

int rowL = i % modRowsL;
int rowR = i % modRowsR;

double argL = left.doubleValue(rowL, colL);
double argR = right.doubleValue(rowR, colR);

double newVal = operator.invoke(argL, argR);

target[i + j * structure] = newVal;
}
}
}

public static void invoke(final double[][] target, final Access2D<?> left, final BinaryFunction<?> operator, final Access2D<?> right) {

int nbRows = target.length;
int nbCols = nbRows > 0 ? target[0].length : 0;

int modRowsL = left.getRowDim();
int modColsL = left.getColDim();

int modRowsR = right.getRowDim();
int modColsR = right.getColDim();

for (int j = 0; j < nbCols; j++) {

int colL = j % modColsL;
int colR = j % modColsR;

for (int i = 0; i < nbRows; i++) {

int rowL = i % modRowsL;
int rowR = i % modRowsR;

double argL = left.doubleValue(rowL, colL);
double argR = right.doubleValue(rowR, colR);

double newVal = operator.invoke(argL, argR);

target[i][j] = newVal;
}
}
}

public static void invoke(final float[] target, final int structure, final Access2D<?> left, final BinaryFunction<?> operator, final Access2D<?> right) {

int nbRows = structure;
int nbCols = target.length / structure;

int modRowsL = left.getRowDim();
int modColsL = left.getColDim();

int modRowsR = right.getRowDim();
int modColsR = right.getColDim();

for (int j = 0; j < nbCols; j++) {

int colL = j % modColsL;
int colR = j % modColsR;

for (int i = 0; i < nbRows; i++) {

int rowL = i % modRowsL;
int rowR = i % modRowsR;

float argL = left.floatValue(rowL, colL);
float argR = right.floatValue(rowR, colR);

float newVal = operator.invoke(argL, argR);

target[i + j * structure] = newVal;
}
}
}

public static <N extends Comparable<N>> void invoke(final Mutate1D target, final Access1D<N> left, final BinaryFunction<N> operator,
final Access1D<N> right) {

int size = target.size();

int modL = left.size();
int modR = right.size();

for (int i = 0; i < size; i++) {

int indexL = i % modL;
int indexR = i % modR;

N argL = left.get(indexL);
N argR = right.get(indexR);

N newVal = operator.invoke(argL, argR);

target.set(i, newVal);
}
}

public static <N extends Comparable<N>> void invoke(final Mutate2D target, final Access2D<N> left, final BinaryFunction<N> operator,
final Access2D<N> right) {

int nbRows = target.getRowDim();
int nbCols = target.getColDim();

int modRowsL = left.getRowDim();
int modColsL = left.getColDim();

int modRowsR = right.getRowDim();
int modColsR = right.getColDim();

for (int j = 0; j < nbCols; j++) {

int colL = j % modColsL;
int colR = j % modColsR;

for (int i = 0; i < nbRows; i++) {

int rowL = i % modRowsL;
int rowR = i % modRowsR;

N argL = left.get(rowL, colL);
N argR = right.get(rowR, colR);

N newVal = operator.invoke(argL, argR);

target.set(i, j, newVal);
}
}
}

public static <N extends Comparable<N>> void invoke(final MutateAnyD target, final AccessAnyD<N> left, final BinaryFunction<N> operator,
final AccessAnyD<N> right) {

int rank = MissingMath.max(target.rank(), left.rank(), right.rank());

long[] refT = new long[target.rank()];
long[] refL = new long[target.rank()];
long[] refR = new long[target.rank()];

FillCompatible.doOneOfAnyD(target, refT, rank - 1, left, refL, operator, right, refR);
}

private static <N extends Comparable<N>> void doOneOfAnyD(final MutateAnyD target, final long[] targRef, final int dim, final AccessAnyD<N> left,
final long[] leftRef, final BinaryFunction<N> operator, final AccessAnyD<N> right, final long[] righRef) {

long modL = left.count(dim);
long modR = right.count(dim);

for (long i = 0L, limit = target.count(dim); i < limit; i++) {

targRef[dim] = i;
leftRef[dim] = i % modL;
righRef[dim] = i % modR;

if (dim == 0) {
target.set(targRef, operator.invoke(left.get(leftRef), right.get(righRef)));
} else {
FillCompatible.doOneOfAnyD(target, targRef, dim - 1, left, leftRef, operator, right, righRef);
}
}
}

static <N extends Comparable<N>, T extends Mutate1D> T expand(final Factory1D<T> factory, final Access1D<N> left, final BinaryFunction<N> operator,
final Access1D<N> right) {

T target = factory.make(left, right);

FillCompatible.invoke(target, left, operator, right);

return target;
}

static <N extends Comparable<N>, T extends Mutate2D> T expand(final Factory2D<T> factory, final Access2D<N> left, final BinaryFunction<N> operator,
final Access2D<N> right) {

T target = factory.make(left, right);

FillCompatible.invoke(target, left, operator, right);

return target;
}

static <N extends Comparable<N>, T extends MutateAnyD> T expand(final FactoryAnyD<T> factory, final AccessAnyD<N> left, final BinaryFunction<N> operator,
final AccessAnyD<N> right) {

T target = factory.make(left, right);

FillCompatible.invoke(target, left, operator, right);

return target;
}

}
Loading

0 comments on commit 5464ff8

Please sign in to comment.