Skip to content

Commit 9c135d8

Browse files
cdtwiggmeta-codesync[bot]
authored andcommitted
Add round-trip test for parameter transform. (#788)
Summary: Pull Request resolved: #788 I want to do some refactoring of the parameter transform, so having some test coverage is helpful. Let's also add the ability to save out the parameter transform, which is useful in this test and could be useful in the future. Reviewed By: jeongseok-meta, cstollmeta Differential Revision: D86350961 fbshipit-source-id: ed04d96a644fc34fa85cb14e793b7a082f6d071b
1 parent d6f7e14 commit 9c135d8

File tree

4 files changed

+280
-0
lines changed

4 files changed

+280
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ if(MOMENTUM_BUILD_TESTING)
749749
character_test_helpers
750750
io_skeleton
751751
io_test_helper
752+
character_test_helpers_gtest
752753
)
753754

754755
mt_test(

momentum/io/skeleton/parameter_transform_io.cpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ std::tuple<ParameterTransform, ParameterLimits> loadModelDefinitionFromStream(
104104

105105
pt = parseParameterTransform(data, skeleton);
106106
pt.parameterSets = parseParameterSets(data, pt);
107+
pt.poseConstraints = parsePoseConstraints(data, pt);
107108
pl = parseParameterLimits(data, skeleton, pt);
108109

109110
return res;
@@ -435,4 +436,159 @@ std::tuple<ParameterTransform, ParameterLimits> loadModelDefinition(
435436
return loadModelDefinitionFromStream(inputStream, skeleton);
436437
}
437438

439+
std::string writeParameterTransform(
440+
const ParameterTransform& parameterTransform,
441+
const Skeleton& skeleton) {
442+
std::ostringstream oss;
443+
444+
// Write each joint parameter definition
445+
for (size_t iJoint = 0; iJoint < skeleton.joints.size(); ++iJoint) {
446+
const auto& joint = skeleton.joints[iJoint];
447+
448+
for (size_t iParam = 0; iParam < kParametersPerJoint; ++iParam) {
449+
const size_t jointParamIndex = iJoint * kParametersPerJoint + iParam;
450+
451+
// Skip inactive joint parameters
452+
if (!parameterTransform.activeJointParams[jointParamIndex]) {
453+
continue;
454+
}
455+
456+
// Write joint parameter name
457+
oss << joint.name << "." << kJointParameterNames[iParam] << " = ";
458+
459+
// Collect all model parameters that influence this joint parameter
460+
std::vector<std::pair<size_t, float>> influences;
461+
for (Eigen::Index iModelParam = 0; iModelParam < parameterTransform.transform.cols();
462+
++iModelParam) {
463+
const float weight = parameterTransform.transform.coeff(jointParamIndex, iModelParam);
464+
if (weight != 0.0f) {
465+
influences.emplace_back(iModelParam, weight);
466+
}
467+
}
468+
469+
// Write the offset if it exists
470+
const float offset = parameterTransform.offsets[jointParamIndex];
471+
bool needsPlus = false;
472+
473+
if (!influences.empty()) {
474+
for (size_t i = 0; i < influences.size(); ++i) {
475+
if (i > 0) {
476+
oss << " + ";
477+
}
478+
const auto [modelParamIndex, weight] = influences[i];
479+
oss << weight << "*" << parameterTransform.name[modelParamIndex];
480+
}
481+
needsPlus = true;
482+
}
483+
484+
if (offset != 0.0f) {
485+
if (needsPlus) {
486+
oss << " + ";
487+
}
488+
oss << offset;
489+
}
490+
491+
oss << "\n";
492+
}
493+
}
494+
495+
return oss.str();
496+
}
497+
498+
std::string writeParameterSets(const ParameterSets& parameterSets) {
499+
std::ostringstream oss;
500+
501+
for (const auto& [name, paramSet] : parameterSets) {
502+
oss << "parameterset " << name;
503+
504+
// Write all active parameters in the set
505+
for (size_t i = 0; i < paramSet.size(); ++i) {
506+
if (paramSet.test(i)) {
507+
// We need the parameter name, but we don't have access to parameterTransform here
508+
// This is a limitation - we'll need to pass it in
509+
oss << " param_" << i;
510+
}
511+
}
512+
oss << "\n";
513+
}
514+
515+
return oss.str();
516+
}
517+
518+
std::string writePoseConstraints(const PoseConstraints& poseConstraints) {
519+
std::ostringstream oss;
520+
521+
for (const auto& [name, constraint] : poseConstraints) {
522+
oss << "poseconstraints " << name;
523+
524+
// Write all parameter=value pairs
525+
for (const auto& [paramIndex, value] : constraint.parameterIdValue) {
526+
// We need the parameter name, but we don't have access to parameterTransform here
527+
oss << " param_" << paramIndex << "=" << value;
528+
}
529+
oss << "\n";
530+
}
531+
532+
return oss.str();
533+
}
534+
535+
std::string writeModelDefinition(
536+
const Skeleton& skeleton,
537+
const ParameterTransform& parameterTransform,
538+
const ParameterLimits& parameterLimits) {
539+
std::ostringstream oss;
540+
541+
// Write header
542+
oss << "Momentum Model Definition V1.0\n\n";
543+
544+
// Write ParameterTransform section
545+
if (!parameterTransform.name.empty()) {
546+
oss << "[ParameterTransform]\n";
547+
oss << writeParameterTransform(parameterTransform, skeleton);
548+
oss << "\n";
549+
}
550+
551+
// Write ParameterSets section
552+
if (!parameterTransform.parameterSets.empty()) {
553+
oss << "[ParameterSets]\n";
554+
for (const auto& [name, paramSet] : parameterTransform.parameterSets) {
555+
oss << "parameterset " << name;
556+
557+
// Write all active parameters in the set
558+
for (size_t i = 0; i < paramSet.size() && i < parameterTransform.name.size(); ++i) {
559+
if (paramSet.test(i)) {
560+
oss << " " << parameterTransform.name[i];
561+
}
562+
}
563+
oss << "\n";
564+
}
565+
oss << "\n";
566+
}
567+
568+
// Write PoseConstraints section
569+
if (!parameterTransform.poseConstraints.empty()) {
570+
oss << "[PoseConstraints]\n";
571+
for (const auto& [name, constraint] : parameterTransform.poseConstraints) {
572+
oss << "poseconstraints " << name;
573+
574+
// Write all parameter=value pairs
575+
for (const auto& [paramIndex, value] : constraint.parameterIdValue) {
576+
if (paramIndex < parameterTransform.name.size()) {
577+
oss << " " << parameterTransform.name[paramIndex] << "=" << value;
578+
}
579+
}
580+
oss << "\n";
581+
}
582+
oss << "\n";
583+
}
584+
585+
// Write ParameterLimits section using existing function
586+
if (!parameterLimits.empty()) {
587+
oss << "[ParameterLimits]\n";
588+
oss << writeParameterLimits(parameterLimits, skeleton, parameterTransform);
589+
}
590+
591+
return oss.str();
592+
}
593+
438594
} // namespace momentum

momentum/io/skeleton/parameter_transform_io.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,23 @@ std::tuple<ParameterTransform, ParameterLimits> loadModelDefinition(
3838
std::span<const std::byte> rawData,
3939
const Skeleton& skeleton);
4040

41+
// Write functions to serialize model definition components
42+
std::string writeParameterTransform(
43+
const ParameterTransform& parameterTransform,
44+
const Skeleton& skeleton);
45+
46+
std::string writeParameterSets(const ParameterSets& parameterSets);
47+
48+
std::string writePoseConstraints(const PoseConstraints& poseConstraints);
49+
50+
/// Write complete model definition file
51+
/// @param skeleton The character's skeletal structure
52+
/// @param parameterTransform Maps model parameters to joint parameters
53+
/// @param parameterLimits Constraints on model parameters (can be empty)
54+
/// @return String containing the complete model definition
55+
std::string writeModelDefinition(
56+
const Skeleton& skeleton,
57+
const ParameterTransform& parameterTransform,
58+
const ParameterLimits& parameterLimits);
59+
4160
} // namespace momentum

