Skip to content

Commit 881d8da

Browse files
Patch tf2onnx to ensure compatibility with numpy>=2.0.0 (#20725)
* Patch tf2onnx to support numpy 2 * Fix warnings * Update export_onnx
1 parent 94977dd commit 881d8da

File tree

4 files changed

+211
-45
lines changed

4 files changed

+211
-45
lines changed

keras/src/export/onnx.py

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import pathlib
2-
import tempfile
3-
41
from keras.src import backend
52
from keras.src import tree
63
from keras.src.export.export_utils import convert_spec_to_tensor
74
from keras.src.export.export_utils import get_input_signature
8-
from keras.src.export.saved_model import export_saved_model
9-
from keras.src.utils.module_utils import tensorflow as tf
5+
from keras.src.export.export_utils import make_tf_tensor_spec
6+
from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME
7+
from keras.src.export.saved_model import ExportArchive
8+
from keras.src.export.tf2onnx_lib import patch_tf2onnx
109

1110

1211
def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
@@ -65,18 +64,18 @@ def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
6564
)
6665

6766
if backend.backend() in ("tensorflow", "jax"):
68-
working_dir = pathlib.Path(filepath).parent
69-
with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir:
70-
if backend.backend() == "jax":
71-
kwargs = _check_jax_kwargs(kwargs)
72-
export_saved_model(
73-
model,
74-
temp_dir,
75-
verbose,
76-
input_signature,
77-
**kwargs,
78-
)
79-
saved_model_to_onnx(temp_dir, filepath, model.name)
67+
from keras.src.utils.module_utils import tf2onnx
68+
69+
input_signature = tree.map_structure(
70+
make_tf_tensor_spec, input_signature
71+
)
72+
decorated_fn = get_concrete_fn(model, input_signature, **kwargs)
73+
74+
# Use `tf2onnx` to convert the `decorated_fn` to the ONNX format.
75+
patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2.
76+
tf2onnx.convert.from_function(
77+
decorated_fn, input_signature, output_path=filepath
78+
)
8079

8180
elif backend.backend() == "torch":
8281
import torch
@@ -133,30 +132,14 @@ def _check_jax_kwargs(kwargs):
133132
return kwargs
134133

135134

136-
def saved_model_to_onnx(saved_model_dir, filepath, name):
137-
from keras.src.utils.module_utils import tf2onnx
138-
139-
# Convert to ONNX using `tf2onnx` library.
140-
(graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = (
141-
tf2onnx.tf_loader.from_saved_model(
142-
saved_model_dir,
143-
None,
144-
None,
145-
return_initialized_tables=True,
146-
return_tensors_to_rename=True,
147-
)
135+
def get_concrete_fn(model, input_signature, **kwargs):
136+
"""Get the `tf.function` associated with the model."""
137+
if backend.backend() == "jax":
138+
kwargs = _check_jax_kwargs(kwargs)
139+
export_archive = ExportArchive()
140+
export_archive.track_and_add_endpoint(
141+
DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs
148142
)
149-
150-
with tf.device("/cpu:0"):
151-
_ = tf2onnx.convert._convert_common(
152-
graph_def,
153-
name=name,
154-
target=[],
155-
custom_op_handlers={},
156-
extra_opset=[],
157-
input_names=inputs,
158-
output_names=outputs,
159-
tensors_to_rename=tensors_to_rename,
160-
initialized_tables=initialized_tables,
161-
output_path=filepath,
162-
)
143+
if backend.backend() == "tensorflow":
144+
export_archive._filter_and_track_resources()
145+
return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME)

keras/src/export/saved_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
)
3636

3737

