Skip to content

Commit 5a093e2

Browse files
cstollmetameta-codesync[bot]
authored andcommitted
add wrapper for joint-to-joint error function (#806)
Summary: Pull Request resolved: #806 See title Reviewed By: jeongseok-meta Differential Revision: D86779635 fbshipit-source-id: dd93fc2064a229064db9937f9d02d82547511def
1 parent 80749d3 commit 5a093e2

File tree

2 files changed

+337
-0
lines changed

2 files changed

+337
-0
lines changed

pymomentum/solver2/solver2_error_functions.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <momentum/character_solver/collision_error_function.h>
1313
#include <momentum/character_solver/distance_error_function.h>
1414
#include <momentum/character_solver/fixed_axis_error_function.h>
15+
#include <momentum/character_solver/joint_to_joint_distance_error_function.h>
1516
#include <momentum/character_solver/limit_error_function.h>
1617
#include <momentum/character_solver/model_parameters_error_function.h>
1718
#include <momentum/character_solver/normal_error_function.h>
@@ -1267,6 +1268,180 @@ or maintaining the width of body parts.)")
12671268
"Returns the number of constraints.");
12681269
}
12691270

1271+
void defJointToJointDistanceErrorFunction(py::module_& m) {
1272+
py::class_<mm::JointToJointDistanceConstraintT<float>>(m, "JointToJointDistanceConstraint")
1273+
.def(
1274+
"__repr__",
1275+
[](const mm::JointToJointDistanceConstraintT<float>& self) {
1276+
return fmt::format(
1277+
"JointToJointDistanceConstraint(joint1={}, joint2={}, offset1=[{:.3f}, {:.3f}, {:.3f}], offset2=[{:.3f}, {:.3f}, {:.3f}], weight={}, target_distance={})",
1278+
self.joint1,
1279+
self.joint2,
1280+
self.offset1.x(),
1281+
self.offset1.y(),
1282+
self.offset1.z(),
1283+
self.offset2.x(),
1284+
self.offset2.y(),
1285+
self.offset2.z(),
1286+
self.weight,
1287+
self.targetDistance);
1288+
})
1289+
.def_readonly(
1290+
"joint1",
1291+
&mm::JointToJointDistanceConstraintT<float>::joint1,
1292+
"The index of the first joint")
1293+
.def_readonly(
1294+
"joint2",
1295+
&mm::JointToJointDistanceConstraintT<float>::joint2,
1296+
"The index of the second joint")
1297+
.def_readonly(
1298+
"offset1",
1299+
&mm::JointToJointDistanceConstraintT<float>::offset1,
1300+
"The offset from joint1 in the local coordinate system of joint1")
1301+
.def_readonly(
1302+
"offset2",
1303+
&mm::JointToJointDistanceConstraintT<float>::offset2,
1304+
"The offset from joint2 in the local coordinate system of joint2")
1305+
.def_readonly(
1306+
"weight",
1307+
&mm::JointToJointDistanceConstraintT<float>::weight,
1308+
"The weight of the constraint")
1309+
.def_readonly(
1310+
"target_distance",
1311+
&mm::JointToJointDistanceConstraintT<float>::targetDistance,
1312+
"The target distance between the two points");
1313+
1314+
py::class_<
1315+
mm::JointToJointDistanceErrorFunctionT<float>,
1316+
mm::SkeletonErrorFunction,
1317+
std::shared_ptr<mm::JointToJointDistanceErrorFunctionT<float>>>(
1318+
m,
1319+
"JointToJointDistanceErrorFunction",
1320+
R"(Error function that penalizes deviation from a target distance between two points attached to different joints.
1321+
1322+
This is useful for enforcing distance constraints between different parts of a character,
1323+
such as maintaining a fixed distance between hands or ensuring two joints stay a certain distance apart.)")
1324+
.def(
1325+
"__repr__",
1326+
[](const mm::JointToJointDistanceErrorFunctionT<float>& self) {
1327+
return fmt::format(
1328+
"JointToJointDistanceErrorFunction(weight={}, num_constraints={})",
1329+
self.getWeight(),
1330+
self.getConstraints().size());
1331+
})
1332+
.def(
1333+
py::init<>(
1334+
[](const mm::Character& character,
1335+
float weight) -> std::shared_ptr<mm::JointToJointDistanceErrorFunctionT<float>> {
1336+
validateWeight(weight, "weight");
1337+
auto result =
1338+
std::make_shared<mm::JointToJointDistanceErrorFunctionT<float>>(character);
1339+
result->setWeight(weight);
1340+
return result;
1341+
}),
1342+
R"(Initialize a JointToJointDistanceErrorFunction.
1343+
1344+
:param character: The character to use.
1345+
:param weight: The weight applied to the error function.)",
1346+
py::keep_alive<1, 2>(),
1347+
py::arg("character"),
1348+
py::kw_only(),
1349+
py::arg("weight") = 1.0f)
1350+
.def(
1351+
"add_constraint",
1352+
[](mm::JointToJointDistanceErrorFunctionT<float>& self,
1353+
size_t joint1,
1354+
const Eigen::Vector3f& offset1,
1355+
size_t joint2,
1356+
const Eigen::Vector3f& offset2,
1357+
float targetDistance,
1358+
float weight) {
1359+
validateJointIndex(joint1, "joint1", self.getSkeleton());
1360+
validateJointIndex(joint2, "joint2", self.getSkeleton());
1361+
validateWeight(weight, "weight");
1362+
self.addConstraint(joint1, offset1, joint2, offset2, targetDistance, weight);
1363+
},
1364+
R"(Adds a joint-to-joint distance constraint to the error function.
1365+
1366+
:param joint1: The index of the first joint.
1367+
:param offset1: The offset from joint1 in the local coordinate system of joint1.
1368+
:param joint2: The index of the second joint.
1369+
:param offset2: The offset from joint2 in the local coordinate system of joint2.
1370+
:param target_distance: The desired distance between the two points in world space.
1371+
:param weight: The weight of the constraint.)",
1372+
py::arg("joint1"),
1373+
py::arg("offset1"),
1374+
py::arg("joint2"),
1375+
py::arg("offset2"),
1376+
py::arg("target_distance"),
1377+
py::arg("weight") = 1.0f)
1378+
.def(
1379+
"add_constraints",
1380+
[](mm::JointToJointDistanceErrorFunctionT<float>& self,
1381+
const py::array_t<int>& joint1,
1382+
const py::array_t<float>& offset1,
1383+
const py::array_t<int>& joint2,
1384+
const py::array_t<float>& offset2,
1385+
const py::array_t<float>& targetDistance,
1386+
const std::optional<py::array_t<float>>& weight) {
1387+
ArrayShapeValidator validator;
1388+
const int nConsIdx = -1;
1389+
validator.validate(joint1, "joint1", {nConsIdx}, {"n_cons"});
1390+
validateJointIndex(joint1, "joint1", self.getSkeleton());
1391+
validator.validate(offset1, "offset1", {nConsIdx, 3}, {"n_cons", "xyz"});
1392+
validator.validate(joint2, "joint2", {nConsIdx}, {"n_cons"});
1393+
validateJointIndex(joint2, "joint2", self.getSkeleton());
1394+
validator.validate(offset2, "offset2", {nConsIdx, 3}, {"n_cons", "xyz"});
1395+
validator.validate(targetDistance, "target_distance", {nConsIdx}, {"n_cons"});
1396+
validator.validate(weight, "weight", {nConsIdx}, {"n_cons"});
1397+
validateWeights(weight, "weight");
1398+
1399+
auto joint1Acc = joint1.unchecked<1>();
1400+
auto offset1Acc = offset1.unchecked<2>();
1401+
auto joint2Acc = joint2.unchecked<1>();
1402+
auto offset2Acc = offset2.unchecked<2>();
1403+
auto targetDistanceAcc = targetDistance.unchecked<1>();
1404+
auto weightAcc =
1405+
weight.has_value() ? std::make_optional(weight->unchecked<1>()) : std::nullopt;
1406+
1407+
py::gil_scoped_release release;
1408+
1409+
for (py::ssize_t i = 0; i < joint1.shape(0); ++i) {
1410+
self.addConstraint(
1411+
joint1Acc(i),
1412+
Eigen::Vector3f(offset1Acc(i, 0), offset1Acc(i, 1), offset1Acc(i, 2)),
1413+
joint2Acc(i),
1414+
Eigen::Vector3f(offset2Acc(i, 0), offset2Acc(i, 1), offset2Acc(i, 2)),
1415+
targetDistanceAcc(i),
1416+
weightAcc.has_value() ? (*weightAcc)(i) : 1.0f);
1417+
}
1418+
},
1419+
R"(Adds multiple joint-to-joint distance constraints to the error function.
1420+
1421+
:param joint1: A numpy array of indices for the first joints.
1422+
:param offset1: A numpy array of shape (n, 3) for offsets from joint1 in local coordinates.
1423+
:param joint2: A numpy array of indices for the second joints.
1424+
:param offset2: A numpy array of shape (n, 3) for offsets from joint2 in local coordinates.
1425+
:param target_distance: A numpy array of desired distances between point pairs.
1426+
:param weight: A numpy array of weights for the constraints.)",
1427+
py::arg("joint1"),
1428+
py::arg("offset1"),
1429+
py::arg("joint2"),
1430+
py::arg("offset2"),
1431+
py::arg("target_distance"),
1432+
py::arg("weight") = std::nullopt)
1433+
.def(
1434+
"clear_constraints",
1435+
&mm::JointToJointDistanceErrorFunctionT<float>::clearConstraints,
1436+
"Clears all joint-to-joint distance constraints from the error function.")
1437+
.def_property_readonly(
1438+
"constraints",
1439+
[](const mm::JointToJointDistanceErrorFunctionT<float>& self) {
1440+
return self.getConstraints();
1441+
},
1442+
"Returns the list of joint-to-joint distance constraints.");
1443+
}
1444+
12701445
} // namespace
12711446

