diff --git a/tmva/sofie_parsers/src/RModelParser_Keras.cxx b/tmva/sofie_parsers/src/RModelParser_Keras.cxx index 598dbdc18f787..e7333bbd0606a 100644 --- a/tmva/sofie_parsers/src/RModelParser_Keras.cxx +++ b/tmva/sofie_parsers/src/RModelParser_Keras.cxx @@ -873,8 +873,16 @@ RModel Parse(std::string filename, int batch_size){ // For each layer: type,name,activation,dtype,input tensor's name, // output tensor's name, kernel's name, bias's name // None object is returned for if property doesn't belong to layer - PyRunString("import tensorflow",fGlobalNS,fLocalNS); - PyRunString("import tensorflow.keras as keras",fGlobalNS,fLocalNS); + PyRunString("import tensorflow\n", fGlobalNS, fLocalNS); + PyRunString("import tensorflow.keras as keras\n" + "version = keras.__version__\n" + "major = int(version.split('.')[0])\n" + "if major >= 3:\n" + " raise RuntimeError(\n" + " 'TMVA SOFIE Keras parser supports Keras 2 only.\\n'\n" + " 'Keras 3 detected. Please export the model to ONNX.\\n'\n" + " )\n", + fGlobalNS, fLocalNS); PyRunString("from tensorflow.keras.models import load_model",fGlobalNS,fLocalNS); PyRunString("print('TF/Keras Version: '+ tensorflow.__version__)",fGlobalNS,fLocalNS); PyRunString(TString::Format("model=load_model('%s')",filename.c_str()),fGlobalNS,fLocalNS);