@@ -56,69 +56,90 @@ pub mod train_and_predict_functions {
56
56
y : Array ,
57
57
algorithm : ImmutableString ,
58
58
) -> Result < Model , Box < EvalAltResult > > {
59
- let algorithm_string = algorithm. as_str ( ) ;
60
- let xvec = smartcorelib:: linalg:: basic:: matrix:: DenseMatrix :: from_2d_vec (
61
- & x. into_iter ( )
62
- . map ( |observation| {
63
- array_to_vec_float ( & mut observation. clone ( ) . into_array ( ) . unwrap ( ) )
64
- } )
65
- . collect :: < Vec < Vec < FLOAT > > > ( ) ,
66
- ) ;
67
- match algorithm_string {
68
- "linear" => {
69
- let yvec = y
70
- . clone ( )
71
- . into_iter ( )
72
- . map ( |el| el. as_float ( ) . unwrap ( ) )
73
- . collect :: < Vec < FLOAT > > ( ) ;
74
- match LinearRegression :: fit ( & xvec, & yvec, LinearRegressionParameters :: default ( ) ) {
75
- Ok ( model) => Ok ( Model {
76
- saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
77
- model_type : algorithm_string. to_string ( ) ,
78
- } ) ,
79
- Err ( e) => {
80
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
59
+ // Make x array
60
+ let array_as_vec_vec_float = & x
61
+ . into_iter ( )
62
+ . map ( |observation| {
63
+ crate :: train_and_predict_functions:: array_to_vec_float (
64
+ & mut observation. clone ( ) . into_array ( ) . unwrap ( ) ,
65
+ )
66
+ } )
67
+ . collect :: < Vec < Vec < FLOAT > > > ( ) ;
68
+
69
+ // Check if x array is empty
70
+ if array_as_vec_vec_float. len ( ) == 0 {
71
+ Err ( EvalAltResult :: ErrorArrayBounds ( 0 , 0 , Position :: NONE ) . into ( ) )
72
+ } else {
73
+ let algorithm_string = algorithm. as_str ( ) ;
74
+ let xvec = smartcorelib:: linalg:: basic:: matrix:: DenseMatrix :: from_2d_vec (
75
+ array_as_vec_vec_float,
76
+ ) ;
77
+ match algorithm_string {
78
+ "linear" => {
79
+ let yvec = y
80
+ . clone ( )
81
+ . into_iter ( )
82
+ . map ( |el| el. as_float ( ) . unwrap ( ) )
83
+ . collect :: < Vec < FLOAT > > ( ) ;
84
+ match LinearRegression :: fit ( & xvec, & yvec, LinearRegressionParameters :: default ( ) )
85
+ {
86
+ Ok ( model) => Ok ( Model {
87
+ saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
88
+ model_type : algorithm_string. to_string ( ) ,
89
+ } ) ,
90
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
91
+ format ! ( "{e}" ) ,
92
+ Position :: NONE ,
93
+ )
94
+ . into ( ) ) ,
81
95
}
82
96
}
83
- }
84
- "lasso" => {
85
- let yvec = y
86
- . clone ( )
87
- . into_iter ( )
88
- . map ( |el| el. as_float ( ) . unwrap ( ) )
89
- . collect :: < Vec < FLOAT > > ( ) ;
90
- match Lasso :: fit ( & xvec, & yvec, LassoParameters :: default ( ) ) {
91
- Ok ( model) => Ok ( Model {
92
- saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
93
- model_type : algorithm_string. to_string ( ) ,
94
- } ) ,
95
- Err ( e) => {
96
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
97
+ "lasso" => {
98
+ let yvec = y
99
+ . clone ( )
100
+ . into_iter ( )
101
+ . map ( |el| el. as_float ( ) . unwrap ( ) )
102
+ . collect :: < Vec < FLOAT > > ( ) ;
103
+ match Lasso :: fit ( & xvec, & yvec, LassoParameters :: default ( ) ) {
104
+ Ok ( model) => Ok ( Model {
105
+ saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
106
+ model_type : algorithm_string. to_string ( ) ,
107
+ } ) ,
108
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
109
+ format ! ( "{e}" ) ,
110
+ Position :: NONE ,
111
+ )
112
+ . into ( ) ) ,
97
113
}
98
114
}
99
- }
100
- "logistic" => {
101
- let yvec = y
102
- . clone ( )
103
- . into_iter ( )
104
- . map ( |el| el. as_int ( ) . unwrap ( ) )
105
- . collect :: < Vec < INT > > ( ) ;
106
- match LogisticRegression :: fit ( & xvec, & yvec, LogisticRegressionParameters :: default ( ) )
107
- {
108
- Ok ( model) => Ok ( Model {
109
- saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
110
- model_type : algorithm_string. to_string ( ) ,
111
- } ) ,
112
- Err ( e) => {
113
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
115
+ "logistic" => {
116
+ let yvec = y
117
+ . clone ( )
118
+ . into_iter ( )
119
+ . map ( |el| el. as_int ( ) . unwrap ( ) )
120
+ . collect :: < Vec < INT > > ( ) ;
121
+ match LogisticRegression :: fit (
122
+ & xvec,
123
+ & yvec,
124
+ LogisticRegressionParameters :: default ( ) ,
125
+ ) {
126
+ Ok ( model) => Ok ( Model {
127
+ saved_model : bincode:: serialize ( & model) . unwrap ( ) ,
128
+ model_type : algorithm_string. to_string ( ) ,
129
+ } ) ,
130
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
131
+ format ! ( "{e}" ) ,
132
+ Position :: NONE ,
133
+ )
134
+ . into ( ) ) ,
114
135
}
115
136
}
137
+ & _ => Err ( EvalAltResult :: ErrorArithmetic (
138
+ format ! ( "{} is not a recognized model type." , algorithm_string) ,
139
+ Position :: NONE ,
140
+ )
141
+ . into ( ) ) ,
116
142
}
117
- & _ => Err ( EvalAltResult :: ErrorArithmetic (
118
- format ! ( "{} is not a recognized model type." , algorithm_string) ,
119
- Position :: NONE ,
120
- )
121
- . into ( ) ) ,
122
143
}
123
144
}
124
145
@@ -136,59 +157,78 @@ pub mod train_and_predict_functions {
136
157
/// ```
137
158
#[ rhai_fn( name = "predict" , return_raw, pure) ]
138
159
pub fn predict_with_model ( x : & mut Array , model : Model ) -> Result < Array , Box < EvalAltResult > > {
139
- let xvec = DenseMatrix :: from_2d_vec (
140
- & x. into_iter ( )
141
- . map ( |observation| {
142
- array_to_vec_float ( & mut observation. clone ( ) . into_array ( ) . unwrap ( ) )
143
- } )
144
- . collect :: < Vec < Vec < FLOAT > > > ( ) ,
145
- ) ;
146
- let algorithm_string = model. model_type . as_str ( ) ;
147
- match algorithm_string {
148
- "linear" => {
149
- let model_ready: LinearRegression < FLOAT , FLOAT , DenseMatrix < FLOAT > , Vec < FLOAT > > =
150
- bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
151
- return match model_ready. predict ( & xvec) {
152
- Ok ( y) => Ok ( y
153
- . into_iter ( )
154
- . map ( |observation| Dynamic :: from_float ( observation) )
155
- . collect :: < Vec < Dynamic > > ( ) ) ,
156
- Err ( e) => {
157
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
158
- }
159
- } ;
160
- }
161
- "lasso" => {
162
- let model_ready: Lasso < FLOAT , FLOAT , DenseMatrix < FLOAT > , Vec < FLOAT > > =
163
- bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
164
- return match model_ready. predict ( & xvec) {
165
- Ok ( y) => Ok ( y
166
- . into_iter ( )
167
- . map ( |observation| Dynamic :: from_float ( observation) )
168
- . collect :: < Vec < Dynamic > > ( ) ) ,
169
- Err ( e) => {
170
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
171
- }
172
- } ;
173
- }
174
- "logistic" => {
175
- let model_ready: LogisticRegression < FLOAT , INT , DenseMatrix < FLOAT > , Vec < INT > > =
176
- bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
177
- return match model_ready. predict ( & xvec) {
178
- Ok ( y) => Ok ( y
179
- . into_iter ( )
180
- . map ( |observation| Dynamic :: from_int ( observation) )
181
- . collect :: < Vec < Dynamic > > ( ) ) ,
182
- Err ( e) => {
183
- Err ( EvalAltResult :: ErrorArithmetic ( format ! ( "{e}" ) , Position :: NONE ) . into ( ) )
184
- }
185
- } ;
160
+ // Make x array
161
+ let array_as_vec_vec_float = & x
162
+ . into_iter ( )
163
+ . map ( |observation| {
164
+ crate :: train_and_predict_functions:: array_to_vec_float (
165
+ & mut observation. clone ( ) . into_array ( ) . unwrap ( ) ,
166
+ )
167
+ } )
168
+ . collect :: < Vec < Vec < FLOAT > > > ( ) ;
169
+
170
+ // Check if x array is empty
171
+ if array_as_vec_vec_float. len ( ) == 0 {
172
+ Err ( EvalAltResult :: ErrorArrayBounds ( 0 , 0 , Position :: NONE ) . into ( ) )
173
+ } else {
174
+ let xvec = DenseMatrix :: from_2d_vec ( array_as_vec_vec_float) ;
175
+ let algorithm_string = model. model_type . as_str ( ) ;
176
+ match algorithm_string {
177
+ "linear" => {
178
+ let model_ready: LinearRegression <
179
+ FLOAT ,
180
+ FLOAT ,
181
+ DenseMatrix < FLOAT > ,
182
+ Vec < FLOAT > ,
183
+ > = bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
184
+ return match model_ready. predict ( & xvec) {
185
+ Ok ( y) => Ok ( y
186
+ . into_iter ( )
187
+ . map ( |observation| Dynamic :: from_float ( observation) )
188
+ . collect :: < Vec < Dynamic > > ( ) ) ,
189
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
190
+ format ! ( "{e}" ) ,
191
+ Position :: NONE ,
192
+ )
193
+ . into ( ) ) ,
194
+ } ;
195
+ }
196
+ "lasso" => {
197
+ let model_ready: Lasso < FLOAT , FLOAT , DenseMatrix < FLOAT > , Vec < FLOAT > > =
198
+ bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
199
+ return match model_ready. predict ( & xvec) {
200
+ Ok ( y) => Ok ( y
201
+ . into_iter ( )
202
+ . map ( |observation| Dynamic :: from_float ( observation) )
203
+ . collect :: < Vec < Dynamic > > ( ) ) ,
204
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
205
+ format ! ( "{e}" ) ,
206
+ Position :: NONE ,
207
+ )
208
+ . into ( ) ) ,
209
+ } ;
210
+ }
211
+ "logistic" => {
212
+ let model_ready: LogisticRegression < FLOAT , INT , DenseMatrix < FLOAT > , Vec < INT > > =
213
+ bincode:: deserialize ( & * model. saved_model ) . unwrap ( ) ;
214
+ return match model_ready. predict ( & xvec) {
215
+ Ok ( y) => Ok ( y
216
+ . into_iter ( )
217
+ . map ( |observation| Dynamic :: from_int ( observation) )
218
+ . collect :: < Vec < Dynamic > > ( ) ) ,
219
+ Err ( e) => Err ( EvalAltResult :: ErrorArithmetic (
220
+ format ! ( "{e}" ) ,
221
+ Position :: NONE ,
222
+ )
223
+ . into ( ) ) ,
224
+ } ;
225
+ }
226
+ & _ => Err ( EvalAltResult :: ErrorArithmetic (
227
+ format ! ( "{} is not a recognized model type." , algorithm_string) ,
228
+ Position :: NONE ,
229
+ )
230
+ . into ( ) ) ,
186
231
}
187
- & _ => Err ( EvalAltResult :: ErrorArithmetic (
188
- format ! ( "{} is not a recognized model type." , algorithm_string) ,
189
- Position :: NONE ,
190
- )
191
- . into ( ) ) ,
192
232
}
193
233
}
194
234
}
0 commit comments