Skip to content

Commit 712d944

Browse files
Optimize LSTM (#1130)
* Optimize LSTM * Turn off cnn and cache optimizations
1 parent 84259c5 commit 712d944

File tree

13 files changed

+36
-34
lines changed

13 files changed

+36
-34
lines changed

NeoML/src/Dnn/BaseLayer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ void CBaseLayer::AllocateOutputBlobs()
241241
return;
242242
}
243243

244+
CMemoryModeSwitcher switcher( MathEngine(), GetDnn()->isReuseMemoryMode );
245+
244246
for( int i = 0; i < outputDescs.Size(); ++i ) {
245247
if( outputBlobs[i] == nullptr ) {
246248
outputBlobs[i] = CDnnBlob::CreateBlob( MathEngine(), outputDescs[i].GetDataType(), outputDescs[i] );

NeoML/src/Dnn/Dnn.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ static constexpr int dnnVersion = 2000;
812812
void CDnn::Serialize( CArchive& archive )
813813
{
814814
NeoAssertMsg( !IsReferenceDnn(), "For ReferenceDnn serializing is restricted" );
815-
size_t before = mathEngine.GetCurrentMemoryUsage();
815+
816816
int version = dnnVersion;
817817
archive.Serialize( version );
818818

@@ -855,9 +855,6 @@ void CDnn::Serialize( CArchive& archive )
855855
// In order to avoid the CDnnSolver::Reset for the next solver
856856
rebuild();
857857
}
858-
size_t after = mathEngine.GetCurrentMemoryUsage();
859-
860-
OptimizeDnnOnLoad(*this, after - before);
861858
}
862859

863860
void CDnn::SerializeCheckpoint( CArchive& archive )

NeoML/src/Dnn/Layers/LstmLayer.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ void CLstmLayer::Serialize( CArchive& archive )
350350
void CLstmLayer::RunOnce()
351351
{
352352
if( MathEngine().GetType() == MET_Cpu &&
353-
!isInCompatibilityMode &&
354353
!IsBackwardPerformed() &&
355354
!IsLearningPerformed() &&
356355
recurrentActivation == AF_Sigmoid )
@@ -471,7 +470,7 @@ void CLstmLayer::initDesc()
471470
: inputHiddenLayer->FreeTerms()->GetData();
472471
CConstFloatHandle recurrentFreeTerm = recurHiddenLayer->FreeTerms() == nullptr ? CConstFloatHandle()
473472
: recurHiddenLayer->FreeTerms()->GetData();
474-
lstmDesc = MathEngine().InitLstm( GetHiddenSize(), inputBlobs[0]->GetObjectSize(),
473+
lstmDesc = MathEngine().InitLstm( isInCompatibilityMode, GetHiddenSize(), inputBlobs[0]->GetObjectSize(),
475474
inputHiddenLayer->Weights()->GetData(), inputFreeTerm,
476475
recurHiddenLayer->Weights()->GetData(), recurrentFreeTerm );
477476
}

NeoMathEngine/include/NeoMathEngine/NeoMathEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ class NEOMATHENGINE_API IDnnEngine : public IBlasEngine {
10021002
const CFloatHandle& inputDiff ) = 0;
10031003

10041004
// Creates descriptor of LSTM with given weights be created.
1005-
virtual CLstmDesc* InitLstm( int hiddenSize, int objectSize,
1005+
virtual CLstmDesc* InitLstm( bool isCompatibleMode, int hiddenSize, int objectSize,
10061006
const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
10071007
const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm ) = 0;
10081008
virtual void Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, int sequenceCount,

NeoMathEngine/src/CPU/CpuMathEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ class CCpuMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
548548
void BertConvBackward( const CConstFloatHandle& dataHandle, const CConstFloatHandle& kernelHandle,
549549
const CConstFloatHandle& outDiffHandle, int seqLen, int batchSize, int numHeads, int headSize, int kernelSize,
550550
const CFloatHandle& dataDiffHandle, const CFloatHandle& kernelDiffHandle ) override;
551-
CLstmDesc* InitLstm( int hiddenSize, int objectSize,
551+
CLstmDesc* InitLstm( bool isCompatibleMode, int hiddenSize, int objectSize,
552552
const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
553553
const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm ) override;
554554
void Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, int sequenceCount,

NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ static const float* initLstmFreeTerm( const CFloatHandleVar* freeTermVar, const
6363

6464
//-------------------------------------------------------------------------------------------------------------------------
6565

66-
CMathEngineLstmDesc::CMathEngineLstmDesc(
66+
CMathEngineLstmDesc::CMathEngineLstmDesc( bool isCompatibleMode,
6767
int hiddenSize, int objectSize, const CConstFloatHandle& inputWeights,
6868
const CConstFloatHandle& inputFreeTerm, const CConstFloatHandle& recurWeights,
6969
const CConstFloatHandle& recurFreeTerm ) :
70+
IsCompatibleMode( IsCompatibleMode ),
7071
HiddenSize( hiddenSize ),
7172
ObjectSize( objectSize ),
7273
InputWeights( GetRaw( inputWeights ) ),
@@ -78,11 +79,11 @@ CMathEngineLstmDesc::CMathEngineLstmDesc(
7879

7980
CMathEngineLstmDesc::~CMathEngineLstmDesc() = default;
8081

81-
CLstmDesc* CCpuMathEngine::InitLstm( int hiddenSize, int objectSize,
82+
CLstmDesc* CCpuMathEngine::InitLstm( bool isCompatibleMode, int hiddenSize, int objectSize,
8283
const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
8384
const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm )
8485
{
85-
return new CMathEngineLstmDesc( hiddenSize, objectSize, inputWeights, inputFreeTerm,
86+
return new CMathEngineLstmDesc( isCompatibleMode, hiddenSize, objectSize, inputWeights, inputFreeTerm,
8687
recurrentWeights, recurrentFreeTerm );
8788
}
8889

@@ -131,8 +132,11 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
131132
// Write state data directly to output or create temporary blob for recurent
132133
std::unique_ptr<CFloatHandleStackVar> stateBackLinkVar;
133134
if( outputStateBackLink.IsNull() ) {
134-
stateBackLinkVar.reset( new CFloatHandleStackVar( *this, sequenceCount * lstmDesc.HiddenSize ) );
135+
stateBackLinkVar.reset( new CFloatHandleStackVar( *this, sequenceLength * sequenceCount * lstmDesc.HiddenSize ) );
135136
}
137+
138+
CSequenceWrapper<float> output(outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize);
139+
136140
CSequenceWrapper<float> stateBackLink(
137141
outputStateBackLink.IsNull() ? stateBackLinkVar->GetHandle() : outputStateBackLink,
138142
outputStateBackLink.IsNull() ? 1 : sequenceLength,
@@ -146,7 +150,11 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
146150
fcLen, sequenceCount * 4 * lstmDesc.HiddenSize );
147151

148152
// Emulate working of LSTM recurrent implementation
149-
CSequenceWrapper<float> mainBackLink( outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize );
153+
std::unique_ptr<CFloatHandleStackVar> mainBackLinkVar;
154+
if (lstmDesc.IsCompatibleMode) {
155+
mainBackLinkVar.reset(new CFloatHandleStackVar(*this, sequenceLength * sequenceCount * lstmDesc.HiddenSize));
156+
}
157+
CSequenceWrapper<float> mainBackLink(lstmDesc.IsCompatibleMode ? mainBackLinkVar->GetHandle() : outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize );
150158
initializeBacklink( inputMainBackLink, mainBackLink );
151159

152160
CSequenceWrapper<const float> input( inputHandle, sequenceLength, sequenceCount * lstmDesc.ObjectSize );
@@ -197,14 +205,8 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
197205
sequenceCount, resultWidth, resultWidth, resultWidth, lstmDesc.FreeTerm );
198206
}
199207

200-
// if outputMainBackLink != output then we are in compatibility mode
201-
if( simdMathEngine != nullptr ) {
202-
simdMathEngine->RunOnceRestOfLstm( &lstmDesc, sequenceCount, fullyConnectedResult[outputPos],
203-
stateBackLink[inputPos], stateBackLink[outputPos], mainBackLink[outputPos] );
204-
} else {
205-
lstmDesc.RunOnceRestOfLstm( sequenceCount, fullyConnectedResult[outputPos], stateBackLink[inputPos],
206-
stateBackLink[outputPos], mainBackLink[outputPos] );
207-
}
208+
lstmDesc.RunOnceRestOfLstm(lstmDesc.IsCompatibleMode, sequenceCount, fullyConnectedResult[outputPos], stateBackLink[inputPos],
209+
stateBackLink[outputPos], mainBackLink[outputPos], output[outputPos]);
208210
--seqElemsInBuffer;
209211
}
210212
}

NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ limitations under the License.
2626
namespace NeoML {
2727

2828
struct CMathEngineLstmDesc : public CLstmDesc {
29-
CMathEngineLstmDesc(
29+
CMathEngineLstmDesc( bool isCompatibleMode,
3030
int hiddenSize, int objectSize, const CConstFloatHandle& inputWeights,
3131
const CConstFloatHandle& inputFreeTerm, const CConstFloatHandle& recurWeights,
3232
const CConstFloatHandle& recurFreeTerm );
3333
~CMathEngineLstmDesc() override;
3434

3535
static int constexpr GatesNum = 4;
3636

37+
const bool IsCompatibleMode;
3738
const int HiddenSize;
3839
const int ObjectSize;
3940
const float* const InputWeights;
@@ -42,12 +43,12 @@ struct CMathEngineLstmDesc : public CLstmDesc {
4243
const std::unique_ptr<CFloatHandleVar> FreeTermVar;
4344
const float* const FreeTerm;
4445

45-
virtual void RunOnceRestOfLstm( int objectCount, float* fullyConnectedResult, const float* inputStateBackLink,
46-
float* outputStateBackLink, float* outputMainBackLink );
46+
virtual void RunOnceRestOfLstm( bool isCompatibleMode, int objectCount, float* fullyConnectedResult, const float* inputStateBackLink,
47+
float* outputStateBackLink, float* outputMainBackLink, float* output );
4748
};
4849

49-
inline void CMathEngineLstmDesc::RunOnceRestOfLstm( int objectCount, float* fullyConnectedResult,
50-
const float* inputStateBackLink, float* outputStateBackLink, float* outputMainBackLink )
50+
inline void CMathEngineLstmDesc::RunOnceRestOfLstm( bool isCompatibleMode, int objectCount, float* fullyConnectedResult,
51+
const float* inputStateBackLink, float* outputStateBackLink, float* outputMainBackLink, float* compatibleOutput )
5152
{
5253
// Elementwise summ of fully connected layers' results (inplace)
5354
const int resultMatrixWidth = CMathEngineLstmDesc::GatesNum * HiddenSize;
@@ -76,10 +77,10 @@ inline void CMathEngineLstmDesc::RunOnceRestOfLstm( int objectCount, float* full
7677
NeoML::vectorAdd( forgetData, inputData, outputStateBackLink, HiddenSize );
7778

7879
// Apply tanh to state baclink
79-
NeoML::vectorTanh( outputStateBackLink, inputData, HiddenSize );
80+
NeoML::vectorTanh( outputStateBackLink, isCompatibleMode ? compatibleOutput : inputData, HiddenSize );
8081

8182
// Multiply output gate with result of previous operation
82-
NeoML::vectorEltwiseMultiply( outputData, inputData, outputMainBackLink, HiddenSize );
83+
NeoML::vectorEltwiseMultiply( outputData, isCompatibleMode ? compatibleOutput : inputData, outputMainBackLink, HiddenSize );
8384

8485
inputTanhData += resultMatrixWidth;
8586
forgetData += resultMatrixWidth;
@@ -88,6 +89,7 @@ inline void CMathEngineLstmDesc::RunOnceRestOfLstm( int objectCount, float* full
8889
inputStateBackLink += HiddenSize;
8990
outputStateBackLink += HiddenSize;
9091
outputMainBackLink += HiddenSize;
92+
compatibleOutput += HiddenSize;
9193
}
9294
}
9395

NeoMathEngine/src/GPU/CUDA/CudaMathEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ class CCudaMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
551551
void BertConvBackward( const CConstFloatHandle& dataHandle, const CConstFloatHandle& kernelHandle,
552552
const CConstFloatHandle& outDiffHandle, int seqLen, int batchSize, int numHeads, int headSize, int kernelSize,
553553
const CFloatHandle& dataDiffHandle, const CFloatHandle& kernelDiffHandle ) override;
554-
CLstmDesc* InitLstm( int hiddenSize, int objectSize,
554+
CLstmDesc* InitLstm( bool isCompatibleMode, int hiddenSize, int objectSize,
555555
const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
556556
const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm ) override;
557557
void Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, int sequenceCount,

NeoMathEngine/src/GPU/CUDA/CudaMathEngineDnnLstm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222

2323
namespace NeoML {
2424

25-
CLstmDesc* CCudaMathEngine::InitLstm( int, int, const CConstFloatHandle&, const CConstFloatHandle&,
25+
CLstmDesc* CCudaMathEngine::InitLstm( bool, int, int, const CConstFloatHandle&, const CConstFloatHandle&,
2626
const CConstFloatHandle&, const CConstFloatHandle& )
2727
{
2828
ASSERT_EXPR( false );

NeoMathEngine/src/GPU/Metal/MetalMathEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
560560
void BertConvBackward( const CConstFloatHandle& dataHandle, const CConstFloatHandle& kernelHandle,
561561
const CConstFloatHandle& outDiffHandle, int seqLen, int batchSize, int numHeads, int headSize, int kernelSize,
562562
const CFloatHandle& dataDiffHandle, const CFloatHandle& kernelDiffHandle ) override;
563-
CLstmDesc* InitLstm( int hiddenSize, int objectSize,
563+
CLstmDesc* InitLstm( bool isCompatibleMode, int hiddenSize, int objectSize,
564564
const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
565565
const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm ) override;
566566
void Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, int sequenceCount,

0 commit comments

Comments
 (0)