Skip to content

Commit 29daeed

Browse files
authored
Improve ndarray testing (wjakob#1168)
1 parent 6663f6a commit 29daeed

File tree

4 files changed

+137
-132
lines changed

4 files changed

+137
-132
lines changed

tests/test_jax.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import test_ndarray_ext as t
2+
import test_jax_ext as tj
3+
import pytest
4+
import warnings
5+
import importlib
6+
from common import collect
7+
8+
try:
9+
import jax.numpy as jnp
10+
def needs_jax(x):
11+
return x
12+
except:
13+
needs_jax = pytest.mark.skip(reason="JAX is required")
14+
15+
16+
@needs_jax
17+
def test01_constrain_order_jax():
18+
with warnings.catch_warnings():
19+
warnings.simplefilter("ignore")
20+
try:
21+
c = jnp.zeros((3, 5))
22+
except:
23+
pytest.skip('jax is missing')
24+
25+
z = jnp.zeros((3, 5, 4, 6))
26+
assert t.check_order(z) == 'C'
27+
28+
29+
@needs_jax
30+
def test02_implicit_conversion_jax():
31+
with warnings.catch_warnings():
32+
warnings.simplefilter("ignore")
33+
try:
34+
c = jnp.zeros((3, 5))
35+
except:
36+
pytest.skip('jax is missing')
37+
38+
t.implicit(jnp.zeros((2, 2), dtype=jnp.int32))
39+
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.float32)[:, :, 4])
40+
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.int32)[:, :, 4])
41+
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.bool_)[:, :, 4])
42+
43+
with pytest.raises(TypeError) as excinfo:
44+
t.noimplicit(jnp.zeros((2, 2), dtype=jnp.int32))
45+
46+
with pytest.raises(TypeError) as excinfo:
47+
t.noimplicit(jnp.zeros((2, 2), dtype=jnp.uint8))
48+
49+
50+
@needs_jax
51+
def test03_return_jax():
52+
collect()
53+
dc = tj.destruct_count()
54+
x = tj.ret_jax()
55+
assert x.shape == (2, 4)
56+
assert jnp.all(x == jnp.array([[1,2,3,4], [5,6,7,8]], dtype=jnp.float32))
57+
del x
58+
collect()
59+
assert tj.destruct_count() - dc == 1
60+
61+
62+
@needs_jax
63+
def test04_check_jax():
64+
assert t.check(jnp.zeros((1)))

tests/test_ndarray.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ int destruct_count = 0;
1212
static float f_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };
1313
static int i_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };
1414

