@@ -241,8 +241,6 @@ void CBaseLayer::AllocateOutputBlobs()
241241 return ;
242242 }
243243
244- CMemoryModeSwitcher switcher ( MathEngine (), GetDnn ()->isReuseMemoryMode );
245-
246244 for ( int i = 0 ; i < outputDescs.Size (); ++i ) {
247245 if ( outputBlobs[i] == nullptr ) {
248246 outputBlobs[i] = CDnnBlob::CreateBlob ( MathEngine (), outputDescs[i].GetDataType (), outputDescs[i] );
@@ -405,7 +403,7 @@ void CBaseLayer::reshape()
405403 CArray<CBlobDesc> prevInputDescs;
406404 inputDescs.MoveTo ( prevInputDescs );
407405 inputDescs.SetSize (inputs.Size ());
408-
406+
409407 // Call the input layers reshape recursively, reset the input blobs
410408 for ( int i = 0 ; i < GetInputCount (); ++i ) {
411409 GetInputLayer (i)->reshape ();
@@ -420,7 +418,7 @@ void CBaseLayer::reshape()
420418
421419 if (!forcedReshape) {
422420 for (int i = 0 ; i < inputBlobs.Size (); i++) {
423- forcedReshape = forcedReshape
421+ forcedReshape = forcedReshape
424422 || !inputDescs[i].HasEqualDimensions (prevInputDescs[i]);
425423 }
426424 }
@@ -529,7 +527,11 @@ void CBaseLayer::runOnce()
529527 inputBlobs[i] = prevLayerOutput;
530528 }
531529
532- if ( mayFreeIoBlobs () ) {
530+ const bool mayFreeIoBlobs = GetDnn ()->isReuseMemoryMode
531+ && ( !GetDnn ()->isBackwardPerformed || !GetDnn ()->IsRecurrentMode () || GetDnn ()->IsLastSequencePos ()
532+ || ( ( blobsNeededForBackward & TInputBlobs ) == 0 && ( !isInPlace || ( blobsNeededForBackward & TOutputBlobs ) == 0 ) ) );
533+
534+ if ( mayFreeIoBlobs ) {
533535 for ( int i = 0 ; i < inputBlobs.Size (); ++i ) {
534536 CBaseLayer* inputLayer = GetInputLayer ( i );
535537 const int outputNumber = inputs[i].OutputNumber ;
@@ -560,39 +562,6 @@ void CBaseLayer::runOnce()
560562 }
561563}
562564
563- // Checks if output blobs of input layers can be discarded.
564- bool CBaseLayer::mayFreeIoBlobs () const
565- {
566- assert ( dnn != nullptr );
567-
568- if ( !dnn->isReuseMemoryMode ) {
569- // Memory reuse turned off.
570- return false ;
571- }
572-
573- if ( dnn->IsRecurrentMode () && !dnn->IsLastSequencePos () ) {
574- // Recurrent layer processing is incomplete.
575- return false ;
576- }
577-
578- if ( !dnn->isBackwardPerformed ) {
579- // Inference mode, intermediate data is not required.
580- return true ;
581- }
582-
583- if ( (blobsNeededForBackward & TInputBlobs) != 0 ) {
584- // Input blobs are required for back propagation.
585- return false ;
586- }
587-
588- if ( isInPlace && (blobsNeededForBackward & TOutputBlobs) != 0 ) {
589- // Output blobs are required for back propagation and they are the same as input.
590- return false ;
591- }
592-
593- return true ;
594- }
595-
596565// Recalculates the isBackwardNeeded flag; recursively checks the inputs
597566void CBaseLayer::recheckBackwardNeeded ()
598567{
@@ -654,7 +623,7 @@ void CBaseLayer::backwardRunAndLearnOnce()
654623 }
655624 }
656625
657- // Perform one step of error backward propagation:
626+ // Perform one step of error backward propagation:
658627 // calculate the input error from the output one
659628 BackwardOnce ();
660629 }
@@ -676,7 +645,7 @@ void CBaseLayer::backwardRunAndLearnOnce()
676645 paramDiffBlobs.DeleteAll ();
677646 }
678647 }
679-
648+
680649 outputDiffBlobs.DeleteAll ();
681650
682651 if ( IsBackwardPerformed () ) {
@@ -710,7 +679,7 @@ void CBaseLayer::backwardRunAndLearnOnce()
710679}
711680
712681// Handles the notification that output diff is ready for a given output
713- // If that is the last output diff necessary for learning,
682+ // If that is the last output diff necessary for learning,
714683// backpropagation and learning are started for this layer
715684void CBaseLayer::transferDiffBlob ( CDnnBlob* diffBlob, int outputNum )
716685{
0 commit comments