|
7 | 7 | class TestKalmanFilter(unittest.TestCase):
|
8 | 8 |
|
9 | 9 | def test_KalmanFilter(self):
|
10 |
| - |
11 | 10 | print('Run KF class test.')
|
12 | 11 |
|
13 | 12 | # initialization of state matrices
|
14 | 13 | A = np.eye(2)
|
15 | 14 | X = np.array([[0.1], [0.1]])
|
16 |
| - B = np.ones((X.shape[0], 1)) |
17 |
| - U = 0.5 |
| 15 | + B = np.ones((2, 1)) # 2 rows (same as X), 1 column |
| 16 | + U = np.array([[0.5]]) # single control input as a column vector |
18 | 17 |
|
19 | 18 | # process noise covariance matrix and initial state covariance
|
20 |
| - Q = np.eye(X.shape[0]) |
| 19 | + Q = np.eye(2) |
21 | 20 | P = np.diag((0.01, 0.01))
|
22 | 21 |
|
23 | 22 | # measurement matrices (state X plus a random gaussian noise)
|
24 |
| - Y = np.array([X[0, 0] + 0.001, X[1, 0] + 0.001]) |
25 |
| - Y = Y.reshape(-1, 1) |
| 23 | + Y = np.array([[X[0, 0] + 0.001], [X[1, 0] + 0.001]]) # 2x1 matrix |
26 | 24 | C = np.eye(2)
|
27 | 25 |
|
28 | 26 | # measurement noise covariance
|
29 |
| - R = np.eye(Y.shape[0]) |
30 |
| - |
31 |
| - var = {} |
32 |
| - var.update({'X': X}) |
33 |
| - var.update({'A': A}) |
34 |
| - var.update({'B': B}) |
35 |
| - var.update({'U': U}) |
36 |
| - var.update({'Q': Q}) |
37 |
| - var.update({'C': C}) |
38 |
| - var.update({'R': R}) |
| 27 | + R = np.eye(2) |
| 28 | + |
| 29 | + var = { |
| 30 | + 'X': X, |
| 31 | + 'A': A, |
| 32 | + 'B': B, |
| 33 | + 'U': U, |
| 34 | + 'Q': Q, |
| 35 | + 'C': C, |
| 36 | + 'R': R |
| 37 | + } |
39 | 38 |
|
40 | 39 | kf = KalmanFilter()
|
41 | 40 | kf.setup(var)
|
42 | 41 | kf.predict()
|
43 | 42 | x_est, y_predict = kf.update(Y)
|
44 | 43 |
|
45 | 44 | # verify if the object of the class is correct
|
46 |
| - self.assertEqual(x_est[0], 0.3505) |
47 |
| - self.assertEqual(x_est[1], 0.3505) |
48 |
| - self.assertEqual(y_predict[0], 0.07026195923972386) |
| 45 | + np.testing.assert_almost_equal(x_est[0], 0.1014, decimal=4) |
| 46 | + np.testing.assert_almost_equal(x_est[1], 0.1014, decimal=4) |
| 47 | + np.testing.assert_almost_equal(y_predict[0], 0.0001588, decimal=7) |
49 | 48 |
|
50 | 49 |
|
51 | 50 | if __name__ == '__main__':
|
52 | 51 | suite = unittest.TestSuite()
|
53 | 52 | suite.addTest(TestKalmanFilter('test_KalmanFilter'))
|
| 53 | + unittest.TextTestRunner().run(suite) |
0 commit comments