-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Add AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers #21820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
f99cc63
f830e93
9938ef1
323a1ab
df57227
5343b71
4cc8ac0
12edcb4
248773f
53a5dc9
2727a24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1464,3 +1464,155 @@ def _pair(x): | |
| # ---- reshape -> (N, C*kH*kW, L) ---- | ||
| _, CKK, oH, oW = patches.shape | ||
| return patches.reshape(N, CKK, oH * oW) | ||
|
|
||
|
|
||
| def _adaptive_pool_start_index(output_idx, output_size, input_size): | ||
| """Calculate start index for adaptive pooling (PyTorch compatible).""" | ||
| return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32) | ||
|
|
||
|
|
||
| def _adaptive_pool_end_index(output_idx, output_size, input_size): | ||
| """Calculate end index for adaptive pooling (PyTorch compatible).""" | ||
| return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype( | ||
| jnp.int32 | ||
| ) | ||
|
|
||
|
|
||
| def adaptive_avg_pool( | ||
| inputs, output_size, data_format="channels_last", name=None | ||
| ): | ||
| """ | ||
| Adaptive average pooling for JAX backend (PyTorch-compatible). | ||
| """ | ||
| # Convert output_size to tuple | ||
| spatial_dims = inputs.ndim - 2 | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size,) * spatial_dims | ||
| else: | ||
| output_size = tuple(output_size) | ||
|
|
||
| # Get spatial shape | ||
| if data_format == "channels_last": | ||
| batch_size = inputs.shape[0] | ||
| channels = inputs.shape[-1] | ||
| spatial_shape = inputs.shape[1:-1] | ||
| else: # channels_first | ||
| batch_size = inputs.shape[0] | ||
| channels = inputs.shape[1] | ||
| spatial_shape = inputs.shape[2:] | ||
|
|
||
| if len(output_size) != 2: | ||
| raise NotImplementedError( | ||
| "Only 2D adaptive pooling is currently supported" | ||
| ) | ||
|
|
||
| out_h, out_w = output_size | ||
| in_h, in_w = spatial_shape | ||
|
|
||
| # Build output by iterating over output positions | ||
| result_list = [] | ||
|
|
||
| for i in range(out_h): | ||
| for j in range(out_w): | ||
| # Calculate pooling region for this output position | ||
| start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32) | ||
| end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32) | ||
| start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32) | ||
| end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32) | ||
|
|
||
| # Extract region and apply average pooling | ||
| if data_format == "channels_last": | ||
| region = inputs[:, start_h:end_h, start_w:end_w, :] | ||
| # Average over spatial dimensions (axis 1, 2) | ||
| pooled = jnp.mean(region, axis=(1, 2)) | ||
| else: # channels_first | ||
| region = inputs[:, :, start_h:end_h, start_w:end_w] | ||
| # Average over spatial dimensions (axis 2, 3) | ||
| pooled = jnp.mean(region, axis=(2, 3)) | ||
|
|
||
| result_list.append(pooled) | ||
|
||
|
|
||
| # Stack results: (out_h*out_w, batch, channels) | ||
| output = jnp.stack(result_list, axis=0) | ||
|
|
||
| # Reshape and transpose to correct output shape | ||
| if data_format == "channels_last": | ||
| # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels) | ||
| output = output.reshape(out_h, out_w, batch_size, channels) | ||
| output = jnp.transpose(output, (2, 0, 1, 3)) | ||
| else: # channels_first | ||
| # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w) | ||
| output = output.reshape(out_h, out_w, batch_size, channels) | ||
| output = jnp.transpose(output, (2, 3, 0, 1)) | ||
|
|
||
| return output | ||
|
|
||
|
|
||
| def adaptive_max_pool( | ||
| inputs, output_size, data_format="channels_last", name=None | ||
| ): | ||
| """ | ||
| Adaptive max pooling for JAX backend (PyTorch-compatible). | ||
| """ | ||
| # Convert output_size to tuple | ||
| spatial_dims = inputs.ndim - 2 | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size,) * spatial_dims | ||
| else: | ||
| output_size = tuple(output_size) | ||
|
|
||
| # Get spatial shape | ||
| if data_format == "channels_last": | ||
| batch_size = inputs.shape[0] | ||
| channels = inputs.shape[-1] | ||
| spatial_shape = inputs.shape[1:-1] | ||
| else: # channels_first | ||
| batch_size = inputs.shape[0] | ||
| channels = inputs.shape[1] | ||
| spatial_shape = inputs.shape[2:] | ||
|
|
||
| if len(output_size) != 2: | ||
| raise NotImplementedError( | ||
| "Only 2D adaptive pooling is currently supported" | ||
| ) | ||
|
|
||
| out_h, out_w = output_size | ||
| in_h, in_w = spatial_shape | ||
|
|
||
| # Build output by iterating over output positions | ||
| result_list = [] | ||
|
|
||
| for i in range(out_h): | ||
| for j in range(out_w): | ||
| # Calculate pooling region for this output position | ||
| start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32) | ||
| end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32) | ||
| start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32) | ||
| end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32) | ||
|
|
||
| # Extract region and apply max pooling | ||
| if data_format == "channels_last": | ||
| region = inputs[:, start_h:end_h, start_w:end_w, :] | ||
| # Max over spatial dimensions (axis 1, 2) | ||
| pooled = jnp.max(region, axis=(1, 2)) | ||
| else: # channels_first | ||
| region = inputs[:, :, start_h:end_h, start_w:end_w] | ||
| # Max over spatial dimensions (axis 2, 3) | ||
| pooled = jnp.max(region, axis=(2, 3)) | ||
|
|
||
| result_list.append(pooled) | ||
|
|
||
| # Stack results: (out_h*out_w, batch, channels) | ||
| output = jnp.stack(result_list, axis=0) | ||
|
|
||
| # Reshape and transpose to correct output shape | ||
| if data_format == "channels_last": | ||
| # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels) | ||
| output = output.reshape(out_h, out_w, batch_size, channels) | ||
| output = jnp.transpose(output, (2, 0, 1, 3)) | ||
| else: # channels_first | ||
| # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w) | ||
| output = output.reshape(out_h, out_w, batch_size, channels) | ||
| output = jnp.transpose(output, (2, 3, 0, 1)) | ||
|
|
||
| return output | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from keras.src.layers.pooling.adaptive_average_pooling2d import ( | ||
| AdaptiveAveragePooling2D, | ||
| ) | ||
| from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| """Adaptive Average Pooling 2D layer.""" | ||
|
|
||
| from keras import config | ||
| from keras.src import ops | ||
| from keras.src.api_export import keras_export | ||
| from keras.src.layers.layer import Layer | ||
|
|
||
|
|
||
| @keras_export("keras.layers.AdaptiveAveragePooling2D") | ||
| class AdaptiveAveragePooling2D(Layer): | ||
| """Adaptive average pooling operation for 2D spatial data. | ||
|
|
||
| This layer applies an adaptive average pooling operation, which pools the | ||
| input such that the output has a target shape specified by `output_size`, | ||
| regardless of the input shape. The kernel size and stride are automatically | ||
| computed to achieve the target output size. | ||
|
|
||
| Args: | ||
| output_size: Integer or tuple of 2 integers, specifying the target | ||
| output size `(height, width)`. If a single integer is provided, | ||
| the same value is used for both dimensions. | ||
| data_format: string, either `"channels_last"` or `"channels_first"`. | ||
| The ordering of the dimensions in the inputs. `"channels_last"` | ||
| corresponds to inputs with shape `(batch, height, width, channels)` | ||
| while `"channels_first"` corresponds to inputs with shape | ||
| `(batch, channels, height, width)`. Defaults to the value found in | ||
| your Keras config file at `~/.keras/keras.json`. If never set, then | ||
| "channels_last" will be used. | ||
|
|
||
| Input shape: | ||
| - If `data_format="channels_last"`: | ||
| 4D tensor with shape `(batch_size, height, width, channels)`. | ||
| - If `data_format="channels_first"`: | ||
| 4D tensor with shape `(batch_size, channels, height, width)`. | ||
|
|
||
| Output shape: | ||
| - If `data_format="channels_last"`: | ||
| 4D tensor with shape | ||
| `(batch_size, output_height, output_width, channels)`. | ||
| - If `data_format="channels_first"`: | ||
| 4D tensor with shape | ||
| `(batch_size, channels, output_height, output_width)`. | ||
|
|
||
| Examples: | ||
|
|
||
| >>> input_img = np.random.rand(1, 64, 64, 3) | ||
| >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=(32, 32)) | ||
| >>> output_img = layer(input_img) | ||
| >>> output_img.shape | ||
| (1, 32, 32, 3) | ||
|
|
||
| >>> # Single integer for square output | ||
| >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=7) | ||
| >>> output_img = layer(input_img) | ||
| >>> output_img.shape | ||
| (1, 7, 7, 3) | ||
| """ | ||
|
|
||
| def __init__(self, output_size, data_format=None, **kwargs): | ||
| super().__init__(**kwargs) | ||
| if isinstance(output_size, int): | ||
| self.output_size = (output_size, output_size) | ||
| elif isinstance(output_size, (list, tuple)): | ||
| if len(output_size) != 2: | ||
| raise ValueError( | ||
| f"`output_size` must be an integer or tuple of 2 integers. " | ||
| f"Received: output_size={output_size}" | ||
| ) | ||
| self.output_size = tuple(output_size) | ||
| else: | ||
| raise TypeError( | ||
| f"`output_size` must be an integer or tuple of 2 integers. " | ||
| f"Received: output_size={output_size} of type " | ||
| f"{type(output_size)}" | ||
| ) | ||
|
|
||
| self.data_format = data_format or config.image_data_format() | ||
|
|
||
| if self.data_format not in {"channels_first", "channels_last"}: | ||
| raise ValueError( | ||
| f"Invalid data_format: {self.data_format}. " | ||
| "Must be either 'channels_first' or 'channels_last'." | ||
| ) | ||
|
|
||
| def call(self, inputs): | ||
| return ops.adaptive_avg_pool( | ||
| inputs, output_size=self.output_size, data_format=self.data_format | ||
| ) | ||
|
|
||
| def compute_output_shape(self, input_shape): | ||
| if self.data_format == "channels_last": | ||
| return ( | ||
| input_shape[0], | ||
| self.output_size[0], | ||
| self.output_size[1], | ||
| input_shape[3], | ||
| ) | ||
| else: # channels_first | ||
| return ( | ||
| input_shape[0], | ||
| input_shape[1], | ||
| self.output_size[0], | ||
| self.output_size[1], | ||
| ) | ||
|
|
||
| def get_config(self): | ||
| config_dict = { | ||
| "output_size": self.output_size, | ||
| "data_format": self.data_format, | ||
| } | ||
| base_config = super().get_config() | ||
| return {**base_config, **config_dict} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The helper functions
_adaptive_pool_start_indexand_adaptive_pool_end_indexare defined but not used. This dead code should be removed to improve code clarity.