Description
I wanted to compare JAX and PyTorch as Keras backends. For this, I wrote a simple autoencoder and only switched os.environ["KERAS_BACKEND"] from "jax" to "torch".
I found that with PyTorch as the backend, jit_compile=True makes performance much worse than not compiling at all.
Also, JAX outperforms PyTorch in every scenario I tried.
This makes it hard to compare the two backends meaningfully, because I don't believe that correct usage of torch.compile should actually degrade performance.
Here are the steps to reproduce:
I used micromamba to create the environments:
micromamba env create -n keras-python3_XX python=3.XX -c conda-forge
Replace "XX" with "12" or "13", respectively. I believe Python 3.14 is irrelevant here (at least for now) because torch.compile does not support it yet.
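The notebook later decides whether torch.compile may be used by checking whether the substring '3.14' appears in sys.version. That check would not catch 3.15 or later. A minimal sketch of a sturdier version check that compares sys.version_info as a tuple; the helper name torch_compile_supported is hypothetical, and the (3, 14) cutoff only encodes the belief stated above, not a verified PyTorch requirement.

```python
import sys


def torch_compile_supported(version_info=sys.version_info) -> bool:
    """Return True if this Python version is assumed to support torch.compile.

    Assumption (from the issue text, not verified here): torch.compile does
    not work on Python 3.14 or newer yet.
    """
    # Compare only (major, minor) so patch releases such as 3.13.7 are handled.
    return tuple(version_info[:2]) < (3, 14)


print(sys.version.split()[0], "->", torch_compile_supported())
```

Calling torch_compile_supported() with no argument checks the running interpreter.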
Then, after activating the environment with
micromamba activate keras-python3_XX
I then installed the following packages, one after another:
pip install "jax[cuda13]"
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu130
pip install matplotlib scikit-learn keras
The resulting package lists (pip list output) are attached as "python-3_12-pip_list-output.txt" and "python-3_13-pip_list-output.txt" for Python 3.12 and Python 3.13, respectively.
I did all of this in a Jupyter notebook, which I also attached (it was originally an exercise for my lectures, in case you wonder why the heading is "Exercise 02").
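Since the full pip list output is long, a short sketch that prints only the versions most relevant to this comparison may help with triage. It uses importlib.metadata from the standard library; the selection of distributions in RELEVANT is an assumption and can be extended (for example with CUDA plugin packages).

```python
import importlib.metadata
import platform

# Distributions assumed to matter most for a JAX vs. PyTorch Keras comparison.
RELEVANT = ("keras", "jax", "jaxlib", "torch")


def relevant_versions(names=RELEVANT) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    versions = {"python": platform.python_version()}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in relevant_versions().items():
    print(f"{name}: {version or 'not installed'}")
```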
By setting the flag OVERRIDE_TORCH you can force the program to use "torch" as the backend, and with WANT_TORCH_COMPILE you can choose whether jit_compile=True is passed to model.compile.
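The flag handling in the compile cell boils down to a small decision table. A sketch of it as a pure function (the name compile_kwargs is hypothetical) that mirrors the notebook's branches: non-torch backends always get jit_compile=True, while the torch backend gets it only when WANT_TORCH_COMPILE is set and Python is older than 3.14. In every other case jit_compile is omitted, so Keras falls back to its default, exactly as in the original cell.

```python
import sys


def compile_kwargs(backend: str, want_torch_compile: bool, python_version=sys.version_info) -> dict:
    """Extra keyword arguments for model.compile, mirroring the notebook's if/else chain."""
    if backend != "torch":
        # JAX (or any other non-torch backend) is always compiled in the notebook.
        return {"jit_compile": True}
    if want_torch_compile and tuple(python_version[:2]) < (3, 14):
        # torch backend with torch.compile requested on a Python version assumed to support it.
        return {"jit_compile": True}
    # torch backend without compilation: leave jit_compile at Keras' default.
    return {}
```

In the notebook this would be used as test_model.compile(optimizer=opt, loss=loss, **compile_kwargs(backend, WANT_TORCH_COMPILE)), which keeps the branching in one testable place.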
python-3_12-pip_list-output.txt
python-3_13-pip_list-output.txt
For completeness, I copy-pasted the code from all cells below. I measured only the time it took for the last code cell.
# imports
import os
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# try to import jax and use it as keras backend if it is available, else use pytorch
# (KERAS_BACKEND must be set before keras is imported)
try:
    import jax
    os.environ["KERAS_BACKEND"] = "jax"
except ImportError:
    import torch
    os.environ["KERAS_BACKEND"] = "torch"
# OVERRIDE BACKEND WITH TORCH:
OVERRIDE_TORCH = True
if OVERRIDE_TORCH:
    os.environ["KERAS_BACKEND"] = "torch"
    import torch
backend = os.environ["KERAS_BACKEND"]
print(f"Using keras backend: {backend}")
# further imports
import keras
from typing import List
# model creation with keras for multi-backend support
class Autoencoder(keras.Model):
    def __init__(self, input_size: int = 2, bottleneck_size: int = 1, hidden_size: int = 10, layers: int | List[int] = 5):
        super().__init__()
        self.latent_dim = bottleneck_size
        self.hidden_size = hidden_size
        # check if layers is a list
        # if it is, create the encoder and decoder according to the list
        # if it is not, create the encoder and decoder according to the layers and hidden_size parameters
        inputs = keras.Input((input_size,))
        # using keras functional api
        if isinstance(layers, list):
            x = inputs
            for layer_size in layers:
                x = keras.layers.Dense(layer_size, activation='relu')(x)
            latent_outputs = keras.layers.Dense(self.latent_dim, activation=None)(x)
            self.encoder = keras.Model(inputs=inputs, outputs=latent_outputs, name="encoder")
            x = latent_outputs
            # mirror the encoder layer sizes without mutating the caller's list
            for layer_size in reversed(layers):
                x = keras.layers.Dense(layer_size, activation='relu')(x)
            outputs = keras.layers.Dense(input_size, activation=None)(x)
            self.decoder = keras.Model(inputs=latent_outputs, outputs=outputs, name="decoder")
            self.autoencoder = keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")
        elif isinstance(layers, int):
            x = inputs
            for _ in range(layers):
                x = keras.layers.Dense(hidden_size, activation='relu')(x)
            latent_outputs = keras.layers.Dense(self.latent_dim, activation=None)(x)
            self.encoder = keras.Model(inputs=inputs, outputs=latent_outputs, name="encoder")
            x = latent_outputs
            for _ in range(layers):
                x = keras.layers.Dense(hidden_size, activation='relu')(x)
            outputs = keras.layers.Dense(input_size, activation=None)(x)
            self.decoder = keras.Model(inputs=latent_outputs, outputs=outputs, name="decoder")
            self.autoencoder = keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")
        else:
            print("Model creation failed. Make sure to use the right arguments.")

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, x):
        return self.decoder(x)

    def call(self, x):
        return self._decode(self._encode(x))

    def custom_summary(self):
        # self.encoder.summary()
        # self.decoder.summary()
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.autoencoder.summary()
test_model = Autoencoder()
opt = keras.optimizers.Adam(1e-3)
loss = keras.losses.MeanSquaredError()
# check whether we can compile the torch model (only possible in python <= 3.13)
import sys
WANT_TORCH_COMPILE: bool = True  # change this variable to compile the torch model - it actually worsens performance, it's just for demonstration purposes
if sys.version_info < (3, 14):
    if backend == "torch" and WANT_TORCH_COMPILE:
        test_model.compile(optimizer=opt, loss=loss, jit_compile=True)
    elif backend == "torch":
        test_model.compile(optimizer=opt, loss=loss)
    else:
        test_model.compile(optimizer=opt, loss=loss, jit_compile=True)
else:
    if backend != "torch":
        test_model.compile(optimizer=opt, loss=loss, jit_compile=True)
    else:
        test_model.compile(optimizer=opt, loss=loss)
test_dataset, _ = make_moons(n_samples=10_000, noise=0.1)
# recommended by the warnings generated by running this cell without it:
if backend == "torch":
    torch.set_float32_matmul_precision('high')  # by doing this, there are no warnings anymore in python 3.12 at least, python 3.13 still shows warnings
train_hist = test_model.fit(x=test_dataset, y=test_dataset, batch_size=32, epochs=10)
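Because the whole last cell was timed, the measurement mixes any one-off compilation cost (for the torch backend, presumably torch.compile tracing the train step on the first calls) with the training itself. A minimal standard-library sketch, assuming one training epoch is wrapped in a zero-argument callable, that reports the first (warm-up) run separately from the median of the later runs. It also assumes Keras reuses the compiled train function across repeated fit calls. For the notebook's settings, each epoch is ceil(10000 / 32) = 313 steps. Whether compilation overhead explains the slowdown is not established here; the sketch only makes the two costs visible.

```python
import math
import statistics
import time
from typing import Callable


def time_runs(run: Callable[[], object], repeats: int = 5) -> dict:
    """Time `run` `repeats` times, keeping the first (warm-up) run separate."""
    if repeats < 2:
        raise ValueError("need at least one warm-up run and one timed run")
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        durations.append(time.perf_counter() - start)
    return {
        "first_s": durations[0],  # includes any one-off compilation
        "steady_median_s": statistics.median(durations[1:]),  # later runs only
    }


# Steps per epoch for the notebook's settings (10_000 samples, batch_size=32).
steps_per_epoch = math.ceil(10_000 / 32)

# Hypothetical usage in the notebook, one epoch per timed run:
# stats = time_runs(lambda: test_model.fit(x=test_dataset, y=test_dataset,
#                                          batch_size=32, epochs=1, verbose=0))
```

Running this once with jit_compile=True and once without, on each backend, separates "compilation is slow" from "compiled steps are slow".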
I hope someone can reproduce this. I'm not sure whether I'm doing something fundamentally wrong, but if I understood the Keras API correctly, it should be possible to run this code as-is without running into this issue.
Thank you in advance for looking into this.