Skip to content

Commit b7ac6f6

Browse files
Add weights per feature dimension (#83)
1 parent dd08ad7 commit b7ac6f6

5 files changed

+88
-17
lines changed

src/features/mm_feature.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ class MMFeature : public Resource {
3636

3737
virtual void display_data(const Ref<EditorNode3DGizmo>& p_gizmo, const Transform3D p_transform, const float* p_data) const {};
3838

39-
void normalize(float* p_data) const;
40-
void denormalize(float* p_data) const;
41-
float calculate_normalized_weight() const {
39+
virtual float calculate_normalized_weight(int64_t p_feature_dim) const {
4240
return weight / get_dimension_count();
4341
}
42+
43+
void normalize(float* p_data) const;
44+
void denormalize(float* p_data) const;
45+
4446
GETSET(float, weight, 1.0f);
4547
GETSET(NormalizationMode, normalization_mode, Standard);
4648
GETSET(PackedFloat32Array, means);

src/features/mm_trajectory_feature.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,24 @@ void MMTrajectoryFeature::display_data(const Ref<EditorNode3DGizmo>& p_gizmo, co
162162
delete[] dernomalized_data;
163163
}
164164

165+
float MMTrajectoryFeature::calculate_normalized_weight(int64_t p_feature_dim) const {
166+
167+
float weight = MMFeature::calculate_normalized_weight(p_feature_dim);
168+
169+
const uint32_t point_dim = _get_point_dimension_count();
170+
171+
const bool is_height = include_height && (p_feature_dim % point_dim) == 2;
172+
const bool is_facing = include_facing && (p_feature_dim % point_dim) == (include_height ? 3 : 2);
173+
174+
if (is_height) {
175+
weight *= height_weight;
176+
} else if (is_facing) {
177+
weight *= facing_weight;
178+
}
179+
180+
return weight;
181+
}
182+
165183
TypedArray<Dictionary> MMTrajectoryFeature::get_trajectory_points(const Transform3D& p_character_transform, const PackedFloat32Array& p_trajectory_data) const {
166184
ERR_FAIL_COND_V(p_trajectory_data.is_empty(), TypedArray<Dictionary>());
167185

@@ -189,14 +207,48 @@ TypedArray<Dictionary> MMTrajectoryFeature::get_trajectory_points(const Transfor
189207
return result;
190208
}
191209

210+
bool MMTrajectoryFeature::get_include_height() const {
211+
return include_height;
212+
}
213+
214+
void MMTrajectoryFeature::set_include_height(bool value) {
215+
include_height = value;
216+
notify_property_list_changed();
217+
}
218+
219+
bool MMTrajectoryFeature::get_include_facing() const {
220+
return include_facing;
221+
}
222+
223+
void MMTrajectoryFeature::set_include_facing(bool value) {
224+
include_facing = value;
225+
notify_property_list_changed();
226+
}
227+
228+
void MMTrajectoryFeature::_validate_property(PropertyInfo& p_property) const {
229+
if (p_property.name == StringName("facing_weight")) {
230+
if (!include_facing) {
231+
p_property.usage = PROPERTY_USAGE_NO_EDITOR;
232+
}
233+
}
234+
235+
if (p_property.name == StringName("height_weight")) {
236+
if (!include_height) {
237+
p_property.usage = PROPERTY_USAGE_NO_EDITOR;
238+
}
239+
}
240+
}
241+
192242
void MMTrajectoryFeature::_bind_methods() {
193243
ClassDB::bind_method(D_METHOD("get_trajectory_points", "character_transform", "trajectory_data"), &MMTrajectoryFeature::get_trajectory_points);
194244
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, past_delta_time);
195245
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::INT, past_frames);
196246
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, future_delta_time);
197247
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::INT, future_frames);
198248
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::BOOL, include_height);
249+
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, height_weight);
199250
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::BOOL, include_facing);
251+
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, facing_weight);
200252
}
201253

202254
uint32_t MMTrajectoryFeature::_get_point_dimension_count() const {

src/features/mm_trajectory_feature.h

+15-2
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,29 @@ class MMTrajectoryFeature : public MMFeature {
2626

2727
virtual void display_data(const Ref<EditorNode3DGizmo>& p_gizmo, const Transform3D p_transform, const float* p_data) const override;
2828

29+
virtual float calculate_normalized_weight(int64_t p_feature_dim) const override;
30+
2931
TypedArray<Dictionary> get_trajectory_points(const Transform3D& p_character_transform, const PackedFloat32Array& p_trajectory_data) const;
3032

3133
GETSET(double, past_delta_time, 0.1);
3234
GETSET(int64_t, past_frames, 1);
3335
GETSET(double, future_delta_time, 0.1);
3436
GETSET(int64_t, future_frames, 5);
35-
GETSET(bool, include_height, false);
36-
GETSET(bool, include_facing, true);
37+
38+
bool include_height{false};
39+
bool get_include_height() const;
40+
void set_include_height(bool value);
41+
42+
GETSET(float, height_weight, 1.0);
43+
44+
bool include_facing{true};
45+
bool get_include_facing() const;
46+
void set_include_facing(bool value);
47+
48+
GETSET(float, facing_weight, 1.0);
3749

3850
protected:
51+
void _validate_property(PropertyInfo& p_property) const;
3952
static void _bind_methods();
4053

4154
private:

src/mm_animation_library.cpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,13 @@ float MMAnimationLibrary::_compute_feature_costs(int p_pose_index, const PackedF
218218
continue;
219219
}
220220

221-
const float feature_cost =
222-
distance_squared((motion_data.ptr() + start_frame_index + start_feature_index),
223-
(p_query.ptr() + start_feature_index),
224-
feature->get_dimension_count()) *
225-
feature->calculate_normalized_weight();
221+
float feature_cost = 0.f;
222+
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
223+
feature_cost += distance_squared((motion_data.ptr() + start_frame_index + start_feature_index + feature_dim_index),
224+
(p_query.ptr() + start_feature_index + feature_dim_index),
225+
1) *
226+
feature->calculate_normalized_weight(feature_dim_index);
227+
}
226228

227229
if (p_feature_costs) {
228230
p_feature_costs->get_or_add(feature->get_class(), feature_cost);
@@ -250,11 +252,13 @@ MMQueryOutput MMAnimationLibrary::_search_naive(const PackedFloat32Array& p_quer
250252
continue;
251253
}
252254

253-
const float feature_cost =
254-
distance_squared((motion_data.ptr() + start_feature_index),
255-
(p_query.ptr() + start_feature_index - start_frame_index),
256-
feature->get_dimension_count()) *
257-
feature->calculate_normalized_weight();
255+
float feature_cost = 0.f;
256+
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
257+
feature_cost += distance_squared((motion_data.ptr() + start_frame_index + start_feature_index + feature_dim_index),
258+
(p_query.ptr() + start_feature_index + feature_dim_index),
259+
1) *
260+
feature->calculate_normalized_weight(feature_dim_index);
261+
}
258262

259263
feature_costs.get_or_add(feature->get_class(), feature_cost);
260264
pose_cost += feature_cost;
@@ -302,7 +306,7 @@ MMQueryOutput MMAnimationLibrary::_search_kd_tree(const PackedFloat32Array& p_qu
302306
continue;
303307
}
304308
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
305-
dimension_weights.push_back(feature->calculate_normalized_weight());
309+
dimension_weights.push_back(feature->calculate_normalized_weight(feature_dim_index));
306310
}
307311
}
308312

0 commit comments

Comments
 (0)