38+
DEFAULT_ENDPOINT_NAME = "serve"
39+
40+
3841
@keras_export("keras.export.ExportArchive")
3942
class ExportArchive(BackendExportArchive):
4043
"""ExportArchive is used to write SavedModel artifacts (e.g. for inference).
@@ -623,7 +626,7 @@ def export_saved_model(
623626
input_signature = get_input_signature(model)
624627

625628
export_archive.track_and_add_endpoint(
626-
"serve", model, input_signature, **kwargs
629+
DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs
627630
)
628631
export_archive.write_out(filepath, verbose=verbose)
629632

keras/src/export/tf2onnx_lib.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import copy
2+
import functools
3+
import logging
4+
import traceback
5+
6+
import numpy as np
7+
8+
9+
@functools.lru_cache()
10+
def patch_tf2onnx():
11+
"""Patches `tf2onnx` to ensure compatibility with numpy>=2.0.0."""
12+
13+
from onnx import AttributeProto
14+
from onnx import TensorProto
15+
16+
from keras.src.utils.module_utils import tf2onnx
17+
18+
logger = logging.getLogger(tf2onnx.__name__)
19+
20+
def patched_rewrite_constant_fold(g, ops):
21+
"""
22+
We call tensorflow transform with constant folding but in some cases
23+
tensorflow does fold all constants. Since there are a bunch of ops in
24+
onnx that use attributes where tensorflow has dynamic inputs, we badly
25+
want constant folding to work. For cases where tensorflow missed
26+
something, make another pass over the graph and fix want we care about.
27+
"""
28+
func_map = {
29+
"Add": np.add,
30+
"GreaterEqual": np.greater_equal,
31+
"Cast": np.asarray,
32+
"ConcatV2": np.concatenate,
33+
"Less": np.less,
34+
"ListDiff": np.setdiff1d,
35+
"Mul": np.multiply,
36+
"Pack": np.stack,
37+
"Range": np.arange,
38+
"Sqrt": np.sqrt,
39+
"Sub": np.subtract,
40+
}
41+
ops = list(ops)
42+
43+
keep_looking = True
44+
while keep_looking:
45+
keep_looking = False
46+
for idx, op in enumerate(ops):
47+
func = func_map.get(op.type)
48+
if func is None:
49+
continue
50+
if set(op.output) & set(g.outputs):
51+
continue
52+
try:
53+
inputs = []
54+
for node in op.inputs:
55+
if not node.is_const():
56+
break
57+
inputs.append(node.get_tensor_value(as_list=False))
58+
59+
logger.debug(
60+
"op name %s, %s, %s",
61+
op.name,
62+
len(op.input),
63+
len(inputs),
64+
)
65+
if inputs and len(op.input) == len(inputs):
66+
logger.info(
67+
"folding node type=%s, name=%s" % (op.type, op.name)
68+
)
69+
if op.type == "Cast":
70+
dst = op.get_attr_int("to")
71+
np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
72+
val = np.asarray(*inputs, dtype=np_type)
73+
elif op.type == "ConcatV2":
74+
axis = inputs[-1]
75+
values = inputs[:-1]
76+
val = func(tuple(values), axis)
77+
elif op.type == "ListDiff":
78+
out_type = op.get_attr_int("out_idx")
79+
np_type = tf2onnx.utils.map_onnx_to_numpy_type(
80+
out_type
81+
)
82+
val = func(*inputs)
83+
val = val.astype(np_type)
84+
elif op.type in ["Pack"]:
85+
# handle ops that need input array and axis
86+
axis = op.get_attr_int("axis")
87+
val = func(inputs, axis=axis)
88+
elif op.type == "Range":
89+
dtype = op.get_attr_int("Tidx")
90+
np_type = tf2onnx.utils.map_onnx_to_numpy_type(
91+
dtype
92+
)
93+
val = func(*inputs, dtype=np_type)
94+
else:
95+
val = func(*inputs)
96+
97+
new_node_name = tf2onnx.utils.make_name(op.name)
98+
new_output_name = new_node_name
99+
old_output_name = op.output[0]
100+
old_node_name = op.name
101+
logger.debug(
102+
"create const node [%s] replacing [%s]",
103+
new_node_name,
104+
old_node_name,
105+
)
106+
ops[idx] = g.make_const(new_node_name, val)
107+
108+
logger.debug(
109+
"replace old output [%s] with new output [%s]",
110+
old_output_name,
111+
new_output_name,
112+
)
113+
# need to re-write the consumers input name to use the
114+
# const name
115+
consumers = g.find_output_consumers(old_output_name)
116+
if consumers:
117+
for consumer in consumers:
118+
g.replace_input(
119+
consumer, old_output_name, new_output_name
120+
)
121+
122+
# keep looking until there is nothing we can fold.
123+
# We keep the graph in topological order so if we
124+
# folded, the result might help a following op.
125+
keep_looking = True
126+
except Exception as ex:
127+
tb = traceback.format_exc()
128+
logger.info("exception: %s, details: %s", ex, tb)
129+
# ignore errors
130+
131+
return ops
132+
133+
def patched_get_value_attr(self, external_tensor_storage=None):
134+
"""
135+
Return onnx attr for value property of node.
136+
Attr is modified to point to external tensor data stored in
137+
external_tensor_storage, if included.
138+
"""
139+
a = self._attr["value"]
140+
if (
141+
external_tensor_storage is not None
142+
and self in external_tensor_storage.node_to_modified_value_attr
143+
):
144+
return external_tensor_storage.node_to_modified_value_attr[self]
145+
if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
146+
return a
147+
148+
def prod(x):
149+
if hasattr(np, "product"):
150+
return np.product(x)
151+
else:
152+
return np.prod(x)
153+
154+
if (
155+
prod(a.t.dims)
156+
> external_tensor_storage.external_tensor_size_threshold
157+
):
158+
a = copy.deepcopy(a)
159+
tensor_name = (
160+
self.name.strip()
161+
+ "_"
162+
+ str(external_tensor_storage.name_counter)
163+
)
164+
for c in '~"#%&*:<>?/\\{|}':
165+
tensor_name = tensor_name.replace(c, "_")
166+
external_tensor_storage.name_counter += 1
167+
external_tensor_storage.name_to_tensor_data[tensor_name] = (
168+
a.t.raw_data
169+
)
170+
external_tensor_storage.node_to_modified_value_attr[self] = a
171+
a.t.raw_data = b""
172+
a.t.ClearField("raw_data")
173+
location = a.t.external_data.add()
174+
location.key = "location"
175+
location.value = tensor_name
176+
a.t.data_location = TensorProto.EXTERNAL
177+
return a
178+
179+
tf2onnx.tfonnx.rewrite_constant_fold = patched_rewrite_constant_fold
180+
tf2onnx.graph.Node.get_value_attr = patched_get_value_attr

requirements-common.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
namex>=0.0.8
22
ruff
33
pytest
4-
numpy<2.0.0 # TODO: Remove the restriction when tf2onnx supports numpy>2.0.0
4+
numpy
55
scipy
66
scikit-learn
77
pandas

0 commit comments

Comments
 (0)