momentum/test/io/io_model_parser_test.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
*/
77

88
#include <gtest/gtest.h>
9+
#include <momentum/character/character.h>
10+
#include <momentum/character/skeleton.h>
911
#include <momentum/io/skeleton/parameter_transform_io.h>
12+
#include <momentum/test/character/character_helpers_gtest.h>
1013

1114
using namespace momentum;
1215

@@ -103,3 +106,104 @@ limit param2 minmax [-2.0, 2.0] 1.0
103106
EXPECT_TRUE(limitsContent.find("limit param1 minmax") != std::string::npos);
104107
EXPECT_TRUE(limitsContent.find("limit param2 minmax") != std::string::npos);
105108
}
109+
110+
TEST(IoModelParserTest, WriteModelDefinition_RoundTrip) {
111+
// Create original character
112+
Character original;
113+
original.name = "test_character";
114+
115+
// Create a simple skeleton
116+
original.skeleton.joints.resize(2);
117+
original.skeleton.joints[0].name = "root";
118+
original.skeleton.joints[0].parent = kInvalidIndex;
119+
original.skeleton.joints[1].name = "child";
120+
original.skeleton.joints[1].parent = 0;
121+
122+
// Create parameter transform
123+
original.parameterTransform.name = {"tx", "ty", "tz"};
124+
original.parameterTransform.activeJointParams.setConstant(
125+
original.skeleton.joints.size() * kParametersPerJoint, false);
126+
original.parameterTransform.offsets.setZero(
127+
original.skeleton.joints.size() * kParametersPerJoint);
128+
129+
// Set up transform matrix (2 joints * 7 params per joint = 14 rows, 3 parameters = 3 cols)
130+
original.parameterTransform.transform.resize(
131+
original.skeleton.joints.size() * kParametersPerJoint, 3);
132+
std::vector<Eigen::Triplet<float>> triplets;
133+
134+
// root.tx = 1.0 * tx
135+
original.parameterTransform.activeJointParams[0] = true;
136+
triplets.emplace_back(0, 0, 1.0f);
137+
138+
// root.ty = 1.0 * ty
139+
original.parameterTransform.activeJointParams[1] = true;
140+
triplets.emplace_back(1, 1, 1.0f);
141+
142+
// child.tz = 2.0 * tz + 0.5
143+
original.parameterTransform.activeJointParams[7 + 2] = true;
144+
triplets.emplace_back(7 + 2, 2, 2.0f);
145+
original.parameterTransform.offsets[7 + 2] = 0.5f;
146+
147+
original.parameterTransform.transform.setFromTriplets(triplets.begin(), triplets.end());
148+
149+
// Add parameter sets
150+
ParameterSet ps1;
151+
ps1.set(0, true); // tx
152+
ps1.set(1, true); // ty
153+
original.parameterTransform.parameterSets["translation_xy"] = ps1;
154+
155+
// Add pose constraints
156+
PoseConstraint pc;
157+
pc.parameterIdValue.emplace_back(0, 0.0f); // tx = 0
158+
pc.parameterIdValue.emplace_back(1, 0.0f); // ty = 0
159+
original.parameterTransform.poseConstraints["bind_pose"] = pc;
160+
161+
// Create parameter limits
162+
ParameterLimit limit1;
163+
limit1.type = LimitType::MinMax;
164+
limit1.weight = 1.0f;
165+
limit1.data.minMax.parameterIndex = 0;
166+
limit1.data.minMax.limits = Eigen::Vector2f(-1.0f, 1.0f);
167+
original.parameterLimits.push_back(limit1);
168+
169+
ParameterLimit limit2;
170+
limit2.type = LimitType::MinMax;
171+
limit2.weight = 1.0f;
172+
limit2.data.minMax.parameterIndex = 1;
173+
limit2.data.minMax.limits = Eigen::Vector2f(-2.0f, 2.0f);
174+
original.parameterLimits.push_back(limit2);
175+
176+
// Write model definition
177+
const std::string written = writeModelDefinition(
178+
original.skeleton, original.parameterTransform, original.parameterLimits);
179+
180+
// Verify the output contains expected sections
181+
EXPECT_TRUE(written.find("Momentum Model Definition V1.0") != std::string::npos);
182+
EXPECT_TRUE(written.find("[ParameterTransform]") != std::string::npos);
183+
EXPECT_TRUE(written.find("[ParameterSets]") != std::string::npos);
184+
EXPECT_TRUE(written.find("[PoseConstraints]") != std::string::npos);
185+
EXPECT_TRUE(written.find("[ParameterLimits]") != std::string::npos);
186+
187+
// Verify parameter transform content
188+
EXPECT_TRUE(written.find("root.tx = 1*tx") != std::string::npos);
189+
EXPECT_TRUE(written.find("root.ty = 1*ty") != std::string::npos);
190+
EXPECT_TRUE(written.find("child.tz = 2*tz + 0.5") != std::string::npos);
191+
192+
// Verify parameter sets
193+
EXPECT_TRUE(written.find("parameterset translation_xy tx ty") != std::string::npos);
194+
195+
// Verify pose constraints
196+
EXPECT_TRUE(written.find("poseconstraints bind_pose tx=0 ty=0") != std::string::npos);
197+
198+
// Round-trip test: parse what we wrote and build a new character
199+
Character roundtrip;
200+
roundtrip.name = "test_character";
201+
roundtrip.skeleton = original.skeleton;
202+
std::tie(roundtrip.parameterTransform, roundtrip.parameterLimits) = loadModelDefinition(
203+
std::span<const std::byte>(
204+
reinterpret_cast<const std::byte*>(written.data()), written.size()),
205+
original.skeleton);
206+
207+
// Use the comprehensive compareChars helper to verify everything matches
208+
compareChars(original, roundtrip, false);
209+
}

0 commit comments

Comments
 (0)