
Compiled torch model has much worse performance than non-compiled version. #21851

@blechn

Description

I wanted to compare JAX and PyTorch as backends for Keras. For this, I wrote a simple autoencoder and just switched os.environ["KERAS_BACKEND"] from "jax" to "torch".

I found that with PyTorch as the backend, passing jit_compile=True actually makes performance much worse than not compiling.
Also, JAX outperforms PyTorch in every scenario I was able to set up.

This makes a meaningful comparison of the two backends difficult, because I don't believe that correct usage of torch.compile should actually degrade performance.
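One thing that makes such comparisons tricky: both jax.jit (which the JAX backend uses under jit_compile=True) and torch.compile compile on the first call(s), so timing a whole run mixes one-time compilation cost with steady-state speed. A minimal, framework-agnostic sketch of timing the two phases separately; `time_calls` is a hypothetical helper, and `step` stands for whatever is being measured (for example one epoch of `fit`):

```python
import time
from typing import Callable


def time_calls(step: Callable[[], object], warmup: int = 1,
               repeats: int = 5) -> tuple[list[float], list[float]]:
    """Call `step` (warmup + repeats) times.

    Returns (warmup_times, timed_times) in seconds, so that one-time costs
    such as tracing/compilation on the first calls are reported separately
    from steady-state timings.
    """
    warmup_times: list[float] = []
    timed_times: list[float] = []
    for i in range(warmup + repeats):
        start = time.perf_counter()
        step()
        elapsed = time.perf_counter() - start
        # The first `warmup` calls go into their own bucket.
        (warmup_times if i < warmup else timed_times).append(elapsed)
    return warmup_times, timed_times
```

For instance, `time_calls(lambda: test_model.fit(x=test_dataset, y=test_dataset, batch_size=32, epochs=1, verbose=0))` would report the first epoch apart from the later ones; whether Keras reuses its compiled train function across separate `fit` calls is an assumption worth verifying.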

Here are the steps to reproduce:

I used micromamba to create the environments:

micromamba env create -n keras-python3_XX python=3.XX -c conda-forge

replacing "XX" with "12" or "13", respectively. I believe Python 3.14 is irrelevant for now, since torch.compile does not support it yet.

Then, after activating the environment with

micromamba activate keras-python3_XX

I installed the following packages, in this order:

pip install "jax[cuda13]"
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu130
pip install matplotlib scikit-learn keras

The resulting lists of installed packages are in the attached files "python-3_12-pip_list-output.txt" (Python 3.12) and "python-3_13-pip_list-output.txt" (Python 3.13).
I did all of this in a Jupyter notebook, which I have also attached (it was originally an exercise for my lectures, in case you wonder why its heading is "Exercise 02").
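To check whether the two environments differ in anything besides the Python version, the two attached `pip list` outputs can be compared. A minimal sketch, assuming the default columns format of `pip list` (a "Package Version" header, a dashed separator line, then one package per line); `parse_pip_list` and `diff_versions` are hypothetical helper names:

```python
from __future__ import annotations


def parse_pip_list(text: str) -> dict[str, str]:
    """Parse the default (columns) output of `pip list` into {package: version}."""
    packages: dict[str, str] = {}
    for line in text.splitlines():
        parts = line.split()
        # Skip blank lines, the "Package Version" header and pip's "[notice] ..." lines.
        if len(parts) < 2 or parts[0] == "Package" or parts[0].startswith("["):
            continue
        # Skip the dashed separator row under the header.
        if set(parts[0]) == {"-"}:
            continue
        # Editable installs print an extra location column; keep name and version only.
        packages[parts[0].lower()] = parts[1]
    return packages


def diff_versions(a: dict[str, str],
                  b: dict[str, str]) -> dict[str, tuple[str | None, str | None]]:
    """Return {package: (version_in_a, version_in_b)} for packages that differ.

    None means the package is missing from that environment.
    """
    return {
        name: (a.get(name), b.get(name))
        for name in sorted(set(a) | set(b))
        if a.get(name) != b.get(name)
    }
```

Reading both attached files, parsing each with `parse_pip_list`, and passing the two dicts to `diff_versions` lists every package whose version differs between the Python 3.12 and 3.13 environments.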

Setting the flag OVERRIDE_TORCH forces the program to use "torch" as the backend, and WANT_TORCH_COMPILE determines whether jit_compile=True is passed to model.compile.
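The flag logic boils down to a small decision table. A sketch of it as a pure function (the name `use_jit_compile` is hypothetical), assuming, as stated above, that torch.compile is unavailable on Python 3.14, and that the JAX backend is always compiled as in the notebook:

```python
import sys


def use_jit_compile(backend: str, want_torch_compile: bool,
                    version_info: tuple[int, ...] = tuple(sys.version_info[:2])) -> bool:
    """Return True if model.compile(...) should receive jit_compile=True.

    Non-torch backends (here: JAX) are always compiled. The torch backend is
    compiled only if WANT_TORCH_COMPILE is set and the interpreter is older
    than 3.14 (assumption: torch.compile does not support 3.14 yet).
    """
    if backend != "torch":
        return True
    return want_torch_compile and tuple(version_info[:2]) < (3, 14)
```

Used as `compile_kwargs = {"jit_compile": True} if use_jit_compile(backend, WANT_TORCH_COMPILE) else {}` followed by `test_model.compile(optimizer=opt, loss=loss, **compile_kwargs)`, this leaves jit_compile at its Keras default whenever compilation is not requested, just like the original branches.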

python-3_12-pip_list-output.txt
python-3_13-pip_list-output.txt

autoencoder.ipynb

For completeness, I have pasted the code from all cells below. I only measured the run time of the last code cell.

# imports
import os
import sys
from typing import List

import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Try to import jax and use it as the Keras backend if it is available, else use PyTorch.
# KERAS_BACKEND must be set before `import keras`.
try:
    import jax
    os.environ["KERAS_BACKEND"] = "jax"
except ImportError:
    import torch
    os.environ["KERAS_BACKEND"] = "torch"

# OVERRIDE BACKEND WITH TORCH:
OVERRIDE_TORCH = True
if OVERRIDE_TORCH:
    os.environ["KERAS_BACKEND"] = "torch"
    import torch

backend = os.environ["KERAS_BACKEND"]
print(f"Using keras backend: {backend}")

# further imports
import keras


# model creation with keras for multi-backend support
class Autoencoder(keras.Model):
    def __init__(self, input_size: int = 2, bottleneck_size: int = 1, hidden_size: int = 10,
                 layers: int | List[int] = 5):
        super().__init__()

        self.latent_dim = bottleneck_size
        self.hidden_size = hidden_size

        # If `layers` is a list, it gives the width of each hidden encoder layer.
        # If it is an int, it is the number of hidden layers, each `hidden_size` units wide.
        # The decoder mirrors the encoder in both cases.
        if isinstance(layers, list):
            layer_sizes = list(layers)
        elif isinstance(layers, int):
            layer_sizes = [hidden_size] * layers
        else:
            raise TypeError("Model creation failed: `layers` must be an int or a list of ints.")

        # using the keras functional api
        inputs = keras.Input(shape=(input_size,))

        x = inputs
        for layer_size in layer_sizes:
            x = keras.layers.Dense(layer_size, activation='relu')(x)
        latent_outputs = keras.layers.Dense(self.latent_dim, activation=None)(x)

        self.encoder = keras.Model(inputs=inputs, outputs=latent_outputs, name="encoder")

        # reversed() mirrors the encoder without mutating the caller's list
        x = latent_outputs
        for layer_size in reversed(layer_sizes):
            x = keras.layers.Dense(layer_size, activation='relu')(x)
        outputs = keras.layers.Dense(input_size, activation=None)(x)

        self.decoder = keras.Model(inputs=latent_outputs, outputs=outputs, name="decoder")

        self.autoencoder = keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, x):
        return self.decoder(x)

    def call(self, x):
        return self._decode(self._encode(x))

    def custom_summary(self):
        # summary() prints on its own and returns None, so it is not wrapped in print()
        # self.encoder.summary()
        # self.decoder.summary()
        self.autoencoder.summary()


test_model = Autoencoder()
opt = keras.optimizers.Adam(1e-3)
loss = keras.losses.MeanSquaredError()

# Check whether we can compile the torch model (only possible in Python <= 3.13).
# Change WANT_TORCH_COMPILE to compile the torch model; it actually worsens performance,
# it is just for demonstration purposes.
WANT_TORCH_COMPILE: bool = True
if backend == "torch":
    use_jit = WANT_TORCH_COMPILE and sys.version_info < (3, 14)
else:
    use_jit = True
# Without jit_compile, Keras keeps its default setting (same as the original branches).
compile_kwargs = {"jit_compile": True} if use_jit else {}
test_model.compile(optimizer=opt, loss=loss, **compile_kwargs)

test_dataset, _ = make_moons(n_samples=10_000, noise=0.1)
# recommended by the warnings generated by running this cell without it:
if backend == "torch":
    # with this, Python 3.12 shows no more warnings; Python 3.13 still does
    torch.set_float32_matmul_precision('high')
train_hist = test_model.fit(x=test_dataset, y=test_dataset, batch_size=32, epochs=10)
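Some quick arithmetic on this training setup may help when interpreting the timings. With n_samples=10_000 and batch_size=32 as in the fit call above, each epoch runs 313 steps of a very small model, so fixed per-step overhead may outweigh the actual math; and since 10,000 is not divisible by 32, the last batch of every epoch has only 16 samples. A changed input shape can trigger a recompilation under torch.compile (and a retrace under jax.jit). A small sketch of the calculation (`batch_layout` is a hypothetical helper):

```python
def batch_layout(n_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (steps_per_epoch, size_of_last_batch) when the final partial batch is kept."""
    # Integer ceiling division avoids float rounding issues.
    steps = (n_samples + batch_size - 1) // batch_size
    last = n_samples - (steps - 1) * batch_size
    return steps, last
```

`batch_layout(10_000, 32)` gives `(313, 16)`; trying a larger batch size is one cheap way to check how much of the gap comes from per-step overhead rather than compilation itself.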

I hope someone can reproduce this. I'm not sure whether I'm doing something fundamentally wrong, but if I understand the Keras API correctly, this code should run as written without this issue.

Thank you in advance for looking into this.

Metadata

Labels

type:support (User is asking for help / asking an implementation question. Stackoverflow would be better suited.)