12721447
void addErrorFunctions(py::module_& m) {
@@ -2241,6 +2416,9 @@ rotation matrix to a target rotation.)")
22412416

22422417
// Vertex-to-vertex distance error function
22432418
defVertexVertexDistanceErrorFunction(m);
2419+
2420+
// Joint-to-joint distance error function
2421+
defJointToJointDistanceErrorFunction(m);
22442422
}
22452423

22462424
} // namespace pymomentum

pymomentum/test/test_solver2.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,6 +1626,165 @@ def test_vertex_vertex_distance_constraint(self) -> None:
16261626
self.assertIn("vertex_index1=0", constraint_repr)
16271627
self.assertIn("vertex_index2=2", constraint_repr)
16281628

1629+
def test_joint_to_joint_distance_constraint(self) -> None:
1630+
"""Test JointToJointDistanceErrorFunction to ensure joints are pulled to target distance."""
1631+
1632+
# Create a test character
1633+
character = pym_geometry.create_test_character(num_joints=4)
1634+
1635+
n_params = character.parameter_transform.size
1636+
1637+
# Ensure repeatability in the rng:
1638+
torch.manual_seed(0)
1639+
model_params_init = torch.zeros(n_params, dtype=torch.float32)
1640+
1641+
# Choose two joints to constrain - use joints that are initially far apart
1642+
joint_index1 = 0
1643+
joint_index2 = character.skeleton.size - 1 # Last joint
1644+
offset1 = np.array([0.5, 0.0, 0.0], dtype=np.float32)
1645+
offset2 = np.array([0.0, 0.5, 0.0], dtype=np.float32)
1646+
target_distance = 1.5 # Target distance between the two points
1647+
weight = 1.0
1648+
1649+
# Get initial positions of the points
1650+
skel_state_init = pym_geometry.model_parameters_to_skeleton_state(
1651+
character, model_params_init
1652+
)
1653+
initial_point1 = pym_skel_state.transform_points(
1654+
skel_state_init[joint_index1],
1655+
torch.from_numpy(offset1),
1656+
)
1657+
initial_point2 = pym_skel_state.transform_points(
1658+
skel_state_init[joint_index2],
1659+
torch.from_numpy(offset2),
1660+
)
1661+
initial_distance = torch.norm(initial_point2 - initial_point1).item()
1662+
1663+
# Create JointToJointDistanceErrorFunction
1664+
joint_distance_error = pym_solver2.JointToJointDistanceErrorFunction(character)
1665+
1666+
# Test basic properties
1667+
self.assertEqual(len(joint_distance_error.constraints), 0)
1668+
1669+
# Add a single constraint
1670+
joint_distance_error.add_constraint(
1671+
joint1=joint_index1,
1672+
offset1=offset1,
1673+
joint2=joint_index2,
1674+
offset2=offset2,
1675+
target_distance=target_distance,
1676+
weight=weight,
1677+
)
1678+
1679+
# Verify constraint was added
1680+
self.assertEqual(len(joint_distance_error.constraints), 1)
1681+
1682+
constraint = joint_distance_error.constraints[0]
1683+
self.assertEqual(constraint.joint1, joint_index1)
1684+
self.assertEqual(constraint.joint2, joint_index2)
1685+
self.assertTrue(np.allclose(constraint.offset1, offset1))
1686+
self.assertTrue(np.allclose(constraint.offset2, offset2))
1687+
self.assertAlmostEqual(constraint.weight, weight)
1688+
self.assertAlmostEqual(constraint.target_distance, target_distance)
1689+
1690+
# Create solver function with the joint distance error
1691+
solver_function = pym_solver2.SkeletonSolverFunction(
1692+
character, [joint_distance_error]
1693+
)
1694+
1695+
# Set solver options
1696+
solver_options = pym_solver2.GaussNewtonSolverOptions()
1697+
solver_options.max_iterations = 100
1698+
solver_options.regularization = 1e-5
1699+
1700+
# Create and run the solver
1701+
solver = pym_solver2.GaussNewtonSolver(solver_function, solver_options)
1702+
model_params_final = solver.solve(model_params_init.numpy())
1703+
1704+
# Convert final model parameters to skeleton state
1705+
skel_state_final = pym_geometry.model_parameters_to_skeleton_state(
1706+
character, torch.from_numpy(model_params_final)
1707+
)
1708+
1709+
# Compute final positions of the points
1710+
final_point1 = pym_skel_state.transform_points(
1711+
skel_state_final[joint_index1],
1712+
torch.from_numpy(offset1),
1713+
)
1714+
final_point2 = pym_skel_state.transform_points(
1715+
skel_state_final[joint_index2],
1716+
torch.from_numpy(offset2),
1717+
)
1718+
final_distance = torch.norm(final_point2 - final_point1).item()
1719+
1720+
# Assert that the final distance is close to the target distance
1721+
self.assertAlmostEqual(
1722+
final_distance,
1723+
target_distance,
1724+
delta=1e-3,
1725+
msg=f"Final distance {final_distance} does not match target {target_distance}",
1726+
)
1727+
1728+
# Verify that the distance actually changed from the initial distance
1729+
self.assertNotAlmostEqual(
1730+
initial_distance,
1731+
final_distance,
1732+
delta=1e-1,
1733+
msg=f"Distance did not change significantly from initial {initial_distance} to final {final_distance}",
1734+
)
1735+
1736+
# Test multiple constraints using add_constraints
1737+
joint_distance_error.clear_constraints()
1738+
self.assertEqual(len(joint_distance_error.constraints), 0)
1739+
1740+
# Add multiple constraints
1741+
joint_indices1 = np.array([0, 1], dtype=np.int32)
1742+
offsets1 = np.array([[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]], dtype=np.float32)
1743+
joint_indices2 = np.array([2, 3], dtype=np.int32)
1744+
offsets2 = np.array([[0.0, 0.0, 0.5], [0.5, 0.5, 0.0]], dtype=np.float32)
1745+
target_distances = np.array([0.8, 1.2], dtype=np.float32)
1746+
weights = np.array([1.0, 2.0], dtype=np.float32)
1747+
1748+
joint_distance_error.add_constraints(
1749+
joint1=joint_indices1,
1750+
offset1=offsets1,
1751+
joint2=joint_indices2,
1752+
offset2=offsets2,
1753+
target_distance=target_distances,
1754+
weight=weights,
1755+
)
1756+
1757+
# Verify multiple constraints were added
1758+
self.assertEqual(len(joint_distance_error.constraints), 2)
1759+
constraints = joint_distance_error.constraints
1760+
1761+
# Check first constraint
1762+
self.assertEqual(constraints[0].joint1, 0)
1763+
self.assertEqual(constraints[0].joint2, 2)
1764+
self.assertTrue(np.allclose(constraints[0].offset1, [0.5, 0.0, 0.0]))
1765+
self.assertTrue(np.allclose(constraints[0].offset2, [0.0, 0.0, 0.5]))
1766+
self.assertAlmostEqual(constraints[0].weight, 1.0)
1767+
self.assertAlmostEqual(constraints[0].target_distance, 0.8)
1768+
1769+
# Check second constraint
1770+
self.assertEqual(constraints[1].joint1, 1)
1771+
self.assertEqual(constraints[1].joint2, 3)
1772+
self.assertTrue(np.allclose(constraints[1].offset1, [0.0, 0.5, 0.0]))
1773+
self.assertTrue(np.allclose(constraints[1].offset2, [0.5, 0.5, 0.0]))
1774+
self.assertAlmostEqual(constraints[1].weight, 2.0)
1775+
self.assertAlmostEqual(constraints[1].target_distance, 1.2)
1776+
1777+
# Test string representation
1778+
repr_str = repr(joint_distance_error)
1779+
self.assertIn("JointToJointDistanceErrorFunction", repr_str)
1780+
self.assertIn("num_constraints=2", repr_str)
1781+
1782+
# Test constraint string representation
1783+
constraint_repr = repr(constraints[0])
1784+
self.assertIn("JointToJointDistanceConstraint", constraint_repr)
1785+
self.assertIn("joint1=0", constraint_repr)
1786+
self.assertIn("joint2=2", constraint_repr)
1787+
16291788
def test_weight_validation(self) -> None:
16301789
"""Test that error functions throw ValueError when negative weights are passed."""
16311790

0 commit comments

Comments
 (0)