Feature Request: convert built models from channels_first to channels_last and vice versa. #21889

@DiGyt

Description

I love the option of using Keras 3 with different backends, which makes it possible to train a model in a TensorFlow environment and then use it in a PyTorch environment.

I am currently working on a relatively big project built on this idea. However, to accommodate the preferences of the many PyTorch/channels_first users, it would be extremely beneficial to have a simple, convenient function that converts a fully built or pretrained model from one representation to the other (channels_first <-> channels_last). I think this should be entirely possible with the current Keras workflow, but I have not found any such functionality that works on pretrained models.
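For convolutions, at least, the conversion looks mostly mechanical. The NumPy sketch below (the helpers are hypothetical, not Keras API) computes a "valid", stride-1 2D convolution in both layouts using one and the same (kh, kw, C_in, C_out) kernel array, and checks that the outputs agree once the channel axis is moved. It assumes the kernel is kept in that single layout-independent shape, which, as far as I know, is how Keras 3 stores Conv2D kernels regardless of data_format. If so, a converter would only need to flip each layer's data_format and copy the weights.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_channels_last(x, kernel):
    # x: (N, H, W, C_in); kernel: (kh, kw, C_in, C_out); "valid" padding, stride 1
    kh, kw = kernel.shape[:2]
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))  # (N, H', W', C_in, kh, kw)
    return np.einsum("nhwcij,ijco->nhwo", windows, kernel)


def conv2d_channels_first(x, kernel):
    # x: (N, C_in, H, W); the SAME kernel array, not transposed
    kh, kw = kernel.shape[:2]
    windows = sliding_window_view(x, (kh, kw), axis=(2, 3))  # (N, C_in, H', W', kh, kw)
    return np.einsum("nchwij,ijco->nohw", windows, kernel)


rng = np.random.default_rng(0)
x_last = rng.normal(size=(2, 6, 7, 3))   # (N, H, W, C_in)
kernel = rng.normal(size=(3, 3, 3, 4))   # (kh, kw, C_in, C_out)

y_last = conv2d_channels_last(x_last, kernel)                          # (2, 4, 5, 4)
y_first = conv2d_channels_first(np.moveaxis(x_last, -1, 1), kernel)   # (2, 4, 4, 5)
print(np.allclose(np.moveaxis(y_first, 1, -1), y_last))  # True
```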

I am fully aware that you can simply transpose the input on each call, or stack a Permute layer in front of the model. However, in software applications convenience is king, and I know that many of the users I am targeting would definitely prefer a full model conversion function: it would give them not only inputs and outputs in their preferred layout, but also all the intermediate activations of the network, without the trouble of constant transposing.
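Whichever way it is done, going between the two layouts means moving the channel axis, not reshaping and not plainly swapping axes 1 and -1. A small NumPy check (shapes chosen arbitrarily, with non-square spatial dimensions so the difference is visible):

```python
import numpy as np

# A small channels_last batch: (N, H, W, C) with H != W
x_last = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)

# Correct conversion: move the channel axis, keeping H and W in order
x_first = np.moveaxis(x_last, -1, 1)   # (N, C, H, W)
x_back = np.moveaxis(x_first, 1, -1)   # back to (N, H, W, C)

# Two tempting but wrong alternatives
x_reshaped = x_last.reshape(2, 3, 4, 5)   # right shape, scrambled values
x_swapped = np.swapaxes(x_last, 1, -1)    # (N, C, W, H): spatial dims transposed

print(x_first.shape)    # (2, 3, 4, 5)
print(x_swapped.shape)  # (2, 3, 5, 4)
```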

Here is a MWE of how it could work:

```python
import keras
import numpy as np

# Have an arbitrary built/pretrained Keras model processing channels and spatial dimensions
channels_last_model = arbitrary_built_model  # placeholder

# Proposed (not yet existing) functions that convert each layer from channels_last
# to channels_first and back, transposing the weights/operations accordingly
channels_first_model = keras.something.channels_first(channels_last_model)
reconstructed_model = keras.something.channels_last(channels_first_model)

# Create inputs for both variants
inputs_channels_last = keras.ops.ones([16, 224, 224, 3])
inputs_channels_first = keras.ops.ones([16, 3, 224, 224])

# Call all the models
a = channels_last_model(inputs_channels_last)
b = channels_first_model(inputs_channels_first)
c = reconstructed_model(inputs_channels_last)

# Make sure their outputs match up to floating-point tolerance.
# moveaxis (rather than swapaxes) maps (N, C, H, W) back to (N, H, W, C)
# without also swapping H and W.
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(a),
    keras.ops.convert_to_numpy(keras.ops.moveaxis(b, 1, -1)),
    rtol=1e-5,
    atol=1e-5,
)
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(a), keras.ops.convert_to_numpy(c), rtol=1e-5, atol=1e-5
)
```
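The one place where weights would actually have to be permuted is a Dense layer that follows a Flatten: flattening an (H, W, C) map and a (C, H, W) map yields the same features in different orders. Keras's Flatten layer documents a data_format argument intended to preserve weight ordering in exactly this case; the sketch below instead shows the equivalent permutation applied directly to the Dense kernel. The helper names are hypothetical, and the kernel is assumed to have shape (H * W * C, units).

```python
import numpy as np


def dense_kernel_to_channels_first(kernel, feature_shape_last):
    """Hypothetical helper: reorder the rows of a Dense kernel that follows Flatten.

    kernel: (H * W * C, units), trained on flattened channels_last features.
    feature_shape_last: (H, W, C), the feature map shape before Flatten.
    Returns a kernel for features flattened in channels_first order (C, H, W).
    """
    h, w, c = feature_shape_last
    units = kernel.shape[-1]
    # Rows are indexed (h, w, c); reorder them to (c, h, w)
    return kernel.reshape(h, w, c, units).transpose(2, 0, 1, 3).reshape(c * h * w, units)


def dense_kernel_to_channels_last(kernel, feature_shape_last):
    """Hypothetical inverse of dense_kernel_to_channels_first."""
    h, w, c = feature_shape_last
    units = kernel.shape[-1]
    # Rows are indexed (c, h, w); reorder them back to (h, w, c)
    return kernel.reshape(c, h, w, units).transpose(1, 2, 0, 3).reshape(h * w * c, units)


rng = np.random.default_rng(0)
feature_shape = (4, 5, 3)                               # (H, W, C)
features_last = rng.normal(size=(2, *feature_shape))   # (N, H, W, C)
features_first = np.moveaxis(features_last, -1, 1)     # (N, C, H, W)
kernel_last = rng.normal(size=(4 * 5 * 3, 10))

kernel_first = dense_kernel_to_channels_first(kernel_last, feature_shape)
out_last = features_last.reshape(2, -1) @ kernel_last
out_first = features_first.reshape(2, -1) @ kernel_first
print(np.allclose(out_last, out_first))  # True
```

Other per-channel layers such as BatchNormalization keep their weight vectors unchanged and would only need their axis switched.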

Would it be viable for you to implement such functions? Or are there other simple and elegant solutions that I missed?

Thank you for considering this. I think this feature would put Keras in a really advantageous position in the deep learning landscape, notably by facilitating exchange between backends. 😊👍

Labels: keras-team-review-pending (Pending review by a Keras team member.), type:feature (The user is asking for a new feature.)