Skip to content

Commit faffcbe

Browse files
committed
Improve ArrayReduce for list of integers
- implement special case `n == 1`
1 parent ac20e1a commit faffcbe

File tree

3 files changed

+318
-12
lines changed

3 files changed

+318
-12
lines changed

symja_android_library/matheclipse-core/src/main/java/org/matheclipse/core/builtin/TensorFunctions.java

+59-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.matheclipse.core.interfaces.IReal;
2626
import org.matheclipse.core.interfaces.ISparseArray;
2727
import org.matheclipse.core.interfaces.ISymbol;
28+
import org.matheclipse.core.visit.VisitorLevelSpecification;
2829
import it.unimi.dsi.fastutil.ints.IntList;
2930

3031
public class TensorFunctions {
@@ -57,20 +58,59 @@ private static void init() {
5758
}
5859

5960
private static final class ArrayReduce extends AbstractEvaluator {
61+
private IExpr arrayReduce(IExpr f, IAST array, int[] levels, EvalEngine engine) {
62+
IAST currentArray = array;
63+
Arrays.sort(levels);
64+
IntList dimensions = LinearAlgebra.dimensions(array, S.List, Integer.MAX_VALUE, false);
65+
int iDepth = dimensions.size();
66+
for (int i = levels.length - 1; i >= 0; i--) {
67+
int level = levels[i];
68+
currentArray =
69+
arrayReduce(f, currentArray, dimensions, level, engine, i == 0 ? true : false);
70+
71+
dimensions = LinearAlgebra.dimensions(currentArray, S.List, --iDepth, false);
72+
dimensions = dimensions.subList(0, iDepth);
73+
}
74+
return currentArray;
75+
}
6076

61-
private IExpr arrayReduce(IExpr f, IAST array, int n, EvalEngine engine) {
62-
int iDepth = LinearAlgebra.arrayDepth(array);
77+
/**
78+
*
79+
* @param f
80+
* @param array
81+
* @param dimensions the dimensions of the array or <code>null</code> if the dimension should be
82+
* calculated new
83+
* @param level
84+
* @param engine
85+
* @return an array of 2 objects `[IAST, IntList]` with the reduced array and the new dimensions
86+
*/
87+
private IAST arrayReduce(IExpr f, IAST array, IntList dimensions, int level, EvalEngine engine,
88+
boolean doMap) {
89+
int iDepth = dimensions == null ? LinearAlgebra.arrayDepth(array) : dimensions.size();
6390
IAST range = ListFunctions.range(iDepth + 1);
64-
IAST rotateRight = range.rotateRight(F.NIL, n);
65-
IntList dimensions = LinearAlgebra.dimensions(array, S.List, Integer.MAX_VALUE, false);
91+
IAST rotateRight = range.rotateRight(F.NIL, level);
92+
if (dimensions == null) {
93+
dimensions = LinearAlgebra.dimensions(array, S.List, iDepth, false);
94+
}
6695
IAST transposed = (IAST) LinearAlgebra.transpose(array, rotateRight, dimensions, x -> x,
6796
F.Transpose(array, rotateRight), engine);
68-
IExpr reduced = F.Map(f, transposed, F.List(F.ZZ(iDepth - 1))).eval(engine);
69-
IAST rotateLeft = ListFunctions.range(iDepth).rotateLeft(F.NIL, n - 1);
70-
dimensions =
71-
LinearAlgebra.dimensions(reduced, S.List, Integer.MAX_VALUE, false);
97+
IAST reduced;
98+
if (doMap) {
99+
reduced = (IAST) F.Map(f, transposed, F.List(F.ZZ(iDepth - 1))).eval(engine);
100+
} else {
101+
// flatten lists
102+
VisitorLevelSpecification levelSpec = new VisitorLevelSpecification(
103+
x -> F.binaryAST2(S.Apply, S.Sequence, x), iDepth - 1, false);
104+
reduced = (IAST) transposed.accept(levelSpec);
105+
}
106+
if (level == 1) {
107+
return reduced;
108+
}
109+
IAST rotateLeft = ListFunctions.range(iDepth).rotateLeft(F.NIL, level - 1);
110+
dimensions = LinearAlgebra.dimensions(reduced, S.List, Integer.MAX_VALUE, false);
72111
dimensions = dimensions.subList(0, iDepth - 1);
73-
return LinearAlgebra.transpose(reduced, rotateLeft, dimensions, x -> x,
112+
113+
return (IAST) LinearAlgebra.transpose(reduced, rotateLeft, dimensions, x -> x,
74114
F.Transpose(reduced, rotateLeft), engine);
75115
}
76116

@@ -81,12 +121,20 @@ public IExpr evaluate(final IAST ast, EvalEngine engine) {
81121
final IExpr f = ast.arg1();
82122
IAST tensor = (IAST) ast.arg2();
83123
final IntList dims = LinearAlgebra.dimensions(tensor, S.List);
84-
int n = ast.arg3().toIntDefault();
124+
IExpr arg3 = ast.arg3();
125+
if (arg3.isList()) {
126+
int[] ni = Validate.checkListOfInts(ast, arg3, 1, dims.size(), engine);
127+
if (ni == null) {
128+
return F.NIL;
129+
}
130+
return arrayReduce(f, tensor, ni, engine);
131+
}
132+
int n = arg3.toIntDefault();
85133
if (n > 0) {
86134
if (n == 1 && dims.size() == 1) {
87135
return tensor;
88136
}
89-
return arrayReduce(f, tensor, n, engine);
137+
return arrayReduce(f, tensor, null, n, engine, true);
90138
}
91139
}
92140
return F.NIL;

symja_android_library/matheclipse-core/src/main/java/org/matheclipse/core/interfaces/IExpr.java

+5
Original file line numberDiff line numberDiff line change
@@ -3254,6 +3254,11 @@ default boolean isIndeterminate() {
32543254
return false;
32553255
}
32563256

3257+
@Override
3258+
default boolean isInfinite() {
3259+
return false;
3260+
}
3261+
32573262
/**
32583263
* Test if this expression is an inexact number. I.e. an instance of type <code>INum</code> or
32593264
* <code>IComplexNum</code>.

0 commit comments

Comments
 (0)