Skip to content

Commit ec48be3

Browse files
Add optimization after loading (#1127)
1 parent bb6559a commit ec48be3

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

NeoML/include/NeoML/Dnn/DnnOptimization.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,6 @@ struct NEOML_API CDnnOptimizationSettings final {
116116
CDnnOptimizationReport NEOML_API OptimizeDnn( CDnn& dnn,
117117
const CDnnOptimizationSettings& settings = CDnnOptimizationSettings() );
118118

119+
CDnnOptimizationReport NEOML_API OptimizeDnnOnLoad(CDnn& dnn, size_t dnnWeightsTotalSize);
120+
119121
} // namespace NeoML

NeoML/src/Dnn/Dnn.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ limitations under the License.
104104
#include <NeoML/Dnn/Layers/TransformerLayer.h>
105105
#include <NeoML/Dnn/Layers/TransformerSourceMaskLayer.h>
106106
#include <NeoML/Dnn/Layers/Upsampling2DLayer.h>
107+
#include <NeoML/Dnn/DnnOptimization.h>
108+
107109
#endif //!NEOML_COMPACT
108110

109111
namespace NeoML {
@@ -810,7 +812,7 @@ static constexpr int dnnVersion = 2000;
810812
void CDnn::Serialize( CArchive& archive )
811813
{
812814
NeoAssertMsg( !IsReferenceDnn(), "For ReferenceDnn serializing is restricted" );
813-
815+
size_t before = mathEngine.GetCurrentMemoryUsage();
814816
int version = dnnVersion;
815817
archive.Serialize( version );
816818

@@ -853,6 +855,9 @@ void CDnn::Serialize( CArchive& archive )
853855
// In order to avoid the CDnnSolver::Reset for the next solver
854856
rebuild();
855857
}
858+
size_t after = mathEngine.GetCurrentMemoryUsage();
859+
860+
OptimizeDnnOnLoad(*this, after - before);
856861
}
857862

858863
void CDnn::SerializeCheckpoint( CArchive& archive )

NeoML/src/Dnn/DnnOptimization.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,30 @@ CDnnOptimizationReport OptimizeDnn( CDnn& dnn, const CDnnOptimizationSettings& s
4343
optimization::CMobileNetV3Optimizer( graph ).Apply( report );
4444

4545
CArray<int> chains;
46-
OptimizeRowwiseChains( dnn, chains );
46+
OptimizeRowwiseChains(dnn, chains);
4747
report.RowwiseChainCount = chains.Size();
4848
}
4949
return report;
5050
}
5151

52+
CDnnOptimizationReport OptimizeDnnOnLoad(CDnn& dnn, size_t size)
53+
{
54+
CDnnOptimizationReport report;
55+
optimization::CGraph graph(dnn);
56+
57+
report.UnpackedCompositeLayers = optimization::UnpackComposites(graph);
58+
report.RemovedTrivialLayers = optimization::RemoveTrivialLayers(graph);
59+
optimization::CBatchNormFusionOptimizer(graph).Apply(report);
60+
61+
if (size < 1024 * 1024) {
62+
optimization::CMobileNetV2Optimizer(graph).Apply(report);
63+
64+
CArray<int> chains;
65+
OptimizeRowwiseChains(dnn, chains);
66+
report.RowwiseChainCount = chains.Size();
67+
}
68+
69+
return report;
70+
}
71+
5272
} // namespace NeoML

0 commit comments

Comments
 (0)