Skip to content

Commit

Permalink
Merge pull request #46 from bonsai-rx/fix/get-gaussian-observation-st…
Browse files Browse the repository at this point in the history
…atistics

Fix get gaussian observation statistics
  • Loading branch information
ncguilbeault authored Jan 10, 2025
2 parents db0ab0e + b58cbd2 commit 34e43d1
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ public override void Show(object value)
{
if (value is Observations.GaussianObservationsStatistics gaussianObservationsStatistics)
{

var statesCount = gaussianObservationsStatistics.Means.GetLength(0);
var observationDimensions = gaussianObservationsStatistics.Means.GetLength(1);

Expand Down Expand Up @@ -227,12 +226,14 @@ public override void Show(object value)

var batchObservationsCount = gaussianObservationsStatistics.BatchObservations.GetLength(0);
var offset = BufferData && batchObservationsCount > BufferCount ? batchObservationsCount - BufferCount : 0;
var predictedStatesCount = gaussianObservationsStatistics.PredictedStates.Length;

for (int i = offset; i < batchObservationsCount; i++)
{
var dim1 = gaussianObservationsStatistics.BatchObservations[i, dimension1SelectedIndex];
var dim2 = gaussianObservationsStatistics.BatchObservations[i, dimension2SelectedIndex];
var state = gaussianObservationsStatistics.InferredMostProbableStates[i];
allScatterSeries[(int)state].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
var state = gaussianObservationsStatistics.PredictedStates[i];
allScatterSeries[Convert.ToInt32(state)].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
}

for (int i = 0; i < statesCount; i++)
Expand Down
2 changes: 1 addition & 1 deletion src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="py:Exec">
<py:Script>hmm.most_likely_states([59.7382107943162,3.99285183724331])</py:Script>
<py:Script>hmm.infer_state([59.7382107943162,3.99285183724331])</py:Script>
</Combinator>
</Expression>
<Expression xsi:type="WorkflowOutput" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ public class GaussianObservationsStatistics
public double[,] BatchObservations { get; set; }

/// <summary>
/// The sequence of inferred most probable states.
/// The predicted state for each observation in the batch of observations.
/// </summary>
[Description("The sequence of inferred most probable states.")]
[Description("The predicted state for each observation in the batch of observations.")]
[XmlIgnore]
public int[] InferredMostProbableStates { get; set; }
public long[] PredictedStates { get; set; }

/// <summary>
/// Transforms an observable sequence of <see cref="PyObject"/> into an observable sequence
Expand All @@ -64,15 +64,15 @@ public IObservable<GaussianObservationsStatistics> Process(IObservable<PyObject>
var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas");
var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj);
var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations");
var inferredMostProbableStatesPyObj = (int[])pyObject.GetArrayAttr("inferred_most_probable_states");
var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states");

return new GaussianObservationsStatistics
{
Means = meansPyObj,
StdDevs = stdDevsPyObj,
CovarianceMatrices = covarianceMatricesPyObj,
BatchObservations = batchObservationsPyObj,
InferredMostProbableStates = inferredMostProbableStatesPyObj
PredictedStates = predictedStatesPyObj
};
});
}
Expand Down
22 changes: 13 additions & 9 deletions src/Bonsai.ML.HiddenMarkovModels/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ def get_nonlinearity_type(func):
self.state_probabilities = None

self.batch = None
self.batch_observations = np.array([[]], dtype=float)
self.batch_observations = np.array([[]], dtype=float).reshape((0, dimensions))
self.is_running = False
self._fit_finished = False
self.loop = None
self.thread = None
self.curr_batch_size = 0
self.flush_data_between_batches = True
self.inferred_most_probable_states = np.array([], dtype=int)
self.predicted_states = np.array([], dtype=int)
self.buffer_count = 250

def update_params(self, initial_state_distribution, transitions_params, observations_params):
hmm_params = self.params
Expand Down Expand Up @@ -124,10 +125,17 @@ def update_params(self, initial_state_distribution, transitions_params, observat

def infer_state(self, observation: list[float]):

self.log_alpha = self.compute_log_alpha(
np.expand_dims(np.array(observation), 0), self.log_alpha)
observation = np.expand_dims(np.array(observation), 0)
self.log_alpha = self.compute_log_alpha(observation, self.log_alpha)
self.state_probabilities = np.exp(self.log_alpha).astype(np.double)
return self.state_probabilities.argmax()
prediction = self.state_probabilities.argmax()
self.predicted_states = np.append(self.predicted_states, prediction)
if self.predicted_states.shape[0] > self.buffer_count:
self.predicted_states = self.predicted_states[1:]
self.batch_observations = np.vstack([self.batch_observations, observation])
if self.batch_observations.shape[0] == self.buffer_count:
self.batch_observations = self.batch_observations[1:]
return prediction

def compute_log_alpha(self, obs, log_alpha=None):

Expand Down Expand Up @@ -171,8 +179,6 @@ def fit_async(self,
self.batch = np.vstack(
[self.batch[1:], np.expand_dims(np.array(observation), 0)])

self.batch_observations = self.batch

if not self.is_running and self.loop is None and self.thread is None:

if self.curr_batch_size >= batch_size:
Expand Down Expand Up @@ -221,8 +227,6 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None

self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)

self.is_running = True

if self.loop is None or self.loop.is_closed():
Expand Down

0 comments on commit 34e43d1

Please sign in to comment.