11import test_ndarray_ext as t
2- import test_jax_ext as tj
3- import test_tensorflow_ext as tt
42import pytest
53import warnings
64import importlib
@@ -20,21 +18,6 @@ def needs_torch(x):
2018except :
2119 needs_torch = pytest .mark .skip (reason = "PyTorch is required" )
2220
23- try :
24- import tensorflow as tf
25- import tensorflow .config
26- def needs_tensorflow (x ):
27- return x
28- except :
29- needs_tensorflow = pytest .mark .skip (reason = "TensorFlow is required" )
30-
31- try :
32- import jax .numpy as jnp
33- def needs_jax (x ):
34- return x
35- except :
36- needs_jax = pytest .mark .skip (reason = "JAX is required" )
37-
3821try :
3922 import cupy as cp
4023 def needs_cupy (x ):
@@ -158,19 +141,6 @@ def test05_constrain_order():
158141 assert t .check_order (np .zeros ((3 , 5 , 4 , 6 ), order = 'F' )[:, 2 , :, :]) == '?'
159142
160143
161- @needs_jax
162- def test06_constrain_order_jax ():
163- with warnings .catch_warnings ():
164- warnings .simplefilter ("ignore" )
165- try :
166- c = jnp .zeros ((3 , 5 ))
167- except :
168- pytest .skip ('jax is missing' )
169-
170- z = jnp .zeros ((3 , 5 , 4 , 6 ))
171- assert t .check_order (z ) == 'C'
172-
173-
174144@needs_torch
175145@pytest .mark .filterwarnings
176146def test07_constrain_order_pytorch ():
@@ -190,18 +160,6 @@ def test07_constrain_order_pytorch():
190160 assert t .check_device (torch .zeros (3 , 5 , device = 'cuda' )) == 'cuda'
191161
192162
193- @needs_tensorflow
194- def test08_constrain_order_tensorflow ():
195- with warnings .catch_warnings ():
196- warnings .simplefilter ("ignore" )
197- try :
198- c = tf .zeros ((3 , 5 ))
199- except :
200- pytest .skip ('tensorflow is missing' )
201-
202- assert t .check_order (c ) == 'C'
203-
204-
205163@needs_numpy
206164def test09_write_from_cpp ():
207165 x = np .zeros (10 , dtype = np .float32 )
@@ -251,48 +209,6 @@ def test11_implicit_conversion_pytorch():
251209 t .noimplicit (torch .zeros (2 , 2 , 10 , dtype = torch .float32 )[:, :, 4 ])
252210
253211
254- @needs_tensorflow
255- def test12_implicit_conversion_tensorflow ():
256- with warnings .catch_warnings ():
257- warnings .simplefilter ("ignore" )
258- try :
259- c = tf .zeros ((3 , 5 ))
260- except :
261- pytest .skip ('tensorflow is missing' )
262-
263- t .implicit (tf .zeros ((2 , 2 ), dtype = tf .int32 ))
264- t .implicit (tf .zeros ((2 , 2 , 10 ), dtype = tf .float32 )[:, :, 4 ])
265- t .implicit (tf .zeros ((2 , 2 , 10 ), dtype = tf .int32 )[:, :, 4 ])
266- t .implicit (tf .zeros ((2 , 2 , 10 ), dtype = tf .bool )[:, :, 4 ])
267-
268- with pytest .raises (TypeError ) as excinfo :
269- t .noimplicit (tf .zeros ((2 , 2 ), dtype = tf .int32 ))
270-
271- with pytest .raises (TypeError ) as excinfo :
272- t .noimplicit (tf .zeros ((2 , 2 ), dtype = tf .bool ))
273-
274-
275- @needs_jax
276- def test13_implicit_conversion_jax ():
277- with warnings .catch_warnings ():
278- warnings .simplefilter ("ignore" )
279- try :
280- c = jnp .zeros ((3 , 5 ))
281- except :
282- pytest .skip ('jax is missing' )
283-
284- t .implicit (jnp .zeros ((2 , 2 ), dtype = jnp .int32 ))
285- t .implicit (jnp .zeros ((2 , 2 , 10 ), dtype = jnp .float32 )[:, :, 4 ])
286- t .implicit (jnp .zeros ((2 , 2 , 10 ), dtype = jnp .int32 )[:, :, 4 ])
287- t .implicit (jnp .zeros ((2 , 2 , 10 ), dtype = jnp .bool_ )[:, :, 4 ])
288-
289- with pytest .raises (TypeError ) as excinfo :
290- t .noimplicit (jnp .zeros ((2 , 2 ), dtype = jnp .int32 ))
291-
292- with pytest .raises (TypeError ) as excinfo :
293- t .noimplicit (jnp .zeros ((2 , 2 ), dtype = jnp .uint8 ))
294-
295-
296212def test14_destroy_capsule ():
297213 collect ()
298214 dc = t .destruct_count ()
@@ -376,31 +292,6 @@ def test18_return_pytorch():
376292 assert t .destruct_count () - dc == 1
377293
378294
379- @needs_jax
380- def test19_return_jax ():
381- collect ()
382- dc = tj .destruct_count ()
383- x = tj .ret_jax ()
384- assert x .shape == (2 , 4 )
385- assert jnp .all (x == jnp .array ([[1 ,2 ,3 ,4 ], [5 ,6 ,7 ,8 ]], dtype = jnp .float32 ))
386- del x
387- collect ()
388- assert tj .destruct_count () - dc == 1
389-
390-
391- @needs_tensorflow
392- def test20_return_tensorflow ():
393- collect ()
394- dc = tt .destruct_count ()
395- x = tt .ret_tensorflow ()
396- assert x .get_shape ().as_list () == [2 , 4 ]
397- assert tf .math .reduce_all (
398- x == tf .constant ([[1 ,2 ,3 ,4 ], [5 ,6 ,7 ,8 ]], dtype = tf .float32 ))
399- del x
400- collect ()
401- assert tt .destruct_count () - dc == 1
402-
403-
404295@needs_numpy
405296def test21_return_array_scalar ():
406297 collect ()
@@ -504,16 +395,6 @@ def test28_check_torch():
504395 assert t .check (torch .zeros ((1 )))
505396
506397
507- @needs_tensorflow
508- def test29_check_tensorflow ():
509- assert t .check (tf .zeros ((1 )))
510-
511-
512- @needs_jax
513- def test30_check_jax ():
514- assert t .check (jnp .zeros ((1 )))
515-
516-
517398@needs_numpy
518399def test31_rv_policy ():
519400 def p (a ):
@@ -879,8 +760,6 @@ def test45_implicit_conversion_cupy():
879760@needs_numpy
880761def test46_implicit_conversion_contiguous_complex ():
881762 # Test fix for issue #709
882- import numpy as np
883-
884763 c_f32 = np .random .rand (10 , 10 )
885764 c_c64 = c_f32 .astype (np .complex64 )
886765
@@ -907,7 +786,6 @@ def test_conv(x):
907786
908787@needs_numpy
909788def test_47_ret_infer ():
910- import numpy as np
911789 assert np .all (t .ret_infer_c () == [[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ]])
912790 assert np .all (t .ret_infer_f () == [[1 , 3 , 5 , 7 ], [2 , 4 , 6 , 8 ]])
913791
@@ -956,13 +834,12 @@ def test50_test_matrix4f_copy():
956834
957835@needs_numpy
958836def test51_return_from_stack ():
959- import numpy as np
960837 assert np .all (t .ret_from_stack_1 () == [1 ,2 ,3 ])
961838 assert np .all (t .ret_from_stack_2 () == [1 ,2 ,3 ])
962839
840+
963841@needs_numpy
964842def test52_accept_np_both_true_contig ():
965- import numpy as np
966843 a = np .zeros ((2 , 1 ), dtype = np .float32 )
967844 assert a .flags ['C_CONTIGUOUS' ] and a .flags ['F_CONTIGUOUS' ]
968845 t .accept_np_both_true_contig_a (a )
@@ -972,6 +849,5 @@ def test52_accept_np_both_true_contig():
972849
973850@needs_numpy
974851def test53_issue_930 ():
975- import numpy as np
976852 wrapper = t .Wrapper (np .ones (3 , dtype = np .float32 ))
977853 assert wrapper .value [0 ] == 1
0 commit comments