Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return all identity scores for each instance #22

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Bonsai.Sleap.Design/Bonsai.Sleap.Design.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
<PackageTags>Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking Visualizers</PackageTags>
<TargetFramework>net472</TargetFramework>
<UseWindowsForms>true</UseWindowsForms>
<VersionPrefix>0.2.0</VersionPrefix>
<VersionPrefix>0.3.0</VersionPrefix>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Bonsai.Vision.Design" Version="2.7.0" />
<PackageReference Include="Bonsai.Vision.Design" Version="2.8.0" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion src/Bonsai.Sleap/Bonsai.Sleap.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<PackageTags>Bonsai Rx SLEAP LEAP Markerless Multi Pose Tracking</PackageTags>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<TargetFramework>net472</TargetFramework>
<VersionPrefix>0.2.2</VersionPrefix>
<VersionPrefix>0.3.0</VersionPrefix>
</PropertyGroup>

<ItemGroup>
Expand Down
14 changes: 10 additions & 4 deletions src/Bonsai.Sleap/PoseIdentity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,25 @@ public PoseIdentity(IplImage image)
}

/// <summary>
/// Gets or sets the predicted pose identity.
/// Gets or sets the maximum likelihood predicted pose identity.
/// </summary>
public string Identity { get; set; }

/// <summary>
/// Gets or sets the predicted pose identity index.
/// Gets or sets the maximum likelihood confidence score for the predicted identity.
/// </summary>
public float Confidence { get; set; }

/// <summary>
/// Gets or sets the maximum likelihood predicted pose identity index.
/// </summary>
[XmlIgnore]
public int IdentityIndex { get; set; }

/// <summary>
/// Gets or sets the confidence score for the predicted identity.
/// Gets or sets the predicted identity confidence scores for this instance.
/// </summary>
public float Confidence { get; set; }
[XmlIgnore]
public float[] IdentityScores { get; set; }
}
}
21 changes: 14 additions & 7 deletions src/Bonsai.Sleap/PredictPoseIdentities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private IObservable<PoseIdentityCollection> Process(IObservable<IplImage[]> sour
{
// Find the class with max score
var pose = new PoseIdentity(input.Length == 1 ? input[0] : input[iid]);
var maxIndex = ArgMax(idArr, iid, Comparer<float>.Default, out float maxScore);
pose.IdentityScores = GetRowValues(idArr, iid, Comparer<float>.Default, out float maxScore, out int maxIndex);

if (maxScore < idThreshold || maxIndex < 0)
{
Expand Down Expand Up @@ -257,22 +257,29 @@ public override IObservable<PoseIdentityCollection> Process(IObservable<IplImage
return Process(source.Select(frame => new IplImage[] { frame }));
}

static int ArgMax<TElement>(TElement[,] array, int instance, IComparer<TElement> comparer, out TElement maxValue)
static TElement[] GetRowValues<TElement>(
TElement[,] array,
int rowIndex,
IComparer<TElement> comparer,
out TElement maxValue,
out int maxIndex)
{
if (array == null) throw new ArgumentNullException(nameof(array));
if (comparer == null) throw new ArgumentNullException(nameof(comparer));

int maxIndex = -1;
maxIndex = -1;
maxValue = default;
for (int i = 0; i < array.GetLength(1); i++)
var values = new TElement[array.GetLength(1)];
for (int i = 0; i < values.Length; i++)
{
if (i == 0 || comparer.Compare(array[instance, i], maxValue) > 0)
values[i] = array[rowIndex, i];
if (i == 0 || comparer.Compare(values[i], maxValue) > 0)
{
maxIndex = i;
maxValue = array[instance, i];
maxValue = array[rowIndex, i];
}
}
return maxIndex;
return values;
}
}
}