Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
582c449
Add Jax comparison
EmilyBourne Dec 16, 2025
622b0b4
Add jax dependency
EmilyBourne Dec 16, 2025
161486b
Benchmark of pyccel
EmilyBourne Dec 16, 2025
7d406cc
Update performance comparison
github-actions[bot] Dec 16, 2025
ca74d37
Update README and version
github-actions[bot] Dec 16, 2025
2e5437e
Pythran C++ min version increased (#45)
EmilyBourne Dec 16, 2025
4050bdb
Mention Jax in README
EmilyBourne Dec 16, 2025
7cb095e
Merge remote-tracking branch 'origin/main' into ebourne_add_jax_compa…
EmilyBourne Dec 16, 2025
e484893
Correct test choice. Correct raw string
EmilyBourne Dec 16, 2025
aaeadf6
Benchmark of pyccel
EmilyBourne Dec 16, 2025
3e33149
Update performance comparison
github-actions[bot] Dec 16, 2025
8d11e2b
Update README and version
github-actions[bot] Dec 16, 2025
a51a338
Revert files modified by CI
EmilyBourne Dec 16, 2025
056de8f
Benchmark of pyccel
EmilyBourne Jan 19, 2026
6a904f7
Update performance comparison
github-actions[bot] Jan 19, 2026
b296ed1
Update README and version
github-actions[bot] Jan 19, 2026
32efb01
Use matching Python executable. Copy correct file
EmilyBourne Jan 19, 2026
df0c64e
Benchmark of pyccel
EmilyBourne Jan 19, 2026
8a9757d
Handle all installation parameters from pyproject.toml
EmilyBourne Jan 19, 2026
c53bb25
Benchmark of pyccel
EmilyBourne Jan 19, 2026
ff33a00
Use jax.numpy instead of jit compilation
EmilyBourne Jan 19, 2026
7ee13db
Benchmark of pyccel
EmilyBourne Jan 19, 2026
39449c7
Update performance comparison
github-actions[bot] Jan 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 167 additions & 167 deletions README.md

Large diffs are not rendered by default.

33 changes: 21 additions & 12 deletions benchmarks/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
help="Don't time the execution step")
parser.add_argument('--pypy', action='store_true', help='Run test cases with pypy')
parser.add_argument('--no_numba', action='store_true', help="Don't run numba tests")
parser.add_argument('--no_jax', action='store_true', help="Don't run jax tests")
parser.add_argument('--pythran-config-files', type=str, nargs='*', help='Provide configuration files for pythran', default = [])
parser.add_argument('--pyccel-config-files', type=str, nargs='*', help='Provide configuration files for pyccel', default = [])
parser.add_argument('--output', choices=('latex', 'markdown'), \
Expand Down Expand Up @@ -61,6 +62,9 @@
if not args.no_numba:
test_cases.append('numba')
test_case_names.append('numba')
if not args.no_jax:
test_cases.append('jax')
test_case_names.append('jax')
n_configs = 0
for i,f in enumerate(pyccel_configs):
name = os.path.splitext(os.path.basename(f))[0]
Expand Down Expand Up @@ -205,11 +209,16 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
numba_testname = 'numba_'+testname
numba_test_file = os.path.join(os.path.dirname(test_file), numba_basename)

jax_basename = 'jax_'+basename
jax_testname = 'jax_'+testname
jax_test_file = os.path.join(os.path.dirname(test_file), jax_basename)

new_folder = os.path.join('tmp',t.imports[0])

os.makedirs(new_folder, exist_ok=True)
shutil.copyfile(test_file, os.path.join(new_folder, basename))
shutil.copyfile(numba_test_file, os.path.join(new_folder, numba_basename))
shutil.copyfile(jax_test_file, os.path.join(new_folder, jax_basename))
os.chdir(new_folder)

import_funcs = ', '.join(t.imports)
Expand All @@ -221,9 +230,13 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
run_units = []

for case in test_cases:
setup_cmd = 'from {testname} import {funcs};'.format(
testname = numba_testname if case == 'numba' else testname,
funcs = import_funcs)
if case == 'numba':
chosen_testname = numba_testname
elif case == 'jax':
chosen_testname = jax_testname
else:
chosen_testname = testname
setup_cmd = f'from {chosen_testname} import {import_funcs};'
setup_cmd += t.setup.replace('\n','')
print("-------------------", file=log_file, flush=True)
print(" ",case, file=log_file, flush=True)
Expand Down Expand Up @@ -264,8 +277,8 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
print("Compilation CPU time : ", cpu_time, file=log_file)
comp_times.append('{:.2f}'.format(float(cpu_time)))

elif time_compilation and case == "numba":
cmd = ['pypy'] if case=='pypy' else ['python3']
elif time_compilation and case in ("numba", "jax"):
cmd = ['pypy'] if case=='pypy' else [sys.executable]
run_str = "{setup}import resource; t0 = resource.getrusage(resource.RUSAGE_SELF); {run}; t1 = resource.getrusage(resource.RUSAGE_SELF); {run}; t2 = resource.getrusage(resource.RUSAGE_SELF); print(2*t1.ru_utime-t0.ru_utime-t2.ru_utime + 2*t1.ru_stime-t0.ru_stime-t2.ru_stime)".format(
setup=setup_cmd,
run=exec_cmd)
Expand All @@ -291,7 +304,7 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
comp_times.append('-')

if time_execution:
cmd = ['pypy'] if case=='pypy' else ['python3']
cmd = ['pypy'] if case=='pypy' else [sys.executable]
cmd += ['-m'] + timeit_cmd + ['-s', setup_cmd, exec_cmd]

if verbose:
Expand All @@ -315,9 +328,7 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
stddev = float(r.group(3))
units = r.group(2)

bench_str = '{mean:.2f} $\pm$ {stddev:.2f}'.format(
mean=mean,
stddev=stddev)
bench_str = rf'{mean:.2f} $\pm$ {stddev:.2f}'
run_times.append((mean,stddev))
else:
regexp = re.compile(r'([0-9]+) loops?, best of ([0-9]+): ([0-9.]+) (\w*)')
Expand Down Expand Up @@ -354,9 +365,7 @@ def run_process(cmd: "List[str]", time_compilation: "bool"=False, env = None):
row.append('-')
else:
mean,stddev = time
row.append('{mean:.2f} $\pm$ {stddev:.2f}'.format(
mean=mean*f,
stddev=stddev*f))
row.append(rf'{mean*f:.2f} $\pm$ {stddev*f:.2f}')
else:
for time,f in zip(run_times,mult_fact):
if time is None:
Expand Down
20 changes: 20 additions & 0 deletions benchmarks/tests/jax_ackermann_mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# coding: utf-8
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
""" Module containing functions for testing the ackerman algorithm using numba
"""



def ackermann(m : int, n : int) -> int:
""" Total computable function that is not primitive recursive.
This function is useful for testing recursion
"""
if m == 0:
return n + 1
elif n == 0:
return ackermann(m - 1, 1)
else:
return ackermann(m - 1, ackermann(m, n - 1))
81 changes: 81 additions & 0 deletions benchmarks/tests/jax_bellman_ford_mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# coding: utf-8
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
""" Module containing functions for testing the Bellman-Ford algorithm using numba
"""


import jax.numpy as np



def bellman_ford ( v_num: int, e_num: int, source: int, e: 'int[:,:]', e_weight: 'float[:]',
v_weight: 'float[:]', predecessor: 'int[:]' ):
""" Calculate the shortest paths from a source vertex to all other
vertices in the weighted digraph
"""

r8_big = 1.0E+14

# Step 1: initialize the graph.
for i in range ( 0, v_num ):
v_weight[i] = r8_big
v_weight[source] = 0.0

predecessor[:v_num] = -1

# Step 2: Relax edges repeatedly.
for i in range ( 1, v_num ):
for j in range ( e_num ):
u = e[1, j]
v = e[0, j]
t = v_weight[u] + e_weight[j]
if ( t < v_weight[v] ):
v_weight[v] = t
predecessor[v] = u

# Step 3: check for negative-weight cycles
for j in range ( e_num ):
u = e[1, j]
v = e[0, j]
if ( v_weight[u] + e_weight[j] < v_weight[v] ):
print ( '' )
print ( 'BELLMAN_FORD - Fatal error!' )
print ( ' Graph contains a cycle with negative weight.' )
return 1

return 0




def bellman_ford_test():
""" Test bellman ford's algorithm
"""

e_num = 19900
v_num = 200

e = np.zeros((2, e_num), dtype = 'int')
e_weight = np.zeros(e_num, dtype = 'float')
idx = 0

for i in range(v_num):
for j in range(v_num):
if i > j:
e[0, idx] = i
e[1, idx] = j
idx += 1

for i in range(e_num):
e_weight[i] = np.cos(i) * i

source = 0
v_weight = np.zeros(v_num, dtype = 'float')
predecessor = np.zeros(v_num, dtype = 'int')

bellman_ford(v_num, e_num, source, e, e_weight, v_weight, predecessor)

return v_weight
109 changes: 109 additions & 0 deletions benchmarks/tests/jax_dijkstra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# coding: utf-8
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
""" Module containing functions for testing the Dijkstra algorithm using numba
"""

import jax.numpy as np

# ================================================================

def find_nearest ( nv: int, mind: 'int[:]', connected: 'bool[:]' ):
""" Find the nearest node
"""

i4_huge = 2147483647

d = i4_huge
v = -1
for i in range ( 0, nv ):
if ( not connected[i] and mind[i] <= d ):
d = mind[i]
v = i

return d, v

# ================================================================

def update_mind ( nv: int, mv: int, connected: 'bool[:]', ohd: 'int[:,:]', mind: 'int[:]' ):
""" Update the minimum distance
"""

i4_huge = 2147483647

for i in range ( 0, nv ):
if ( not connected[i] ):
if ( ohd[mv,i] < i4_huge ):
mind[i] = min ( mind[i], mind[mv] + ohd[mv,i] )

# ================================================================

def dijkstra_distance ( nv: int, ohd: 'int[:,:]', mind: 'int[:]' ):
""" Find the shortest paths between nodes in a graph
"""

# Start out with only node 1 connected to the tree.
connected = np.zeros (nv, dtype = 'bool' )

connected[0] = True
for i in range ( 1, nv ):
connected[i] = False

# Initialize the minimum distance to the one-step distance.
for i in range ( 1, nv ):
mind[i] = ohd[0,i]

# Attach one more node on each iteration.

for _ in range ( 1, nv ):
# Find the nearest unconnected node.
_, mv = find_nearest ( nv, mind, connected )

if ( mv == - 1 ):
print ( 'DIJKSTRA_DISTANCE - Fatal error!' )
print ( ' Search terminated early.' )
print ( ' Graph might not be connected.' )
# TODO exit
#exit ( 'DIJKSTRA_DISTANCE - Fatal error!' )

# Mark this node as connected.
connected[mv] = True

# Having determined the minimum distance to node MV, see if
# that reduces the minimum distance to other nodes.
update_mind ( nv, mv, connected, ohd, mind )

# ================================================================

def init ( nv: int, ohd: 'int[:,:]' ):
""" Create a graph
"""

i4_huge = 1 << 20

for i in range ( 0, nv ):
for j in range ( 0, nv ):
ohd[i,j] = i4_huge

ohd[i,i] = 0

ohd[0,333] = 33

# ================================================================

def dijkstra_distance_test ( ):
""" Test Dijkstra's algorithm
"""

# Initialize the problem data.
nv = 3000
ohd = np.zeros ( ( nv, nv ), dtype = 'int' )
init ( nv, ohd )

# Carry out the algorithm.
min_distance = np.zeros ( nv, dtype = 'int' )
dijkstra_distance ( nv, ohd, min_distance )

return min_distance
Loading