25
25
import org .matheclipse .core .interfaces .IReal ;
26
26
import org .matheclipse .core .interfaces .ISparseArray ;
27
27
import org .matheclipse .core .interfaces .ISymbol ;
28
+ import org .matheclipse .core .visit .VisitorLevelSpecification ;
28
29
import it .unimi .dsi .fastutil .ints .IntList ;
29
30
30
31
public class TensorFunctions {
@@ -57,20 +58,59 @@ private static void init() {
57
58
}
58
59
59
60
private static final class ArrayReduce extends AbstractEvaluator {
61
+ private IExpr arrayReduce (IExpr f , IAST array , int [] levels , EvalEngine engine ) {
62
+ IAST currentArray = array ;
63
+ Arrays .sort (levels );
64
+ IntList dimensions = LinearAlgebra .dimensions (array , S .List , Integer .MAX_VALUE , false );
65
+ int iDepth = dimensions .size ();
66
+ for (int i = levels .length - 1 ; i >= 0 ; i --) {
67
+ int level = levels [i ];
68
+ currentArray =
69
+ arrayReduce (f , currentArray , dimensions , level , engine , i == 0 ? true : false );
70
+
71
+ dimensions = LinearAlgebra .dimensions (currentArray , S .List , --iDepth , false );
72
+ dimensions = dimensions .subList (0 , iDepth );
73
+ }
74
+ return currentArray ;
75
+ }
60
76
61
- private IExpr arrayReduce (IExpr f , IAST array , int n , EvalEngine engine ) {
62
- int iDepth = LinearAlgebra .arrayDepth (array );
77
+ /**
78
+ *
79
+ * @param f
80
+ * @param array
81
+ * @param dimensions the dimensions of the array or <code>null</code> if the dimension should be
82
+ * calculated new
83
+ * @param level
84
+ * @param engine
85
+ * @return an array of 2 objects `[IAST, IntList]` with the reduced array and the new dimensions
86
+ */
87
+ private IAST arrayReduce (IExpr f , IAST array , IntList dimensions , int level , EvalEngine engine ,
88
+ boolean doMap ) {
89
+ int iDepth = dimensions == null ? LinearAlgebra .arrayDepth (array ) : dimensions .size ();
63
90
IAST range = ListFunctions .range (iDepth + 1 );
64
- IAST rotateRight = range .rotateRight (F .NIL , n );
65
- IntList dimensions = LinearAlgebra .dimensions (array , S .List , Integer .MAX_VALUE , false );
91
+ IAST rotateRight = range .rotateRight (F .NIL , level );
92
+ if (dimensions == null ) {
93
+ dimensions = LinearAlgebra .dimensions (array , S .List , iDepth , false );
94
+ }
66
95
IAST transposed = (IAST ) LinearAlgebra .transpose (array , rotateRight , dimensions , x -> x ,
67
96
F .Transpose (array , rotateRight ), engine );
68
- IExpr reduced = F .Map (f , transposed , F .List (F .ZZ (iDepth - 1 ))).eval (engine );
69
- IAST rotateLeft = ListFunctions .range (iDepth ).rotateLeft (F .NIL , n - 1 );
70
- dimensions =
71
- LinearAlgebra .dimensions (reduced , S .List , Integer .MAX_VALUE , false );
97
+ IAST reduced ;
98
+ if (doMap ) {
99
+ reduced = (IAST ) F .Map (f , transposed , F .List (F .ZZ (iDepth - 1 ))).eval (engine );
100
+ } else {
101
+ // flatten lists
102
+ VisitorLevelSpecification levelSpec = new VisitorLevelSpecification (
103
+ x -> F .binaryAST2 (S .Apply , S .Sequence , x ), iDepth - 1 , false );
104
+ reduced = (IAST ) transposed .accept (levelSpec );
105
+ }
106
+ if (level == 1 ) {
107
+ return reduced ;
108
+ }
109
+ IAST rotateLeft = ListFunctions .range (iDepth ).rotateLeft (F .NIL , level - 1 );
110
+ dimensions = LinearAlgebra .dimensions (reduced , S .List , Integer .MAX_VALUE , false );
72
111
dimensions = dimensions .subList (0 , iDepth - 1 );
73
- return LinearAlgebra .transpose (reduced , rotateLeft , dimensions , x -> x ,
112
+
113
+ return (IAST ) LinearAlgebra .transpose (reduced , rotateLeft , dimensions , x -> x ,
74
114
F .Transpose (reduced , rotateLeft ), engine );
75
115
}
76
116
@@ -81,12 +121,20 @@ public IExpr evaluate(final IAST ast, EvalEngine engine) {
81
121
final IExpr f = ast .arg1 ();
82
122
IAST tensor = (IAST ) ast .arg2 ();
83
123
final IntList dims = LinearAlgebra .dimensions (tensor , S .List );
84
- int n = ast .arg3 ().toIntDefault ();
124
+ IExpr arg3 = ast .arg3 ();
125
+ if (arg3 .isList ()) {
126
+ int [] ni = Validate .checkListOfInts (ast , arg3 , 1 , dims .size (), engine );
127
+ if (ni == null ) {
128
+ return F .NIL ;
129
+ }
130
+ return arrayReduce (f , tensor , ni , engine );
131
+ }
132
+ int n = arg3 .toIntDefault ();
85
133
if (n > 0 ) {
86
134
if (n == 1 && dims .size () == 1 ) {
87
135
return tensor ;
88
136
}
89
- return arrayReduce (f , tensor , n , engine );
137
+ return arrayReduce (f , tensor , null , n , engine , true );
90
138
}
91
139
}
92
140
return F .NIL ;
0 commit comments