15-
#if defined(__aarch64__)
15+
#if defined(__aarch64__) || defined(__AVX512FP16__)
1616
namespace nanobind::detail {
17-
template <> struct dtype_traits<__fp16> {
17+
template <> struct dtype_traits<_Float16> {
1818
static constexpr dlpack::dtype value {
1919
(uint8_t) dlpack::dtype_code::Float, // type code
2020
16, // size in bits
@@ -392,17 +392,17 @@ NB_MODULE(test_ndarray_ext, m) {
392392
v(i) = -v(i);
393393
}, "x"_a.noconvert());
394394

395-
#if defined(__aarch64__)
395+
#if defined(__aarch64__) || defined(__AVX512FP16__)
396396
m.def("ret_numpy_half", []() {
397-
__fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
397+
_Float16 *f = new _Float16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
398398
size_t shape[2] = { 2, 4 };
399399

400400
nb::capsule deleter(f, [](void *data) noexcept {
401401
destruct_count++;
402-
delete[] (__fp16*) data;
402+
delete[] (_Float16*) data;
403403
});
404-
return nb::ndarray<nb::numpy, __fp16, nb::shape<2, 4>>(f, 2, shape,
405-
deleter);
404+
return nb::ndarray<nb::numpy, _Float16, nb::shape<2, 4>>(f, 2, shape,
405+
deleter);
406406
});
407407
#endif
408408

tests/test_ndarray.py

Lines changed: 1 addition & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import test_ndarray_ext as t
2-
import test_jax_ext as tj
3-
import test_tensorflow_ext as tt
42
import pytest
53
import warnings
64
import importlib
@@ -20,21 +18,6 @@ def needs_torch(x):
2018
except:
2119
needs_torch = pytest.mark.skip(reason="PyTorch is required")
2220

23-
try:
24-
import tensorflow as tf
25-
import tensorflow.config
26-
def needs_tensorflow(x):
27-
return x
28-
except:
29-
needs_tensorflow = pytest.mark.skip(reason="TensorFlow is required")
30-
31-
try:
32-
import jax.numpy as jnp
33-
def needs_jax(x):
34-
return x
35-
except:
36-
needs_jax = pytest.mark.skip(reason="JAX is required")
37-
3821
try:
3922
import cupy as cp
4023
def needs_cupy(x):
@@ -158,19 +141,6 @@ def test05_constrain_order():
158141
assert t.check_order(np.zeros((3, 5, 4, 6), order='F')[:, 2, :, :]) == '?'
159142

160143

161-
@needs_jax
162-
def test06_constrain_order_jax():
163-
with warnings.catch_warnings():
164-
warnings.simplefilter("ignore")
165-
try:
166-
c = jnp.zeros((3, 5))
167-
except:
168-
pytest.skip('jax is missing')
169-
170-
z = jnp.zeros((3, 5, 4, 6))
171-
assert t.check_order(z) == 'C'
172-
173-
174144
@needs_torch
175145
@pytest.mark.filterwarnings
176146
def test07_constrain_order_pytorch():
@@ -190,18 +160,6 @@ def test07_constrain_order_pytorch():
190160
assert t.check_device(torch.zeros(3, 5, device='cuda')) == 'cuda'
191161

192162

193-
@needs_tensorflow
194-
def test08_constrain_order_tensorflow():
195-
with warnings.catch_warnings():
196-
warnings.simplefilter("ignore")
197-
try:
198-
c = tf.zeros((3, 5))
199-
except:
200-
pytest.skip('tensorflow is missing')
201-
202-
assert t.check_order(c) == 'C'
203-
204-
205163
@needs_numpy
206164
def test09_write_from_cpp():
207165
x = np.zeros(10, dtype=np.float32)
@@ -251,48 +209,6 @@ def test11_implicit_conversion_pytorch():
251209
t.noimplicit(torch.zeros(2, 2, 10, dtype=torch.float32)[:, :, 4])
252210

253211

254-
@needs_tensorflow
255-
def test12_implicit_conversion_tensorflow():
256-
with warnings.catch_warnings():
257-
warnings.simplefilter("ignore")
258-
try:
259-
c = tf.zeros((3, 5))
260-
except:
261-
pytest.skip('tensorflow is missing')
262-
263-
t.implicit(tf.zeros((2, 2), dtype=tf.int32))
264-
t.implicit(tf.zeros((2, 2, 10), dtype=tf.float32)[:, :, 4])
265-
t.implicit(tf.zeros((2, 2, 10), dtype=tf.int32)[:, :, 4])
266-
t.implicit(tf.zeros((2, 2, 10), dtype=tf.bool)[:, :, 4])
267-
268-
with pytest.raises(TypeError) as excinfo:
269-
t.noimplicit(tf.zeros((2, 2), dtype=tf.int32))
270-
271-
with pytest.raises(TypeError) as excinfo:
272-
t.noimplicit(tf.zeros((2, 2), dtype=tf.bool))
273-
274-
275-
@needs_jax
276-
def test13_implicit_conversion_jax():
277-
with warnings.catch_warnings():
278-
warnings.simplefilter("ignore")
279-
try:
280-
c = jnp.zeros((3, 5))
281-
except:
282-
pytest.skip('jax is missing')
283-
284-
t.implicit(jnp.zeros((2, 2), dtype=jnp.int32))
285-
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.float32)[:, :, 4])
286-
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.int32)[:, :, 4])
287-
t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.bool_)[:, :, 4])
288-
289-
with pytest.raises(TypeError) as excinfo:
290-
t.noimplicit(jnp.zeros((2, 2), dtype=jnp.int32))
291-
292-
with pytest.raises(TypeError) as excinfo:
293-
t.noimplicit(jnp.zeros((2, 2), dtype=jnp.uint8))
294-
295-
296212
def test14_destroy_capsule():
297213
collect()
298214
dc = t.destruct_count()
@@ -376,31 +292,6 @@ def test18_return_pytorch():
376292
assert t.destruct_count() - dc == 1
377293

378294

379-
@needs_jax
380-
def test19_return_jax():
381-
collect()
382-
dc = tj.destruct_count()
383-
x = tj.ret_jax()
384-
assert x.shape == (2, 4)
385-
assert jnp.all(x == jnp.array([[1,2,3,4], [5,6,7,8]], dtype=jnp.float32))
386-
del x
387-
collect()
388-
assert tj.destruct_count() - dc == 1
389-
390-
391-
@needs_tensorflow
392-
def test20_return_tensorflow():
393-
collect()
394-
dc = tt.destruct_count()
395-
x = tt.ret_tensorflow()
396-
assert x.get_shape().as_list() == [2, 4]
397-
assert tf.math.reduce_all(
398-
x == tf.constant([[1,2,3,4], [5,6,7,8]], dtype=tf.float32))
399-
del x
400-
collect()
401-
assert tt.destruct_count() - dc == 1
402-
403-
404295
@needs_numpy
405296
def test21_return_array_scalar():
406297
collect()
@@ -504,16 +395,6 @@ def test28_check_torch():
504395
assert t.check(torch.zeros((1)))
505396

