diff --git a/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj b/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj index 03a2965..008b9d1 100644 --- a/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj +++ b/src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj @@ -6,11 +6,11 @@ Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking Visualizers net472 true - 0.2.0 + 0.3.0 - + diff --git a/src/Bonsai.Sleap/Bonsai.Sleap.csproj b/src/Bonsai.Sleap/Bonsai.Sleap.csproj index 364dd9d..10b9c8d 100644 --- a/src/Bonsai.Sleap/Bonsai.Sleap.csproj +++ b/src/Bonsai.Sleap/Bonsai.Sleap.csproj @@ -6,7 +6,7 @@ Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking true net472 - 0.2.2 + 0.3.0 diff --git a/src/Bonsai.Sleap/PoseIdentity.cs b/src/Bonsai.Sleap/PoseIdentity.cs index 802c369..56c7154 100644 --- a/src/Bonsai.Sleap/PoseIdentity.cs +++ b/src/Bonsai.Sleap/PoseIdentity.cs @@ -20,19 +20,25 @@ public PoseIdentity(IplImage image) } /// - /// Gets or sets the predicted pose identity. + /// Gets or sets the maximum likelihood predicted pose identity. /// public string Identity { get; set; } /// - /// Gets or sets the predicted pose identity index. + /// Gets or sets the maximum likelihood confidence score for the predicted identity. + /// + public float Confidence { get; set; } + + /// + /// Gets or sets the maximum likelihood predicted pose identity index. /// [XmlIgnore] public int IdentityIndex { get; set; } /// - /// Gets or sets the confidence score for the predicted identity. + /// Gets or sets the predicted identity confidence scores for this instance. /// - public float Confidence { get; set; } + [XmlIgnore] + public float[] IdentityScores { get; set; } } } diff --git a/src/Bonsai.Sleap/PredictPoseIdentities.cs b/src/Bonsai.Sleap/PredictPoseIdentities.cs index 25979a1..36392ad 100644 --- a/src/Bonsai.Sleap/PredictPoseIdentities.cs +++ b/src/Bonsai.Sleap/PredictPoseIdentities.cs @@ -184,7 +184,7 @@ private IObservable Process(IObservable sour { // Find the class with max score var pose = new PoseIdentity(input.Length == 1 ? input[0] : input[iid]); - var maxIndex = ArgMax(idArr, iid, Comparer.Default, out float maxScore); + pose.IdentityScores = GetRowValues(idArr, iid, Comparer.Default, out float maxScore, out int maxIndex); if (maxScore < idThreshold || maxIndex < 0) { @@ -257,22 +257,29 @@ public override IObservable Process(IObservable new IplImage[] { frame })); } - static int ArgMax(TElement[,] array, int instance, IComparer comparer, out TElement maxValue) + static TElement[] GetRowValues( + TElement[,] array, + int rowIndex, + IComparer comparer, + out TElement maxValue, + out int maxIndex) { if (array == null) throw new ArgumentNullException(nameof(array)); if (comparer == null) throw new ArgumentNullException(nameof(comparer)); - int maxIndex = -1; + maxIndex = -1; maxValue = default; - for (int i = 0; i < array.GetLength(1); i++) + var values = new TElement[array.GetLength(1)]; + for (int i = 0; i < values.Length; i++) { - if (i == 0 || comparer.Compare(array[instance, i], maxValue) > 0) + values[i] = array[rowIndex, i]; + if (i == 0 || comparer.Compare(values[i], maxValue) > 0) { maxIndex = i; - maxValue = array[instance, i]; + maxValue = array[rowIndex, i]; } } - return maxIndex; + return values; } } }