diff --git a/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj b/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj
index 03a2965..008b9d1 100644
--- a/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj
+++ b/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj
@@ -6,11 +6,11 @@
Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking Visualizers
net472
true
- 0.2.0
+ 0.3.0
-
+
diff --git a/src/Bonsai.Sleap/Bonsai.Sleap.csproj b/src/Bonsai.Sleap/Bonsai.Sleap.csproj
index 364dd9d..10b9c8d 100644
--- a/src/Bonsai.Sleap/Bonsai.Sleap.csproj
+++ b/src/Bonsai.Sleap/Bonsai.Sleap.csproj
@@ -6,7 +6,7 @@
Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking
true
net472
- 0.2.2
+ 0.3.0
diff --git a/src/Bonsai.Sleap/PoseIdentity.cs b/src/Bonsai.Sleap/PoseIdentity.cs
index 802c369..56c7154 100644
--- a/src/Bonsai.Sleap/PoseIdentity.cs
+++ b/src/Bonsai.Sleap/PoseIdentity.cs
@@ -20,19 +20,25 @@ public PoseIdentity(IplImage image)
}
///
- /// Gets or sets the predicted pose identity.
+ /// Gets or sets the maximum likelihood predicted pose identity.
///
public string Identity { get; set; }
///
- /// Gets or sets the predicted pose identity index.
+ /// Gets or sets the maximum likelihood confidence score for the predicted identity.
+ ///
+ public float Confidence { get; set; }
+
+ ///
+ /// Gets or sets the maximum likelihood predicted pose identity index.
///
[XmlIgnore]
public int IdentityIndex { get; set; }
///
- /// Gets or sets the confidence score for the predicted identity.
+ /// Gets or sets the predicted identity confidence scores for this instance.
///
- public float Confidence { get; set; }
+ [XmlIgnore]
+ public float[] IdentityScores { get; set; }
}
}
diff --git a/src/Bonsai.Sleap/PredictPoseIdentities.cs b/src/Bonsai.Sleap/PredictPoseIdentities.cs
index 25979a1..36392ad 100644
--- a/src/Bonsai.Sleap/PredictPoseIdentities.cs
+++ b/src/Bonsai.Sleap/PredictPoseIdentities.cs
@@ -184,7 +184,7 @@ private IObservable Process(IObservable sour
{
// Find the class with max score
var pose = new PoseIdentity(input.Length == 1 ? input[0] : input[iid]);
- var maxIndex = ArgMax(idArr, iid, Comparer.Default, out float maxScore);
+ pose.IdentityScores = GetRowValues(idArr, iid, Comparer.Default, out float maxScore, out int maxIndex);
if (maxScore < idThreshold || maxIndex < 0)
{
@@ -257,22 +257,29 @@ public override IObservable Process(IObservable new IplImage[] { frame }));
}
- static int ArgMax(TElement[,] array, int instance, IComparer comparer, out TElement maxValue)
+ static TElement[] GetRowValues(
+ TElement[,] array,
+ int rowIndex,
+ IComparer comparer,
+ out TElement maxValue,
+ out int maxIndex)
{
if (array == null) throw new ArgumentNullException(nameof(array));
if (comparer == null) throw new ArgumentNullException(nameof(comparer));
- int maxIndex = -1;
+ maxIndex = -1;
maxValue = default;
- for (int i = 0; i < array.GetLength(1); i++)
+ var values = new TElement[array.GetLength(1)];
+ for (int i = 0; i < values.Length; i++)
{
- if (i == 0 || comparer.Compare(array[instance, i], maxValue) > 0)
+ values[i] = array[rowIndex, i];
+ if (i == 0 || comparer.Compare(values[i], maxValue) > 0)
{
maxIndex = i;
- maxValue = array[instance, i];
+ maxValue = array[rowIndex, i];
}
}
- return maxIndex;
+ return values;
}
}
}