506397

507-
@needs_tensorflow
508-
def test29_check_tensorflow():
509-
assert t.check(tf.zeros((1)))
510-
511-
512-
@needs_jax
513-
def test30_check_jax():
514-
assert t.check(jnp.zeros((1)))
515-
516-
517398
@needs_numpy
518399
def test31_rv_policy():
519400
def p(a):
@@ -879,8 +760,6 @@ def test45_implicit_conversion_cupy():
879760
@needs_numpy
880761
def test46_implicit_conversion_contiguous_complex():
881762
# Test fix for issue #709
882-
import numpy as np
883-
884763
c_f32 = np.random.rand(10, 10)
885764
c_c64 = c_f32.astype(np.complex64)
886765

@@ -907,7 +786,6 @@ def test_conv(x):
907786

908787
@needs_numpy
909788
def test_47_ret_infer():
910-
import numpy as np
911789
assert np.all(t.ret_infer_c() == [[1, 2, 3, 4], [5, 6, 7, 8]])
912790
assert np.all(t.ret_infer_f() == [[1, 3, 5, 7], [2, 4, 6, 8]])
913791

@@ -956,13 +834,12 @@ def test50_test_matrix4f_copy():
956834

957835
@needs_numpy
958836
def test51_return_from_stack():
959-
import numpy as np
960837
assert np.all(t.ret_from_stack_1() == [1,2,3])
961838
assert np.all(t.ret_from_stack_2() == [1,2,3])
962839

840+
963841
@needs_numpy
964842
def test52_accept_np_both_true_contig():
965-
import numpy as np
966843
a = np.zeros((2, 1), dtype=np.float32)
967844
assert a.flags['C_CONTIGUOUS'] and a.flags['F_CONTIGUOUS']
968845
t.accept_np_both_true_contig_a(a)
@@ -972,6 +849,5 @@ def test52_accept_np_both_true_contig():
972849

973850
@needs_numpy
974851
def test53_issue_930():
975-
import numpy as np
976852
wrapper = t.Wrapper(np.ones(3, dtype=np.float32))
977853
assert wrapper.value[0] == 1

tests/test_tensorflow.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import test_ndarray_ext as t
2+
import test_tensorflow_ext as ttf
3+
import pytest
4+
import warnings
5+
import importlib
6+
from common import collect
7+
8+
try:
9+
import tensorflow as tf
10+
import tensorflow.config
11+
def needs_tensorflow(x):
12+
return x
13+
except:
14+
needs_tensorflow = pytest.mark.skip(reason="TensorFlow is required")
15+
16+
17+
@needs_tensorflow
18+
def test01_constrain_order_tensorflow():
19+
with warnings.catch_warnings():
20+
warnings.simplefilter("ignore")
21+
try:
22+
c = tf.zeros((3, 5))
23+
except:
24+
pytest.skip('tensorflow is missing')
25+
26+
assert t.check_order(c) == 'C'
27+
28+
29+
@needs_tensorflow
30+
def test02_implicit_conversion_tensorflow():
31+
with warnings.catch_warnings():
32+
warnings.simplefilter("ignore")
33+
try:
34+
c = tf.zeros((3, 5))
35+
except:
36+
pytest.skip('tensorflow is missing')
37+
38+
t.implicit(tf.zeros((2, 2), dtype=tf.int32))
39+
t.implicit(tf.zeros((2, 2, 10), dtype=tf.float32)[:, :, 4])
40+
t.implicit(tf.zeros((2, 2, 10), dtype=tf.int32)[:, :, 4])
41+
t.implicit(tf.zeros((2, 2, 10), dtype=tf.bool)[:, :, 4])
42+
43+
with pytest.raises(TypeError) as excinfo:
44+
t.noimplicit(tf.zeros((2, 2), dtype=tf.int32))
45+
46+
with pytest.raises(TypeError) as excinfo:
47+
t.noimplicit(tf.zeros((2, 2), dtype=tf.bool))
48+
49+
50+
@needs_tensorflow
51+
def test03_return_tensorflow():
52+
collect()
53+
dc = ttf.destruct_count()
54+
x = ttf.ret_tensorflow()
55+
assert x.get_shape().as_list() == [2, 4]
56+
assert tf.math.reduce_all(
57+
x == tf.constant([[1,2,3,4], [5,6,7,8]], dtype=tf.float32))
58+
del x
59+
collect()
60+
assert ttf.destruct_count() - dc == 1
61+
62+
63+
@needs_tensorflow
64+
def test04_check_tensorflow():
65+
assert t.check(tf.zeros((1)))

0 commit comments

Comments
 (0)