
Commit e66ab50: merge master (#20741)

Authored by divyashreepathihalli, LakshmiKalaKadali, shashaka, martin-gorner, and hertschuh
* Specify `window_length` dtype requirement in `tf.keras.ops.istft` in `math.py` (#20728)

  The `window_length` parameter in `tf.keras.ops.istft` requires `tf.int32` dtype, but this isn't documented, which can cause an unexpected `ValueError` when `tf.int64` or `tf.int16` is used. Example case:

  ```python
  import tensorflow as tf

  input_dict = {
      'stfts': tf.constant([[-0.87817144+1.14583987j, -0.32066484+0.25565411j]], dtype=tf.complex128),
      'frame_length': tf.constant(256, dtype=tf.int16),
      'frame_step': tf.constant(5120, dtype=tf.int64)
  }
  result = tf.signal.inverse_stft(**input_dict)
  print(result)
  ```

  The code throws the following error:

  ```
  ValueError: window_length: Tensor conversion requested dtype int32 for Tensor with dtype int64
  ```

* Add `RandAugment` preprocessing layer (#20716)
  * Add rand_augment init, implementation, and export
  * Add `NotImplementedError` stubs and test cases; fix a failed test case
  * Fix a `random_rotation` bug
  * Add a `build` method to suppress a warning
  * Add an implementation for `transform_bboxes`

* Fix the `batch_dim_name` attribute (#20674)
  * Fix the wrong trainer assumption that the batch dim is always the first one in the mesh
  * Use `functools.partial` to bind the batch dim name
  * Fix a test failure when `distribution=None`
  * Add data sharding for 3D+ meshes
  * Add an `@property` for `batch_dim_name`, plus refactoring, lint, and typo fixes

* Add support for `dtype` / `DTypePolicy` to `JaxLayer` and `FlaxLayer` (#20732)

  The `dtype` / `DTypePolicy` is applied to all float variables.

* Allow dynamic shape in the `STFTSpectrogram` layer (#20736), by simply using `ops.shape(x)` instead of `x.shape`

* Remove duplicate export tests in `model_test` (#20735)

  The same tests exist at:
  - https://github.com/keras-team/keras/blob/master/keras/src/export/saved_model_test.py#L66
  - https://github.com/keras-team/keras/blob/master/keras/src/export/onnx_test.py#L62

  The goal is to isolate the use of `onnxruntime` to a single file, `onnx_test.py`.

* Add OpenVINO into README.md (#20739)

* Remove the duplicated example title in the `metrics.MeanIoU` docstring (#20738)

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: LakshmiKalaKadali <[email protected]>
Co-authored-by: Ugeun Park <[email protected]>
Co-authored-by: Martin Görner <[email protected]>
Co-authored-by: hertschuh <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: LavanyaKV1234 <[email protected]>
1 parent: 1adaaec · commit: e66ab50
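For context on the `istft` change above: since the derived `window_length` must be `tf.int32`, one possible workaround on the caller's side is to cast the integer arguments explicitly. A minimal sketch (not part of the commit; it reuses the values from the example case):

```python
import tensorflow as tf

# Illustrative workaround sketch: cast integer frame parameters to int32
# before calling the op, since the derived window_length must be int32.
stfts = tf.constant(
    [[-0.87817144 + 1.14583987j, -0.32066484 + 0.25565411j]],
    dtype=tf.complex128,
)
frame_length = tf.cast(tf.constant(256, dtype=tf.int16), tf.int32)
frame_step = tf.cast(tf.constant(5120, dtype=tf.int64), tf.int32)

# With int32 inputs, the dtype-related ValueError is no longer raised.
result = tf.signal.inverse_stft(stfts, frame_length, frame_step)
print(result.shape)
```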

File tree

22 files changed (+489, -121 lines)


README.md

Lines changed: 6 additions & 2 deletions
````diff
@@ -1,6 +1,6 @@
 # Keras 3: Deep Learning for Humans
 
-Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, and PyTorch.
+Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference-only).
 Effortlessly build and train models for computer vision, natural language processing, audio processing,
 timeseries forecasting, recommender systems, etc.
 
@@ -73,7 +73,7 @@ python pip_build.py --install
 ## Configuring your backend
 
 You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`
-to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`. Example:
+to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example:
 
 ```
 export KERAS_BACKEND="jax"
@@ -91,6 +91,10 @@ import keras
 **Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
 the package has been imported.
 
+**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model
+predictions using `model.predict()` method.
+To use `openvino` backend, install the required dependencies from the `requirements-openvino.txt` file.
+
 ## Backwards compatibility
 
 Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
````
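A minimal sketch of what the new README note describes, assuming the dependencies from `requirements-openvino.txt` are installed; the model architecture here is arbitrary and only illustrative:

```python
import os

# The backend must be set before keras is imported and cannot be
# changed afterwards.
os.environ["KERAS_BACKEND"] = "openvino"

import numpy as np
import keras

model = keras.Sequential(
    [keras.layers.Dense(4, activation="relu"), keras.layers.Dense(1)]
)

# The OpenVINO backend is inference-only: model.predict() works, while
# training entry points such as model.fit() are not supported.
predictions = model.predict(np.random.rand(2, 8).astype("float32"))
print(predictions.shape)  # (2, 1)
```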

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -152,6 +152,9 @@
     MaxNumBoundingBoxes,
 )
 from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
```

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -152,6 +152,9 @@
     MaxNumBoundingBoxes,
 )
 from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
```
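With these exports in place, the layer becomes reachable as `keras.layers.RandAugment`. A hedged usage sketch: the `value_range` argument follows the convention of the sibling image preprocessing layers and is an assumption here; the authoritative constructor signature lives in `rand_augment.py`.

```python
import numpy as np
import keras

# Assumed constructor argument, mirroring other image preprocessing
# layers; check rand_augment.py for the authoritative signature.
rand_augment = keras.layers.RandAugment(value_range=(0, 255))

images = np.random.uniform(0, 255, size=(4, 64, 64, 3)).astype("float32")
augmented = rand_augment(images)  # randomly chosen augmentations per image
print(augmented.shape)  # (4, 64, 64, 3)
```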

keras/src/backend/jax/distribution_lib.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
     return global_value
 
 
-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.
 
     Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)
 
-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size
+
     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]
```
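The substance of this change: the batch axis is now looked up by name instead of being assumed to be the first mesh dimension, and the model dimension is the product of all remaining axes, which is what makes 3D+ meshes work. A standalone sketch of the arithmetic, with an illustrative mesh shape:

```python
# Illustrative 3D mesh: 2-way data parallelism, 4 x 2 model sharding.
mesh_shape = {"batch": 2, "model": 4, "expert": 2}
batch_dim_name = "batch"

# Number of model replicas = size of the batch axis, wherever it sits.
num_model_replicas_total = mesh_shape[batch_dim_name]  # 2

# Model dimension = product of every non-batch axis (handles 3D+ meshes).
mesh_model_dim_size = 1
for name, dim_size in mesh_shape.items():
    if name != batch_dim_name:
        mesh_model_dim_size *= dim_size  # 4 * 2 = 8

print(num_model_replicas_total, mesh_model_dim_size)  # 2 8
```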

keras/src/backend/jax/distribution_lib_test.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
             mesh, jax.sharding.PartitionSpec("batch", None)
         )
 
-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )
 
         # Check the shape of the global batch array
         self.assertEqual(
```

keras/src/backend/jax/export.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                 self._tf_trackable.non_trainable_variables,
                 non_trainable_variables,
             ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
             return output
 
         stateful_fn.__signature__ = inspect.Signature(
```
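The failure mode this one-line fix guards against: `tf.Variable.assign` rejects values whose dtype differs from the variable's, which can happen when a dtype policy keeps variables in float32 while the computation yields lower-precision values. A minimal reproduction sketch:

```python
import tensorflow as tf

var = tf.Variable(tf.zeros((2,), dtype=tf.float32))
new_value = tf.ones((2,), dtype=tf.bfloat16)

# var.assign(new_value) would raise: a float32 conversion is requested
# for a tensor that is already bfloat16. Casting first makes it safe.
var.assign(tf.cast(new_value, var.dtype))
print(var.numpy())  # [1. 1.]
```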

keras/src/backend/jax/trainer.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial
 
 import jax
 import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(
 
 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
+
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution.batch_dim_name,
         )
+        return tree.map_structure(jax_dist_data_input, data, layouts)
 
     return tree.map_structure(jax.device_put, data)
```
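Why `functools.partial` here: `tree.map_structure(fn, data, layouts)` invokes `fn` with exactly one element from each structure, so the extra `batch_dim_name` keyword has to be pre-bound. A self-contained sketch with a stand-in function:

```python
from functools import partial

# Stand-in for jax_distribution_lib.distribute_data_input.
def distribute_data_input(per_process_batch, layout, batch_dim_name):
    return (per_process_batch, layout, batch_dim_name)

# Pre-bind the keyword so the mapper only needs (data, layout) pairs.
jax_dist_data_input = partial(distribute_data_input, batch_dim_name="batch")

print(jax_dist_data_input("batch_0", "layout_0"))
# ('batch_0', 'layout_0', 'batch')
```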

keras/src/distribution/distribution_lib.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -287,8 +287,9 @@ class Distribution:
         device_mesh: A `DeviceMesh` instance.
     """
 
-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
         self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name
 
     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
     def device_mesh(self):
         return self._device_mesh
 
+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
     def distribute_dataset(self, dataset):
         """Create a distributed dataset instance from the original user dataset.
 
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         else:
             self._initialize_mesh_from_list_devices()
 
-        self._batch_dim_name = self.device_mesh.axis_names[0]
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def _initialize_mesh_from_list_devices(self):
         devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
         # Note that this might be smaller than one if model replicas are sharded
         # across multiple processes.
         mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
         )
         num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
         if num_model_replicas == 1:
```
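The net effect of this refactor: the batch axis name becomes a constructor argument on the base class with a public read-only property, instead of a private attribute set only by subclasses. A stripped-down sketch of the pattern (stand-in classes, not the real ones):

```python
class Distribution:
    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        # Public, read-only access replaces `_batch_dim_name` lookups.
        return self._batch_dim_name


class DataParallel(Distribution):
    def __init__(self, device_mesh, axis_names):
        # Mirrors _initialize_with_device_mesh: first axis is the batch dim.
        super().__init__(device_mesh, axis_names[0])


dist = DataParallel(device_mesh=None, axis_names=["batch", "model"])
print(dist.batch_dim_name)  # batch
```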

keras/src/distribution/distribution_lib_test.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["data"])
-        self.assertEqual(distribution._batch_dim_name, "data")
+        self.assertEqual(distribution.batch_dim_name, "data")
 
         self.assertFalse(distribution._is_multi_process)
         self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")
 
     @mock.patch.object(
         distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")
 
     def test_get_data_layout(self):
         distribution = distribution_lib.DataParallel(
```

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -96,6 +96,9 @@
     MaxNumBoundingBoxes,
 )
 from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
```
