Skip to content

Commit 65b1bc4

Browse files
committed
Checking for non-zero size of x-matrix
1 parent d9467d3 commit 65b1bc4

File tree

1 file changed

+148
-108
lines changed

1 file changed

+148
-108
lines changed

src/train_and_predict.rs

+148-108
Original file line numberDiff line numberDiff line change
@@ -56,69 +56,90 @@ pub mod train_and_predict_functions {
5656
y: Array,
5757
algorithm: ImmutableString,
5858
) -> Result<Model, Box<EvalAltResult>> {
59-
let algorithm_string = algorithm.as_str();
60-
let xvec = smartcorelib::linalg::basic::matrix::DenseMatrix::from_2d_vec(
61-
&x.into_iter()
62-
.map(|observation| {
63-
array_to_vec_float(&mut observation.clone().into_array().unwrap())
64-
})
65-
.collect::<Vec<Vec<FLOAT>>>(),
66-
);
67-
match algorithm_string {
68-
"linear" => {
69-
let yvec = y
70-
.clone()
71-
.into_iter()
72-
.map(|el| el.as_float().unwrap())
73-
.collect::<Vec<FLOAT>>();
74-
match LinearRegression::fit(&xvec, &yvec, LinearRegressionParameters::default()) {
75-
Ok(model) => Ok(Model {
76-
saved_model: bincode::serialize(&model).unwrap(),
77-
model_type: algorithm_string.to_string(),
78-
}),
79-
Err(e) => {
80-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
59+
// Make x array
60+
let array_as_vec_vec_float = &x
61+
.into_iter()
62+
.map(|observation| {
63+
crate::train_and_predict_functions::array_to_vec_float(
64+
&mut observation.clone().into_array().unwrap(),
65+
)
66+
})
67+
.collect::<Vec<Vec<FLOAT>>>();
68+
69+
// Check if x array is empty
70+
if array_as_vec_vec_float.len() == 0 {
71+
Err(EvalAltResult::ErrorArrayBounds(0, 0, Position::NONE).into())
72+
} else {
73+
let algorithm_string = algorithm.as_str();
74+
let xvec = smartcorelib::linalg::basic::matrix::DenseMatrix::from_2d_vec(
75+
array_as_vec_vec_float,
76+
);
77+
match algorithm_string {
78+
"linear" => {
79+
let yvec = y
80+
.clone()
81+
.into_iter()
82+
.map(|el| el.as_float().unwrap())
83+
.collect::<Vec<FLOAT>>();
84+
match LinearRegression::fit(&xvec, &yvec, LinearRegressionParameters::default())
85+
{
86+
Ok(model) => Ok(Model {
87+
saved_model: bincode::serialize(&model).unwrap(),
88+
model_type: algorithm_string.to_string(),
89+
}),
90+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
91+
format!("{e}"),
92+
Position::NONE,
93+
)
94+
.into()),
8195
}
8296
}
83-
}
84-
"lasso" => {
85-
let yvec = y
86-
.clone()
87-
.into_iter()
88-
.map(|el| el.as_float().unwrap())
89-
.collect::<Vec<FLOAT>>();
90-
match Lasso::fit(&xvec, &yvec, LassoParameters::default()) {
91-
Ok(model) => Ok(Model {
92-
saved_model: bincode::serialize(&model).unwrap(),
93-
model_type: algorithm_string.to_string(),
94-
}),
95-
Err(e) => {
96-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
97+
"lasso" => {
98+
let yvec = y
99+
.clone()
100+
.into_iter()
101+
.map(|el| el.as_float().unwrap())
102+
.collect::<Vec<FLOAT>>();
103+
match Lasso::fit(&xvec, &yvec, LassoParameters::default()) {
104+
Ok(model) => Ok(Model {
105+
saved_model: bincode::serialize(&model).unwrap(),
106+
model_type: algorithm_string.to_string(),
107+
}),
108+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
109+
format!("{e}"),
110+
Position::NONE,
111+
)
112+
.into()),
97113
}
98114
}
99-
}
100-
"logistic" => {
101-
let yvec = y
102-
.clone()
103-
.into_iter()
104-
.map(|el| el.as_int().unwrap())
105-
.collect::<Vec<INT>>();
106-
match LogisticRegression::fit(&xvec, &yvec, LogisticRegressionParameters::default())
107-
{
108-
Ok(model) => Ok(Model {
109-
saved_model: bincode::serialize(&model).unwrap(),
110-
model_type: algorithm_string.to_string(),
111-
}),
112-
Err(e) => {
113-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
115+
"logistic" => {
116+
let yvec = y
117+
.clone()
118+
.into_iter()
119+
.map(|el| el.as_int().unwrap())
120+
.collect::<Vec<INT>>();
121+
match LogisticRegression::fit(
122+
&xvec,
123+
&yvec,
124+
LogisticRegressionParameters::default(),
125+
) {
126+
Ok(model) => Ok(Model {
127+
saved_model: bincode::serialize(&model).unwrap(),
128+
model_type: algorithm_string.to_string(),
129+
}),
130+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
131+
format!("{e}"),
132+
Position::NONE,
133+
)
134+
.into()),
114135
}
115136
}
137+
&_ => Err(EvalAltResult::ErrorArithmetic(
138+
format!("{} is not a recognized model type.", algorithm_string),
139+
Position::NONE,
140+
)
141+
.into()),
116142
}
117-
&_ => Err(EvalAltResult::ErrorArithmetic(
118-
format!("{} is not a recognized model type.", algorithm_string),
119-
Position::NONE,
120-
)
121-
.into()),
122143
}
123144
}
124145

@@ -136,59 +157,78 @@ pub mod train_and_predict_functions {
136157
/// ```
137158
#[rhai_fn(name = "predict", return_raw, pure)]
138159
pub fn predict_with_model(x: &mut Array, model: Model) -> Result<Array, Box<EvalAltResult>> {
139-
let xvec = DenseMatrix::from_2d_vec(
140-
&x.into_iter()
141-
.map(|observation| {
142-
array_to_vec_float(&mut observation.clone().into_array().unwrap())
143-
})
144-
.collect::<Vec<Vec<FLOAT>>>(),
145-
);
146-
let algorithm_string = model.model_type.as_str();
147-
match algorithm_string {
148-
"linear" => {
149-
let model_ready: LinearRegression<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
150-
bincode::deserialize(&*model.saved_model).unwrap();
151-
return match model_ready.predict(&xvec) {
152-
Ok(y) => Ok(y
153-
.into_iter()
154-
.map(|observation| Dynamic::from_float(observation))
155-
.collect::<Vec<Dynamic>>()),
156-
Err(e) => {
157-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
158-
}
159-
};
160-
}
161-
"lasso" => {
162-
let model_ready: Lasso<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
163-
bincode::deserialize(&*model.saved_model).unwrap();
164-
return match model_ready.predict(&xvec) {
165-
Ok(y) => Ok(y
166-
.into_iter()
167-
.map(|observation| Dynamic::from_float(observation))
168-
.collect::<Vec<Dynamic>>()),
169-
Err(e) => {
170-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
171-
}
172-
};
173-
}
174-
"logistic" => {
175-
let model_ready: LogisticRegression<FLOAT, INT, DenseMatrix<FLOAT>, Vec<INT>> =
176-
bincode::deserialize(&*model.saved_model).unwrap();
177-
return match model_ready.predict(&xvec) {
178-
Ok(y) => Ok(y
179-
.into_iter()
180-
.map(|observation| Dynamic::from_int(observation))
181-
.collect::<Vec<Dynamic>>()),
182-
Err(e) => {
183-
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
184-
}
185-
};
160+
// Make x array
161+
let array_as_vec_vec_float = &x
162+
.into_iter()
163+
.map(|observation| {
164+
crate::train_and_predict_functions::array_to_vec_float(
165+
&mut observation.clone().into_array().unwrap(),
166+
)
167+
})
168+
.collect::<Vec<Vec<FLOAT>>>();
169+
170+
// Check if x array is empty
171+
if array_as_vec_vec_float.len() == 0 {
172+
Err(EvalAltResult::ErrorArrayBounds(0, 0, Position::NONE).into())
173+
} else {
174+
let xvec = DenseMatrix::from_2d_vec(array_as_vec_vec_float);
175+
let algorithm_string = model.model_type.as_str();
176+
match algorithm_string {
177+
"linear" => {
178+
let model_ready: LinearRegression<
179+
FLOAT,
180+
FLOAT,
181+
DenseMatrix<FLOAT>,
182+
Vec<FLOAT>,
183+
> = bincode::deserialize(&*model.saved_model).unwrap();
184+
return match model_ready.predict(&xvec) {
185+
Ok(y) => Ok(y
186+
.into_iter()
187+
.map(|observation| Dynamic::from_float(observation))
188+
.collect::<Vec<Dynamic>>()),
189+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
190+
format!("{e}"),
191+
Position::NONE,
192+
)
193+
.into()),
194+
};
195+
}
196+
"lasso" => {
197+
let model_ready: Lasso<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
198+
bincode::deserialize(&*model.saved_model).unwrap();
199+
return match model_ready.predict(&xvec) {
200+
Ok(y) => Ok(y
201+
.into_iter()
202+
.map(|observation| Dynamic::from_float(observation))
203+
.collect::<Vec<Dynamic>>()),
204+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
205+
format!("{e}"),
206+
Position::NONE,
207+
)
208+
.into()),
209+
};
210+
}
211+
"logistic" => {
212+
let model_ready: LogisticRegression<FLOAT, INT, DenseMatrix<FLOAT>, Vec<INT>> =
213+
bincode::deserialize(&*model.saved_model).unwrap();
214+
return match model_ready.predict(&xvec) {
215+
Ok(y) => Ok(y
216+
.into_iter()
217+
.map(|observation| Dynamic::from_int(observation))
218+
.collect::<Vec<Dynamic>>()),
219+
Err(e) => Err(EvalAltResult::ErrorArithmetic(
220+
format!("{e}"),
221+
Position::NONE,
222+
)
223+
.into()),
224+
};
225+
}
226+
&_ => Err(EvalAltResult::ErrorArithmetic(
227+
format!("{} is not a recognized model type.", algorithm_string),
228+
Position::NONE,
229+
)
230+
.into()),
186231
}
187-
&_ => Err(EvalAltResult::ErrorArithmetic(
188-
format!("{} is not a recognized model type.", algorithm_string),
189-
Position::NONE,
190-
)
191-
.into()),
192232
}
193233
}
194234
}

0 commit comments

Comments
 (0)