@@ -63,10 +63,11 @@ static const float* initLstmFreeTerm( const CFloatHandleVar* freeTermVar, const
6363
6464// -------------------------------------------------------------------------------------------------------------------------
6565
66- CMathEngineLstmDesc::CMathEngineLstmDesc (
66+ CMathEngineLstmDesc::CMathEngineLstmDesc ( bool isCompatibleMode,
6767 int hiddenSize, int objectSize, const CConstFloatHandle& inputWeights,
6868 const CConstFloatHandle& inputFreeTerm, const CConstFloatHandle& recurWeights,
6969 const CConstFloatHandle& recurFreeTerm ) :
70+ IsCompatibleMode ( isCompatibleMode ), // take the constructor argument; naming the member here would initialize it from itself
7071 HiddenSize ( hiddenSize ),
7172 ObjectSize ( objectSize ),
7273 InputWeights ( GetRaw( inputWeights ) ),
@@ -78,11 +79,11 @@ CMathEngineLstmDesc::CMathEngineLstmDesc(
7879
7980CMathEngineLstmDesc::~CMathEngineLstmDesc () = default ;
8081
81- CLstmDesc* CCpuMathEngine::InitLstm ( int hiddenSize, int objectSize,
82+ CLstmDesc* CCpuMathEngine::InitLstm ( bool isCompatibleMode, int hiddenSize, int objectSize, // new leading flag, forwarded unchanged to the descriptor constructor
8283 const CConstFloatHandle& inputWeights, const CConstFloatHandle& inputFreeTerm,
8384 const CConstFloatHandle& recurrentWeights, const CConstFloatHandle& recurrentFreeTerm )
8485{
85- return new CMathEngineLstmDesc ( hiddenSize, objectSize, inputWeights, inputFreeTerm,
86+ return new CMathEngineLstmDesc ( isCompatibleMode, hiddenSize, objectSize, inputWeights, inputFreeTerm, // allocated with new; ownership presumably passes to the caller - TODO confirm
8687 recurrentWeights, recurrentFreeTerm );
8788}
8889
@@ -131,8 +132,11 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
131132 // Write state data directly to output or create temporary blob for recurrent
132133 std::unique_ptr<CFloatHandleStackVar> stateBackLinkVar;
133134 if ( outputStateBackLink.IsNull () ) {
134- stateBackLinkVar.reset ( new CFloatHandleStackVar ( *this , sequenceCount * lstmDesc.HiddenSize ) );
135+ stateBackLinkVar.reset ( new CFloatHandleStackVar ( *this , sequenceLength * sequenceCount * lstmDesc.HiddenSize ) );
135136 }
137+
138+ CSequenceWrapper<float > output (outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize );
139+
136140 CSequenceWrapper<float > stateBackLink (
137141 outputStateBackLink.IsNull () ? stateBackLinkVar->GetHandle () : outputStateBackLink,
138142 outputStateBackLink.IsNull () ? 1 : sequenceLength,
@@ -146,7 +150,11 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
146150 fcLen, sequenceCount * 4 * lstmDesc.HiddenSize );
147151
148152 // Emulate working of LSTM recurrent implementation
149- CSequenceWrapper<float > mainBackLink ( outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize );
153+ std::unique_ptr<CFloatHandleStackVar> mainBackLinkVar;
154+ if (lstmDesc.IsCompatibleMode ) {
155+ mainBackLinkVar.reset (new CFloatHandleStackVar (*this , sequenceLength * sequenceCount * lstmDesc.HiddenSize ));
156+ }
157+ CSequenceWrapper<float > mainBackLink (lstmDesc.IsCompatibleMode ? mainBackLinkVar->GetHandle () : outputMainBackLink, sequenceLength, sequenceCount * lstmDesc.HiddenSize );
150158 initializeBacklink ( inputMainBackLink, mainBackLink );
151159
152160 CSequenceWrapper<const float > input ( inputHandle, sequenceLength, sequenceCount * lstmDesc.ObjectSize );
@@ -197,14 +205,8 @@ void CCpuMathEngine::Lstm( CLstmDesc& desc, bool reverse, int sequenceLength, in
197205 sequenceCount, resultWidth, resultWidth, resultWidth, lstmDesc.FreeTerm );
198206 }
199207
200- // if outputMainBackLink != output then we are in compatibility mode
201- if ( simdMathEngine != nullptr ) {
202- simdMathEngine->RunOnceRestOfLstm ( &lstmDesc, sequenceCount, fullyConnectedResult[outputPos],
203- stateBackLink[inputPos], stateBackLink[outputPos], mainBackLink[outputPos] );
204- } else {
205- lstmDesc.RunOnceRestOfLstm ( sequenceCount, fullyConnectedResult[outputPos], stateBackLink[inputPos],
206- stateBackLink[outputPos], mainBackLink[outputPos] );
207- }
208+ lstmDesc.RunOnceRestOfLstm (lstmDesc.IsCompatibleMode , sequenceCount, fullyConnectedResult[outputPos], stateBackLink[inputPos],
209+ stateBackLink[outputPos], mainBackLink[outputPos], output[outputPos]);
208210 --seqElemsInBuffer;
209211 }
210212}
0 commit comments