Skip to content

Commit d6f7e14

Browse files
cdtwiggmeta-codesync[bot]
authored andcommitted
Reduce memory usage of sequence_solver (#787)
Summary: Pull Request resolved: #787 In the sequence_solver_function, currently we store skeleton and mesh states for every frame. This is intended to be analogous to the standard skeleton_solver_function which reduces the amount of memory allocation that needs to happen inside the solver loop. The problem with doing this, though, is that we end up storing hundreds or thousands of MeshStates, which can be a huge amount of memory. When solving a sequence, it turns out to be more efficient to just construct the meshstates and skeletonstates on the fly, similar to what we do for Jacobian matrices. We can leverage the State functionality in dispenso to make this thread-local so we don't have to allocate every single time but can reuse in a given worker. Reviewed By: cstollmeta Differential Revision: D86348621 fbshipit-source-id: 780d81bd373e2c251b896cdf97fa4ef498680a6f
1 parent 67ec8e3 commit d6f7e14

File tree

4 files changed

+303
-144
lines changed

4 files changed

+303
-144
lines changed

momentum/character_sequence_solver/sequence_solver.cpp

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,18 @@ size_t computeJacobianSize(std::vector<std::shared_ptr<ErrorFunctionType>>& func
9696
// Compute the full per-frame Jacobian and update the skeleton state:
9797
template <typename T>
9898
std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t>
99-
SequenceSolverT<T>::computePerFrameJacobian(SequenceSolverFunctionT<T>* fn, size_t iFrame) {
99+
SequenceSolverT<T>::computePerFrameJacobian(
100+
SequenceSolverFunctionT<T>* fn,
101+
size_t iFrame,
102+
SkeletonStateT<T>& skelState,
103+
MeshStateT<T>& meshState) {
100104
const auto& frameParameters = fn->frameParameters_[iFrame];
101-
auto& skelState = fn->states_[iFrame];
102-
auto& meshState = fn->meshStates_[iFrame];
103105

104106
const auto& character = fn->getCharacter();
105107

106108
skelState.set(fn->parameterTransform_.apply(frameParameters), character.skeleton);
107109

108-
if (fn->needsMesh()) {
110+
if (fn->needsMeshPerFrame()) {
109111
meshState.update(frameParameters, skelState, character);
110112
}
111113

@@ -143,9 +145,16 @@ SequenceSolverT<T>::computePerFrameJacobian(SequenceSolverFunctionT<T>* fn, size
143145
}
144146

145147
template <typename T>
146-
std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t> SequenceSolverT<
147-
T>::computeSequenceJacobian(SequenceSolverFunctionT<T>* fn, size_t iFrame, size_t bandwidth) {
148+
std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t>
149+
SequenceSolverT<T>::computeSequenceJacobian(
150+
SequenceSolverFunctionT<T>* fn,
151+
size_t iFrame,
152+
size_t bandwidth,
153+
std::span<const SkeletonStateT<T>> skelStates,
154+
std::span<const MeshStateT<T>> meshStates) {
148155
const size_t bandwidth_cur = std::min(fn->getNumFrames() - iFrame, bandwidth);
156+
MT_CHECK(skelStates.size() >= bandwidth_cur);
157+
MT_CHECK(meshStates.size() >= bandwidth_cur);
149158

150159
const auto nFullParameters = fn->parameterTransform_.numAllModelParameters();
151160
const size_t jacobianSize = padForSSE(computeJacobianSize(fn->sequenceErrorFunctions_[iFrame]));
@@ -165,9 +174,9 @@ std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t> SequenceSolverT
165174
const size_t n = errf->getJacobianSize();
166175
int rows = 0;
167176
errorCur += errf->getJacobian(
168-
gsl::make_span(fn->frameParameters_).subspan(iFrame, nFrames),
169-
gsl::make_span(fn->states_).subspan(iFrame, nFrames),
170-
gsl::make_span(fn->meshStates_).subspan(iFrame, nFrames),
177+
std::span(fn->frameParameters_).subspan(iFrame, nFrames),
178+
skelStates,
179+
meshStates,
171180
jacobian.block(offset, 0, n, nFrames * nFullParameters),
172181
residual.middleRows(offset, n),
173182
rows);
@@ -215,9 +224,12 @@ double SequenceSolverT<T>::processPerFrameErrors_serial(
215224
SequenceSolverFunctionT<T>* fn,
216225
OnlineBandedHouseholderQR<T>& qrSolver,
217226
ProgressBar* progress) {
227+
SkeletonStateT<T> skelState;
228+
MeshStateT<T> meshState;
218229
double errorSum = 0;
219230
for (size_t iFrame = 0; iFrame < fn->getNumFrames(); ++iFrame) {
220-
auto [jacobian, residual, errorCur, nFunctions] = computePerFrameJacobian(fn, iFrame);
231+
auto [jacobian, residual, errorCur, nFunctions] =
232+
computePerFrameJacobian(fn, iFrame, skelState, meshState);
221233
errorSum += errorCur;
222234

223235
if (jacobian.rows() != 0) {
@@ -269,7 +281,9 @@ double SequenceSolverT<T>::processErrorFunctions_parallel(
269281
const std::function<UniversalJacobianResid(
270282
size_t,
271283
SequenceSolverFunctionT<T>*,
272-
OnlineBandedHouseholderQR<T>& qrSolver)>& processJac,
284+
OnlineBandedHouseholderQR<T>&,
285+
SkeletonStateT<T>,
286+
MeshStateT<T>&)>& processJac,
273287
SequenceSolverFunctionT<T>* fn,
274288
OnlineBandedHouseholderQR<T>& qrSolver,
275289
ProgressBar* progress) {
@@ -292,15 +306,20 @@ double SequenceSolverT<T>::processErrorFunctions_parallel(
292306
// 1. We compute the full Jacobian/residual for each from each frame.
293307
// 2. We zero out the non-shared parts of the Jacobian in parallel
294308
// 3. We zero out the shared parts of the Jacobian serially in order in the last stage.
309+
std::vector<PerFrameStateT<T>> states;
295310
dispenso::parallel_for(
296-
dispenso::makeChunkedRange(0, end, 1), [&](size_t rangeStart, size_t rangeEnd) {
311+
states,
312+
[]() -> PerFrameStateT<T> { return PerFrameStateT<T>(); },
313+
dispenso::makeChunkedRange(0, end, 1),
314+
[&](PerFrameStateT<T>& state, size_t rangeStart, size_t rangeEnd) {
297315
for (size_t iFrame = rangeStart; iFrame < rangeEnd; ++iFrame) {
298316
if (iFrame >= end) {
299317
// nothing to do:
300318
return;
301319
}
302320

303-
UniversalJacobianResid universalJacRes = processJac(iFrame, fn, qrSolver);
321+
UniversalJacobianResid universalJacRes =
322+
processJac(iFrame, fn, qrSolver, state.skeletonState, state.meshState);
304323

305324
// Need to push into the queue:
306325
std::unique_lock<std::mutex> reorderBufferLock(reorderBufferMutex);
@@ -385,11 +404,14 @@ double SequenceSolverT<T>::processPerFrameErrors_parallel(
385404
return processErrorFunctions_parallel(
386405
[](size_t iFrame,
387406
SequenceSolverFunctionT<T>* fn,
388-
OnlineBandedHouseholderQR<T>& qrSolver) -> UniversalJacobianResid {
407+
OnlineBandedHouseholderQR<T>& qrSolver,
408+
SkeletonStateT<T> skelState,
409+
MeshStateT<T>& meshState) -> UniversalJacobianResid {
389410
// Construct the Jacobian/residual for a single frame and zero out the parts
390411
// of the Jacobian that only affect that frame; this is safe to do because
391412
// the non-shared parameters for each frame don't overlap.
392-
auto [jacobian, residual, errorCur, numFunctions] = computePerFrameJacobian(fn, iFrame);
413+
auto [jacobian, residual, errorCur, numFunctions] =
414+
computePerFrameJacobian(fn, iFrame, skelState, meshState);
393415

394416
qrSolver.zeroBandedPart(
395417
iFrame * fn->perFrameParameterIndices_.size(),
@@ -434,20 +456,47 @@ double SequenceSolverT<T>::processSequenceErrors_serial(
434456
// copies of the perFrameParameterIndices_:
435457
const auto sequenceColumnIndices = buildSequenceColumnIndices(fn, bandwidth);
436458

459+
std::vector<SkeletonStateT<T>> skelStates(bandwidth);
460+
std::vector<MeshStateT<T>> meshStates(bandwidth);
461+
437462
double errorSum = 0;
438463
for (size_t iFrame = 0; iFrame < fn->getNumFrames(); ++iFrame) {
439464
const size_t bandwidth_cur = std::min<size_t>(fn->getNumFrames() - iFrame, bandwidth);
440465

466+
// Determine how many frames are already valid from the previous iteration
467+
// First frame: 0 (compute all), subsequent frames: bandwidth_cur - 1 (reuse all but last)
468+
const size_t numValidFrames = (iFrame == 0) ? 0 : (bandwidth_cur - 1);
469+
470+
// Shift valid frames left by 1 to reuse already-computed states
471+
for (size_t kSubFrame = 0; kSubFrame < numValidFrames; ++kSubFrame) {
472+
skelStates[kSubFrame] = std::move(skelStates[kSubFrame + 1]);
473+
if (fn->needsMeshSequence()) {
474+
meshStates[kSubFrame] = std::move(meshStates[kSubFrame + 1]);
475+
}
476+
}
477+
478+
// Compute new frames (all frames if iFrame == 0, or just the last frame if iFrame > 0)
479+
for (size_t kSubFrame = numValidFrames; kSubFrame < bandwidth_cur; ++kSubFrame) {
480+
skelStates[kSubFrame].set(
481+
fn->parameterTransform_.apply(fn->frameParameters_[iFrame + kSubFrame]),
482+
fn->getCharacter().skeleton);
483+
484+
if (fn->needsMeshSequence()) {
485+
meshStates[kSubFrame].update(
486+
fn->frameParameters_[iFrame + kSubFrame], skelStates[kSubFrame], fn->getCharacter());
487+
}
488+
}
489+
441490
auto [jacobian, residual, errorCur, nFunctions] =
442-
computeSequenceJacobian(fn, iFrame, bandwidth);
491+
computeSequenceJacobian(fn, iFrame, bandwidth, skelStates, meshStates);
443492
errorSum += errorCur;
444493

445494
if (jacobian.rows() != 0) {
446495
qrSolver.addMutating(
447496
iFrame * fn->perFrameParameterIndices_.size(),
448497
ColumnIndexedMatrix<Eigen::MatrixX<T>>(
449498
jacobian,
450-
gsl::make_span(sequenceColumnIndices)
499+
std::span(sequenceColumnIndices)
451500
.subspan(0, bandwidth_cur * fn->perFrameParameterIndices_.size())),
452501
ColumnIndexedMatrix<Eigen::MatrixX<T>>(jacobian, fn->universalParameterIndices_),
453502
residual);

momentum/character_sequence_solver/sequence_solver.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
#pragma once
99

10+
#include <momentum/character/fwd.h>
1011
#include <momentum/character_sequence_solver/fwd.h>
1112
#include <momentum/common/fwd.h>
1213
#include <momentum/math/fwd.h>
1314
#include <momentum/math/online_householder_qr.h>
1415
#include <momentum/solver/solver.h>
1516

1617
#include <functional>
18+
#include <span>
1719

1820
namespace momentum {
1921

@@ -113,20 +115,28 @@ class SequenceSolverT : public SolverT<T> {
113115
const std::function<UniversalJacobianResid(
114116
size_t,
115117
SequenceSolverFunctionT<T>*,
116-
OnlineBandedHouseholderQR<T>&)>& processJac,
118+
OnlineBandedHouseholderQR<T>&,
119+
SkeletonStateT<T> skelState,
120+
MeshStateT<T>& meshState)>& processJac,
117121
SequenceSolverFunctionT<T>* fn,
118122
OnlineBandedHouseholderQR<T>& qrSolver,
119123
ProgressBar* progress);
120124

121125
// Returns the [Jacobian, residual, error] for all the error functions applying to a single frame:
122126
static std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t> computePerFrameJacobian(
123127
SequenceSolverFunctionT<T>* fn,
124-
size_t iFrame);
128+
size_t iFrame,
129+
SkeletonStateT<T>& skelState,
130+
MeshStateT<T>& meshState);
125131

126132
// Returns the [Jacobian, residual, error] for all the sequence error functions starting from a
127133
// single frame:
128-
static std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t>
129-
computeSequenceJacobian(SequenceSolverFunctionT<T>* fn, size_t iFrame, size_t bandwidth);
134+
static std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, double, size_t> computeSequenceJacobian(
135+
SequenceSolverFunctionT<T>* fn,
136+
size_t iFrame,
137+
size_t bandwidth,
138+
std::span<const SkeletonStateT<T>> skelStates,
139+
std::span<const MeshStateT<T>> meshStates);
130140

131141
static std::vector<Eigen::Index> buildSequenceColumnIndices(
132142
const SequenceSolverFunctionT<T>* fn,

0 commit comments

Comments
 (0)