From e3d40fd9e2052078c53d653822e70296fa08ee7f Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Wed, 5 Nov 2025 13:59:39 +0000 Subject: [PATCH 01/17] New SMT-based ArrayIndexAnalysis --- src/psyclone/psyir/tools/__init__.py | 4 +- .../psyir/tools/array_index_analysis.py | 703 ++++++++++++++++++ .../transformations/parallel_loop_trans.py | 67 +- .../tests/psyir/nodes/omp_directives_test.py | 41 + .../psyir/tools/array_index_analysis_test.py | 315 ++++++++ .../transformations/transformations_test.py | 1 + 6 files changed, 1119 insertions(+), 12 deletions(-) create mode 100644 src/psyclone/psyir/tools/array_index_analysis.py create mode 100644 src/psyclone/tests/psyir/tools/array_index_analysis_test.py diff --git a/src/psyclone/psyir/tools/__init__.py b/src/psyclone/psyir/tools/__init__.py index 2e2f38dca7..3df31cb2ce 100644 --- a/src/psyclone/psyir/tools/__init__.py +++ b/src/psyclone/psyir/tools/__init__.py @@ -43,6 +43,7 @@ from psyclone.psyir.tools.read_write_info import ReadWriteInfo from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain from psyclone.psyir.tools.reduction_inference import ReductionInferenceTool +from psyclone.psyir.tools.array_index_analysis import ArrayIndexAnalysis # For AutoAPI documentation generation. __all__ = ['CallTreeUtils', @@ -50,4 +51,5 @@ 'DependencyTools', 'DefinitionUseChain', 'ReadWriteInfo', - 'ReductionInferenceTool'] + 'ReductionInferenceTool', + 'ArrayIndexAnalysis'] diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py new file mode 100644 index 0000000000..e83e354462 --- /dev/null +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -0,0 +1,703 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2025, University of Cambridge, UK. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: M. Naylor, University of Cambridge, UK +# ----------------------------------------------------------------------------- + +'''This module provides a class to determine whether or not distinct iterations +of a given loop can generate conflicting array accesses (if not, the loop can +potentially be parallelised). It formulates the problem as a set of SMT +constraints over array indices which are then are passed to a third-party +solver via pySMT. We currently mandate use of the Z3 solver as it has a useful +timeout option, missing in other solvers.''' + +# PySMT imports +import pysmt.shortcuts as smt +import pysmt.fnode as smt_fnode +import pysmt.typing as smt_types +import pysmt.logics as smt_logics +from pysmt.exceptions import SolverReturnedUnknownResultError + +# PSyclone imports +from psyclone.psyir.nodes import Loop, DataNode, Literal, Assignment, \ + Reference, UnaryOperation, BinaryOperation, IntrinsicCall, \ + Routine, Node, IfBlock, Schedule, ArrayReference, Range, WhileLoop +from psyclone.psyir.symbols import DataType, ScalarType, ArrayType, \ + INTEGER_TYPE + +# Outline +# ======= +# +# The analysis class provides a method 'is_loop_conflict_free()' to decide +# whether or not the array accesses in a given loop are conflicting between +# iterations. Two array accesses are conflicting if they access the same +# element of the same array, and at least one of the accesses is a write. The +# analysis algorithm operates, broadly, as follows. +# +# Given a loop, we find its enclosing routine, and start analysing the routine +# statement-by-statement in a recursive-descent fashion. +# +# As we proceed, we maintain a set of SMT constraints and a substitution that +# maps Fortran variable names to SMT variable names. For each Fortran +# variable, the substitution points to an SMT variable that is constrained (in +# the set of constraints) such that it captures the value of the Fortran +# variable at the current point in the code. When a Fortran variable is +# mutated, the substitution is be modified to point to a fresh SMT variable, +# with new constraints, without destroying the old constraints. +# +# More concretely, when we encounter an assignment of a scalar integer/logical +# variable, of the form 'x = rhs', we translate 'rhs' to the SMT formula +# 'smt_rhs' with the current substitution applied. We then add a constraint +# 'var = smt_rhs' where 'var' is a fresh SMT variable, and update the +# substition so that 'x' maps to 'var'. +# +# The Fortran-expression-to-SMT translator knows about several Fortran +# operators and intrinsics, but not all of them; when it sees something it +# doesn't know about, it simply translates it to a fresh unconstrained SMT +# variable. +# +# Sometimes we reach a statement that modifies a Fortran variable in an unknown +# way (e.g. calling a subroutine). This can be handled by updating the +# substitution to point to a fresh unconstrained SMT variable; we refer to this +# process as "killing" the variable. +# +# In addition to the current substitution, we maintain a stack of previous +# substitutions. This allows substitutions to be saved and restored before and +# after analysing a block of code that may or may not be executed at run time. +# +# We also maintain a "current condition". This can be viewed as a constraint +# that has not been comitted to the constraint set because we want to be able +# to grow, contract, and retract it as we enter and exit conditional blocks of +# code. This current condition is passed in recursive calls, so there is an +# implicit stack of them. +# +# More concretely, when we encouter an 'if' statement, we copy the current +# substitution onto the stack, then recurse into the 'then' body, passing in +# the 'if' condition as an argument, and then restore the old substitution. We +# do the same for the 'else' body if there is one (in this case the negated +# condition is passed to the recursive call). Finally, we kill all variables +# written by the 'then' and 'else' bodies, because we don't know which will be +# executed at run time. +# +# As the analysis proceeds, we also maintain a list of array accesses. For each +# access, we record various information including the name of the array, +# whether it is a read or a write, the current condition at the point the +# access is made, and its list of indices (translated to SMT). When we are +# analysing code that is inside the loop of interest, we add all array accesses +# encountered to the list. +# +# When we encounter the loop of interest, we perform a couple of steps before +# recursing into the loop body. First, we kill all variables written by the +# loop body, because we don't know whether we are entering the loop (at +# run time) for the first time or not. Second, we create two SMT variables to +# represent two arbitary but distinct iterations of the loop. Each variable is +# constrained to the start, stop, and step of the loop, and the two variables +# are constrained to be not equal. After that, we analyse the loop body twice, +# each time mapping the loop variable in the substitution to a different SMT +# variable. After analysing the loop body for the first time, we save the +# array access list and start afresh with a new one. Therefore, once the +# analysis is complete, we have two array access lists, one for each iteration. +# +# When we encounter a loop that is not the loop of interest, we follow a +# similar approach but only consider a single arbitrary iteration of the loop. +# +# When the recursive descent is complete, we are left with two array access +# lists. We are interested in whether any pair of accesses to the same array +# (in which one of the accesses is a write) represents a conflict. An access +# pair is conflict-free if an equality constraint between each access's +# indices, when combined with the current condition of each access and the +# global constraint set, is unsatisfiable. In this way, we can check every +# access pair and determine whether or not the loop is conflict free. + +class ArrayIndexAnalysis: + # Fortran integer width in bits + int_width = 32 + + def __init__(self, smt_timeout_ms=5000): + # Set SMT solver timeout in milliseconds + self.smt_timeout = smt_timeout_ms + + # Class representing an array access + class ArrayAccess: + def __init__(self, + cond: smt_fnode.FNode, + is_write: bool, + indices: list[smt_fnode.FNode], + psyir_node: Node): + # The condition at the location of the access + self.cond = cond + # Whether the access is a read or a write + self.is_write = is_write + # SMT expressions representing the indices of the access + self.indices = indices + # PSyIR Node for the access (useful for error reporting) + self.psyir_node = psyir_node + + # Initialise analysis + def init_analysis(self): + # The substitution maps integer and logical Fortran variables + # to SMT symbols + self.subst = {} + # We have a stack of these to support save/restore + self.subst_stack = [] + # The constraint set is represented as a list of boolean SMT formulae + self.constraints = [] + # The access dict maps each array name to a list of array accesses + self.access_dict = {} + # We record two access dicts, representing two arbitrary but distinct + # iterations of the loop to parallelise + self.saved_access_dicts = [] + # Are we currently analysing the loop to parallelise? + self.in_loop_to_parallelise = False + # Has the analaysis finished? + self.finished = False + + # Push copy of current substitution to the stack + def save_subst(self): + self.subst_stack.append(self.subst.copy()) + + # Pop substitution from stack into current substitution + def restore_subst(self): + self.subst = self.subst_stack.pop() + + # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol + def kill_integer_var(self, var: str): + fresh_sym = smt.FreshSymbol(typename=smt_types.BVType(self.int_width)) + smt_var = smt.Symbol(var, typename=smt_types.BVType(self.int_width)) + self.subst[smt_var] = fresh_sym + + # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol + def kill_logical_var(self, var: str): + fresh_sym = smt.FreshSymbol(typename=smt_types.BOOL) + smt_var = smt.Symbol(var, typename=smt_types.BOOL) + self.subst[smt_var] = fresh_sym + + # Kill all scalar integer/logical variables written inside 'node' + def kill_all_written_vars(self, node: Node): + var_accesses = node.reference_accesses() + for sig, access_seq in var_accesses.items(): + for access_info in access_seq.all_write_accesses: + sym = self.routine.symbol_table.lookup(sig.var_name) + if sym.is_unresolved: + continue # pragma: no cover + elif _is_scalar_integer(sym.datatype): + self.kill_integer_var(sig.var_name) + elif _is_scalar_logical(sym.datatype): + self.kill_logical_var(sig.var_name) + + # Add the SMT constraint to the constraint set + def add_constraint(self, smt_expr: smt_fnode.FNode): + self.constraints.append(smt_expr) + + # Add an integer assignment constraint to the constraint set + def add_integer_assignment(self, var: str, smt_expr: smt_fnode.FNode): + # Create a fresh symbol + fresh_sym = smt.FreshSymbol(typename=smt_types.BVType(self.int_width)) + # Assert equality between this symbol and the given SMT expression + self.add_constraint(smt.Equals(fresh_sym, smt_expr)) + # Update the substitution + smt_var = smt.Symbol(var, typename=smt_types.BVType(self.int_width)) + self.subst[smt_var] = fresh_sym + + # Add a logical assignment constraint to the constraint set + def add_logical_assignment(self, var: str, smt_expr: smt_fnode.FNode): + # Create a fresh symbol + fresh_sym = smt.FreshSymbol(typename=smt_types.BOOL) + # Assert equality between this symbol and the given SMT expression + self.add_constraint(smt.Iff(fresh_sym, smt_expr)) + # Update the substitution + smt_var = smt.Symbol(var, typename=smt_types.BOOL) + self.subst[smt_var] = fresh_sym + + # Translate integer expresison to SMT, and apply current substitution + def translate_integer_expr_with_subst(self, expr: smt_fnode.FNode): + smt_expr = translate_integer_expr(expr, self.int_width) + return smt_expr.substitute(self.subst) + + # Translate logical expresison to SMT, and apply current substitution + def translate_logical_expr_with_subst(self, expr: smt_fnode.FNode): + smt_expr = translate_logical_expr(expr, self.int_width) + return smt_expr.substitute(self.subst) + + # Constrain a loop variable to given start/stop/step + def constrain_loop_var(self, + var: smt_fnode.FNode, + start: DataNode, + stop: DataNode, + step: DataNode): + zero = smt.SBV(0, self.int_width) + var_begin = self.translate_integer_expr_with_subst(start) + var_end = self.translate_integer_expr_with_subst(stop) + if step is None: + step = Literal("1", INTEGER_TYPE) # pragma: no cover + var_step = self.translate_integer_expr_with_subst(step) + self.add_constraint(smt.And( + # (var - var_begin) % var_step == 0 + smt.Equals(smt.BVSRem(smt.BVSub(var, var_begin), var_step), zero), + # var_step > 0 ==> var >= var_begin + smt.Implies(smt.BVSGT(var_step, zero), + smt.BVSGE(var, var_begin)), + # var_step < 0 ==> var <= var_begin + smt.Implies(smt.BVSLT(var_step, zero), + smt.BVSLE(var, var_begin)), + # var_step > 0 ==> var <= var_end + smt.Implies(smt.BVSGT(var_step, zero), + smt.BVSLE(var, var_end)), + # var_step < 0 ==> var >= var_end + smt.Implies(smt.BVSLT(var_step, zero), + smt.BVSGE(var, var_end)))) + + # Add an array access to the current access dict + def add_array_access(self, + array_name: str, + is_write: bool, + cond: smt_fnode.FNode, + indices: list[smt_fnode.FNode], + psyir_node: Node): + access = ArrayIndexAnalysis.ArrayAccess( + cond, is_write, indices, psyir_node) + if array_name not in self.access_dict: + self.access_dict[array_name] = [] + self.access_dict[array_name].append(access) + + # Add all array accesses in the given node to the current access dict + def add_all_array_accesses(self, node: Node, cond: smt_fnode.FNode): + var_accesses = node.reference_accesses() + for sig, access_seq in var_accesses.items(): + for access_info in access_seq: + if access_info.is_data_access: + + # ArrayReference + if isinstance(access_info.node, ArrayReference): + indices = [] + for index in access_info.node.indices: + if isinstance(index, Range): + var = smt.FreshSymbol( + typename=smt_types.BVType( + self.int_width)) + self.constrain_loop_var( + var, index.start, index.stop, index.step) + indices.append(var) + else: + indices.append( + self.translate_integer_expr_with_subst( + index)) + self.add_array_access( + sig.var_name, + access_info.is_any_write(), + cond, indices, access_info.node) + + # Reference with datatype ArrayType + elif (isinstance(access_info.node, Reference) and + isinstance(access_info.node.datatype, ArrayType)): + indices = [] + for index in access_info.node.datatype.shape: + var = smt.FreshSymbol( + typename=smt_types.BVType(self.int_width)) + indices.append(var) + self.add_array_access( + sig.var_name, access_info.is_any_write(), + cond, indices, access_info.node) + + # Move the current access dict to the stack, and proceed with an empty one + def save_access_dict(self): + self.saved_access_dicts.append(self.access_dict) + self.access_dict = {} + + # Check if the given loop has a conflict + def is_loop_conflict_free(self, loop: Loop) -> tuple[Node, Node]: + # Type checking + if not isinstance(loop, Loop): + raise TypeError("ArrayIndexAnalysis: Loop argument expected") + self.loop = loop + + # Find the enclosing routine + routine = loop.ancestor(Routine) + if not routine: + raise ValueError( + "ArrayIndexAnalysis: loop has no enclosing routine") + self.routine = routine + + # Start with an empty constraint set and substitution + self.init_analysis() + self.loop_to_parallelise = loop + smt.reset_env() + + # Step through body of the enclosing routine, statement by statement + for stmt in routine.children: + self.step(stmt, smt.TRUE()) + + # Check that we have found and analysed the loop to parallelise + if not (self.finished and len(self.saved_access_dicts) == 2): + return None + + # Forumlate constraints for solving, considering the two iterations + iter_i = self.saved_access_dicts[0] + iter_j = self.saved_access_dicts[1] + conflicts = [] + for (arr_name, i_accesses) in iter_i.items(): + j_accesses = iter_j[arr_name] + # For each write access in the i iteration + for i_access in i_accesses: + if i_access.is_write: + # Check for conflicts against every access in the + # j iteration + for j_access in j_accesses: + assert len(i_access.indices) == len(j_access.indices) + indices_equal = [] + for (i_idx, j_idx) in zip(i_access.indices, + j_access.indices): + indices_equal.append(smt.Equals(i_idx, j_idx)) + conflicts.append(smt.And( + *indices_equal, + i_access.cond, + j_access.cond)) + + # Invoke Z3 solver with a timeout + solver = smt.Solver(name='z3', + logic=smt_logics.QF_BV, + generate_models=False, + incremental=False, + solver_options={'timeout': self.smt_timeout}) + try: + return not solver.is_sat(smt.And(*self.constraints, + smt.Or(conflicts))) + except SolverReturnedUnknownResultError: # pragma: no cover + return None # pragma: no cover + + # Analyse a single statement + def step(self, stmt: Node, cond: smt_fnode.FNode): + # Has analysis finished? + if self.finished: + return + + # Assignment + if isinstance(stmt, Assignment): + if (isinstance(stmt.lhs, Reference) + and not isinstance(stmt.lhs, ArrayReference)): + if _is_scalar_integer(stmt.lhs.datatype): + rhs_smt = self.translate_integer_expr_with_subst(stmt.rhs) + self.add_integer_assignment(stmt.lhs.name, rhs_smt) + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.rhs, cond) + return + elif _is_scalar_logical(stmt.lhs.datatype): + rhs_smt = self.translate_logical_expr_with_subst(stmt.rhs) + self.add_logical_assignment(stmt.lhs.name, rhs_smt) + return + + # Schedule + if isinstance(stmt, Schedule): + for child in stmt.children: + self.step(child, cond) + return + + # IfBlock + if isinstance(stmt, IfBlock): + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.condition, cond) + # Translate condition to SMT + smt_condition = self.translate_logical_expr_with_subst( + stmt.condition) + # Recursively step into 'then' + if stmt.if_body: + self.save_subst() + self.step(stmt.if_body, smt.And(cond, smt_condition)) + self.restore_subst() + # Recursively step into 'else' + if stmt.else_body: + self.save_subst() + self.step(stmt.else_body, + smt.And(cond, smt.Not(smt_condition))) + self.restore_subst() + # Kill vars written by each branch + if stmt.if_body: + self.kill_all_written_vars(stmt.if_body) + if stmt.else_body: + self.kill_all_written_vars(stmt.else_body) + return + + # Loop + if isinstance(stmt, Loop): + # Kill variables written by loop body + self.kill_all_written_vars(stmt.loop_body) + # Kill loop variable + self.kill_integer_var(stmt.variable.name) + # Have we reached the loop we'd like to parallelise? + if stmt is self.loop_to_parallelise: + self.in_loop_to_parallelise = True + # Consider two arbitary but distinct iterations + i_var = smt.FreshSymbol( + typename=smt_types.BVType(self.int_width)) + j_var = smt.FreshSymbol( + typename=smt_types.BVType(self.int_width)) + self.add_constraint(smt.NotEquals(i_var, j_var)) + iteration_vars = [i_var, j_var] + else: + # Consider a single, arbitrary iteration + i_var = smt.FreshSymbol( + typename=smt_types.BVType(self.int_width)) + iteration_vars = [i_var] + # Analyse loop body for each iteration variable separately + for var in iteration_vars: + self.save_subst() + smt_loop_var = smt.Symbol( + stmt.variable.name, + typename=smt_types.BVType(self.int_width)) + self.subst[smt_loop_var] = var + # Introduce constraints on loop variable + self.constrain_loop_var( + var, stmt.start_expr, stmt.stop_expr, stmt.step_expr) + # Analyse loop body + self.step(stmt.loop_body, cond) + if stmt is self.loop_to_parallelise: + self.save_access_dict() + self.restore_subst() + # Record whether the analysis has finished + if stmt is self.loop_to_parallelise: + self.finished = True + return + + # WhileLoop + if isinstance(stmt, WhileLoop): + # Kill variables written by loop body + self.kill_all_written_vars(stmt.loop_body) + # Add array accesses in condition + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.condition, cond) + # Translate condition to SMT + smt_condition = self.translate_logical_expr_with_subst( + stmt.condition) + # Recursively step into loop body + self.save_subst() + self.step(stmt.loop_body, smt.And(cond, smt_condition)) + self.restore_subst() + return + + # Fall through + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt, cond) + self.kill_all_written_vars(stmt) + +# Translating Fortran expressions to SMT formulae +# =============================================== + + +# Translate a scalar integer Fortran expression to SMT +def translate_integer_expr(expr_root: Node, int_width: int): + # SMT type to represent Fortran integers + bv_int_t = smt_types.BVType(int_width) + + def trans(expr: Node): + # Check that type is a scalar integer of unspecified precision + type_ok = _is_scalar_integer(expr.datatype) + + # Literal + if isinstance(expr, Literal) and type_ok: + return smt.SBV(int(expr.value), int_width) + + # Reference + if (isinstance(expr, Reference) + and not isinstance(expr, ArrayReference) + and type_ok): + return smt.Symbol(expr.name, typename=bv_int_t) + + # UnaryOperation + if isinstance(expr, UnaryOperation): + arg_smt = trans(expr.operand) + if expr.operator == UnaryOperation.Operator.MINUS: + return smt.BVNeg(arg_smt) + if expr.operator == UnaryOperation.Operator.PLUS: + return arg_smt + + # BinaryOperation + if isinstance(expr, BinaryOperation): + (left, right) = expr.operands + left_smt = trans(left) + right_smt = trans(right) + + if expr.operator == BinaryOperation.Operator.ADD: + return smt.BVAdd(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.SUB: + return smt.BVSub(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.MUL: + return smt.BVMul(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.DIV: + return smt.BVSDiv(left_smt, right_smt) + + # IntrinsicCall + if isinstance(expr, IntrinsicCall): + # Unary operators + if expr.intrinsic == IntrinsicCall.Intrinsic.ABS: + zero = smt.BVZero(int_width) + smt_arg = trans(expr.children[1]) + return smt.Ite(smt.BVSLT(smt_arg, zero), + smt.BVNeg(smt_arg), + smt_arg) + + # Binary operators + if expr.intrinsic in [IntrinsicCall.Intrinsic.SHIFTL, + IntrinsicCall.Intrinsic.SHIFTR, + IntrinsicCall.Intrinsic.SHIFTA, + IntrinsicCall.Intrinsic.IAND, + IntrinsicCall.Intrinsic.IOR, + IntrinsicCall.Intrinsic.IEOR, + IntrinsicCall.Intrinsic.MOD]: + left_smt = trans(expr.children[1]) + right_smt = trans(expr.children[2]) + + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: + return smt.BVLShl(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: + return smt.BVLShr(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: + return smt.BVAShr(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: + return smt.BVAnd(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: + return smt.BVOr(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: + return smt.BVXor(left_smt, right_smt) + # TODO: does BVSRem match the semantics of Fortran MOD? + if expr.intrinsic == IntrinsicCall.Intrinsic.MOD: + return smt.BVSRem(left_smt, right_smt) + + # N-ary operators + if expr.intrinsic in [IntrinsicCall.Intrinsic.MIN, + IntrinsicCall.Intrinsic.MAX]: + smt_args = [trans(arg) for arg in expr.children[1:]] + reduced = smt_args[0] + for arg in smt_args[1:]: + if expr.intrinsic == IntrinsicCall.Intrinsic.MIN: + reduced = smt.Ite(smt.BVSLT(reduced, arg), + reduced, arg) + elif expr.intrinsic == IntrinsicCall.Intrinsic.MAX: + reduced = smt.Ite(smt.BVSLT(reduced, arg), + arg, reduced) + return reduced + + # Fall through: return a fresh, unconstrained symbol + return smt.FreshSymbol(typename=bv_int_t) + + return trans(expr_root) + + +# Translate a scalar logical Fortran expression to SMT +def translate_logical_expr(expr_root: Node, int_width: int): + def trans(expr: Node): + # Check that type is a scalar logical + type_ok = _is_scalar_logical(expr.datatype) + + # Literal + if isinstance(expr, Literal) and type_ok: + if expr.value == "true": + return smt.TRUE() + if expr.value == "false": + return smt.FALSE() + + # Reference + if (isinstance(expr, Reference) + and not isinstance(expr, ArrayReference) + and type_ok): + return smt.Symbol(expr.name, typename=smt_types.BOOL) + + # UnaryOperation + if isinstance(expr, UnaryOperation): + arg_smt = trans(expr.operand) + if expr.operator == UnaryOperation.Operator.NOT: + return smt.Not(arg_smt) + + # BinaryOperation + if isinstance(expr, BinaryOperation): + # Operands are logicals + if expr.operator in [BinaryOperation.Operator.AND, + BinaryOperation.Operator.OR, + BinaryOperation.Operator.EQV, + BinaryOperation.Operator.NEQV]: + (left, right) = expr.operands + left_smt = trans(left) + right_smt = trans(right) + + if expr.operator == BinaryOperation.Operator.AND: + return smt.And(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.OR: + return smt.Or(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.EQV: + return smt.Iff(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.NEQV: + return smt.Not(smt.Iff(left_smt, right_smt)) + + # Operands are numbers + if expr.operator in [BinaryOperation.Operator.EQ, + BinaryOperation.Operator.NE, + BinaryOperation.Operator.GT, + BinaryOperation.Operator.LT, + BinaryOperation.Operator.GE, + BinaryOperation.Operator.LE]: + (left, right) = expr.operands + left_smt = translate_integer_expr(left, int_width) + right_smt = translate_integer_expr(right, int_width) + + if expr.operator == BinaryOperation.Operator.EQ: + return smt.Equals(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.NE: + return smt.NotEquals(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.GT: + return smt.BVSGT(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.LT: + return smt.BVSLT(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.GE: + return smt.BVSGE(left_smt, right_smt) + if expr.operator == BinaryOperation.Operator.LE: + return smt.BVSLE(left_smt, right_smt) + + # Fall through: return a fresh, unconstrained symbol + return smt.FreshSymbol(typename=smt_types.BOOL) + + return trans(expr_root) + +# Helper functions +# ================ + + +# Check that type is a scalar integer of unspecified precision +def _is_scalar_integer(dt: DataType) -> bool: + return (isinstance(dt, ScalarType) and + dt.intrinsic == ScalarType.Intrinsic.INTEGER and + dt.precision == ScalarType.Precision.UNDEFINED) + + +# Check that type is a scalar logical +def _is_scalar_logical(dt: DataType) -> bool: + return (isinstance(dt, ScalarType) and + dt.intrinsic == ScalarType.Intrinsic.BOOLEAN) diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index 5b1db26c8b..6c3ab67156 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -53,7 +53,7 @@ BinaryOperation, IntrinsicCall ) from psyclone.psyir.tools import ( - DependencyTools, DTCode, ReductionInferenceTool + DependencyTools, DTCode, ReductionInferenceTool, ArrayIndexAnalysis ) from psyclone.psyir.transformations.loop_trans import LoopTrans from psyclone.psyir.transformations.async_trans_mixin import \ @@ -175,6 +175,9 @@ def validate(self, node, options=None, **kwargs): reduction_ops = self.get_option("reduction_ops", **kwargs) if reduction_ops is None: reduction_ops = [] + use_smt_array_anal = self.get_option( + "use_smt_array_anal", **kwargs) + smt_timeout_ms = self.get_option("smt_timeout_ms", **kwargs) else: verbose = options.get("verbose", False) collapse = options.get("collapse", False) @@ -185,6 +188,8 @@ def validate(self, node, options=None, **kwargs): sequential = options.get("sequential", False) privatise_arrays = options.get("privatise_arrays", False) reduction_ops = options.get("reduction_ops", []) + use_smt_array_anal = options.get("use_smt_array_anal", False) + smt_timeout_ms = options.get("smt_timeout_ms", 5000) # Check type of reduction_ops (not handled by validate_options) if not isinstance(reduction_ops, list): @@ -271,6 +276,7 @@ def validate(self, node, options=None, **kwargs): # The DependencyTools also returns False for things that are # not an issue, so we ignore specific messages. errors = [] + num_depedency_errors = 0 for message in dep_tools.get_all_messages(): if message.code == DTCode.WARN_SCALAR_WRITTEN_ONCE: continue @@ -296,8 +302,22 @@ def validate(self, node, options=None, **kwargs): if clause: self.inferred_reduction_clauses.append(clause) continue + + if (message.code == DTCode.ERROR_DEPENDENCY): + num_depedency_errors = num_depedency_errors + 1 errors.append(str(message)) + # Use ArrayIndexAnalysis + if use_smt_array_anal: + # Are all the errors array dependency errors? + if len(errors) > 0 and len(errors) == num_depedency_errors: + # Try using the ArrayIndexAnalysis to prove that the + # dependency errors are false + arr_anal = ArrayIndexAnalysis( + smt_timeout_ms=smt_timeout_ms) + if arr_anal.is_loop_conflict_free(node): + errors = [] + if errors: error_lines = "\n".join(errors) messages = (f"Loop cannot be parallelised because:\n" @@ -326,6 +346,8 @@ def apply(self, node, options=None, verbose: bool = False, nowait: bool = False, reduction_ops: List[Union[BinaryOperation.Operator, IntrinsicCall.Intrinsic]] = None, + use_smt_array_anal: bool = False, + smt_timeout_ms: int = 5000, **kwargs): ''' Apply the Loop transformation to the specified node in a @@ -370,6 +392,9 @@ def apply(self, node, options=None, verbose: bool = False, :param reduction_ops: if non-empty, attempt parallelisation of loops by inferring reduction clauses involving any of the reduction operators in the list. + :param bool use_smt_array_anal: whether to use the SMT-based + ArrayIndexAnalysis to discharge false dependency errors. + :param bool smt_timeout_ms: SMT solver timeout in milliseconds. ''' if not options: @@ -378,7 +403,9 @@ def apply(self, node, options=None, verbose: bool = False, ignore_dependencies_for=ignore_dependencies_for, privatise_arrays=privatise_arrays, sequential=sequential, nowait=nowait, - reduction_ops=reduction_ops, **kwargs + reduction_ops=reduction_ops, + use_smt_array_anal=use_smt_array_anal, + smt_timeout_ms=smt_timeout_ms, **kwargs ) # Rename the input options that are renamed in this apply method. # TODO 2668, rename options to be consistent. @@ -399,13 +426,18 @@ def apply(self, node, options=None, verbose: bool = False, privatise_arrays = options.get("privatise_arrays", False) nowait = options.get("nowait", False) reduction_ops = options.get("reduction_ops", []) + use_smt_array_anal = options.get("use_smt_array_anal", False) + smt_timeout_ms = options.get("smt_timeout_ms", False) self.validate(node, options=options, verbose=verbose, collapse=collapse, force=force, ignore_dependencies_for=ignore_dependencies_for, privatise_arrays=privatise_arrays, sequential=sequential, nowait=nowait, - reduction_ops=reduction_ops, **kwargs) + reduction_ops=reduction_ops, + use_smt_array_anal=use_smt_array_anal, + smt_timeout_ms=smt_timeout_ms, + **kwargs) list_of_signatures = [Signature(name) for name in list_of_names] dtools = DependencyTools() @@ -483,14 +515,27 @@ def apply(self, node, options=None, verbose: bool = False, if not next_loop.independent_iterations( dep_tools=dtools, signatures_to_ignore=list_of_signatures): - if verbose: - msgs = dtools.get_all_messages() - next_loop.preceding_comment = ( - "\n".join([str(m) for m in msgs]) + - " Consider using the \"ignore_dependencies_" - "for\" transformation option if this is a " - "false dependency.") - break + msgs = dtools.get_all_messages() + discharge_errors = False + + if use_smt_array_anal: + all_dep_errors = all( + [msg.code == DTCode.ERROR_DEPENDENCY + for msg in msgs]) + arr_anal = ArrayIndexAnalysis( + smt_timeout_ms=smt_timeout_ms) + discharge_errors = ( + all_dep_errors and + arr_anal.is_loop_conflict_free(next_loop)) + + if not discharge_errors: + if verbose: + next_loop.preceding_comment = ( + "\n".join([str(m) for m in msgs]) + + " Consider using the \"ignore_dependenc" + "ies_for\" transformation option if this " + "is a false dependency.") + break else: num_collapsable_loops = None diff --git a/src/psyclone/tests/psyir/nodes/omp_directives_test.py b/src/psyclone/tests/psyir/nodes/omp_directives_test.py index 1da2c8c0c0..1dd5b524bf 100644 --- a/src/psyclone/tests/psyir/nodes/omp_directives_test.py +++ b/src/psyclone/tests/psyir/nodes/omp_directives_test.py @@ -74,6 +74,7 @@ LFRicOMPLoopTrans, OMPParallelTrans, OMPParallelLoopTrans, LFRicOMPParallelLoopTrans, OMPSingleTrans, OMPMasterTrans, OMPLoopTrans, TransformationError) +from pysmt.exceptions import NoSolverAvailableError BASE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(__file__)))), "test_files", "lfric") @@ -5349,3 +5350,43 @@ def test_firstprivate_with_uninitialised(fortran_reader, fortran_writer): output = fortran_writer(psyir) assert "firstprivate(a)" in output assert "firstprivate(b)" in output + + +def test_array_analysis_option(fortran_reader, fortran_writer): + '''Test that a tiled loop can be parallelised when using the SMT-based + array index analysis. + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine my_matmul(a, b, c) + integer, dimension(:,:), intent(in) :: a + integer, dimension(:,:), intent(in) :: b + integer, dimension(:,:), intent(out) :: c + integer :: x, y, k, k_out_var, x_out_var, y_out_var, a1_n, a2_n, b1_n + + a2_n = SIZE(a, 2) + b1_n = SIZE(b, 1) + a1_n = SIZE(a, 1) + + c(:,:) = 0 + do y_out_var = 1, a2_n, 8 + do x_out_var = 1, b1_n, 8 + do k_out_var = 1, a1_n, 8 + do y = y_out_var, MIN(y_out_var + (8 - 1), a2_n), 1 + do x = x_out_var, MIN(x_out_var + (8 - 1), b1_n), 1 + do k = k_out_var, MIN(k_out_var + (8 - 1), a1_n), 1 + c(x,y) = c(x,y) + a(k,y) * b(x,k) + enddo + enddo + enddo + enddo + enddo + enddo + end subroutine my_matmul''') + omplooptrans = OMPLoopTrans(omp_directive="paralleldo") + loop = psyir.walk(Loop)[0] + try: + omplooptrans.apply(loop, collapse=True, use_smt_array_anal=True) + output = fortran_writer(psyir) + assert "collapse(2)" in output + except NoSolverAvailableError: + pass diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py new file mode 100644 index 0000000000..09cda0b778 --- /dev/null +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -0,0 +1,315 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2025, University of Cambridge, UK +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author M. Naylor, University of Cambridge, UK + +''' Module containing tests for the SMT-based array index analysis.''' + +import pytest +from psyclone.psyir.nodes import (Loop, Assignment, Reference) +from psyclone.psyir.symbols import Symbol +from psyclone.psyir.tools import ArrayIndexAnalysis +from psyclone.psyir.tools.array_index_analysis import translate_logical_expr +import pysmt.shortcuts as smt +from pysmt.exceptions import NoSolverAvailableError + + +# ----------------------------------------------------------------------------- +def test_reverse(fortran_reader, fortran_writer): + '''Test that an array reversal routine has no array conflicts + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine reverse(arr) + real, intent(inout) :: arr(:) + real :: tmp + integer :: i, n + n = size(arr) + do i = 1, n/2 + tmp = arr(i) + arr(i) = arr(n+1-i) + arr(n+1-i) = tmp + end do + end subroutine''') + try: + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_odd_even_trans(fortran_reader, fortran_writer): + '''Test that Knuth's odd-even transposition has no array conflicts + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine odd_even_transposition(arr, start) + real, intent(inout) :: arr(:) + integer, intent(in) :: start + real :: tmp + integer :: i + do i = start, size(arr), 2 + if (arr(i) > arr(i+1)) then + tmp = arr(i+1) + arr(i+1) = arr(i) + arr(i) = tmp + end if + end do + end subroutine''') + try: + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_tiled_matmul(fortran_reader, fortran_writer): + '''Test that tiled matmul has no array conflicts in 4/6 loops + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine my_matmul(a, b, c) + integer, dimension(:,:), intent(in) :: a + integer, dimension(:,:), intent(in) :: b + integer, dimension(:,:), intent(out) :: c + integer :: x, y, k, k_out_var, x_out_var, y_out_var, a1_n, a2_n, b1_n + + a2_n = SIZE(a, 2) + b1_n = SIZE(b, 1) + a1_n = SIZE(a, 1) + + c(:,:) = 0 + do y_out_var = 1, a2_n, 8 + do x_out_var = 1, b1_n, 8 + do k_out_var = 1, a1_n, 8 + do y = y_out_var, MIN(y_out_var + (8 - 1), a2_n), 1 + do x = x_out_var, MIN(x_out_var + (8 - 1), b1_n), 1 + do k = k_out_var, MIN(k_out_var + (8 - 1), a1_n), 1 + c(x,y) = c(x,y) + a(k,y) * b(x,k) + enddo + enddo + enddo + enddo + enddo + enddo + end subroutine my_matmul''') + try: + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True, True, False, True, True, False] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_flatten1(fortran_reader, fortran_writer): + '''Test that an array flattening routine has no array conflicts in its + inner loop (there are conflicts, due to integer overflow, in its outer + loop) + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine flatten1(mat, arr) + real, intent(in) :: mat(0:,0:) + real, intent(out) :: arr(0:) + integer :: x, y + integer :: nx, ny + nx = size(mat, 1) + ny = size(mat, 2) + do y = 0, ny-1 + do x = 0, nx-1 + arr(nx * y + x) = mat(x, y) + end do + end do + end subroutine''') + try: + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [False, True] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_flatten2(fortran_reader, fortran_writer): + '''Test that an array flattening routine has no array conflicts + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine flatten2(mat, arr) + real, intent(in) :: mat(0:,0:) + real, intent(out) :: arr(0:) + integer :: i, n, ny + n = size(arr) + ny = size(mat, 2) + do i = 0, n-1 + arr(i) = mat(mod(i, ny), i/ny) + end do + end subroutine''') + try: + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_translate_expr(fortran_reader, fortran_writer): + '''Test that Fortran expressions are being correctly translated to SMT. + ''' + def test(expr): + psyir = fortran_reader.psyir_from_source(f''' + subroutine sub(x) + logical, intent(out) :: x + x = {expr} + end subroutine''') + for assign in psyir.walk(Assignment): + rhs_smt = translate_logical_expr(assign.rhs, 32) + try: + assert smt.is_sat(rhs_smt) is True + except NoSolverAvailableError: + pass + return + + test("+1 == 1") + test("abs(-1) == 1") + test("shiftr(2,1) == 1") + test("shifta(-2,1) == -1") + test("iand(5,1) == 1") + test("ior(1,2) == 3") + test("ieor(3,1) == 2") + test("max(3,1) == 3") + test(".true.") + test(".not. .false.") + test(".true. .and. .true.") + test(".true. .or. .false.") + test(".false. .eqv. .false.") + test(".false. .neqv. .true.") + test("1 < 2") + test("10 > 2") + test("1 <= 1 .and. 0 <= 1") + test("1 >= 1 .and. 2 >= 1") + test("foo(1)") + + +# ----------------------------------------------------------------------------- +def check_conflict_free(fortran_reader, loop_str, yesno): + '''Helper function to check that given loop for conflicts. + The loop may refer to array "arr", integer variables "i" and "n", + and logical variable "ok". + ''' + psyir = fortran_reader.psyir_from_source(f''' + subroutine sub(arr, n) + integer, intent(inout) :: arr(:) + integer, intent(in) :: n, i + logical :: ok + {loop_str} + end subroutine''') + try: + results = [] + for loop in psyir.walk(Loop): + analysis = ArrayIndexAnalysis() + results.append(analysis.is_loop_conflict_free(loop)) + assert results == [yesno] + except NoSolverAvailableError: + pass + + +# ----------------------------------------------------------------------------- +def test_ifblock_with_else(fortran_reader, fortran_writer): + '''Test that an IfBlock with an "else" is correctly handled''' + check_conflict_free(fortran_reader, + '''do i = 1, n + ok = i == 1 + if (ok) then + arr(1) = 0 + else + arr(i) = i + end if + end do + arr(2) = 0 + ''', + True) + + +# ----------------------------------------------------------------------------- +def test_array_reference(fortran_reader, fortran_writer): + '''Test an array Reference with no indices is correctly handled''' + check_conflict_free(fortran_reader, + '''do i = 1, n + arr = arr + i + end do + ''', + False) + + +# ----------------------------------------------------------------------------- +def test_singleton_slice(fortran_reader, fortran_writer): + '''Test that an array slice with a single element is correctly handled''' + check_conflict_free(fortran_reader, + '''do i = 1, n + arr(i:i:) = 0 + end do + ''', + True) + + +# ----------------------------------------------------------------------------- +def test_errors(fortran_reader, fortran_writer): + '''Test that ArrayIndexAnalysis raises appropriate exceptions in + error cases + ''' + with pytest.raises(TypeError) as err: + ArrayIndexAnalysis().is_loop_conflict_free(Reference(Symbol("foo"))) + assert ("ArrayIndexAnalysis: Loop argument expected" + in str(err.value)) + + psyir = fortran_reader.psyir_from_source(''' + subroutine sub(arr, n) + integer, intent(inout) :: arr(:) + integer, intent(in) :: n, i + do i = 1, n + arr(i) = i + end do + end subroutine''') + loop = psyir.walk(Loop)[0] + loop.detach() + with pytest.raises(ValueError) as err: + ArrayIndexAnalysis().is_loop_conflict_free(loop) + assert ("ArrayIndexAnalysis: loop has no enclosing routine" + in str(err.value)) diff --git a/src/psyclone/tests/psyir/transformations/transformations_test.py b/src/psyclone/tests/psyir/transformations/transformations_test.py index 3aeb000149..9700875f14 100644 --- a/src/psyclone/tests/psyir/transformations/transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/transformations_test.py @@ -578,6 +578,7 @@ def test_omploop_trans_new_options(sample_psyir): "'fakeoption2']. Valid options are '['node_type_check', " "'verbose', 'collapse', 'force', 'ignore_dependencies_for', " "'privatise_arrays', 'sequential', 'nowait', 'reduction_ops', " + "'use_smt_array_anal', 'smt_timeout_ms', " "'options', 'reprod', 'enable_reductions']." in str(excinfo.value)) From ebd03abc62cb01f40fef5695b29fdbbbb2bdaac5 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Thu, 6 Nov 2025 15:15:28 +0000 Subject: [PATCH 02/17] Support arbitrary precision integers in addition to bit vectors Also use z3-solver instead of pysmt as it exposes additional useful features, such as predicates for checking bit vector overflow. --- .../psyir/tools/array_index_analysis.py | 318 ++++++++++-------- .../tests/psyir/nodes/omp_directives_test.py | 10 +- .../psyir/tools/array_index_analysis_test.py | 83 ++--- 3 files changed, 209 insertions(+), 202 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index e83e354462..58941c63c8 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -37,18 +37,9 @@ '''This module provides a class to determine whether or not distinct iterations of a given loop can generate conflicting array accesses (if not, the loop can potentially be parallelised). It formulates the problem as a set of SMT -constraints over array indices which are then are passed to a third-party -solver via pySMT. We currently mandate use of the Z3 solver as it has a useful -timeout option, missing in other solvers.''' - -# PySMT imports -import pysmt.shortcuts as smt -import pysmt.fnode as smt_fnode -import pysmt.typing as smt_types -import pysmt.logics as smt_logics -from pysmt.exceptions import SolverReturnedUnknownResultError - -# PSyclone imports +constraints over array indices which are then are passed to the Z3 solver.''' + +import z3 from psyclone.psyir.nodes import Loop, DataNode, Literal, Assignment, \ Reference, UnaryOperation, BinaryOperation, IntrinsicCall, \ Routine, Node, IfBlock, Schedule, ArrayReference, Range, WhileLoop @@ -140,19 +131,23 @@ # access pair and determine whether or not the loop is conflict free. class ArrayIndexAnalysis: - # Fortran integer width in bits - int_width = 32 - - def __init__(self, smt_timeout_ms=5000): + def __init__(self, + int_width : int = 32, + use_bv : int = True, + smt_timeout_ms : int = 5000): # Set SMT solver timeout in milliseconds self.smt_timeout = smt_timeout_ms + # Fortran integer width in bits + self.int_width = int_width + # Use fixed-width bit vectors or arbirary precision integers? + self.use_bv = use_bv # Class representing an array access class ArrayAccess: def __init__(self, - cond: smt_fnode.FNode, + cond: z3.BoolRef, is_write: bool, - indices: list[smt_fnode.FNode], + indices: list[z3.ExprRef], psyir_node: Node): # The condition at the location of the access self.cond = cond @@ -190,16 +185,37 @@ def save_subst(self): def restore_subst(self): self.subst = self.subst_stack.pop() + # Create an fresh SMT integer variable + def fresh_integer_var(self) -> z3.ExprRef: + if self.use_bv: + return z3.FreshConst(z3.BitVecSort(self.int_width)) + else: + return z3.FreshInt() + + # Create an integer SMT variable with the given name + def integer_var(self, var) -> z3.ExprRef: + if self.use_bv: + return z3.BitVec(var, self.int_width) + else: + return z3.Int(var) + + # Create an SMT integer value + def integer_val(self, val : int) -> z3.ExprRef: + if self.use_bv: + return z3.BitVecVal(val, self.int_width) + else: + return z3.IntVal(val) + # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol def kill_integer_var(self, var: str): - fresh_sym = smt.FreshSymbol(typename=smt_types.BVType(self.int_width)) - smt_var = smt.Symbol(var, typename=smt_types.BVType(self.int_width)) + fresh_sym = self.fresh_integer_var() + smt_var = self.integer_var(var) self.subst[smt_var] = fresh_sym # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol def kill_logical_var(self, var: str): - fresh_sym = smt.FreshSymbol(typename=smt_types.BOOL) - smt_var = smt.Symbol(var, typename=smt_types.BOOL) + fresh_sym = z3.FreshBool() + smt_var = z3.Bool(var) self.subst[smt_var] = fresh_sym # Kill all scalar integer/logical variables written inside 'node' @@ -216,73 +232,66 @@ def kill_all_written_vars(self, node: Node): self.kill_logical_var(sig.var_name) # Add the SMT constraint to the constraint set - def add_constraint(self, smt_expr: smt_fnode.FNode): + def add_constraint(self, smt_expr: z3.BoolRef): self.constraints.append(smt_expr) # Add an integer assignment constraint to the constraint set - def add_integer_assignment(self, var: str, smt_expr: smt_fnode.FNode): + def add_integer_assignment(self, var: str, smt_expr: z3.ExprRef): # Create a fresh symbol - fresh_sym = smt.FreshSymbol(typename=smt_types.BVType(self.int_width)) + fresh_sym = self.fresh_integer_var() # Assert equality between this symbol and the given SMT expression - self.add_constraint(smt.Equals(fresh_sym, smt_expr)) + self.add_constraint(fresh_sym == smt_expr) # Update the substitution - smt_var = smt.Symbol(var, typename=smt_types.BVType(self.int_width)) + smt_var = self.integer_var(var) self.subst[smt_var] = fresh_sym # Add a logical assignment constraint to the constraint set - def add_logical_assignment(self, var: str, smt_expr: smt_fnode.FNode): + def add_logical_assignment(self, var: str, smt_expr: z3.BoolRef): # Create a fresh symbol - fresh_sym = smt.FreshSymbol(typename=smt_types.BOOL) + fresh_sym = z3.FreshBool() # Assert equality between this symbol and the given SMT expression - self.add_constraint(smt.Iff(fresh_sym, smt_expr)) + self.add_constraint(fresh_sym == smt_expr) # Update the substitution - smt_var = smt.Symbol(var, typename=smt_types.BOOL) + smt_var = z3.Bool(var) self.subst[smt_var] = fresh_sym # Translate integer expresison to SMT, and apply current substitution - def translate_integer_expr_with_subst(self, expr: smt_fnode.FNode): - smt_expr = translate_integer_expr(expr, self.int_width) - return smt_expr.substitute(self.subst) + def translate_integer_expr_with_subst(self, expr: z3.ExprRef): + smt_expr = translate_integer_expr(expr, self.int_width, self.use_bv) + subst_pairs = list(self.subst.items()) + return z3.substitute(smt_expr, *subst_pairs) # Translate logical expresison to SMT, and apply current substitution - def translate_logical_expr_with_subst(self, expr: smt_fnode.FNode): - smt_expr = translate_logical_expr(expr, self.int_width) - return smt_expr.substitute(self.subst) + def translate_logical_expr_with_subst(self, expr: z3.BoolRef): + smt_expr = translate_logical_expr(expr, self.int_width, self.use_bv) + subst_pairs = list(self.subst.items()) + return z3.substitute(smt_expr, *subst_pairs) # Constrain a loop variable to given start/stop/step def constrain_loop_var(self, - var: smt_fnode.FNode, + var: z3.ExprRef, start: DataNode, stop: DataNode, step: DataNode): - zero = smt.SBV(0, self.int_width) + zero = self.integer_val(0) var_begin = self.translate_integer_expr_with_subst(start) var_end = self.translate_integer_expr_with_subst(stop) if step is None: step = Literal("1", INTEGER_TYPE) # pragma: no cover var_step = self.translate_integer_expr_with_subst(step) - self.add_constraint(smt.And( - # (var - var_begin) % var_step == 0 - smt.Equals(smt.BVSRem(smt.BVSub(var, var_begin), var_step), zero), - # var_step > 0 ==> var >= var_begin - smt.Implies(smt.BVSGT(var_step, zero), - smt.BVSGE(var, var_begin)), - # var_step < 0 ==> var <= var_begin - smt.Implies(smt.BVSLT(var_step, zero), - smt.BVSLE(var, var_begin)), - # var_step > 0 ==> var <= var_end - smt.Implies(smt.BVSGT(var_step, zero), - smt.BVSLE(var, var_end)), - # var_step < 0 ==> var >= var_end - smt.Implies(smt.BVSLT(var_step, zero), - smt.BVSGE(var, var_end)))) + self.add_constraint(z3.And( + ((var - var_begin) % var_step) == zero, + z3.Implies(var_step > zero, var >= var_begin), + z3.Implies(var_step < zero, var <= var_begin), + z3.Implies(var_step > zero, var <= var_end), + z3.Implies(var_step < zero, var >= var_end))) # Add an array access to the current access dict def add_array_access(self, array_name: str, is_write: bool, - cond: smt_fnode.FNode, - indices: list[smt_fnode.FNode], + cond: z3.BoolRef, + indices: list[z3.ExprRef], psyir_node: Node): access = ArrayIndexAnalysis.ArrayAccess( cond, is_write, indices, psyir_node) @@ -291,7 +300,7 @@ def add_array_access(self, self.access_dict[array_name].append(access) # Add all array accesses in the given node to the current access dict - def add_all_array_accesses(self, node: Node, cond: smt_fnode.FNode): + def add_all_array_accesses(self, node: Node, cond: z3.BoolRef): var_accesses = node.reference_accesses() for sig, access_seq in var_accesses.items(): for access_info in access_seq: @@ -302,9 +311,7 @@ def add_all_array_accesses(self, node: Node, cond: smt_fnode.FNode): indices = [] for index in access_info.node.indices: if isinstance(index, Range): - var = smt.FreshSymbol( - typename=smt_types.BVType( - self.int_width)) + var = self.fresh_integer_var() self.constrain_loop_var( var, index.start, index.stop, index.step) indices.append(var) @@ -322,8 +329,7 @@ def add_all_array_accesses(self, node: Node, cond: smt_fnode.FNode): isinstance(access_info.node.datatype, ArrayType)): indices = [] for index in access_info.node.datatype.shape: - var = smt.FreshSymbol( - typename=smt_types.BVType(self.int_width)) + var = self.fresh_integer_var() indices.append(var) self.add_array_access( sig.var_name, access_info.is_any_write(), @@ -335,7 +341,7 @@ def save_access_dict(self): self.access_dict = {} # Check if the given loop has a conflict - def is_loop_conflict_free(self, loop: Loop) -> tuple[Node, Node]: + def is_loop_conflict_free(self, loop: Loop) -> bool: # Type checking if not isinstance(loop, Loop): raise TypeError("ArrayIndexAnalysis: Loop argument expected") @@ -351,11 +357,10 @@ def is_loop_conflict_free(self, loop: Loop) -> tuple[Node, Node]: # Start with an empty constraint set and substitution self.init_analysis() self.loop_to_parallelise = loop - smt.reset_env() # Step through body of the enclosing routine, statement by statement for stmt in routine.children: - self.step(stmt, smt.TRUE()) + self.step(stmt, z3.BoolVal(True)) # Check that we have found and analysed the loop to parallelise if not (self.finished and len(self.saved_access_dicts) == 2): @@ -377,26 +382,26 @@ def is_loop_conflict_free(self, loop: Loop) -> tuple[Node, Node]: indices_equal = [] for (i_idx, j_idx) in zip(i_access.indices, j_access.indices): - indices_equal.append(smt.Equals(i_idx, j_idx)) - conflicts.append(smt.And( + indices_equal.append(i_idx == j_idx) + conflicts.append(z3.And( *indices_equal, i_access.cond, j_access.cond)) # Invoke Z3 solver with a timeout - solver = smt.Solver(name='z3', - logic=smt_logics.QF_BV, - generate_models=False, - incremental=False, - solver_options={'timeout': self.smt_timeout}) - try: - return not solver.is_sat(smt.And(*self.constraints, - smt.Or(conflicts))) - except SolverReturnedUnknownResultError: # pragma: no cover - return None # pragma: no cover + solver = z3.Solver() + solver.set("timeout", self.smt_timeout) + solver.add(z3.And(*self.constraints, z3.Or(*conflicts))) + result = solver.check() + if result == z3.unknown: + return None # pragma: no cover + elif result == z3.sat: + return False + else: + return True # Analyse a single statement - def step(self, stmt: Node, cond: smt_fnode.FNode): + def step(self, stmt: Node, cond: z3.BoolRef): # Has analysis finished? if self.finished: return @@ -432,13 +437,13 @@ def step(self, stmt: Node, cond: smt_fnode.FNode): # Recursively step into 'then' if stmt.if_body: self.save_subst() - self.step(stmt.if_body, smt.And(cond, smt_condition)) + self.step(stmt.if_body, z3.And(cond, smt_condition)) self.restore_subst() # Recursively step into 'else' if stmt.else_body: self.save_subst() self.step(stmt.else_body, - smt.And(cond, smt.Not(smt_condition))) + z3.And(cond, z3.Not(smt_condition))) self.restore_subst() # Kill vars written by each branch if stmt.if_body: @@ -457,23 +462,18 @@ def step(self, stmt: Node, cond: smt_fnode.FNode): if stmt is self.loop_to_parallelise: self.in_loop_to_parallelise = True # Consider two arbitary but distinct iterations - i_var = smt.FreshSymbol( - typename=smt_types.BVType(self.int_width)) - j_var = smt.FreshSymbol( - typename=smt_types.BVType(self.int_width)) - self.add_constraint(smt.NotEquals(i_var, j_var)) + i_var = self.fresh_integer_var() + j_var = self.fresh_integer_var() + self.add_constraint(i_var != j_var) iteration_vars = [i_var, j_var] else: # Consider a single, arbitrary iteration - i_var = smt.FreshSymbol( - typename=smt_types.BVType(self.int_width)) + i_var = self.fresh_integer_var() iteration_vars = [i_var] # Analyse loop body for each iteration variable separately for var in iteration_vars: self.save_subst() - smt_loop_var = smt.Symbol( - stmt.variable.name, - typename=smt_types.BVType(self.int_width)) + smt_loop_var = self.integer_var(stmt.variable.name) self.subst[smt_loop_var] = var # Introduce constraints on loop variable self.constrain_loop_var( @@ -500,7 +500,7 @@ def step(self, stmt: Node, cond: smt_fnode.FNode): stmt.condition) # Recursively step into loop body self.save_subst() - self.step(stmt.loop_body, smt.And(cond, smt_condition)) + self.step(stmt.loop_body, z3.And(cond, smt_condition)) self.restore_subst() return @@ -514,29 +514,34 @@ def step(self, stmt: Node, cond: smt_fnode.FNode): # Translate a scalar integer Fortran expression to SMT -def translate_integer_expr(expr_root: Node, int_width: int): - # SMT type to represent Fortran integers - bv_int_t = smt_types.BVType(int_width) - +def translate_integer_expr(expr_root: Node, + int_width: int, + use_bv: bool) -> z3.ExprRef: def trans(expr: Node): # Check that type is a scalar integer of unspecified precision type_ok = _is_scalar_integer(expr.datatype) # Literal if isinstance(expr, Literal) and type_ok: - return smt.SBV(int(expr.value), int_width) + if use_bv: + return z3.BitVecVal(int(expr.value), int_width) + else: + return z3.IntVal(int(expr.value)) # Reference if (isinstance(expr, Reference) and not isinstance(expr, ArrayReference) and type_ok): - return smt.Symbol(expr.name, typename=bv_int_t) + if use_bv: + return z3.BitVec(expr.name, int_width) + else: + return z3.Int(expr.name) # UnaryOperation if isinstance(expr, UnaryOperation): arg_smt = trans(expr.operand) if expr.operator == UnaryOperation.Operator.MINUS: - return smt.BVNeg(arg_smt) + return -arg_smt if expr.operator == UnaryOperation.Operator.PLUS: return arg_smt @@ -547,23 +552,20 @@ def trans(expr: Node): right_smt = trans(right) if expr.operator == BinaryOperation.Operator.ADD: - return smt.BVAdd(left_smt, right_smt) + return left_smt + right_smt if expr.operator == BinaryOperation.Operator.SUB: - return smt.BVSub(left_smt, right_smt) + return left_smt - right_smt if expr.operator == BinaryOperation.Operator.MUL: - return smt.BVMul(left_smt, right_smt) + return left_smt * right_smt if expr.operator == BinaryOperation.Operator.DIV: - return smt.BVSDiv(left_smt, right_smt) + return left_smt / right_smt # IntrinsicCall if isinstance(expr, IntrinsicCall): # Unary operators if expr.intrinsic == IntrinsicCall.Intrinsic.ABS: - zero = smt.BVZero(int_width) smt_arg = trans(expr.children[1]) - return smt.Ite(smt.BVSLT(smt_arg, zero), - smt.BVNeg(smt_arg), - smt_arg) + return z3.Abs(smt_arg) # Binary operators if expr.intrinsic in [IntrinsicCall.Intrinsic.SHIFTL, @@ -576,21 +578,48 @@ def trans(expr: Node): left_smt = trans(expr.children[1]) right_smt = trans(expr.children[2]) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: - return smt.BVLShl(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: - return smt.BVLShr(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: - return smt.BVAShr(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: - return smt.BVAnd(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: - return smt.BVOr(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: - return smt.BVXor(left_smt, right_smt) - # TODO: does BVSRem match the semantics of Fortran MOD? if expr.intrinsic == IntrinsicCall.Intrinsic.MOD: - return smt.BVSRem(left_smt, right_smt) + return left_smt % right_smt + + if use_bv: + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: + return left_smt << right_smt + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: + return z3.LShR(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: + return left_smt >> right_smt + if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: + return left_smt & right_smt + if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: + return left_smt | right_smt + if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: + return left_smt ^ right_smt + else: + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: + return z3.BV2Int(z3.Int2BV(left_smt, int_width) << + z3.Int2BV(right_smt, int_width), + is_signed = True) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: + return z3.BV2Int(z3.LShR( + z3.Int2BV(left_smt, int_width), + z3.Int2BV(right_smt, int_width)), + is_signed = True) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: + return z3.BV2Int(z3.Int2BV(left_smt, int_width) >> + z3.Int2BV(right_smt, int_width), + is_signed = True) + if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: + return z3.BV2Int(z3.Int2BV(left_smt, int_width) & + z3.Int2BV(right_smt, int_width), + is_signed = True) + if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: + return z3.BV2Int(z3.Int2BV(left_smt, int_width) | + z3.Int2BV(right_smt, int_width), + is_signed = True) + if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: + return z3.BV2Int(z3.Int2BV(left_smt, int_width) ^ + z3.Int2BV(right_smt, int_width), + is_signed = True) # N-ary operators if expr.intrinsic in [IntrinsicCall.Intrinsic.MIN, @@ -599,21 +628,24 @@ def trans(expr: Node): reduced = smt_args[0] for arg in smt_args[1:]: if expr.intrinsic == IntrinsicCall.Intrinsic.MIN: - reduced = smt.Ite(smt.BVSLT(reduced, arg), - reduced, arg) + reduced = z3.If(reduced < arg, reduced, arg) elif expr.intrinsic == IntrinsicCall.Intrinsic.MAX: - reduced = smt.Ite(smt.BVSLT(reduced, arg), - arg, reduced) + reduced = z3.If(reduced < arg, arg, reduced) return reduced # Fall through: return a fresh, unconstrained symbol - return smt.FreshSymbol(typename=bv_int_t) + if use_bv: + return z3.FreshConst(z3.BitVecSort(int_width)) + else: + return z3.FreshInt() return trans(expr_root) # Translate a scalar logical Fortran expression to SMT -def translate_logical_expr(expr_root: Node, int_width: int): +def translate_logical_expr(expr_root: Node, + int_width: int, + use_bv: bool) -> z3.BoolRef: def trans(expr: Node): # Check that type is a scalar logical type_ok = _is_scalar_logical(expr.datatype) @@ -621,21 +653,21 @@ def trans(expr: Node): # Literal if isinstance(expr, Literal) and type_ok: if expr.value == "true": - return smt.TRUE() + return z3.BoolVal(True) if expr.value == "false": - return smt.FALSE() + return z3.BoolVal(False) # Reference if (isinstance(expr, Reference) and not isinstance(expr, ArrayReference) and type_ok): - return smt.Symbol(expr.name, typename=smt_types.BOOL) + return z3.Bool(expr.name) # UnaryOperation if isinstance(expr, UnaryOperation): arg_smt = trans(expr.operand) if expr.operator == UnaryOperation.Operator.NOT: - return smt.Not(arg_smt) + return z3.Not(arg_smt) # BinaryOperation if isinstance(expr, BinaryOperation): @@ -649,13 +681,13 @@ def trans(expr: Node): right_smt = trans(right) if expr.operator == BinaryOperation.Operator.AND: - return smt.And(left_smt, right_smt) + return z3.And(left_smt, right_smt) if expr.operator == BinaryOperation.Operator.OR: - return smt.Or(left_smt, right_smt) + return z3.Or(left_smt, right_smt) if expr.operator == BinaryOperation.Operator.EQV: - return smt.Iff(left_smt, right_smt) + return left_smt == right_smt if expr.operator == BinaryOperation.Operator.NEQV: - return smt.Not(smt.Iff(left_smt, right_smt)) + return left_smt != right_smt # Operands are numbers if expr.operator in [BinaryOperation.Operator.EQ, @@ -665,24 +697,24 @@ def trans(expr: Node): BinaryOperation.Operator.GE, BinaryOperation.Operator.LE]: (left, right) = expr.operands - left_smt = translate_integer_expr(left, int_width) - right_smt = translate_integer_expr(right, int_width) + left_smt = translate_integer_expr(left, int_width, use_bv) + right_smt = translate_integer_expr(right, int_width, use_bv) if expr.operator == BinaryOperation.Operator.EQ: - return smt.Equals(left_smt, right_smt) + return left_smt == right_smt if expr.operator == BinaryOperation.Operator.NE: - return smt.NotEquals(left_smt, right_smt) + return left_smt != right_smt if expr.operator == BinaryOperation.Operator.GT: - return smt.BVSGT(left_smt, right_smt) + return left_smt > right_smt if expr.operator == BinaryOperation.Operator.LT: - return smt.BVSLT(left_smt, right_smt) + return left_smt < right_smt if expr.operator == BinaryOperation.Operator.GE: - return smt.BVSGE(left_smt, right_smt) + return left_smt >= right_smt if expr.operator == BinaryOperation.Operator.LE: - return smt.BVSLE(left_smt, right_smt) + return left_smt <= right_smt # Fall through: return a fresh, unconstrained symbol - return smt.FreshSymbol(typename=smt_types.BOOL) + return z3.FreshBool() return trans(expr_root) diff --git a/src/psyclone/tests/psyir/nodes/omp_directives_test.py b/src/psyclone/tests/psyir/nodes/omp_directives_test.py index 1dd5b524bf..a063d0542b 100644 --- a/src/psyclone/tests/psyir/nodes/omp_directives_test.py +++ b/src/psyclone/tests/psyir/nodes/omp_directives_test.py @@ -74,7 +74,6 @@ LFRicOMPLoopTrans, OMPParallelTrans, OMPParallelLoopTrans, LFRicOMPParallelLoopTrans, OMPSingleTrans, OMPMasterTrans, OMPLoopTrans, TransformationError) -from pysmt.exceptions import NoSolverAvailableError BASE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(__file__)))), "test_files", "lfric") @@ -5384,9 +5383,6 @@ def test_array_analysis_option(fortran_reader, fortran_writer): end subroutine my_matmul''') omplooptrans = OMPLoopTrans(omp_directive="paralleldo") loop = psyir.walk(Loop)[0] - try: - omplooptrans.apply(loop, collapse=True, use_smt_array_anal=True) - output = fortran_writer(psyir) - assert "collapse(2)" in output - except NoSolverAvailableError: - pass + omplooptrans.apply(loop, collapse=True, use_smt_array_anal=True) + output = fortran_writer(psyir) + assert "collapse(2)" in output diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py index 09cda0b778..853012a6aa 100644 --- a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -40,8 +40,7 @@ from psyclone.psyir.symbols import Symbol from psyclone.psyir.tools import ArrayIndexAnalysis from psyclone.psyir.tools.array_index_analysis import translate_logical_expr -import pysmt.shortcuts as smt -from pysmt.exceptions import NoSolverAvailableError +import z3 # ----------------------------------------------------------------------------- @@ -60,13 +59,10 @@ def test_reverse(fortran_reader, fortran_writer): arr(n+1-i) = tmp end do end subroutine''') - try: - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [True] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] # ----------------------------------------------------------------------------- @@ -87,13 +83,10 @@ def test_odd_even_trans(fortran_reader, fortran_writer): end if end do end subroutine''') - try: - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [True] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] # ----------------------------------------------------------------------------- @@ -126,13 +119,10 @@ def test_tiled_matmul(fortran_reader, fortran_writer): enddo enddo end subroutine my_matmul''') - try: - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [True, True, False, True, True, False] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True, True, False, True, True, False] # ----------------------------------------------------------------------------- @@ -155,13 +145,10 @@ def test_flatten1(fortran_reader, fortran_writer): end do end do end subroutine''') - try: - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [False, True] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [False, True] # ----------------------------------------------------------------------------- @@ -179,17 +166,15 @@ def test_flatten2(fortran_reader, fortran_writer): arr(i) = mat(mod(i, ny), i/ny) end do end subroutine''') - try: - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [True] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + assert results == [True] # ----------------------------------------------------------------------------- -def test_translate_expr(fortran_reader, fortran_writer): +@pytest.mark.parametrize("use_bv", [True, False]) +def test_translate_expr(use_bv, fortran_reader, fortran_writer): '''Test that Fortran expressions are being correctly translated to SMT. ''' def test(expr): @@ -199,12 +184,9 @@ def test(expr): x = {expr} end subroutine''') for assign in psyir.walk(Assignment): - rhs_smt = translate_logical_expr(assign.rhs, 32) - try: - assert smt.is_sat(rhs_smt) is True - except NoSolverAvailableError: - pass - return + rhs_smt = translate_logical_expr(assign.rhs, 32, use_bv) + solver = z3.Solver() + assert solver.check(rhs_smt) == z3.sat test("+1 == 1") test("abs(-1) == 1") @@ -240,14 +222,11 @@ def check_conflict_free(fortran_reader, loop_str, yesno): logical :: ok {loop_str} end subroutine''') - try: - results = [] - for loop in psyir.walk(Loop): - analysis = ArrayIndexAnalysis() - results.append(analysis.is_loop_conflict_free(loop)) - assert results == [yesno] - except NoSolverAvailableError: - pass + results = [] + for loop in psyir.walk(Loop): + analysis = ArrayIndexAnalysis() + results.append(analysis.is_loop_conflict_free(loop)) + assert results == [yesno] # ----------------------------------------------------------------------------- From f46740f9baca2d0a8cfb5a7abc8da7982c157731 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Thu, 6 Nov 2025 15:44:38 +0000 Subject: [PATCH 03/17] Add more SMT-related options to ParallelLoopTrans --- .../transformations/parallel_loop_trans.py | 29 ++++++++++++++++--- .../transformations/transformations_test.py | 4 +-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index 6c3ab67156..a3996c1a53 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -178,6 +178,8 @@ def validate(self, node, options=None, **kwargs): use_smt_array_anal = self.get_option( "use_smt_array_anal", **kwargs) smt_timeout_ms = self.get_option("smt_timeout_ms", **kwargs) + smt_use_bv = self.get_option("smt_use_bv", **kwargs) + smt_int_width = self.get_option("smt_int_width", **kwargs) else: verbose = options.get("verbose", False) collapse = options.get("collapse", False) @@ -190,6 +192,8 @@ def validate(self, node, options=None, **kwargs): reduction_ops = options.get("reduction_ops", []) use_smt_array_anal = options.get("use_smt_array_anal", False) smt_timeout_ms = options.get("smt_timeout_ms", 5000) + smt_use_bv = options.get("smt_use_bv", True) + smt_int_width = options.get("smt_int_width", 32) # Check type of reduction_ops (not handled by validate_options) if not isinstance(reduction_ops, list): @@ -314,7 +318,9 @@ def validate(self, node, options=None, **kwargs): # Try using the ArrayIndexAnalysis to prove that the # dependency errors are false arr_anal = ArrayIndexAnalysis( - smt_timeout_ms=smt_timeout_ms) + smt_timeout_ms=smt_timeout_ms, + use_bv=smt_use_bv, + int_width=smt_int_width) if arr_anal.is_loop_conflict_free(node): errors = [] @@ -348,6 +354,8 @@ def apply(self, node, options=None, verbose: bool = False, IntrinsicCall.Intrinsic]] = None, use_smt_array_anal: bool = False, smt_timeout_ms: int = 5000, + smt_use_bv: bool = True, + smt_int_width: int = 32, **kwargs): ''' Apply the Loop transformation to the specified node in a @@ -395,6 +403,10 @@ def apply(self, node, options=None, verbose: bool = False, :param bool use_smt_array_anal: whether to use the SMT-based ArrayIndexAnalysis to discharge false dependency errors. :param bool smt_timeout_ms: SMT solver timeout in milliseconds. + :param bool smt_use_bv: use bit vectors or arbitary precision integers + in the SMT solver? + :param bool smt_int_width: width of Fortran integers in bits for + the SMT solver. ''' if not options: @@ -405,7 +417,10 @@ def apply(self, node, options=None, verbose: bool = False, sequential=sequential, nowait=nowait, reduction_ops=reduction_ops, use_smt_array_anal=use_smt_array_anal, - smt_timeout_ms=smt_timeout_ms, **kwargs + smt_timeout_ms=smt_timeout_ms, + smt_use_bv=smt_use_bv, + smt_int_width=smt_int_width, + **kwargs ) # Rename the input options that are renamed in this apply method. # TODO 2668, rename options to be consistent. @@ -427,7 +442,9 @@ def apply(self, node, options=None, verbose: bool = False, nowait = options.get("nowait", False) reduction_ops = options.get("reduction_ops", []) use_smt_array_anal = options.get("use_smt_array_anal", False) - smt_timeout_ms = options.get("smt_timeout_ms", False) + smt_timeout_ms = options.get("smt_timeout_ms", 5000) + smt_timeout_ms = options.get("smt_use_bv", True) + smt_int_width = options.get("smt_int_width", 32) self.validate(node, options=options, verbose=verbose, collapse=collapse, force=force, @@ -437,6 +454,8 @@ def apply(self, node, options=None, verbose: bool = False, reduction_ops=reduction_ops, use_smt_array_anal=use_smt_array_anal, smt_timeout_ms=smt_timeout_ms, + smt_use_bv=smt_use_bv, + smt_int_width=smt_int_width, **kwargs) list_of_signatures = [Signature(name) for name in list_of_names] @@ -523,7 +542,9 @@ def apply(self, node, options=None, verbose: bool = False, [msg.code == DTCode.ERROR_DEPENDENCY for msg in msgs]) arr_anal = ArrayIndexAnalysis( - smt_timeout_ms=smt_timeout_ms) + smt_timeout_ms=smt_timeout_ms, + use_bv=smt_use_bv, + int_width=smt_int_width) discharge_errors = ( all_dep_errors and arr_anal.is_loop_conflict_free(next_loop)) diff --git a/src/psyclone/tests/psyir/transformations/transformations_test.py b/src/psyclone/tests/psyir/transformations/transformations_test.py index 9700875f14..b2c27d7895 100644 --- a/src/psyclone/tests/psyir/transformations/transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/transformations_test.py @@ -578,8 +578,8 @@ def test_omploop_trans_new_options(sample_psyir): "'fakeoption2']. Valid options are '['node_type_check', " "'verbose', 'collapse', 'force', 'ignore_dependencies_for', " "'privatise_arrays', 'sequential', 'nowait', 'reduction_ops', " - "'use_smt_array_anal', 'smt_timeout_ms', " - "'options', 'reprod', 'enable_reductions']." + "'use_smt_array_anal', 'smt_timeout_ms', 'smt_use_bv', " + "'smt_int_width', 'options', 'reprod', 'enable_reductions']." in str(excinfo.value)) # Check we get the relevant error message when submitting multiple From 6385bd2abdbe07a93e00285b87881d33efebab55 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Fri, 7 Nov 2025 10:31:46 +0000 Subject: [PATCH 04/17] New analysis option to prohibit bit vector overflow --- .../psyir/tools/array_index_analysis.py | 158 ++++++++++++------ .../transformations/parallel_loop_trans.py | 44 ++--- .../psyir/tools/array_index_analysis_test.py | 3 +- .../transformations/transformations_test.py | 4 +- 4 files changed, 128 insertions(+), 81 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 58941c63c8..95f4683536 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -130,17 +130,23 @@ # global constraint set, is unsatisfiable. In this way, we can check every # access pair and determine whether or not the loop is conflict free. + class ArrayIndexAnalysis: - def __init__(self, - int_width : int = 32, - use_bv : int = True, - smt_timeout_ms : int = 5000): - # Set SMT solver timeout in milliseconds - self.smt_timeout = smt_timeout_ms - # Fortran integer width in bits - self.int_width = int_width - # Use fixed-width bit vectors or arbirary precision integers? - self.use_bv = use_bv + # Class representing analysis options + class Options: + def __init__(self, + int_width: int = 32, + use_bv: int = True, + smt_timeout_ms: int = 5000, + prohibit_overflow: bool = False): + # Set SMT solver timeout in milliseconds + self.smt_timeout = smt_timeout_ms + # Fortran integer width in bits + self.int_width = int_width + # Use fixed-width bit vectors or arbirary precision integers? + self.use_bv = use_bv + # Prohibit bit-vector overflow when solving constraints? + self.prohibit_overflow = prohibit_overflow # Class representing an array access class ArrayAccess: @@ -155,9 +161,16 @@ def __init__(self, self.is_write = is_write # SMT expressions representing the indices of the access self.indices = indices - # PSyIR Node for the access (useful for error reporting) + # PSyIR node for the access (useful for error reporting) self.psyir_node = psyir_node + # ArrayIndexAnalysis constructor + def __init__(self, options=Options()): + self.smt_timeout = options.smt_timeout + self.int_width = options.int_width + self.use_bv = options.use_bv + self.prohibit_overflow = options.prohibit_overflow + # Initialise analysis def init_analysis(self): # The substitution maps integer and logical Fortran variables @@ -200,7 +213,7 @@ def integer_var(self, var) -> z3.ExprRef: return z3.Int(var) # Create an SMT integer value - def integer_val(self, val : int) -> z3.ExprRef: + def integer_val(self, val: int) -> z3.ExprRef: if self.use_bv: return z3.BitVecVal(val, self.int_width) else: @@ -257,13 +270,28 @@ def add_logical_assignment(self, var: str, smt_expr: z3.BoolRef): # Translate integer expresison to SMT, and apply current substitution def translate_integer_expr_with_subst(self, expr: z3.ExprRef): - smt_expr = translate_integer_expr(expr, self.int_width, self.use_bv) + (smt_expr, prohibit_overflow) = translate_integer_expr( + expr, self.int_width, self.use_bv) subst_pairs = list(self.subst.items()) + if self.prohibit_overflow: + self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) return z3.substitute(smt_expr, *subst_pairs) # Translate logical expresison to SMT, and apply current substitution def translate_logical_expr_with_subst(self, expr: z3.BoolRef): - smt_expr = translate_logical_expr(expr, self.int_width, self.use_bv) + (smt_expr, prohibit_overflow) = translate_logical_expr( + expr, self.int_width, self.use_bv) + subst_pairs = list(self.subst.items()) + if self.prohibit_overflow: + self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) + return z3.substitute(smt_expr, *subst_pairs) + + # Translate conditional expresison to SMT, and apply current substitution + def translate_cond_expr_with_subst(self, expr: z3.BoolRef): + (smt_expr, prohibit_overflow) = translate_logical_expr( + expr, self.int_width, self.use_bv) + if self.prohibit_overflow: + smt_expr = z3.And(smt_expr, prohibit_overflow) subst_pairs = list(self.subst.items()) return z3.substitute(smt_expr, *subst_pairs) @@ -393,12 +421,12 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: solver.set("timeout", self.smt_timeout) solver.add(z3.And(*self.constraints, z3.Or(*conflicts))) result = solver.check() - if result == z3.unknown: - return None # pragma: no cover + if result == z3.unsat: + return True elif result == z3.sat: - return False + return False else: - return True + return None # pragma: no cover # Analyse a single statement def step(self, stmt: Node, cond: z3.BoolRef): @@ -432,8 +460,7 @@ def step(self, stmt: Node, cond: z3.BoolRef): if self.in_loop_to_parallelise: self.add_all_array_accesses(stmt.condition, cond) # Translate condition to SMT - smt_condition = self.translate_logical_expr_with_subst( - stmt.condition) + smt_condition = self.translate_cond_expr_with_subst(stmt.condition) # Recursively step into 'then' if stmt.if_body: self.save_subst() @@ -490,19 +517,19 @@ def step(self, stmt: Node, cond: z3.BoolRef): # WhileLoop if isinstance(stmt, WhileLoop): - # Kill variables written by loop body - self.kill_all_written_vars(stmt.loop_body) - # Add array accesses in condition - if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.condition, cond) - # Translate condition to SMT - smt_condition = self.translate_logical_expr_with_subst( - stmt.condition) - # Recursively step into loop body - self.save_subst() - self.step(stmt.loop_body, z3.And(cond, smt_condition)) - self.restore_subst() - return + # Kill variables written by loop body + self.kill_all_written_vars(stmt.loop_body) + # Add array accesses in condition + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.condition, cond) + # Translate condition to SMT + smt_condition = self.translate_cond_expr_with_subst( + stmt.condition) + # Recursively step into loop body + self.save_subst() + self.step(stmt.loop_body, z3.And(cond, smt_condition)) + self.restore_subst() + return # Fall through if self.in_loop_to_parallelise: @@ -513,11 +540,14 @@ def step(self, stmt: Node, cond: z3.BoolRef): # =============================================== -# Translate a scalar integer Fortran expression to SMT +# Translate a scalar integer Fortran expression to SMT. In addition, +# return a constraint that prohibits bit vector overflow in the expression. def translate_integer_expr(expr_root: Node, int_width: int, - use_bv: bool) -> z3.ExprRef: - def trans(expr: Node): + use_bv: bool) -> (z3.ExprRef, z3.BoolRef): + constraints = [] + + def trans(expr: Node) -> z3.ExprRef: # Check that type is a scalar integer of unspecified precision type_ok = _is_scalar_integer(expr.datatype) @@ -541,6 +571,8 @@ def trans(expr: Node): if isinstance(expr, UnaryOperation): arg_smt = trans(expr.operand) if expr.operator == UnaryOperation.Operator.MINUS: + if use_bv: + constraints.append(z3.BVSNegNoOverflow(arg_smt)) return -arg_smt if expr.operator == UnaryOperation.Operator.PLUS: return arg_smt @@ -552,12 +584,30 @@ def trans(expr: Node): right_smt = trans(right) if expr.operator == BinaryOperation.Operator.ADD: + if use_bv: + constraints.append(z3.BVAddNoOverflow( + left_smt, right_smt, True)) + constraints.append(z3.BVAddNoUnderflow( + left_smt, right_smt)) return left_smt + right_smt if expr.operator == BinaryOperation.Operator.SUB: + if use_bv: + constraints.append(z3.BVSubNoOverflow( + left_smt, right_smt)) + constraints.append(z3.BVSubNoUnderflow( + left_smt, right_smt, True)) return left_smt - right_smt if expr.operator == BinaryOperation.Operator.MUL: + if use_bv: + constraints.append(z3.BVMulNoOverflow( + left_smt, right_smt, True)) + constraints.append(z3.BVMulNoUnderflow( + left_smt, right_smt)) return left_smt * right_smt if expr.operator == BinaryOperation.Operator.DIV: + if use_bv: + constraints.append(z3.BVSDivNoOverflow( + left_smt, right_smt)) return left_smt / right_smt # IntrinsicCall @@ -565,6 +615,8 @@ def trans(expr: Node): # Unary operators if expr.intrinsic == IntrinsicCall.Intrinsic.ABS: smt_arg = trans(expr.children[1]) + if use_bv: + constraints.append(z3.BVSNegNoOverflow(smt_arg)) return z3.Abs(smt_arg) # Binary operators @@ -598,28 +650,28 @@ def trans(expr: Node): if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: return z3.BV2Int(z3.Int2BV(left_smt, int_width) << z3.Int2BV(right_smt, int_width), - is_signed = True) + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: return z3.BV2Int(z3.LShR( z3.Int2BV(left_smt, int_width), z3.Int2BV(right_smt, int_width)), - is_signed = True) + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: return z3.BV2Int(z3.Int2BV(left_smt, int_width) >> z3.Int2BV(right_smt, int_width), - is_signed = True) + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: return z3.BV2Int(z3.Int2BV(left_smt, int_width) & z3.Int2BV(right_smt, int_width), - is_signed = True) + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: return z3.BV2Int(z3.Int2BV(left_smt, int_width) | z3.Int2BV(right_smt, int_width), - is_signed = True) + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: return z3.BV2Int(z3.Int2BV(left_smt, int_width) ^ z3.Int2BV(right_smt, int_width), - is_signed = True) + is_signed=True) # N-ary operators if expr.intrinsic in [IntrinsicCall.Intrinsic.MIN, @@ -639,13 +691,18 @@ def trans(expr: Node): else: return z3.FreshInt() - return trans(expr_root) + expr_root_smt = trans(expr_root) + return (expr_root_smt, z3.And(*constraints)) -# Translate a scalar logical Fortran expression to SMT +# Translate a scalar logical Fortran expression to SMT. In addition, +# return a constraint that prohibits bit vector overflow in the expression. def translate_logical_expr(expr_root: Node, int_width: int, - use_bv: bool) -> z3.BoolRef: + use_bv: bool) -> (z3.BoolRef, z3.BoolRef): + # Constraints to prohibit bit-vector overflow + overflow = [] + def trans(expr: Node): # Check that type is a scalar logical type_ok = _is_scalar_logical(expr.datatype) @@ -697,8 +754,12 @@ def trans(expr: Node): BinaryOperation.Operator.GE, BinaryOperation.Operator.LE]: (left, right) = expr.operands - left_smt = translate_integer_expr(left, int_width, use_bv) - right_smt = translate_integer_expr(right, int_width, use_bv) + (left_smt, prohibit_overflow) = translate_integer_expr( + left, int_width, use_bv) + overflow.append(prohibit_overflow) + (right_smt, prohibit_overflow) = translate_integer_expr( + right, int_width, use_bv) + overflow.append(prohibit_overflow) if expr.operator == BinaryOperation.Operator.EQ: return left_smt == right_smt @@ -716,7 +777,8 @@ def trans(expr: Node): # Fall through: return a fresh, unconstrained symbol return z3.FreshBool() - return trans(expr_root) + expr_root_smt = trans(expr_root) + return (expr_root_smt, z3.And(*overflow)) # Helper functions # ================ diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index a3996c1a53..9309c02731 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -177,9 +177,8 @@ def validate(self, node, options=None, **kwargs): reduction_ops = [] use_smt_array_anal = self.get_option( "use_smt_array_anal", **kwargs) - smt_timeout_ms = self.get_option("smt_timeout_ms", **kwargs) - smt_use_bv = self.get_option("smt_use_bv", **kwargs) - smt_int_width = self.get_option("smt_int_width", **kwargs) + smt_array_anal_options = self.get_option( + "smt_array_anal_options", **kwargs) else: verbose = options.get("verbose", False) collapse = options.get("collapse", False) @@ -191,9 +190,8 @@ def validate(self, node, options=None, **kwargs): privatise_arrays = options.get("privatise_arrays", False) reduction_ops = options.get("reduction_ops", []) use_smt_array_anal = options.get("use_smt_array_anal", False) - smt_timeout_ms = options.get("smt_timeout_ms", 5000) - smt_use_bv = options.get("smt_use_bv", True) - smt_int_width = options.get("smt_int_width", 32) + smt_array_anal_options = options.get( + "smt_array_anal_options", ArrayIndexAnalysis.Options()) # Check type of reduction_ops (not handled by validate_options) if not isinstance(reduction_ops, list): @@ -317,10 +315,7 @@ def validate(self, node, options=None, **kwargs): if len(errors) > 0 and len(errors) == num_depedency_errors: # Try using the ArrayIndexAnalysis to prove that the # dependency errors are false - arr_anal = ArrayIndexAnalysis( - smt_timeout_ms=smt_timeout_ms, - use_bv=smt_use_bv, - int_width=smt_int_width) + arr_anal = ArrayIndexAnalysis(smt_array_anal_options) if arr_anal.is_loop_conflict_free(node): errors = [] @@ -353,9 +348,8 @@ def apply(self, node, options=None, verbose: bool = False, reduction_ops: List[Union[BinaryOperation.Operator, IntrinsicCall.Intrinsic]] = None, use_smt_array_anal: bool = False, - smt_timeout_ms: int = 5000, - smt_use_bv: bool = True, - smt_int_width: int = 32, + smt_array_anal_options: + ArrayIndexAnalysis.Options = ArrayIndexAnalysis.Options(), **kwargs): ''' Apply the Loop transformation to the specified node in a @@ -402,11 +396,8 @@ def apply(self, node, options=None, verbose: bool = False, the reduction operators in the list. :param bool use_smt_array_anal: whether to use the SMT-based ArrayIndexAnalysis to discharge false dependency errors. - :param bool smt_timeout_ms: SMT solver timeout in milliseconds. - :param bool smt_use_bv: use bit vectors or arbitary precision integers - in the SMT solver? - :param bool smt_int_width: width of Fortran integers in bits for - the SMT solver. + :param bool smt_array_anal_options: options for the array index + analysis. ''' if not options: @@ -417,9 +408,7 @@ def apply(self, node, options=None, verbose: bool = False, sequential=sequential, nowait=nowait, reduction_ops=reduction_ops, use_smt_array_anal=use_smt_array_anal, - smt_timeout_ms=smt_timeout_ms, - smt_use_bv=smt_use_bv, - smt_int_width=smt_int_width, + smt_array_anal_options=smt_array_anal_options, **kwargs ) # Rename the input options that are renamed in this apply method. @@ -442,9 +431,8 @@ def apply(self, node, options=None, verbose: bool = False, nowait = options.get("nowait", False) reduction_ops = options.get("reduction_ops", []) use_smt_array_anal = options.get("use_smt_array_anal", False) - smt_timeout_ms = options.get("smt_timeout_ms", 5000) - smt_timeout_ms = options.get("smt_use_bv", True) - smt_int_width = options.get("smt_int_width", 32) + smt_array_anal_options = options.get( + "smt_array_anal_options", ArrayIndexAnalysis.Options()) self.validate(node, options=options, verbose=verbose, collapse=collapse, force=force, @@ -453,9 +441,7 @@ def apply(self, node, options=None, verbose: bool = False, sequential=sequential, nowait=nowait, reduction_ops=reduction_ops, use_smt_array_anal=use_smt_array_anal, - smt_timeout_ms=smt_timeout_ms, - smt_use_bv=smt_use_bv, - smt_int_width=smt_int_width, + smt_array_anal_options=smt_array_anal_options, **kwargs) list_of_signatures = [Signature(name) for name in list_of_names] @@ -542,9 +528,7 @@ def apply(self, node, options=None, verbose: bool = False, [msg.code == DTCode.ERROR_DEPENDENCY for msg in msgs]) arr_anal = ArrayIndexAnalysis( - smt_timeout_ms=smt_timeout_ms, - use_bv=smt_use_bv, - int_width=smt_int_width) + smt_array_anal_options) discharge_errors = ( all_dep_errors and arr_anal.is_loop_conflict_free(next_loop)) diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py index 853012a6aa..c2ea821892 100644 --- a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -184,7 +184,8 @@ def test(expr): x = {expr} end subroutine''') for assign in psyir.walk(Assignment): - rhs_smt = translate_logical_expr(assign.rhs, 32, use_bv) + (rhs_smt, prohibit_overflow) = translate_logical_expr( + assign.rhs, 32, use_bv) solver = z3.Solver() assert solver.check(rhs_smt) == z3.sat diff --git a/src/psyclone/tests/psyir/transformations/transformations_test.py b/src/psyclone/tests/psyir/transformations/transformations_test.py index b2c27d7895..ad5a8f7fd3 100644 --- a/src/psyclone/tests/psyir/transformations/transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/transformations_test.py @@ -578,8 +578,8 @@ def test_omploop_trans_new_options(sample_psyir): "'fakeoption2']. Valid options are '['node_type_check', " "'verbose', 'collapse', 'force', 'ignore_dependencies_for', " "'privatise_arrays', 'sequential', 'nowait', 'reduction_ops', " - "'use_smt_array_anal', 'smt_timeout_ms', 'smt_use_bv', " - "'smt_int_width', 'options', 'reprod', 'enable_reductions']." + "'use_smt_array_anal', 'smt_array_anal_options', " + "'options', 'reprod', 'enable_reductions']." in str(excinfo.value)) # Check we get the relevant error message when submitting multiple From 7567511e92b4496758aa3de7954644b54886cd4a Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Mon, 10 Nov 2025 14:09:15 +0000 Subject: [PATCH 05/17] Add z3-solver to setup.py --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 747aad367f..8db17d7c01 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,8 @@ def get_files(directory, install_path, valid_suffixes): packages=PACKAGES, package_dir={"": "src"}, install_requires=['pyparsing', 'fparser>=0.2.1', 'configparser', - 'sympy', "Jinja2", 'termcolor', 'graphviz'], + 'sympy', "Jinja2", 'termcolor', 'graphviz', + 'z3-solver'], extras_require={ 'doc': ["sphinx", "sphinxcontrib.bibtex", "sphinx_design", "pydata-sphinx-theme", "sphinx-autodoc-typehints", From c04c9141c469c398a42c3cf48148fe6f55beed79 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Fri, 14 Nov 2025 15:52:42 +0000 Subject: [PATCH 06/17] Handle structures correctly --- .../psyir/tools/array_index_analysis.py | 163 ++++++++++-------- .../transformations/parallel_loop_trans.py | 6 +- 2 files changed, 95 insertions(+), 74 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 95f4683536..22180edeb4 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -52,8 +52,13 @@ # The analysis class provides a method 'is_loop_conflict_free()' to decide # whether or not the array accesses in a given loop are conflicting between # iterations. Two array accesses are conflicting if they access the same -# element of the same array, and at least one of the accesses is a write. The -# analysis algorithm operates, broadly, as follows. +# element of the same array, and at least one of the accesses is a write. +# +# The analysis assumes that any scalar integer or scalar logical variables +# written by the loop can safely be considered as private within each +# iteration. This should be validated by the callee. +# +# The analysis algorithm operates, broadly, as follows. # # Given a loop, we find its enclosing routine, and start analysing the routine # statement-by-statement in a recursive-descent fashion. @@ -153,7 +158,7 @@ class ArrayAccess: def __init__(self, cond: z3.BoolRef, is_write: bool, - indices: list[z3.ExprRef], + indices: list[list[z3.ExprRef]], psyir_node: Node): # The condition at the location of the access self.cond = cond @@ -236,13 +241,16 @@ def kill_all_written_vars(self, node: Node): var_accesses = node.reference_accesses() for sig, access_seq in var_accesses.items(): for access_info in access_seq.all_write_accesses: - sym = self.routine.symbol_table.lookup(sig.var_name) - if sym.is_unresolved: - continue # pragma: no cover - elif _is_scalar_integer(sym.datatype): + if isinstance(access_info.node, Loop): self.kill_integer_var(sig.var_name) - elif _is_scalar_logical(sym.datatype): - self.kill_logical_var(sig.var_name) + break + elif isinstance(access_info.node, Reference): + if _is_scalar_integer(access_info.node.datatype): + self.kill_integer_var(sig.var_name) + break + elif _is_scalar_logical(access_info.node.datatype): + self.kill_logical_var(sig.var_name) + break # Add the SMT constraint to the constraint set def add_constraint(self, smt_expr: z3.BoolRef): @@ -319,7 +327,7 @@ def add_array_access(self, array_name: str, is_write: bool, cond: z3.BoolRef, - indices: list[z3.ExprRef], + indices: list[list[z3.ExprRef]], psyir_node: Node): access = ArrayIndexAnalysis.ArrayAccess( cond, is_write, indices, psyir_node) @@ -332,36 +340,32 @@ def add_all_array_accesses(self, node: Node, cond: z3.BoolRef): var_accesses = node.reference_accesses() for sig, access_seq in var_accesses.items(): for access_info in access_seq: - if access_info.is_data_access: - - # ArrayReference - if isinstance(access_info.node, ArrayReference): - indices = [] - for index in access_info.node.indices: - if isinstance(index, Range): - var = self.fresh_integer_var() - self.constrain_loop_var( - var, index.start, index.stop, index.step) - indices.append(var) - else: - indices.append( - self.translate_integer_expr_with_subst( - index)) + if isinstance(access_info.node, Reference): + (_, indices) = access_info.node.get_signature_and_indices() + indices_flat = [i for inds in indices for i in inds] + is_array_access = ( + access_info.is_data_access and + (indices_flat != [] or + isinstance(access_info.node.datatype, ArrayType))) + if is_array_access: + smt_indices = [] + for inds in indices: + smt_inds = [] + for ind in inds: + if isinstance(ind, Range): + var = self.fresh_integer_var() + self.constrain_loop_var( + var, ind.start, ind.stop, ind.step) + smt_inds.append(var) + else: + smt_inds.append( + self.translate_integer_expr_with_subst( + ind)) + smt_indices.append(smt_inds) self.add_array_access( - sig.var_name, + str(sig), access_info.is_any_write(), - cond, indices, access_info.node) - - # Reference with datatype ArrayType - elif (isinstance(access_info.node, Reference) and - isinstance(access_info.node.datatype, ArrayType)): - indices = [] - for index in access_info.node.datatype.shape: - var = self.fresh_integer_var() - indices.append(var) - self.add_array_access( - sig.var_name, access_info.is_any_write(), - cond, indices, access_info.node) + cond, smt_indices, access_info.node) # Move the current access dict to the stack, and proceed with an empty one def save_access_dict(self): @@ -398,23 +402,26 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: iter_i = self.saved_access_dicts[0] iter_j = self.saved_access_dicts[1] conflicts = [] - for (arr_name, i_accesses) in iter_i.items(): - j_accesses = iter_j[arr_name] - # For each write access in the i iteration - for i_access in i_accesses: - if i_access.is_write: - # Check for conflicts against every access in the - # j iteration - for j_access in j_accesses: - assert len(i_access.indices) == len(j_access.indices) - indices_equal = [] - for (i_idx, j_idx) in zip(i_access.indices, - j_access.indices): - indices_equal.append(i_idx == j_idx) - conflicts.append(z3.And( - *indices_equal, - i_access.cond, - j_access.cond)) + for (i_arr_name, i_accesses) in iter_i.items(): + for (j_arr_name, j_accesses) in iter_j.items(): + if (i_arr_name == j_arr_name or + i_arr_name.startswith(j_arr_name + "%") or + j_arr_name.startswith(i_arr_name + "%")): + # For each write access in the i iteration + for i_access in i_accesses: + if i_access.is_write: + # Check for conflicts against every access in the + # j iteration + for j_access in j_accesses: + indices_equal = [] + for (i_idxs, j_idxs) in zip(i_access.indices, + j_access.indices): + for (i_idx, j_idx) in zip(i_idxs, j_idxs): + indices_equal.append(i_idx == j_idx) + conflicts.append(z3.And( + *indices_equal, + i_access.cond, + j_access.cond)) # Invoke Z3 solver with a timeout solver = z3.Solver() @@ -436,18 +443,24 @@ def step(self, stmt: Node, cond: z3.BoolRef): # Assignment if isinstance(stmt, Assignment): - if (isinstance(stmt.lhs, Reference) - and not isinstance(stmt.lhs, ArrayReference)): - if _is_scalar_integer(stmt.lhs.datatype): - rhs_smt = self.translate_integer_expr_with_subst(stmt.rhs) - self.add_integer_assignment(stmt.lhs.name, rhs_smt) - if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.rhs, cond) - return - elif _is_scalar_logical(stmt.lhs.datatype): - rhs_smt = self.translate_logical_expr_with_subst(stmt.rhs) - self.add_logical_assignment(stmt.lhs.name, rhs_smt) - return + if isinstance(stmt.lhs, Reference): + (sig, indices) = stmt.lhs.get_signature_and_indices() + indices_flat = [i for inds in indices for i in inds] + if indices_flat == [] and len(sig) == 1: + if _is_scalar_integer(stmt.lhs.datatype): + rhs_smt = self.translate_integer_expr_with_subst( + stmt.rhs) + self.add_integer_assignment(sig.var_name, rhs_smt) + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.rhs, cond) + return + elif _is_scalar_logical(stmt.lhs.datatype): + rhs_smt = self.translate_logical_expr_with_subst( + stmt.rhs) + self.add_logical_assignment(sig.var_name, rhs_smt) + if self.in_loop_to_parallelise: + self.add_all_array_accesses(stmt.rhs, cond) + return # Schedule if isinstance(stmt, Schedule): @@ -562,10 +575,13 @@ def trans(expr: Node) -> z3.ExprRef: if (isinstance(expr, Reference) and not isinstance(expr, ArrayReference) and type_ok): - if use_bv: - return z3.BitVec(expr.name, int_width) - else: - return z3.Int(expr.name) + (sig, indices) = expr.get_signature_and_indices() + indices_flat = [i for inds in indices for i in inds] + if indices_flat == []: + if use_bv: + return z3.BitVec(str(sig), int_width) + else: + return z3.Int(str(sig)) # UnaryOperation if isinstance(expr, UnaryOperation): @@ -718,7 +734,10 @@ def trans(expr: Node): if (isinstance(expr, Reference) and not isinstance(expr, ArrayReference) and type_ok): - return z3.Bool(expr.name) + (sig, indices) = expr.get_signature_and_indices() + indices_flat = [i for inds in indices for i in inds] + if indices_flat == []: + return z3.Bool(str(sig)) # UnaryOperation if isinstance(expr, UnaryOperation): diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index 9309c02731..e8cbda9207 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -305,7 +305,8 @@ def validate(self, node, options=None, **kwargs): self.inferred_reduction_clauses.append(clause) continue - if (message.code == DTCode.ERROR_DEPENDENCY): + if (message.code == DTCode.ERROR_DEPENDENCY or + message.code == DTCode.ERROR_WRITE_WRITE_RACE): num_depedency_errors = num_depedency_errors + 1 errors.append(str(message)) @@ -525,7 +526,8 @@ def apply(self, node, options=None, verbose: bool = False, if use_smt_array_anal: all_dep_errors = all( - [msg.code == DTCode.ERROR_DEPENDENCY + [msg.code == DTCode.ERROR_DEPENDENCY or + msg.code == DTCode.ERROR_WRITE_WRITE_RACE for msg in msgs]) arr_anal = ArrayIndexAnalysis( smt_array_anal_options) From 2488ed0d747cd43b529b6ed6c4243290397d79c5 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Mon, 17 Nov 2025 11:14:56 +0000 Subject: [PATCH 07/17] Call ArrayIndexAnalysis from DependencyTools instead of ParallelLoopTrans --- .../psyir/tools/array_index_analysis.py | 8 +- src/psyclone/psyir/tools/dependency_tools.py | 42 ++++++++- .../transformations/parallel_loop_trans.py | 90 ++++++------------- .../tests/psyir/nodes/omp_directives_test.py | 3 +- .../transformations/transformations_test.py | 2 +- 5 files changed, 74 insertions(+), 71 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 22180edeb4..807ded8823 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -345,8 +345,8 @@ def add_all_array_accesses(self, node: Node, cond: z3.BoolRef): indices_flat = [i for inds in indices for i in inds] is_array_access = ( access_info.is_data_access and - (indices_flat != [] or - isinstance(access_info.node.datatype, ArrayType))) + (indices_flat != [] or + isinstance(access_info.node.datatype, ArrayType))) if is_array_access: smt_indices = [] for inds in indices: @@ -405,8 +405,8 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: for (i_arr_name, i_accesses) in iter_i.items(): for (j_arr_name, j_accesses) in iter_j.items(): if (i_arr_name == j_arr_name or - i_arr_name.startswith(j_arr_name + "%") or - j_arr_name.startswith(i_arr_name + "%")): + i_arr_name.startswith(j_arr_name + "%") or + j_arr_name.startswith(i_arr_name + "%")): # For each write access in the i iteration for i_access in i_accesses: if i_access.is_write: diff --git a/src/psyclone/psyir/tools/dependency_tools.py b/src/psyclone/psyir/tools/dependency_tools.py index 57daa3de42..853755903e 100644 --- a/src/psyclone/psyir/tools/dependency_tools.py +++ b/src/psyclone/psyir/tools/dependency_tools.py @@ -50,6 +50,7 @@ from psyclone.psyir.backend.sympy_writer import SymPyWriter from psyclone.psyir.backend.visitor import VisitorError from psyclone.psyir.nodes import Loop, Node, Range +from psyclone.psyir.tools.array_index_analysis import ArrayIndexAnalysis # pylint: disable=too-many-lines @@ -162,11 +163,20 @@ class DependencyTools(): specified in the PSyclone config file. This can be used to exclude for example 1-dimensional loops. :type loop_types_to_parallelise: Optional[List[str]] + :param use_smt_array_index_analysis: if True, the SMT-based + array index analysis will be used for detecting array access + conflicts. An ArrayIndexAnalysis.Options value can also be given, + instead of a bool, in which case the analysis will be invoked + with the given options. + :type use_smt_array_index_analysis: Union[ + bool, ArrayIndexAnalysis.Options] :raises TypeError: if an invalid loop type is specified. ''' - def __init__(self, loop_types_to_parallelise=None): + def __init__(self, + loop_types_to_parallelise=None, + use_smt_array_index_analysis=False): if loop_types_to_parallelise: # Verify that all loop types specified are valid: config = Config.get() @@ -183,6 +193,7 @@ def __init__(self, loop_types_to_parallelise=None): else: self._loop_types_to_parallelise = [] self._clear_messages() + self._use_smt_array_index_analysis = use_smt_array_index_analysis # ------------------------------------------------------------------------- def _clear_messages(self): @@ -884,9 +895,15 @@ def can_loop_be_parallelised(self, loop, # TODO #1270 - the is_array_access function might be moved is_array = symbol.is_array_access(access_info=var_info) if is_array: - # Handle arrays - par_able = self._array_access_parallelisable(loop_vars, - var_info) + # If using the SMT-based array index analysis then do + # nothing for now. This analysis is run after the loop. + if self._use_smt_array_index_analysis: + # This analysis runs after the loop + par_able = True + else: + # Handle arrays + par_able = self._array_access_parallelisable(loop_vars, + var_info) else: # Handle scalar variable par_able = self._is_scalar_parallelisable(signature, var_info) @@ -898,6 +915,23 @@ def can_loop_be_parallelised(self, loop, # not just the first one result = False + # Apply the SMT-based array index analysis, if enabled + if self._use_smt_array_index_analysis: + if isinstance(self._use_smt_array_index_analysis, + ArrayIndexAnalysis.Options): + options = self._use_smt_array_index_analysis + else: + options = ArrayIndexAnalysis.Options() + analysis = ArrayIndexAnalysis(options) + conflict_free = analysis.is_loop_conflict_free(loop) + if not conflict_free: + self._add_message( + "The ArrayIndexAnalysis has determined that the" + "array accesses in the loop may be conflicting " + "and hence cannot be parallelised.", + DTCode.ERROR_DEPENDENCY) + result = False + return result # ------------------------------------------------------------------------- diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index e8cbda9207..07873c397e 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -53,7 +53,7 @@ BinaryOperation, IntrinsicCall ) from psyclone.psyir.tools import ( - DependencyTools, DTCode, ReductionInferenceTool, ArrayIndexAnalysis + DependencyTools, DTCode, ReductionInferenceTool, ArrayIndexAnalysis, ) from psyclone.psyir.transformations.loop_trans import LoopTrans from psyclone.psyir.transformations.async_trans_mixin import \ @@ -175,10 +175,8 @@ def validate(self, node, options=None, **kwargs): reduction_ops = self.get_option("reduction_ops", **kwargs) if reduction_ops is None: reduction_ops = [] - use_smt_array_anal = self.get_option( - "use_smt_array_anal", **kwargs) - smt_array_anal_options = self.get_option( - "smt_array_anal_options", **kwargs) + use_smt_array_index_analysis = self.get_option( + "use_smt_array_index_analysis", **kwargs) else: verbose = options.get("verbose", False) collapse = options.get("collapse", False) @@ -189,9 +187,8 @@ def validate(self, node, options=None, **kwargs): sequential = options.get("sequential", False) privatise_arrays = options.get("privatise_arrays", False) reduction_ops = options.get("reduction_ops", []) - use_smt_array_anal = options.get("use_smt_array_anal", False) - smt_array_anal_options = options.get( - "smt_array_anal_options", ArrayIndexAnalysis.Options()) + use_smt_array_index_analysis = options.get( + "use_smt_array_index_analysis", False) # Check type of reduction_ops (not handled by validate_options) if not isinstance(reduction_ops, list): @@ -267,7 +264,8 @@ def validate(self, node, options=None, **kwargs): f" object containing str representing the " f"symbols to ignore, but got '{ignore_dependencies_for}'.") - dep_tools = DependencyTools() + dep_tools = DependencyTools( + use_smt_array_index_analysis=use_smt_array_index_analysis) signatures = [Signature(name) for name in ignore_dependencies_for] @@ -278,7 +276,6 @@ def validate(self, node, options=None, **kwargs): # The DependencyTools also returns False for things that are # not an issue, so we ignore specific messages. errors = [] - num_depedency_errors = 0 for message in dep_tools.get_all_messages(): if message.code == DTCode.WARN_SCALAR_WRITTEN_ONCE: continue @@ -304,22 +301,8 @@ def validate(self, node, options=None, **kwargs): if clause: self.inferred_reduction_clauses.append(clause) continue - - if (message.code == DTCode.ERROR_DEPENDENCY or - message.code == DTCode.ERROR_WRITE_WRITE_RACE): - num_depedency_errors = num_depedency_errors + 1 errors.append(str(message)) - # Use ArrayIndexAnalysis - if use_smt_array_anal: - # Are all the errors array dependency errors? - if len(errors) > 0 and len(errors) == num_depedency_errors: - # Try using the ArrayIndexAnalysis to prove that the - # dependency errors are false - arr_anal = ArrayIndexAnalysis(smt_array_anal_options) - if arr_anal.is_loop_conflict_free(node): - errors = [] - if errors: error_lines = "\n".join(errors) messages = (f"Loop cannot be parallelised because:\n" @@ -348,9 +331,8 @@ def apply(self, node, options=None, verbose: bool = False, nowait: bool = False, reduction_ops: List[Union[BinaryOperation.Operator, IntrinsicCall.Intrinsic]] = None, - use_smt_array_anal: bool = False, - smt_array_anal_options: - ArrayIndexAnalysis.Options = ArrayIndexAnalysis.Options(), + use_smt_array_index_analysis: + Union[bool, ArrayIndexAnalysis.Options] = False, **kwargs): ''' Apply the Loop transformation to the specified node in a @@ -395,10 +377,11 @@ def apply(self, node, options=None, verbose: bool = False, :param reduction_ops: if non-empty, attempt parallelisation of loops by inferring reduction clauses involving any of the reduction operators in the list. - :param bool use_smt_array_anal: whether to use the SMT-based - ArrayIndexAnalysis to discharge false dependency errors. - :param bool smt_array_anal_options: options for the array index - analysis. + :param use_smt_array_index_analysis: if True, the SMT-based + array index analysis will be used for detecting array access + conflicts. An ArrayIndexAnalysis.Options value can also be given, + instead of a bool, in which case the analysis will be invoked + with the given options. ''' if not options: @@ -408,8 +391,7 @@ def apply(self, node, options=None, verbose: bool = False, privatise_arrays=privatise_arrays, sequential=sequential, nowait=nowait, reduction_ops=reduction_ops, - use_smt_array_anal=use_smt_array_anal, - smt_array_anal_options=smt_array_anal_options, + use_smt_array_index_analysis=use_smt_array_index_analysis, **kwargs ) # Rename the input options that are renamed in this apply method. @@ -431,9 +413,8 @@ def apply(self, node, options=None, verbose: bool = False, privatise_arrays = options.get("privatise_arrays", False) nowait = options.get("nowait", False) reduction_ops = options.get("reduction_ops", []) - use_smt_array_anal = options.get("use_smt_array_anal", False) - smt_array_anal_options = options.get( - "smt_array_anal_options", ArrayIndexAnalysis.Options()) + use_smt_array_index_analysis = options.get( + "use_smt_array_index_analysis", False) self.validate(node, options=options, verbose=verbose, collapse=collapse, force=force, @@ -441,12 +422,13 @@ def apply(self, node, options=None, verbose: bool = False, privatise_arrays=privatise_arrays, sequential=sequential, nowait=nowait, reduction_ops=reduction_ops, - use_smt_array_anal=use_smt_array_anal, - smt_array_anal_options=smt_array_anal_options, + use_smt_array_index_analysis=( + use_smt_array_index_analysis), **kwargs) list_of_signatures = [Signature(name) for name in list_of_names] - dtools = DependencyTools() + dtools = DependencyTools( + use_smt_array_index_analysis=use_smt_array_index_analysis) # Add all reduction variables inferred by 'validate' to the list # of signatures to ignore @@ -521,28 +503,14 @@ def apply(self, node, options=None, verbose: bool = False, if not next_loop.independent_iterations( dep_tools=dtools, signatures_to_ignore=list_of_signatures): - msgs = dtools.get_all_messages() - discharge_errors = False - - if use_smt_array_anal: - all_dep_errors = all( - [msg.code == DTCode.ERROR_DEPENDENCY or - msg.code == DTCode.ERROR_WRITE_WRITE_RACE - for msg in msgs]) - arr_anal = ArrayIndexAnalysis( - smt_array_anal_options) - discharge_errors = ( - all_dep_errors and - arr_anal.is_loop_conflict_free(next_loop)) - - if not discharge_errors: - if verbose: - next_loop.preceding_comment = ( - "\n".join([str(m) for m in msgs]) + - " Consider using the \"ignore_dependenc" - "ies_for\" transformation option if this " - "is a false dependency.") - break + if verbose: + msgs = dtools.get_all_messages() + next_loop.preceding_comment = ( + "\n".join([str(m) for m in msgs]) + + " Consider using the \"ignore_dependencies_" + "for\" transformation option if this is a " + "false dependency.") + break else: num_collapsable_loops = None diff --git a/src/psyclone/tests/psyir/nodes/omp_directives_test.py b/src/psyclone/tests/psyir/nodes/omp_directives_test.py index a063d0542b..6a04a673d0 100644 --- a/src/psyclone/tests/psyir/nodes/omp_directives_test.py +++ b/src/psyclone/tests/psyir/nodes/omp_directives_test.py @@ -5383,6 +5383,7 @@ def test_array_analysis_option(fortran_reader, fortran_writer): end subroutine my_matmul''') omplooptrans = OMPLoopTrans(omp_directive="paralleldo") loop = psyir.walk(Loop)[0] - omplooptrans.apply(loop, collapse=True, use_smt_array_anal=True) + omplooptrans.apply( + loop, collapse=True, use_smt_array_index_analysis=True) output = fortran_writer(psyir) assert "collapse(2)" in output diff --git a/src/psyclone/tests/psyir/transformations/transformations_test.py b/src/psyclone/tests/psyir/transformations/transformations_test.py index ad5a8f7fd3..304aeb06bd 100644 --- a/src/psyclone/tests/psyir/transformations/transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/transformations_test.py @@ -578,7 +578,7 @@ def test_omploop_trans_new_options(sample_psyir): "'fakeoption2']. Valid options are '['node_type_check', " "'verbose', 'collapse', 'force', 'ignore_dependencies_for', " "'privatise_arrays', 'sequential', 'nowait', 'reduction_ops', " - "'use_smt_array_anal', 'smt_array_anal_options', " + "'use_smt_array_index_analysis', " "'options', 'reprod', 'enable_reductions']." in str(excinfo.value)) From 06c7b9afb6b0e31d3e67b8927badca1d0351cb74 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Mon, 17 Nov 2025 12:24:51 +0000 Subject: [PATCH 08/17] Simple heuristic to decide between integer and bit-vector solvers --- src/psyclone/psyir/tools/array_index_analysis.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 807ded8823..24d1b020ce 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -137,11 +137,12 @@ class ArrayIndexAnalysis: + # Class representing analysis options class Options: def __init__(self, int_width: int = 32, - use_bv: int = True, + use_bv: bool = None, smt_timeout_ms: int = 5000, prohibit_overflow: bool = False): # Set SMT solver timeout in milliseconds @@ -390,6 +391,19 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: self.init_analysis() self.loop_to_parallelise = loop + # Resolve choice of integers v. bit vectors + if self.use_bv is None: + for call in loop.walk(IntrinsicCall): + i = call.intrinsic + if i in [IntrinsicCall.Intrinsic.SHIFTL, + IntrinsicCall.Intrinsic.SHIFTR, + IntrinsicCall.Intrinsic.SHIFTA, + IntrinsicCall.Intrinsic.IAND, + IntrinsicCall.Intrinsic.IOR, + IntrinsicCall.Intrinsic.IEOR]: + self.use_bv = True + break + # Step through body of the enclosing routine, statement by statement for stmt in routine.children: self.step(stmt, z3.BoolVal(True)) From 9f62114909c969c1a83bb03fa66001c2acfba9c1 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Mon, 17 Nov 2025 12:28:37 +0000 Subject: [PATCH 09/17] Small tweak to heuristic --- src/psyclone/psyir/tools/array_index_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 24d1b020ce..cb32b2bb80 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -393,7 +393,7 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: # Resolve choice of integers v. bit vectors if self.use_bv is None: - for call in loop.walk(IntrinsicCall): + for call in routine.walk(IntrinsicCall): i = call.intrinsic if i in [IntrinsicCall.Intrinsic.SHIFTL, IntrinsicCall.Intrinsic.SHIFTR, From cde3a8f24ae9fde7978c0eaf22539c9d31ff9ae2 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Tue, 18 Nov 2025 16:00:09 +0000 Subject: [PATCH 10/17] Option to handle some common array intrinsics --- src/psyclone/psyir/tools/__init__.py | 6 +- .../psyir/tools/array_index_analysis.py | 204 ++++++++++++------ src/psyclone/psyir/tools/dependency_tools.py | 11 +- .../transformations/parallel_loop_trans.py | 7 +- 4 files changed, 150 insertions(+), 78 deletions(-) diff --git a/src/psyclone/psyir/tools/__init__.py b/src/psyclone/psyir/tools/__init__.py index 3df31cb2ce..8583883156 100644 --- a/src/psyclone/psyir/tools/__init__.py +++ b/src/psyclone/psyir/tools/__init__.py @@ -43,7 +43,8 @@ from psyclone.psyir.tools.read_write_info import ReadWriteInfo from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain from psyclone.psyir.tools.reduction_inference import ReductionInferenceTool -from psyclone.psyir.tools.array_index_analysis import ArrayIndexAnalysis +from psyclone.psyir.tools.array_index_analysis import (ArrayIndexAnalysis, + ArrayIndexAnalysisOptions) # For AutoAPI documentation generation. __all__ = ['CallTreeUtils', @@ -52,4 +53,5 @@ 'DefinitionUseChain', 'ReadWriteInfo', 'ReductionInferenceTool', - 'ArrayIndexAnalysis'] + 'ArrayIndexAnalysis', + 'ArrayIndexAnalysisOptions'] diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index cb32b2bb80..4529c0d0e5 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -136,24 +136,28 @@ # access pair and determine whether or not the loop is conflict free. +# Class representing analysis options +class ArrayIndexAnalysisOptions: + def __init__(self, + int_width: int = 32, + use_bv: bool = None, + smt_timeout_ms: int = 5000, + prohibit_overflow: bool = False, + handle_array_intrins: bool = False): + # Set SMT solver timeout in milliseconds + self.smt_timeout = smt_timeout_ms + # Fortran integer width in bits + self.int_width = int_width + # Use fixed-width bit vectors or arbirary precision integers? + self.use_bv = use_bv + # Prohibit bit-vector overflow when solving constraints? + self.prohibit_overflow = prohibit_overflow + # Handle array intrinsics such as size, lbound, and ubound + self.handle_array_intrins = handle_array_intrins + + +# Main analysis class class ArrayIndexAnalysis: - - # Class representing analysis options - class Options: - def __init__(self, - int_width: int = 32, - use_bv: bool = None, - smt_timeout_ms: int = 5000, - prohibit_overflow: bool = False): - # Set SMT solver timeout in milliseconds - self.smt_timeout = smt_timeout_ms - # Fortran integer width in bits - self.int_width = int_width - # Use fixed-width bit vectors or arbirary precision integers? - self.use_bv = use_bv - # Prohibit bit-vector overflow when solving constraints? - self.prohibit_overflow = prohibit_overflow - # Class representing an array access class ArrayAccess: def __init__(self, @@ -171,11 +175,8 @@ def __init__(self, self.psyir_node = psyir_node # ArrayIndexAnalysis constructor - def __init__(self, options=Options()): - self.smt_timeout = options.smt_timeout - self.int_width = options.int_width - self.use_bv = options.use_bv - self.prohibit_overflow = options.prohibit_overflow + def __init__(self, options=ArrayIndexAnalysisOptions()): + self.opts = options # Initialise analysis def init_analysis(self): @@ -195,6 +196,22 @@ def init_analysis(self): self.in_loop_to_parallelise = False # Has the analaysis finished? self.finished = False + # We map array intrinsic calls (e.g. size, lbound, ubound) to SMT + # integer variables. The following dict maps array names to a + # set of instrinsic variables associated with that array. + self.array_intrins_vars = {} + + # Initialise the 'array_intrins_vars' dict + def init_array_intrins_vars(self, routine): + if self.opts.handle_array_intrins: + for stmt in routine.children: + for call in stmt.walk(IntrinsicCall): + intrins_pair = translate_array_intrinsic_call(call) + if intrins_pair: + (arr_name, var_name) = intrins_pair + if arr_name not in self.array_intrins_vars: + self.array_intrins_vars[arr_name] = set() + self.array_intrins_vars[arr_name].add(var_name) # Push copy of current substitution to the stack def save_subst(self): @@ -206,22 +223,22 @@ def restore_subst(self): # Create an fresh SMT integer variable def fresh_integer_var(self) -> z3.ExprRef: - if self.use_bv: - return z3.FreshConst(z3.BitVecSort(self.int_width)) + if self.opts.use_bv: + return z3.FreshConst(z3.BitVecSort(self.opts.int_width)) else: return z3.FreshInt() # Create an integer SMT variable with the given name def integer_var(self, var) -> z3.ExprRef: - if self.use_bv: - return z3.BitVec(var, self.int_width) + if self.opts.use_bv: + return z3.BitVec(var, self.opts.int_width) else: return z3.Int(var) # Create an SMT integer value def integer_val(self, val: int) -> z3.ExprRef: - if self.use_bv: - return z3.BitVecVal(val, self.int_width) + if self.opts.use_bv: + return z3.BitVecVal(val, self.opts.int_width) else: return z3.IntVal(val) @@ -252,6 +269,15 @@ def kill_all_written_vars(self, node: Node): elif _is_scalar_logical(access_info.node.datatype): self.kill_logical_var(sig.var_name) break + elif isinstance(access_info.node.datatype, ArrayType): + # If an array variable is modified we kill intrinsic + # vars associated with it. This is overly safe: + # we probably only need to kill these vars if the + # array is passed to a mutating routine/intrinsic. + if sig.var_name in self.array_intrins_vars: + for v in self.array_intrins_vars[sig.var_name]: + self.kill_integer_var(v) + break # Add the SMT constraint to the constraint set def add_constraint(self, smt_expr: z3.BoolRef): @@ -280,26 +306,26 @@ def add_logical_assignment(self, var: str, smt_expr: z3.BoolRef): # Translate integer expresison to SMT, and apply current substitution def translate_integer_expr_with_subst(self, expr: z3.ExprRef): (smt_expr, prohibit_overflow) = translate_integer_expr( - expr, self.int_width, self.use_bv) + expr, self.opts) subst_pairs = list(self.subst.items()) - if self.prohibit_overflow: + if self.opts.prohibit_overflow: self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) return z3.substitute(smt_expr, *subst_pairs) # Translate logical expresison to SMT, and apply current substitution def translate_logical_expr_with_subst(self, expr: z3.BoolRef): (smt_expr, prohibit_overflow) = translate_logical_expr( - expr, self.int_width, self.use_bv) + expr, self.opts) subst_pairs = list(self.subst.items()) - if self.prohibit_overflow: + if self.opts.prohibit_overflow: self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) return z3.substitute(smt_expr, *subst_pairs) # Translate conditional expresison to SMT, and apply current substitution def translate_cond_expr_with_subst(self, expr: z3.BoolRef): (smt_expr, prohibit_overflow) = translate_logical_expr( - expr, self.int_width, self.use_bv) - if self.prohibit_overflow: + expr, self.opts) + if self.opts.prohibit_overflow: smt_expr = z3.And(smt_expr, prohibit_overflow) subst_pairs = list(self.subst.items()) return z3.substitute(smt_expr, *subst_pairs) @@ -390,9 +416,10 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: # Start with an empty constraint set and substitution self.init_analysis() self.loop_to_parallelise = loop + self.init_array_intrins_vars(routine) # Resolve choice of integers v. bit vectors - if self.use_bv is None: + if self.opts.use_bv is None: for call in routine.walk(IntrinsicCall): i = call.intrinsic if i in [IntrinsicCall.Intrinsic.SHIFTL, @@ -401,7 +428,7 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: IntrinsicCall.Intrinsic.IAND, IntrinsicCall.Intrinsic.IOR, IntrinsicCall.Intrinsic.IEOR]: - self.use_bv = True + self.opts.use_bv = True break # Step through body of the enclosing routine, statement by statement @@ -439,7 +466,7 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: # Invoke Z3 solver with a timeout solver = z3.Solver() - solver.set("timeout", self.smt_timeout) + solver.set("timeout", self.opts.smt_timeout) solver.add(z3.And(*self.constraints, z3.Or(*conflicts))) result = solver.check() if result == z3.unsat: @@ -570,8 +597,8 @@ def step(self, stmt: Node, cond: z3.BoolRef): # Translate a scalar integer Fortran expression to SMT. In addition, # return a constraint that prohibits bit vector overflow in the expression. def translate_integer_expr(expr_root: Node, - int_width: int, - use_bv: bool) -> (z3.ExprRef, z3.BoolRef): + opts: ArrayIndexAnalysisOptions + ) -> (z3.ExprRef, z3.BoolRef): constraints = [] def trans(expr: Node) -> z3.ExprRef: @@ -580,8 +607,8 @@ def trans(expr: Node) -> z3.ExprRef: # Literal if isinstance(expr, Literal) and type_ok: - if use_bv: - return z3.BitVecVal(int(expr.value), int_width) + if opts.use_bv: + return z3.BitVecVal(int(expr.value), opts.int_width) else: return z3.IntVal(int(expr.value)) @@ -592,8 +619,8 @@ def trans(expr: Node) -> z3.ExprRef: (sig, indices) = expr.get_signature_and_indices() indices_flat = [i for inds in indices for i in inds] if indices_flat == []: - if use_bv: - return z3.BitVec(str(sig), int_width) + if opts.use_bv: + return z3.BitVec(str(sig), opts.int_width) else: return z3.Int(str(sig)) @@ -601,7 +628,7 @@ def trans(expr: Node) -> z3.ExprRef: if isinstance(expr, UnaryOperation): arg_smt = trans(expr.operand) if expr.operator == UnaryOperation.Operator.MINUS: - if use_bv: + if opts.use_bv: constraints.append(z3.BVSNegNoOverflow(arg_smt)) return -arg_smt if expr.operator == UnaryOperation.Operator.PLUS: @@ -614,28 +641,28 @@ def trans(expr: Node) -> z3.ExprRef: right_smt = trans(right) if expr.operator == BinaryOperation.Operator.ADD: - if use_bv: + if opts.use_bv: constraints.append(z3.BVAddNoOverflow( left_smt, right_smt, True)) constraints.append(z3.BVAddNoUnderflow( left_smt, right_smt)) return left_smt + right_smt if expr.operator == BinaryOperation.Operator.SUB: - if use_bv: + if opts.use_bv: constraints.append(z3.BVSubNoOverflow( left_smt, right_smt)) constraints.append(z3.BVSubNoUnderflow( left_smt, right_smt, True)) return left_smt - right_smt if expr.operator == BinaryOperation.Operator.MUL: - if use_bv: + if opts.use_bv: constraints.append(z3.BVMulNoOverflow( left_smt, right_smt, True)) constraints.append(z3.BVMulNoUnderflow( left_smt, right_smt)) return left_smt * right_smt if expr.operator == BinaryOperation.Operator.DIV: - if use_bv: + if opts.use_bv: constraints.append(z3.BVSDivNoOverflow( left_smt, right_smt)) return left_smt / right_smt @@ -645,7 +672,7 @@ def trans(expr: Node) -> z3.ExprRef: # Unary operators if expr.intrinsic == IntrinsicCall.Intrinsic.ABS: smt_arg = trans(expr.children[1]) - if use_bv: + if opts.use_bv: constraints.append(z3.BVSNegNoOverflow(smt_arg)) return z3.Abs(smt_arg) @@ -663,7 +690,7 @@ def trans(expr: Node) -> z3.ExprRef: if expr.intrinsic == IntrinsicCall.Intrinsic.MOD: return left_smt % right_smt - if use_bv: + if opts.use_bv: if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: return left_smt << right_smt if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: @@ -678,29 +705,29 @@ def trans(expr: Node) -> z3.ExprRef: return left_smt ^ right_smt else: if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: - return z3.BV2Int(z3.Int2BV(left_smt, int_width) << - z3.Int2BV(right_smt, int_width), + return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) << + z3.Int2BV(right_smt, opts.int_width), is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: return z3.BV2Int(z3.LShR( - z3.Int2BV(left_smt, int_width), - z3.Int2BV(right_smt, int_width)), + z3.Int2BV(left_smt, opts.int_width), + z3.Int2BV(right_smt, opts.int_width)), is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: - return z3.BV2Int(z3.Int2BV(left_smt, int_width) >> - z3.Int2BV(right_smt, int_width), + return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) >> + z3.Int2BV(right_smt, opts.int_width), is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: - return z3.BV2Int(z3.Int2BV(left_smt, int_width) & - z3.Int2BV(right_smt, int_width), + return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) & + z3.Int2BV(right_smt, opts.int_width), is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: - return z3.BV2Int(z3.Int2BV(left_smt, int_width) | - z3.Int2BV(right_smt, int_width), + return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) | + z3.Int2BV(right_smt, opts.int_width), is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: - return z3.BV2Int(z3.Int2BV(left_smt, int_width) ^ - z3.Int2BV(right_smt, int_width), + return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) ^ + z3.Int2BV(right_smt, opts.int_width), is_signed=True) # N-ary operators @@ -715,9 +742,18 @@ def trans(expr: Node) -> z3.ExprRef: reduced = z3.If(reduced < arg, arg, reduced) return reduced + # Array intrinsics + if opts.handle_array_intrins: + array_intrins_pair = translate_array_intrinsic_call(expr) + if array_intrins_pair: + if opts.use_bv: + return z3.BitVec(array_intrins_pair[1], opts.int_width) + else: + return z3.Int(array_intrins_pair[1]) + # Fall through: return a fresh, unconstrained symbol - if use_bv: - return z3.FreshConst(z3.BitVecSort(int_width)) + if opts.use_bv: + return z3.FreshConst(z3.BitVecSort(opts.int_width)) else: return z3.FreshInt() @@ -728,8 +764,8 @@ def trans(expr: Node) -> z3.ExprRef: # Translate a scalar logical Fortran expression to SMT. In addition, # return a constraint that prohibits bit vector overflow in the expression. def translate_logical_expr(expr_root: Node, - int_width: int, - use_bv: bool) -> (z3.BoolRef, z3.BoolRef): + opts: ArrayIndexAnalysisOptions + ) -> (z3.BoolRef, z3.BoolRef): # Constraints to prohibit bit-vector overflow overflow = [] @@ -788,10 +824,10 @@ def trans(expr: Node): BinaryOperation.Operator.LE]: (left, right) = expr.operands (left_smt, prohibit_overflow) = translate_integer_expr( - left, int_width, use_bv) + left, opts) overflow.append(prohibit_overflow) (right_smt, prohibit_overflow) = translate_integer_expr( - right, int_width, use_bv) + right, opts) overflow.append(prohibit_overflow) if expr.operator == BinaryOperation.Operator.EQ: @@ -813,6 +849,38 @@ def trans(expr: Node): expr_root_smt = trans(expr_root) return (expr_root_smt, z3.And(*overflow)) + +# Translate array intrinsic call to an array name and a scalar integer +# variable name +def translate_array_intrinsic_call(call: IntrinsicCall) -> (str, str): + if call.intrinsic == IntrinsicCall.Intrinsic.SIZE: + var = "#size" + elif call.intrinsic == IntrinsicCall.Intrinsic.LBOUND: + var = "#lbound" + elif call.intrinsic == IntrinsicCall.Intrinsic.UBOUND: + var = "#ubound" + else: + return None + + if (len(call.children) != 2 and len(call.children) != 3): + return None + + array = call.children[1] + if isinstance(array, Reference): + (sig, indices) = array.get_signature_and_indices() + indices_flat = [i for inds in indices for i in inds] + if indices_flat == [] and len(sig) == 1: + var = var + "_" + sig.var_name + if len(call.children) == 3: + rank = call.children[2] + if isinstance(rank, Literal): + var = var + "_" + rank.value + else: + return None + return (sig.var_name, var) + + return None + # Helper functions # ================ diff --git a/src/psyclone/psyir/tools/dependency_tools.py b/src/psyclone/psyir/tools/dependency_tools.py index 853755903e..d324da1695 100644 --- a/src/psyclone/psyir/tools/dependency_tools.py +++ b/src/psyclone/psyir/tools/dependency_tools.py @@ -50,7 +50,8 @@ from psyclone.psyir.backend.sympy_writer import SymPyWriter from psyclone.psyir.backend.visitor import VisitorError from psyclone.psyir.nodes import Loop, Node, Range -from psyclone.psyir.tools.array_index_analysis import ArrayIndexAnalysis +from psyclone.psyir.tools.array_index_analysis import ( + ArrayIndexAnalysis, ArrayIndexAnalysisOptions) # pylint: disable=too-many-lines @@ -165,11 +166,11 @@ class DependencyTools(): :type loop_types_to_parallelise: Optional[List[str]] :param use_smt_array_index_analysis: if True, the SMT-based array index analysis will be used for detecting array access - conflicts. An ArrayIndexAnalysis.Options value can also be given, + conflicts. An ArrayIndexAnalysisOptions value can also be given, instead of a bool, in which case the analysis will be invoked with the given options. :type use_smt_array_index_analysis: Union[ - bool, ArrayIndexAnalysis.Options] + bool, ArrayIndexAnalysisOptions] :raises TypeError: if an invalid loop type is specified. @@ -918,10 +919,10 @@ def can_loop_be_parallelised(self, loop, # Apply the SMT-based array index analysis, if enabled if self._use_smt_array_index_analysis: if isinstance(self._use_smt_array_index_analysis, - ArrayIndexAnalysis.Options): + ArrayIndexAnalysisOptions): options = self._use_smt_array_index_analysis else: - options = ArrayIndexAnalysis.Options() + options = ArrayIndexAnalysisOptions() analysis = ArrayIndexAnalysis(options) conflict_free = analysis.is_loop_conflict_free(loop) if not conflict_free: diff --git a/src/psyclone/psyir/transformations/parallel_loop_trans.py b/src/psyclone/psyir/transformations/parallel_loop_trans.py index 07873c397e..08ca6fa9f5 100644 --- a/src/psyclone/psyir/transformations/parallel_loop_trans.py +++ b/src/psyclone/psyir/transformations/parallel_loop_trans.py @@ -53,7 +53,8 @@ BinaryOperation, IntrinsicCall ) from psyclone.psyir.tools import ( - DependencyTools, DTCode, ReductionInferenceTool, ArrayIndexAnalysis, + DependencyTools, DTCode, ReductionInferenceTool, + ArrayIndexAnalysisOptions ) from psyclone.psyir.transformations.loop_trans import LoopTrans from psyclone.psyir.transformations.async_trans_mixin import \ @@ -332,7 +333,7 @@ def apply(self, node, options=None, verbose: bool = False, reduction_ops: List[Union[BinaryOperation.Operator, IntrinsicCall.Intrinsic]] = None, use_smt_array_index_analysis: - Union[bool, ArrayIndexAnalysis.Options] = False, + Union[bool, ArrayIndexAnalysisOptions] = False, **kwargs): ''' Apply the Loop transformation to the specified node in a @@ -379,7 +380,7 @@ def apply(self, node, options=None, verbose: bool = False, the reduction operators in the list. :param use_smt_array_index_analysis: if True, the SMT-based array index analysis will be used for detecting array access - conflicts. An ArrayIndexAnalysis.Options value can also be given, + conflicts. An ArrayIndexAnalysisOptions value can also be given, instead of a bool, in which case the analysis will be invoked with the given options. From 6e5c443733db6d5743ddd76e34c34f4dd1ae8dd5 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Wed, 19 Nov 2025 13:44:50 +0000 Subject: [PATCH 11/17] Improve testing and coverage --- .../psyir/tools/array_index_analysis.py | 52 +++++---- .../psyir/tools/array_index_analysis_test.py | 105 +++++++++++------- 2 files changed, 97 insertions(+), 60 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 4529c0d0e5..f4dfa1653c 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -143,7 +143,7 @@ def __init__(self, use_bv: bool = None, smt_timeout_ms: int = 5000, prohibit_overflow: bool = False, - handle_array_intrins: bool = False): + handle_array_intrins: bool = True): # Set SMT solver timeout in milliseconds self.smt_timeout = smt_timeout_ms # Fortran integer width in bits @@ -437,7 +437,7 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: # Check that we have found and analysed the loop to parallelise if not (self.finished and len(self.saved_access_dicts) == 2): - return None + return None # pragma: no cover # Forumlate constraints for solving, considering the two iterations iter_i = self.saved_access_dicts[0] @@ -691,12 +691,15 @@ def trans(expr: Node) -> z3.ExprRef: return left_smt % right_smt if opts.use_bv: - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: - return left_smt << right_smt - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: - return z3.LShR(left_smt, right_smt) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: - return left_smt >> right_smt + # TODO: when fparser supports shift operations (#428), + # we can remove the "no cover" block + if True: # pragma: no cover + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: + return left_smt << right_smt + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: + return z3.LShR(left_smt, right_smt) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: + return left_smt >> right_smt if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: return left_smt & right_smt if expr.intrinsic == IntrinsicCall.Intrinsic.IOR: @@ -704,19 +707,24 @@ def trans(expr: Node) -> z3.ExprRef: if expr.intrinsic == IntrinsicCall.Intrinsic.IEOR: return left_smt ^ right_smt else: - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: - return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) << - z3.Int2BV(right_smt, opts.int_width), - is_signed=True) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: - return z3.BV2Int(z3.LShR( - z3.Int2BV(left_smt, opts.int_width), - z3.Int2BV(right_smt, opts.int_width)), - is_signed=True) - if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: - return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) >> - z3.Int2BV(right_smt, opts.int_width), - is_signed=True) + # TODO: when fparser supports shift operations (#428), + # we can remove the "no cover" block + if True: # pragma: no cover + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTL: + return z3.BV2Int( + z3.Int2BV(left_smt, opts.int_width) << + z3.Int2BV(right_smt, opts.int_width), + is_signed=True) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTR: + return z3.BV2Int(z3.LShR( + z3.Int2BV(left_smt, opts.int_width), + z3.Int2BV(right_smt, opts.int_width)), + is_signed=True) + if expr.intrinsic == IntrinsicCall.Intrinsic.SHIFTA: + return z3.BV2Int( + z3.Int2BV(left_smt, opts.int_width) >> + z3.Int2BV(right_smt, opts.int_width), + is_signed=True) if expr.intrinsic == IntrinsicCall.Intrinsic.IAND: return z3.BV2Int(z3.Int2BV(left_smt, opts.int_width) & z3.Int2BV(right_smt, opts.int_width), @@ -863,7 +871,7 @@ def translate_array_intrinsic_call(call: IntrinsicCall) -> (str, str): return None if (len(call.children) != 2 and len(call.children) != 3): - return None + return None # pragma: no cover array = call.children[1] if isinstance(array, Reference): diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py index c2ea821892..fea5bd7ff0 100644 --- a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -38,13 +38,15 @@ import pytest from psyclone.psyir.nodes import (Loop, Assignment, Reference) from psyclone.psyir.symbols import Symbol -from psyclone.psyir.tools import ArrayIndexAnalysis +from psyclone.psyir.tools import ( + ArrayIndexAnalysis, ArrayIndexAnalysisOptions) from psyclone.psyir.tools.array_index_analysis import translate_logical_expr import z3 # ----------------------------------------------------------------------------- -def test_reverse(fortran_reader, fortran_writer): +@pytest.mark.parametrize("use_bv", [True, False]) +def test_reverse(use_bv, fortran_reader, fortran_writer): '''Test that an array reversal routine has no array conflicts ''' psyir = fortran_reader.psyir_from_source(''' @@ -59,9 +61,10 @@ def test_reverse(fortran_reader, fortran_writer): arr(n+1-i) = tmp end do end subroutine''') + opts = ArrayIndexAnalysisOptions(use_bv=use_bv, prohibit_overflow=True) results = [] for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + results.append(ArrayIndexAnalysis(opts).is_loop_conflict_free(loop)) assert results == [True] @@ -84,8 +87,9 @@ def test_odd_even_trans(fortran_reader, fortran_writer): end do end subroutine''') results = [] + opts = ArrayIndexAnalysisOptions(prohibit_overflow=True) for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) + results.append(ArrayIndexAnalysis(opts).is_loop_conflict_free(loop)) assert results == [True] @@ -126,10 +130,9 @@ def test_tiled_matmul(fortran_reader, fortran_writer): # ----------------------------------------------------------------------------- -def test_flatten1(fortran_reader, fortran_writer): - '''Test that an array flattening routine has no array conflicts in its - inner loop (there are conflicts, due to integer overflow, in its outer - loop) +def test_flatten(fortran_reader, fortran_writer): + '''Test that an array flattening routine has no array conflicts in + either loop. ''' psyir = fortran_reader.psyir_from_source(''' subroutine flatten1(mat, arr) @@ -148,66 +151,60 @@ def test_flatten1(fortran_reader, fortran_writer): results = [] for loop in psyir.walk(Loop): results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [False, True] - - -# ----------------------------------------------------------------------------- -def test_flatten2(fortran_reader, fortran_writer): - '''Test that an array flattening routine has no array conflicts - ''' - psyir = fortran_reader.psyir_from_source(''' - subroutine flatten2(mat, arr) - real, intent(in) :: mat(0:,0:) - real, intent(out) :: arr(0:) - integer :: i, n, ny - n = size(arr) - ny = size(mat, 2) - do i = 0, n-1 - arr(i) = mat(mod(i, ny), i/ny) - end do - end subroutine''') - results = [] - for loop in psyir.walk(Loop): - results.append(ArrayIndexAnalysis().is_loop_conflict_free(loop)) - assert results == [True] + assert results == [True, True] # ----------------------------------------------------------------------------- @pytest.mark.parametrize("use_bv", [True, False]) -def test_translate_expr(use_bv, fortran_reader, fortran_writer): +def test_translate_expr(use_bv, + fortran_reader, + fortran_writer): '''Test that Fortran expressions are being correctly translated to SMT. ''' + opts = ArrayIndexAnalysisOptions( + use_bv=use_bv, + prohibit_overflow=True) def test(expr): psyir = fortran_reader.psyir_from_source(f''' subroutine sub(x) + integer :: arr(10) logical, intent(out) :: x + integer :: i x = {expr} end subroutine''') for assign in psyir.walk(Assignment): (rhs_smt, prohibit_overflow) = translate_logical_expr( - assign.rhs, 32, use_bv) + assign.rhs, opts) solver = z3.Solver() assert solver.check(rhs_smt) == z3.sat test("+1 == 1") test("abs(-1) == 1") - test("shiftr(2,1) == 1") - test("shifta(-2,1) == -1") + #test("shiftl(2,1) == 4") + #test("shiftr(2,1) == 1") + #test("shifta(-2,1) == -1") test("iand(5,1) == 1") test("ior(1,2) == 3") test("ieor(3,1) == 2") test("max(3,1) == 3") + test("i == 3") test(".true.") test(".not. .false.") test(".true. .and. .true.") test(".true. .or. .false.") test(".false. .eqv. .false.") test(".false. .neqv. .true.") + test("1 /= 2") test("1 < 2") test("10 > 2") test("1 <= 1 .and. 0 <= 1") test("1 >= 1 .and. 2 >= 1") + test("1 * 1 == 1") + test("mod(3, 2) == 1") test("foo(1)") + test("foo(1) == 1") + test("size(arr,tmp) == 1") + test("size(arr(1:2)) == 2") # ----------------------------------------------------------------------------- @@ -219,13 +216,14 @@ def check_conflict_free(fortran_reader, loop_str, yesno): psyir = fortran_reader.psyir_from_source(f''' subroutine sub(arr, n) integer, intent(inout) :: arr(:) - integer, intent(in) :: n, i + integer, intent(in) :: n, i, tmp, tmp2 logical :: ok {loop_str} end subroutine''') results = [] + opts = ArrayIndexAnalysisOptions(prohibit_overflow=True) for loop in psyir.walk(Loop): - analysis = ArrayIndexAnalysis() + analysis = ArrayIndexAnalysis(opts) results.append(analysis.is_loop_conflict_free(loop)) assert results == [yesno] @@ -237,9 +235,10 @@ def test_ifblock_with_else(fortran_reader, fortran_writer): '''do i = 1, n ok = i == 1 if (ok) then - arr(1) = 0 + arr(ior(1, 1)) = 0 else - arr(i) = i + tmp = i + arr(tmp) = i end if end do arr(2) = 0 @@ -269,6 +268,36 @@ def test_singleton_slice(fortran_reader, fortran_writer): True) +# ----------------------------------------------------------------------------- +def test_while_loop(fortran_reader, fortran_writer): + '''Test a do loop nested within a while loop''' + check_conflict_free(fortran_reader, + '''do while (tmp > 0) + do i = 1, n + tmp2 = arr(i) + arr(i) = 0 + do while (tmp2 > 0) + tmp2 = tmp2 - 1 + end do + end do + tmp = tmp - 1 + end do + ''', + True) + + +# ----------------------------------------------------------------------------- +def test_injective_index(fortran_reader, fortran_writer): + '''Test a do loop with an injective index mapping''' + check_conflict_free(fortran_reader, + '''do i = 1, n + tmp = i+1 + arr(tmp) = 0 + end do + ''', + True) + + # ----------------------------------------------------------------------------- def test_errors(fortran_reader, fortran_writer): '''Test that ArrayIndexAnalysis raises appropriate exceptions in From f6c42f5cd7d961d43947d136ee0ca3f6c4970bb0 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Wed, 19 Nov 2025 13:55:51 +0000 Subject: [PATCH 12/17] Make flake8 happy again --- .../tests/psyir/tools/array_index_analysis_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py index fea5bd7ff0..58dde660c2 100644 --- a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -164,6 +164,7 @@ def test_translate_expr(use_bv, opts = ArrayIndexAnalysisOptions( use_bv=use_bv, prohibit_overflow=True) + def test(expr): psyir = fortran_reader.psyir_from_source(f''' subroutine sub(x) @@ -180,9 +181,9 @@ def test(expr): test("+1 == 1") test("abs(-1) == 1") - #test("shiftl(2,1) == 4") - #test("shiftr(2,1) == 1") - #test("shifta(-2,1) == -1") + # test("shiftl(2,1) == 4") + # test("shiftr(2,1) == 1") + # test("shifta(-2,1) == -1") test("iand(5,1) == 1") test("ior(1,2) == 3") test("ieor(3,1) == 2") From 724837291b4fa1b968131343a85616760ce8b1c5 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Wed, 19 Nov 2025 15:54:13 +0000 Subject: [PATCH 13/17] Add a test for a failed analysis --- .../tests/psyir/nodes/omp_directives_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/psyclone/tests/psyir/nodes/omp_directives_test.py b/src/psyclone/tests/psyir/nodes/omp_directives_test.py index 6a04a673d0..3da17b54b7 100644 --- a/src/psyclone/tests/psyir/nodes/omp_directives_test.py +++ b/src/psyclone/tests/psyir/nodes/omp_directives_test.py @@ -5387,3 +5387,22 @@ def test_array_analysis_option(fortran_reader, fortran_writer): loop, collapse=True, use_smt_array_index_analysis=True) output = fortran_writer(psyir) assert "collapse(2)" in output + + +def test_array_analysis_failure(fortran_reader, fortran_writer): + '''Test that a conflicting loop is not parallelised when using the + SMT-based array index analysis. + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine non_injective_index(arr) + integer, intent(inout) :: arr(:) + integer :: i + do i = 1, size(arr) + arr(i/2) = 0 + end do + end subroutine''') + omplooptrans = OMPLoopTrans(omp_directive="paralleldo") + loop = psyir.walk(Loop)[0] + with pytest.raises(TransformationError) as err: + omplooptrans.apply(loop, use_smt_array_index_analysis=True) + assert "cannot be parallelised" in str(err.value) From 9eef149dc17c3cef5bd6e9c7857823aa140342da Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Fri, 21 Nov 2025 19:43:40 +0000 Subject: [PATCH 14/17] Move some tests to dependency_tools_test.py for coverage --- .../tests/psyir/nodes/omp_directives_test.py | 19 -------------- .../psyir/tools/dependency_tools_test.py | 26 ++++++++++++++++++- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/psyclone/tests/psyir/nodes/omp_directives_test.py b/src/psyclone/tests/psyir/nodes/omp_directives_test.py index 3da17b54b7..6a04a673d0 100644 --- a/src/psyclone/tests/psyir/nodes/omp_directives_test.py +++ b/src/psyclone/tests/psyir/nodes/omp_directives_test.py @@ -5387,22 +5387,3 @@ def test_array_analysis_option(fortran_reader, fortran_writer): loop, collapse=True, use_smt_array_index_analysis=True) output = fortran_writer(psyir) assert "collapse(2)" in output - - -def test_array_analysis_failure(fortran_reader, fortran_writer): - '''Test that a conflicting loop is not parallelised when using the - SMT-based array index analysis. - ''' - psyir = fortran_reader.psyir_from_source(''' - subroutine non_injective_index(arr) - integer, intent(inout) :: arr(:) - integer :: i - do i = 1, size(arr) - arr(i/2) = 0 - end do - end subroutine''') - omplooptrans = OMPLoopTrans(omp_directive="paralleldo") - loop = psyir.walk(Loop)[0] - with pytest.raises(TransformationError) as err: - omplooptrans.apply(loop, use_smt_array_index_analysis=True) - assert "cannot be parallelised" in str(err.value) diff --git a/src/psyclone/tests/psyir/tools/dependency_tools_test.py b/src/psyclone/tests/psyir/tools/dependency_tools_test.py index 8563b0b388..787cb56056 100644 --- a/src/psyclone/tests/psyir/tools/dependency_tools_test.py +++ b/src/psyclone/tests/psyir/tools/dependency_tools_test.py @@ -42,7 +42,8 @@ from psyclone.core import AccessType, Signature from psyclone.errors import InternalError from psyclone.psyir.nodes import Assignment, Loop -from psyclone.psyir.tools import DependencyTools, DTCode +from psyclone.psyir.tools import ( + DependencyTools, DTCode, ArrayIndexAnalysisOptions) from psyclone.tests.utilities import get_invoke @@ -1182,3 +1183,26 @@ def test_nemo_example_ranges(fortran_reader): # is tested in test_ranges_overlap above, here we check that this # overlap is indeed ignored because of the jj index). assert dep_tools.can_loop_be_parallelised(loops) + + +# ---------------------------------------------------------------------------- +@pytest.mark.parametrize("use_bv", [True, None]) +def test_array_analysis_failure(use_bv, fortran_reader, fortran_writer): + '''Test that a conflicting loop is not parallelised when using the + SMT-based array index analysis. + ''' + psyir = fortran_reader.psyir_from_source(''' + subroutine non_injective_index(arr) + integer, intent(inout) :: arr(:) + integer :: i + do i = 1, size(arr) + arr(i/2) = 0 + end do + end subroutine''') + if use_bv: + opts = ArrayIndexAnalysisOptions(use_bv=True) + dep_tools = DependencyTools(use_smt_array_index_analysis=opts) + else: + dep_tools = DependencyTools(use_smt_array_index_analysis=True) + loop = psyir.walk(Loop)[0] + assert not dep_tools.can_loop_be_parallelised(loop) From 91161683d85542e273da3e3175dd1a2a4eabacb0 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Wed, 26 Nov 2025 14:57:55 +0000 Subject: [PATCH 15/17] Add more pydoc and tests --- .../psyir/tools/array_index_analysis.py | 503 ++++++++++-------- .../psyir/tools/array_index_analysis_test.py | 87 ++- 2 files changed, 370 insertions(+), 220 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index f4dfa1653c..c0d8aefe35 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -34,10 +34,11 @@ # Author: M. Naylor, University of Cambridge, UK # ----------------------------------------------------------------------------- -'''This module provides a class to determine whether or not distinct iterations -of a given loop can generate conflicting array accesses (if not, the loop can -potentially be parallelised). It formulates the problem as a set of SMT -constraints over array indices which are then are passed to the Z3 solver.''' +'''This module provides a class to determine whether or not distinct +iterations of a given loop can generate conflicting array accesses (if +not, the loop can potentially be parallelised). It formulates the +problem as a set of SMT constraints over array indices which are then +are passed to the Z3 solver.''' import z3 from psyclone.psyir.nodes import Loop, DataNode, Literal, Assignment, \ @@ -49,137 +50,192 @@ # Outline # ======= # -# The analysis class provides a method 'is_loop_conflict_free()' to decide -# whether or not the array accesses in a given loop are conflicting between -# iterations. Two array accesses are conflicting if they access the same -# element of the same array, and at least one of the accesses is a write. +# The analysis class provides a method 'is_loop_conflict_free()' to +# decide whether or not the array accesses in a given loop are +# conflicting between iterations. Two array accesses are conflicting +# if they access the same element of the same array, and at least one +# of the accesses is a write. # -# The analysis assumes that any scalar integer or scalar logical variables -# written by the loop can safely be considered as private within each -# iteration. This should be validated by the callee. +# The analysis assumes that any scalar integer or scalar logical +# variables written by the loop can safely be considered as private +# to each iteration. This should be validated by the callee. # # The analysis algorithm operates, broadly, as follows. # -# Given a loop, we find its enclosing routine, and start analysing the routine -# statement-by-statement in a recursive-descent fashion. +# Given a loop, we find its enclosing routine, and start analysing the +# routine statement-by-statement in a recursive-descent fashion. # -# As we proceed, we maintain a set of SMT constraints and a substitution that -# maps Fortran variable names to SMT variable names. For each Fortran -# variable, the substitution points to an SMT variable that is constrained (in -# the set of constraints) such that it captures the value of the Fortran -# variable at the current point in the code. When a Fortran variable is -# mutated, the substitution is be modified to point to a fresh SMT variable, -# with new constraints, without destroying the old constraints. +# As we proceed, we maintain a set of SMT constraints and a +# substitution that maps Fortran variable names to SMT variable names. +# For each Fortran variable, the substitution points to an SMT +# variable that is constrained (in the set of constraints) such that +# it captures the value of the Fortran variable at the current point +# in the code. When a Fortran variable is mutated, the substitution is +# modified to point to a fresh SMT variable, with new constraints, +# without destroying the old constraints. # -# More concretely, when we encounter an assignment of a scalar integer/logical -# variable, of the form 'x = rhs', we translate 'rhs' to the SMT formula -# 'smt_rhs' with the current substitution applied. We then add a constraint -# 'var = smt_rhs' where 'var' is a fresh SMT variable, and update the -# substition so that 'x' maps to 'var'. +# More concretely, when we encounter an assignment of a scalar +# integer/logical variable, of the form 'x = rhs', we translate 'rhs' +# to the SMT formula 'smt_rhs' with the current substitution applied. +# We then add a constraint 'var = smt_rhs' where 'var' is a fresh SMT +# variable, and update the substition so that 'x' maps to 'var'. # # The Fortran-expression-to-SMT translator knows about several Fortran -# operators and intrinsics, but not all of them; when it sees something it -# doesn't know about, it simply translates it to a fresh unconstrained SMT -# variable. +# operators and intrinsics, but not all of them; when it sees +# something it doesn't know about, it simply translates it to a fresh +# unconstrained SMT variable. # -# Sometimes we reach a statement that modifies a Fortran variable in an unknown -# way (e.g. calling a subroutine). This can be handled by updating the -# substitution to point to a fresh unconstrained SMT variable; we refer to this -# process as "killing" the variable. +# Sometimes we reach a statement that modifies a Fortran variable in +# an unknown way (e.g. calling a subroutine). This can be handled by +# updating the substitution to point to a fresh unconstrained SMT +# variable; we refer to this process as "killing" the variable. # -# In addition to the current substitution, we maintain a stack of previous -# substitutions. This allows substitutions to be saved and restored before and -# after analysing a block of code that may or may not be executed at run time. +# In addition to the current substitution, we maintain a stack of +# previous substitutions. This allows substitutions to be saved and +# restored before and after analysing a block of code that may or may +# not be executed at run time. # -# We also maintain a "current condition". This can be viewed as a constraint -# that has not been comitted to the constraint set because we want to be able -# to grow, contract, and retract it as we enter and exit conditional blocks of -# code. This current condition is passed in recursive calls, so there is an -# implicit stack of them. +# We also maintain a "current condition". This can be viewed as a +# constraint that has not been comitted to the constraint set because +# we want to be able to grow, contract, and retract it as we enter and +# exit conditional blocks of code. This current condition is passed in +# recursive calls, so there is an implicit stack of them. # -# More concretely, when we encouter an 'if' statement, we copy the current -# substitution onto the stack, then recurse into the 'then' body, passing in -# the 'if' condition as an argument, and then restore the old substitution. We -# do the same for the 'else' body if there is one (in this case the negated -# condition is passed to the recursive call). Finally, we kill all variables -# written by the 'then' and 'else' bodies, because we don't know which will be -# executed at run time. +# More concretely, when we encouter an 'if' statement, we copy the +# current substitution onto the stack, then recurse into the 'then' +# body, passing in the 'if' condition as an argument, and then restore +# the old substitution. We do the same for the 'else' body if there is +# one (in this case the negated condition is passed to the recursive +# call). Finally, we kill all variables written by the 'then' and +# 'else' bodies, because we don't know which will be executed at run +# time. (In future, we could do better here by introducing OR +# constraints, e.g. each variable written is either equal to the value +# written in the 'then' OR the 'else' depending on the condition.) # -# As the analysis proceeds, we also maintain a list of array accesses. For each -# access, we record various information including the name of the array, -# whether it is a read or a write, the current condition at the point the -# access is made, and its list of indices (translated to SMT). When we are -# analysing code that is inside the loop of interest, we add all array accesses -# encountered to the list. +# As the analysis proceeds, we also maintain a list of array accesses. +# For each access, we record various information including the name of +# the array, whether it is a read or a write, the current condition at +# the point the access is made, and its list of indices (translated to +# SMT). When we are analysing code that is inside the loop of +# interest, we add all array accesses encountered to the list. # -# When we encounter the loop of interest, we perform a couple of steps before -# recursing into the loop body. First, we kill all variables written by the -# loop body, because we don't know whether we are entering the loop (at -# run time) for the first time or not. Second, we create two SMT variables to -# represent two arbitary but distinct iterations of the loop. Each variable is -# constrained to the start, stop, and step of the loop, and the two variables -# are constrained to be not equal. After that, we analyse the loop body twice, -# each time mapping the loop variable in the substitution to a different SMT -# variable. After analysing the loop body for the first time, we save the -# array access list and start afresh with a new one. Therefore, once the -# analysis is complete, we have two array access lists, one for each iteration. +# When we encounter the loop of interest, we perform a couple of steps +# before recursing into the loop body. First, we kill all variables +# written by the loop body, because we don't know whether we are +# entering the loop (at run time) for the first time or not. Second, +# we create two SMT variables to represent two arbitary but distinct +# iterations of the loop. Each variable is constrained to the start, +# stop, and step of the loop, and the two variables are constrained to +# be not equal. After that, we analyse the loop body twice, each time +# mapping the loop variable in the substitution to each of the SMT +# loop variables. After analysing the loop body for the first time, +# we save the array access list and start afresh with a new one. +# Therefore, once the analysis is complete, we have two array access +# lists, one for each iteration. # -# When we encounter a loop that is not the loop of interest, we follow a -# similar approach but only consider a single arbitrary iteration of the loop. +# When we encounter a loop that is not the loop of interest, we follow +# a similar approach but only consider a single arbitrary iteration of +# the loop. # -# When the recursive descent is complete, we are left with two array access -# lists. We are interested in whether any pair of accesses to the same array -# (in which one of the accesses is a write) represents a conflict. An access -# pair is conflict-free if an equality constraint between each access's -# indices, when combined with the current condition of each access and the -# global constraint set, is unsatisfiable. In this way, we can check every -# access pair and determine whether or not the loop is conflict free. +# When the recursive descent is complete, we are left with two array +# access lists. We are interested in whether any pair of accesses to +# the same array (in which one of the accesses is a write) represents +# a conflict. An access pair is conflict-free if an equality +# constraint between each access's indices, when combined with the +# current condition of each access and the global constraint set, is +# unsatisfiable. In this way, we can check every access pair and +# determine whether or not the loop is conflict free. -# Class representing analysis options +# Analysis Options +# ================ + class ArrayIndexAnalysisOptions: + '''The analysis supports a range of different options, which are all + captured together in this class. + + :param int_width: the bit width of Fortran integers. This is 32 by + default but it can be useful to reduce it to (say) 8 in particular + cases to improve the ability of solver to find a timely solution, + provided the user considers it safe to do so. (Note that the analysis + currently only gathers information about Fortran integer values of + unspecified width.) + + :param use_bv: whether to treat Fortran integers as bit vectors or + arbitrary-precision integers. If None is specified then the + analysis will use a simple heuristic to decide. + + :param smt_timeout_ms: the time limit (in milliseconds) given to + the SMT solver to find a solution. If the solver does not + return within this time, the analysis will conservatively return + that a conflict exists even though it has not yet found one. + + :param prohibit_overflow: if True, the analysis will tell the solver + to ignore the possibility of integer overflow. Integer overflow is + undefined behaviour in Fortran so this is safe. + + :param handle_array_intrins: handle array intrinsics 'size()', + 'lbound()', and 'ubound()' specially. For example, multiple + occurences of 'size(arr)' will be assumed to return the same value, + provided that those occurrences are not separated by a statement + that may modify the size/bounds of 'arr'. + ''' def __init__(self, int_width: int = 32, use_bv: bool = None, smt_timeout_ms: int = 5000, prohibit_overflow: bool = False, handle_array_intrins: bool = True): - # Set SMT solver timeout in milliseconds self.smt_timeout = smt_timeout_ms - # Fortran integer width in bits self.int_width = int_width - # Use fixed-width bit vectors or arbirary precision integers? self.use_bv = use_bv - # Prohibit bit-vector overflow when solving constraints? self.prohibit_overflow = prohibit_overflow - # Handle array intrinsics such as size, lbound, and ubound self.handle_array_intrins = handle_array_intrins -# Main analysis class +# Analysis +# ======== + class ArrayIndexAnalysis: - # Class representing an array access class ArrayAccess: + '''This class is used to record details of each array access + encountered during the analysis. + + :param cond: a boolean SMT expression representing the current + condition at the point the array access is made. + + :param is_write: whether the access is a read or a write. + + :param indices: SMT integer expressions representing the + indices of the array access. + + :param psyir_node: PSyIR node for the access (useful for reporting + conflict messages / errors). + ''' def __init__(self, cond: z3.BoolRef, is_write: bool, indices: list[list[z3.ExprRef]], psyir_node: Node): - # The condition at the location of the access self.cond = cond - # Whether the access is a read or a write self.is_write = is_write - # SMT expressions representing the indices of the access self.indices = indices - # PSyIR node for the access (useful for error reporting) self.psyir_node = psyir_node - # ArrayIndexAnalysis constructor def __init__(self, options=ArrayIndexAnalysisOptions()): + '''This class provides a method 'is_loop_conflict_free()' to + determine whether or not distinct iterations of a given loop + can generate conflicting array accesses. + + :param options: these options allow user control over features + provided by, and choices made by, the analysis. + ''' self.opts = options - # Initialise analysis - def init_analysis(self): + def _init_analysis(self): + '''Intialise the analysis by setting all the internal state + varibles accordingly.''' + # The substitution maps integer and logical Fortran variables # to SMT symbols self.subst = {} @@ -198,11 +254,19 @@ def init_analysis(self): self.finished = False # We map array intrinsic calls (e.g. size, lbound, ubound) to SMT # integer variables. The following dict maps array names to a - # set of instrinsic variables associated with that array. + # set of integer variable names holding the results of intrinsic + # calls on that array. self.array_intrins_vars = {} - # Initialise the 'array_intrins_vars' dict - def init_array_intrins_vars(self, routine): + def _init_array_intrins_vars(self, routine: Routine): + '''Initialise the 'array_intrins_vars' dict so that, for each + array accessed, it holds a set of integer variables + representing the results of intrinsics (such as size, + lbound, ubound) applied to that array. + + :param routine: the Routine holding the Loop that we are + analysing for conflicts. + ''' if self.opts.handle_array_intrins: for stmt in routine.children: for call in stmt.walk(IntrinsicCall): @@ -213,61 +277,63 @@ def init_array_intrins_vars(self, routine): self.array_intrins_vars[arr_name] = set() self.array_intrins_vars[arr_name].add(var_name) - # Push copy of current substitution to the stack - def save_subst(self): + def _save_subst(self): + '''Push copy of current substitution to the stack.''' self.subst_stack.append(self.subst.copy()) - # Pop substitution from stack into current substitution - def restore_subst(self): + def _restore_subst(self): + '''Pop substitution from stack into current substitution.''' self.subst = self.subst_stack.pop() - # Create an fresh SMT integer variable - def fresh_integer_var(self) -> z3.ExprRef: + def _fresh_integer_var(self) -> z3.ExprRef: + '''Create an fresh SMT integer variable.''' if self.opts.use_bv: return z3.FreshConst(z3.BitVecSort(self.opts.int_width)) else: return z3.FreshInt() - # Create an integer SMT variable with the given name - def integer_var(self, var) -> z3.ExprRef: + def _integer_var(self, var: str) -> z3.ExprRef: + '''Create an integer SMT variable with the given name.''' if self.opts.use_bv: return z3.BitVec(var, self.opts.int_width) else: return z3.Int(var) - # Create an SMT integer value - def integer_val(self, val: int) -> z3.ExprRef: + def _integer_val(self, val: int) -> z3.ExprRef: + '''Create an SMT integer value.''' if self.opts.use_bv: return z3.BitVecVal(val, self.opts.int_width) else: return z3.IntVal(val) - # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol - def kill_integer_var(self, var: str): - fresh_sym = self.fresh_integer_var() - smt_var = self.integer_var(var) + def _kill_integer_var(self, var: str): + '''Clear knowledge of integer 'var' by mapping it to a fresh, + unconstrained symbol.''' + fresh_sym = self._fresh_integer_var() + smt_var = self._integer_var(var) self.subst[smt_var] = fresh_sym - # Clear knowledge of 'var' by mapping it to a fresh, unconstrained symbol - def kill_logical_var(self, var: str): + def _kill_logical_var(self, var: str): + '''Clear knowledge of logical 'var' by mapping it to a fresh, + unconstrained symbol''' fresh_sym = z3.FreshBool() smt_var = z3.Bool(var) self.subst[smt_var] = fresh_sym - # Kill all scalar integer/logical variables written inside 'node' - def kill_all_written_vars(self, node: Node): + def _kill_all_written_vars(self, node: Node): + '''Kill all scalar integer/logical variables written inside 'node'.''' var_accesses = node.reference_accesses() for sig, access_seq in var_accesses.items(): for access_info in access_seq.all_write_accesses: if isinstance(access_info.node, Loop): - self.kill_integer_var(sig.var_name) + self._kill_integer_var(sig.var_name) break elif isinstance(access_info.node, Reference): if _is_scalar_integer(access_info.node.datatype): - self.kill_integer_var(sig.var_name) + self._kill_integer_var(sig.var_name) break elif _is_scalar_logical(access_info.node.datatype): - self.kill_logical_var(sig.var_name) + self._kill_logical_var(sig.var_name) break elif isinstance(access_info.node.datatype, ArrayType): # If an array variable is modified we kill intrinsic @@ -276,53 +342,60 @@ def kill_all_written_vars(self, node: Node): # array is passed to a mutating routine/intrinsic. if sig.var_name in self.array_intrins_vars: for v in self.array_intrins_vars[sig.var_name]: - self.kill_integer_var(v) + self._kill_integer_var(v) break - # Add the SMT constraint to the constraint set - def add_constraint(self, smt_expr: z3.BoolRef): + def _add_constraint(self, smt_expr: z3.BoolRef): + '''Add the SMT constraint to the constraint set.''' self.constraints.append(smt_expr) - # Add an integer assignment constraint to the constraint set - def add_integer_assignment(self, var: str, smt_expr: z3.ExprRef): + def _add_integer_assignment(self, var: str, smt_expr: z3.ExprRef): + '''Add an integer assignment constraint to the constraint set.''' # Create a fresh symbol - fresh_sym = self.fresh_integer_var() + fresh_sym = self._fresh_integer_var() # Assert equality between this symbol and the given SMT expression - self.add_constraint(fresh_sym == smt_expr) + self._add_constraint(fresh_sym == smt_expr) # Update the substitution - smt_var = self.integer_var(var) + smt_var = self._integer_var(var) self.subst[smt_var] = fresh_sym - # Add a logical assignment constraint to the constraint set - def add_logical_assignment(self, var: str, smt_expr: z3.BoolRef): + def _add_logical_assignment(self, var: str, smt_expr: z3.BoolRef): + '''Add a logical assignment constraint to the constraint set.''' # Create a fresh symbol fresh_sym = z3.FreshBool() # Assert equality between this symbol and the given SMT expression - self.add_constraint(fresh_sym == smt_expr) + self._add_constraint(fresh_sym == smt_expr) # Update the substitution smt_var = z3.Bool(var) self.subst[smt_var] = fresh_sym - # Translate integer expresison to SMT, and apply current substitution - def translate_integer_expr_with_subst(self, expr: z3.ExprRef): + def _translate_integer_expr_with_subst(self, expr: Node): + '''Translate the given integer expresison to SMT, and apply the + current substitution.''' (smt_expr, prohibit_overflow) = translate_integer_expr( expr, self.opts) subst_pairs = list(self.subst.items()) if self.opts.prohibit_overflow: - self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) + self._add_constraint( + z3.substitute(prohibit_overflow, *subst_pairs)) return z3.substitute(smt_expr, *subst_pairs) - # Translate logical expresison to SMT, and apply current substitution - def translate_logical_expr_with_subst(self, expr: z3.BoolRef): + def _translate_logical_expr_with_subst(self, expr: Node): + '''Translate the given logical expresison to SMT, and apply the + current substitution.''' (smt_expr, prohibit_overflow) = translate_logical_expr( expr, self.opts) subst_pairs = list(self.subst.items()) if self.opts.prohibit_overflow: - self.add_constraint(z3.substitute(prohibit_overflow, *subst_pairs)) + self._add_constraint( + z3.substitute(prohibit_overflow, *subst_pairs)) return z3.substitute(smt_expr, *subst_pairs) - # Translate conditional expresison to SMT, and apply current substitution - def translate_cond_expr_with_subst(self, expr: z3.BoolRef): + def _translate_cond_expr_with_subst(self, expr: Node): + '''Translate the given conditional expresison to SMT, and apply + the current substitution. Instead of adding constraints to + the constraint set, this function ANDs constraints with the + translated expression.''' (smt_expr, prohibit_overflow) = translate_logical_expr( expr, self.opts) if self.opts.prohibit_overflow: @@ -330,40 +403,34 @@ def translate_cond_expr_with_subst(self, expr: z3.BoolRef): subst_pairs = list(self.subst.items()) return z3.substitute(smt_expr, *subst_pairs) - # Constrain a loop variable to given start/stop/step - def constrain_loop_var(self, - var: z3.ExprRef, - start: DataNode, - stop: DataNode, - step: DataNode): - zero = self.integer_val(0) - var_begin = self.translate_integer_expr_with_subst(start) - var_end = self.translate_integer_expr_with_subst(stop) + def _constrain_loop_var(self, + var: z3.ExprRef, + start: DataNode, + stop: DataNode, + step: DataNode): + '''Constrain a loop variable to given start/stop/step.''' + zero = self._integer_val(0) + var_begin = self._translate_integer_expr_with_subst(start) + var_end = self._translate_integer_expr_with_subst(stop) if step is None: step = Literal("1", INTEGER_TYPE) # pragma: no cover - var_step = self.translate_integer_expr_with_subst(step) - self.add_constraint(z3.And( + var_step = self._translate_integer_expr_with_subst(step) + self._add_constraint(z3.And( ((var - var_begin) % var_step) == zero, z3.Implies(var_step > zero, var >= var_begin), z3.Implies(var_step < zero, var <= var_begin), z3.Implies(var_step > zero, var <= var_end), z3.Implies(var_step < zero, var >= var_end))) - # Add an array access to the current access dict - def add_array_access(self, - array_name: str, - is_write: bool, - cond: z3.BoolRef, - indices: list[list[z3.ExprRef]], - psyir_node: Node): - access = ArrayIndexAnalysis.ArrayAccess( - cond, is_write, indices, psyir_node) + def _add_array_access(self, array_name: str, access: ArrayAccess): + '''Add an array access to the current access dict.''' if array_name not in self.access_dict: self.access_dict[array_name] = [] self.access_dict[array_name].append(access) - # Add all array accesses in the given node to the current access dict - def add_all_array_accesses(self, node: Node, cond: z3.BoolRef): + def _add_all_array_accesses(self, node: Node, cond: z3.BoolRef): + '''Add all array accesses in the given node to the current + access dict.''' var_accesses = node.reference_accesses() for sig, access_seq in var_accesses.items(): for access_info in access_seq: @@ -380,27 +447,31 @@ def add_all_array_accesses(self, node: Node, cond: z3.BoolRef): smt_inds = [] for ind in inds: if isinstance(ind, Range): - var = self.fresh_integer_var() - self.constrain_loop_var( + var = self._fresh_integer_var() + self._constrain_loop_var( var, ind.start, ind.stop, ind.step) smt_inds.append(var) else: smt_inds.append( - self.translate_integer_expr_with_subst( + self._translate_integer_expr_with_subst( ind)) smt_indices.append(smt_inds) - self.add_array_access( + self._add_array_access( str(sig), - access_info.is_any_write(), - cond, smt_indices, access_info.node) + ArrayIndexAnalysis.ArrayAccess( + cond, access_info.is_any_write(), + smt_indices, access_info.node)) - # Move the current access dict to the stack, and proceed with an empty one - def save_access_dict(self): + def _save_access_dict(self): + '''Move the current access dict to the stack, and proceed with + an empty one.''' self.saved_access_dicts.append(self.access_dict) self.access_dict = {} - # Check if the given loop has a conflict def is_loop_conflict_free(self, loop: Loop) -> bool: + '''Determine whether or not distinct iterations of the given loop + can generate conflicting array accesses.''' + # Type checking if not isinstance(loop, Loop): raise TypeError("ArrayIndexAnalysis: Loop argument expected") @@ -414,9 +485,9 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: self.routine = routine # Start with an empty constraint set and substitution - self.init_analysis() + self._init_analysis() self.loop_to_parallelise = loop - self.init_array_intrins_vars(routine) + self._init_array_intrins_vars(routine) # Resolve choice of integers v. bit vectors if self.opts.use_bv is None: @@ -433,7 +504,7 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: # Step through body of the enclosing routine, statement by statement for stmt in routine.children: - self.step(stmt, z3.BoolVal(True)) + self._step(stmt, z3.BoolVal(True)) # Check that we have found and analysed the loop to parallelise if not (self.finished and len(self.saved_access_dicts) == 2): @@ -476,8 +547,9 @@ def is_loop_conflict_free(self, loop: Loop) -> bool: else: return None # pragma: no cover - # Analyse a single statement - def step(self, stmt: Node, cond: z3.BoolRef): + def _step(self, stmt: Node, cond: z3.BoolRef): + '''Analyse the given statement in recursive-descent fashion.''' + # Has analysis finished? if self.finished: return @@ -489,81 +561,81 @@ def step(self, stmt: Node, cond: z3.BoolRef): indices_flat = [i for inds in indices for i in inds] if indices_flat == [] and len(sig) == 1: if _is_scalar_integer(stmt.lhs.datatype): - rhs_smt = self.translate_integer_expr_with_subst( + rhs_smt = self._translate_integer_expr_with_subst( stmt.rhs) - self.add_integer_assignment(sig.var_name, rhs_smt) + self._add_integer_assignment(sig.var_name, rhs_smt) if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.rhs, cond) + self._add_all_array_accesses(stmt.rhs, cond) return elif _is_scalar_logical(stmt.lhs.datatype): - rhs_smt = self.translate_logical_expr_with_subst( + rhs_smt = self._translate_logical_expr_with_subst( stmt.rhs) - self.add_logical_assignment(sig.var_name, rhs_smt) + self._add_logical_assignment(sig.var_name, rhs_smt) if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.rhs, cond) + self._add_all_array_accesses(stmt.rhs, cond) return # Schedule if isinstance(stmt, Schedule): for child in stmt.children: - self.step(child, cond) + self._step(child, cond) return # IfBlock if isinstance(stmt, IfBlock): if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.condition, cond) + self._add_all_array_accesses(stmt.condition, cond) # Translate condition to SMT - smt_condition = self.translate_cond_expr_with_subst(stmt.condition) + smt_cond = self._translate_cond_expr_with_subst(stmt.condition) # Recursively step into 'then' if stmt.if_body: - self.save_subst() - self.step(stmt.if_body, z3.And(cond, smt_condition)) - self.restore_subst() + self._save_subst() + self._step(stmt.if_body, z3.And(cond, smt_cond)) + self._restore_subst() # Recursively step into 'else' if stmt.else_body: - self.save_subst() - self.step(stmt.else_body, - z3.And(cond, z3.Not(smt_condition))) - self.restore_subst() + self._save_subst() + self._step(stmt.else_body, + z3.And(cond, z3.Not(smt_cond))) + self._restore_subst() # Kill vars written by each branch if stmt.if_body: - self.kill_all_written_vars(stmt.if_body) + self._kill_all_written_vars(stmt.if_body) if stmt.else_body: - self.kill_all_written_vars(stmt.else_body) + self._kill_all_written_vars(stmt.else_body) return # Loop if isinstance(stmt, Loop): # Kill variables written by loop body - self.kill_all_written_vars(stmt.loop_body) + self._kill_all_written_vars(stmt.loop_body) # Kill loop variable - self.kill_integer_var(stmt.variable.name) + self._kill_integer_var(stmt.variable.name) # Have we reached the loop we'd like to parallelise? if stmt is self.loop_to_parallelise: self.in_loop_to_parallelise = True # Consider two arbitary but distinct iterations - i_var = self.fresh_integer_var() - j_var = self.fresh_integer_var() - self.add_constraint(i_var != j_var) + i_var = self._fresh_integer_var() + j_var = self._fresh_integer_var() + self._add_constraint(i_var != j_var) iteration_vars = [i_var, j_var] else: # Consider a single, arbitrary iteration - i_var = self.fresh_integer_var() + i_var = self._fresh_integer_var() iteration_vars = [i_var] # Analyse loop body for each iteration variable separately for var in iteration_vars: - self.save_subst() - smt_loop_var = self.integer_var(stmt.variable.name) + self._save_subst() + smt_loop_var = self._integer_var(stmt.variable.name) self.subst[smt_loop_var] = var # Introduce constraints on loop variable - self.constrain_loop_var( + self._constrain_loop_var( var, stmt.start_expr, stmt.stop_expr, stmt.step_expr) # Analyse loop body - self.step(stmt.loop_body, cond) + self._step(stmt.loop_body, cond) if stmt is self.loop_to_parallelise: - self.save_access_dict() - self.restore_subst() + self._save_access_dict() + self._restore_subst() # Record whether the analysis has finished if stmt is self.loop_to_parallelise: self.finished = True @@ -572,33 +644,35 @@ def step(self, stmt: Node, cond: z3.BoolRef): # WhileLoop if isinstance(stmt, WhileLoop): # Kill variables written by loop body - self.kill_all_written_vars(stmt.loop_body) + self._kill_all_written_vars(stmt.loop_body) # Add array accesses in condition if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt.condition, cond) + self._add_all_array_accesses(stmt.condition, cond) # Translate condition to SMT - smt_condition = self.translate_cond_expr_with_subst( + smt_condition = self._translate_cond_expr_with_subst( stmt.condition) # Recursively step into loop body - self.save_subst() - self.step(stmt.loop_body, z3.And(cond, smt_condition)) - self.restore_subst() + self._save_subst() + self._step(stmt.loop_body, z3.And(cond, smt_condition)) + self._restore_subst() return # Fall through if self.in_loop_to_parallelise: - self.add_all_array_accesses(stmt, cond) - self.kill_all_written_vars(stmt) + self._add_all_array_accesses(stmt, cond) + self._kill_all_written_vars(stmt) # Translating Fortran expressions to SMT formulae # =============================================== -# Translate a scalar integer Fortran expression to SMT. In addition, -# return a constraint that prohibits bit vector overflow in the expression. def translate_integer_expr(expr_root: Node, opts: ArrayIndexAnalysisOptions ) -> (z3.ExprRef, z3.BoolRef): + '''Translate a scalar integer Fortran expression to SMT. In addition, + return a constraint that prohibits/ignores bit-vector overflow in the + expression.''' + constraints = [] def trans(expr: Node) -> z3.ExprRef: @@ -769,11 +843,13 @@ def trans(expr: Node) -> z3.ExprRef: return (expr_root_smt, z3.And(*constraints)) -# Translate a scalar logical Fortran expression to SMT. In addition, -# return a constraint that prohibits bit vector overflow in the expression. def translate_logical_expr(expr_root: Node, opts: ArrayIndexAnalysisOptions ) -> (z3.BoolRef, z3.BoolRef): + '''Translate a scalar logical Fortran expression to SMT. In addition, + return a constraint that prohibits/ignores bit-vector overflow in the + expression.''' + # Constraints to prohibit bit-vector overflow overflow = [] @@ -858,9 +934,10 @@ def trans(expr: Node): return (expr_root_smt, z3.And(*overflow)) -# Translate array intrinsic call to an array name and a scalar integer -# variable name def translate_array_intrinsic_call(call: IntrinsicCall) -> (str, str): + '''Translate array intrinsic call to an array name and a scalar + integer variable name.''' + if call.intrinsic == IntrinsicCall.Intrinsic.SIZE: var = "#size" elif call.intrinsic == IntrinsicCall.Intrinsic.LBOUND: @@ -893,14 +970,14 @@ def translate_array_intrinsic_call(call: IntrinsicCall) -> (str, str): # ================ -# Check that type is a scalar integer of unspecified precision def _is_scalar_integer(dt: DataType) -> bool: + '''Check that type is a scalar integer of unspecified precision.''' return (isinstance(dt, ScalarType) and dt.intrinsic == ScalarType.Intrinsic.INTEGER and dt.precision == ScalarType.Precision.UNDEFINED) -# Check that type is a scalar logical def _is_scalar_logical(dt: DataType) -> bool: + '''Check that type is a scalar logical.''' return (isinstance(dt, ScalarType) and dt.intrinsic == ScalarType.Intrinsic.BOOLEAN) diff --git a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py index 58dde660c2..efc6c78e2d 100644 --- a/src/psyclone/tests/psyir/tools/array_index_analysis_test.py +++ b/src/psyclone/tests/psyir/tools/array_index_analysis_test.py @@ -129,6 +129,38 @@ def test_tiled_matmul(fortran_reader, fortran_writer): assert results == [True, True, False, True, True, False] +# ----------------------------------------------------------------------------- +def test_chunking_loop(fortran_reader, fortran_writer): + '''Test that a loop with array chunking has no array conflicts + ''' + psyir = fortran_reader.psyir_from_source(''' + module chunking_example + contains + + subroutine chunking(arr, chunk_size) + integer, dimension(:), intent(inout) :: arr + integer, intent(in) :: chunk_size + integer :: n, chunk_begin, chunk_end + + n = size(arr) + do chunk_begin = 1, n, chunk_size + chunk_end = min(chunk_begin+chunk_size-1, n) + call modify(arr(chunk_begin:chunk_end)) + end do + end subroutine + + pure subroutine modify(a) + integer, intent(inout) :: a(:) + end subroutine + + end module''') + opts = ArrayIndexAnalysisOptions(use_bv=False) + results = [] + for loop in psyir.walk(Loop): + results.append(ArrayIndexAnalysis(opts).is_loop_conflict_free(loop)) + assert results == [True] + + # ----------------------------------------------------------------------------- def test_flatten(fortran_reader, fortran_writer): '''Test that an array flattening routine has no array conflicts in @@ -217,7 +249,7 @@ def check_conflict_free(fortran_reader, loop_str, yesno): psyir = fortran_reader.psyir_from_source(f''' subroutine sub(arr, n) integer, intent(inout) :: arr(:) - integer, intent(in) :: n, i, tmp, tmp2 + integer, intent(inout) :: n, i, j, tmp, tmp2 logical :: ok {loop_str} end subroutine''') @@ -226,7 +258,7 @@ def check_conflict_free(fortran_reader, loop_str, yesno): for loop in psyir.walk(Loop): analysis = ArrayIndexAnalysis(opts) results.append(analysis.is_loop_conflict_free(loop)) - assert results == [yesno] + assert results == yesno # ----------------------------------------------------------------------------- @@ -244,7 +276,7 @@ def test_ifblock_with_else(fortran_reader, fortran_writer): end do arr(2) = 0 ''', - True) + [True]) # ----------------------------------------------------------------------------- @@ -255,7 +287,7 @@ def test_array_reference(fortran_reader, fortran_writer): arr = arr + i end do ''', - False) + [False]) # ----------------------------------------------------------------------------- @@ -266,7 +298,7 @@ def test_singleton_slice(fortran_reader, fortran_writer): arr(i:i:) = 0 end do ''', - True) + [True]) # ----------------------------------------------------------------------------- @@ -284,7 +316,7 @@ def test_while_loop(fortran_reader, fortran_writer): tmp = tmp - 1 end do ''', - True) + [True]) # ----------------------------------------------------------------------------- @@ -296,7 +328,48 @@ def test_injective_index(fortran_reader, fortran_writer): arr(tmp) = 0 end do ''', - True) + [True]) + + +# ----------------------------------------------------------------------------- +def test_invariant_if(fortran_reader, fortran_writer): + '''Test a do loop with an invariant if-condition''' + check_conflict_free(fortran_reader, + '''do i = 1, size(arr)-1 + if (tmp >= 0) then + arr(i) = 1 + else + arr(i+1) = 2 + end if + end do''', + [True]) + + +# ----------------------------------------------------------------------------- +def test_last_iteration(fortran_reader, fortran_writer): + '''Test a do loop with special behaviour on final iteration''' + check_conflict_free(fortran_reader, + '''n = size(arr) + do i = 1, n-1 + arr(i) = 0 + if (i == n-1) then + arr(i+1) = 10 + end if + end do''', + [True]) + + +# ----------------------------------------------------------------------------- +def test_triangular_loop(fortran_reader, fortran_writer): + '''Test a triangular nested loop''' + check_conflict_free(fortran_reader, + '''n = size(arr) + do i = 1, n-1 + do j = i+1, n + arr(j) = arr(j) + arr(i) + end do + end do''', + [False, True]) # ----------------------------------------------------------------------------- From 558c5f9d3574856396a9e008731109bf8a998aa9 Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Thu, 27 Nov 2025 14:41:08 +0000 Subject: [PATCH 16/17] Use a clearer encoding for loop variable constraints --- .../psyir/tools/array_index_analysis.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index c0d8aefe35..355668282c 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -415,12 +415,19 @@ def _constrain_loop_var(self, if step is None: step = Literal("1", INTEGER_TYPE) # pragma: no cover var_step = self._translate_integer_expr_with_subst(step) + i = self._fresh_integer_var() self._add_constraint(z3.And( - ((var - var_begin) % var_step) == zero, - z3.Implies(var_step > zero, var >= var_begin), - z3.Implies(var_step < zero, var <= var_begin), - z3.Implies(var_step > zero, var <= var_end), - z3.Implies(var_step < zero, var >= var_end))) + var_step != zero, + z3.Implies(var_step > zero, + z3.And(var >= var_begin, var <= var_end)), + z3.Implies(var_step < zero, + z3.And(var <= var_begin, var >= var_end)), + var == var_begin + i * var_step, + i >= zero)) + # Prohibit overflow/underflow of "i * var_step" + if self.opts.use_bv and self.opts.prohibit_overflow: + self._add_constraint(z3.BVMulNoOverflow(i, var_step, True)) + self._add_constraint(z3.BVMulNoUnderflow(i, var_step)) def _add_array_access(self, array_name: str, access: ArrayAccess): '''Add an array access to the current access dict.''' From 5cb7280edca6231d0fbc8bc4cb52af03224be82d Mon Sep 17 00:00:00 2001 From: Matthew Naylor Date: Thu, 27 Nov 2025 14:53:08 +0000 Subject: [PATCH 17/17] flake8 --- src/psyclone/psyir/tools/array_index_analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/psyclone/psyir/tools/array_index_analysis.py b/src/psyclone/psyir/tools/array_index_analysis.py index 355668282c..4ac1580d07 100644 --- a/src/psyclone/psyir/tools/array_index_analysis.py +++ b/src/psyclone/psyir/tools/array_index_analysis.py @@ -419,9 +419,9 @@ def _constrain_loop_var(self, self._add_constraint(z3.And( var_step != zero, z3.Implies(var_step > zero, - z3.And(var >= var_begin, var <= var_end)), + z3.And(var >= var_begin, var <= var_end)), z3.Implies(var_step < zero, - z3.And(var <= var_begin, var >= var_end)), + z3.And(var <= var_begin, var >= var_end)), var == var_begin + i * var_step, i >= zero)) # Prohibit overflow/underflow of "i * var_step"