Description
I love the option of using Keras 3 with different backends, which makes it possible to simply train my model in a TensorFlow environment but use it in a PyTorch environment.
I am currently working on a relatively large project built on this idea. However, I have noticed that, to accommodate the preferences of the many PyTorch/channels_first users, it would be extremely beneficial to have a simple and convenient function that converts a fully built or pretrained model from one representation to the other (channels_first <-> channels_last). I think this should be entirely possible within the current Keras workflow, but I have not encountered such functionality that works on pretrained models.
I am fully aware that you can simply transpose the input on each call, or prepend a permutation layer. However, I know that for software applications convenience is king, and many of the users I am targeting would definitely prefer a full model-conversion function, which would give them not only the input and output in their preferred layout, but also all the intermediate activations of the network, without the trouble of constant transposing.
Here is a MWE of how it could work:
```python
# start from an arbitrary built/pretrained Keras model that processes
# channels and spatial dimensions
channels_last_model = arbitrary_built_model

# hypothetical conversion functions that transform each layer from one
# data format to the other, correctly transposing the weights/operations
# (`keras.something` is a placeholder namespace)
channels_first_model = keras.something.channels_first(channels_last_model)
reconstructed_model = keras.something.channels_last(channels_first_model)

# create inputs for both variants
inputs_channels_last = keras.ops.ones([16, 224, 224, 3])
inputs_channels_first = keras.ops.ones([16, 3, 224, 224])

# call all the models
a = channels_last_model(inputs_channels_last)
b = channels_first_model(inputs_channels_first)
c = reconstructed_model(inputs_channels_last)

# make sure their outputs are equal (NCHW -> NHWC)
assert_all_equal(a, keras.ops.transpose(b, (0, 2, 3, 1)))
assert_all_equal(a, c)
```

Would it be viable for you to implement such functions? Or are there other simple and elegant solutions that I missed?
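As a rough sketch of what such a conversion could look like internally, here is a hypothetical `conv_to_channels_first` helper (the name and scope are mine). It relies on the fact that Keras stores Conv2D kernels as `(kernel_h, kernel_w, in_channels, out_channels)` for both data formats, so for plain conv/pooling stacks it suffices to rebuild each layer with `data_format="channels_first"` and copy the weights over. Layers whose weights or configuration depend on the layout (e.g. a Dense after Flatten, or BatchNormalization's `axis`) would additionally need permutation logic, which this sketch does not handle, and note that channels_first convolutions are not supported on all backends/devices (e.g. TensorFlow on CPU):

```python
import keras

def conv_to_channels_first(channels_last_model):
    # hypothetical conversion for a Sequential of Conv2D/pooling layers:
    # rebuild each layer from its config with data_format flipped, then
    # copy the weights over unchanged (conv kernels are stored in the
    # same layout for both data formats)
    _, h, w, c = channels_last_model.input_shape  # assumes 4D image input
    new_layers = []
    for layer in channels_last_model.layers:
        cfg = layer.get_config()
        if "data_format" in cfg:
            cfg["data_format"] = "channels_first"
        new_layers.append(layer.__class__.from_config(cfg))
    new_model = keras.Sequential([keras.Input(shape=(c, h, w))] + new_layers)
    new_model.set_weights(channels_last_model.get_weights())
    return new_model
```

A full implementation in Keras proper would presumably walk the functional graph instead of a Sequential layer list and handle the layout-dependent layers, which is exactly why having it built in would be so valuable.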
Thank you for considering this. I think this feature would put Keras in a really advantageous position in the deep learning landscape, notably by facilitating exchange between backends. 😊👍