Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions aie_kernels/aie2/reduce_min.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#include "../aie_kernel_utils.h"
#include <aie_api/aie.hpp>

void _reduce_min_vector(int32_t *restrict in, int32_t *restrict out,
const int32_t input_size) {
void reduce_min_vector(int32_t *restrict in, int32_t *restrict out,
const int32_t input_size) {

event0();
v16int32 massive = broadcast_to_v16int32((int32_t)INT32_MAX);
Expand Down Expand Up @@ -46,8 +46,8 @@ void _reduce_min_vector(int32_t *restrict in, int32_t *restrict out,
return;
}

void _reduce_min_scalar(int32_t *restrict in, int32_t *restrict out,
const int32_t input_size) {
void reduce_min_scalar(int32_t *restrict in, int32_t *restrict out,
const int32_t input_size) {
event0();
int32_t running_min = (int32_t)INT32_MAX;
for (int32_t i = 0; i < input_size; i++) {
Expand All @@ -59,15 +59,3 @@ void _reduce_min_scalar(int32_t *restrict in, int32_t *restrict out,

return;
}

extern "C" {

void reduce_min_vector(int32_t *a_in, int32_t *c_out, int32_t input_size) {
_reduce_min_vector(a_in, c_out, input_size);
}

void reduce_min_scalar(int32_t *a_in, int32_t *c_out, int32_t input_size) {
_reduce_min_scalar(a_in, c_out, input_size);
}

} // extern "C"
10 changes: 5 additions & 5 deletions programming_examples/basic/vector_reduce_min/Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
##===- Makefile -----------------------------------------------------------===##
#
#
# This file licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
#
##===----------------------------------------------------------------------===##

srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
Expand Down Expand Up @@ -35,9 +35,9 @@ else
cd ${@D} && ${PEANO_INSTALL_DIR}/bin/clang++ ${PEANOWRAP2_FLAGS} -c $< -o ${@F}
endif

build/aie.mlir: ${srcdir}/${aie_py_src}
build/aie.mlir: ${srcdir}/${aie_py_src} build/reduce_min.cc.o
mkdir -p ${@D}
python3 $< ${devicename} ${col} > $@
cd ${@D} && python3 $< ${devicename} ${col} > ${@F}

build/final.xclbin: build/aie.mlir build/reduce_min.cc.o
mkdir -p ${@D}
Expand All @@ -53,7 +53,7 @@ ${targetname}.exe: ${srcdir}/test.cpp
ifeq "${powershell}" "powershell.exe"
cp _build/${targetname}.exe $@
else
cp _build/${targetname} $@
cp _build/${targetname} $@
endif

run: ${targetname}.exe build/final.xclbin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def my_reduce_min():

# AIE Core Function declarations
reduce_add_vector = Kernel(
"reduce_min_vector", "reduce_min.cc.o", [in_ty, out_ty, np.int32]
"reduce_min_vector(int*, int*, int)",
"reduce_min.cc.o",
[in_ty, out_ty, np.int32],
)

# Define a task
Expand Down
42 changes: 40 additions & 2 deletions python/iron/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# (c) Copyright 2024 Advanced Micro Devices, Inc.
# (c) Copyright 2024-2025 Advanced Micro Devices, Inc.

import os
import numpy as np
import cxxfilt
from elftools.elf.elffile import ELFFile
from elftools.elf.sections import SymbolTableSection

from .. import ir # type: ignore
from ..extras.dialects.ext.func import FuncOp # type: ignore
Expand All @@ -15,6 +19,34 @@
from .resolvable import Resolvable


def find_mangled_symbol(file: os.PathLike, demangled_name):
"""
Find the mangled symbol that corresponds to the demangled_name.

Args:
file (str): Path to the file to analyze
demangled_name (str): The demangled name of the symbol to find

Returns:
str: The mangled name of the symbol if found, otherwise None
"""
with open(file, "rb") as file:
elf_file = ELFFile(file)

for section in elf_file.iter_sections():
if isinstance(section, SymbolTableSection):
for symbol in section.iter_symbols():
# Filter out function symbols
if symbol and symbol["st_info"]["type"] == "STT_FUNC":
if symbol.name == demangled_name:
# Name matches the demangled name, thus it has C linkage
return symbol.name
if cxxfilt.demangle(symbol.name) == demangled_name:
# Demangled symbol name matches the demangled name, thus it has C++ linkage
return symbol.name
return None


class BaseKernel(Resolvable):
"""Base class for kernel-like objects that resolve to FuncOp."""

Expand Down Expand Up @@ -72,7 +104,13 @@ def resolve(
ip: ir.InsertionPoint | None = None,
) -> None:
if not self._op:
self._op = external_func(self._name, inputs=self._arg_types)
bin_file = os.path.abspath(self._bin_name)
symbol_name = find_mangled_symbol(bin_file, self._name)
if not symbol_name:
raise ValueError(
f"Could not find symbol for {self._name} in {bin_file}"
)
self._op = external_func(symbol_name, inputs=self._arg_types)


class ExternalFunction(BaseKernel):
Expand Down
2 changes: 2 additions & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ dataclasses>=0.6, <=0.8
numpy>=1.19.5, <2.0 # 2.1 would be nice for typing improvements, but it doesn't seem to work well with pyxrt
rich
ml_dtypes
cxxfilt
elftools
Loading