105 changes: 105 additions & 0 deletions .github/workflows/ci-paddle.yml
@@ -0,0 +1,105 @@
name: CI Paddle
on:
  push:
    branches: [paddle]
    tags: ["v*"]
  pull_request:
  merge_group:
  workflow_dispatch:

permissions:
  contents: read

concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true

jobs:
  test:
    name: Test
    runs-on:
      group: H20
    timeout-minutes: 30
    env:
      container_name: tilelang-paddle-test-${{ github.run_id }}
    steps:
      - name: Check docker image and run container
        env:
          FLAGS_fraction_of_gpu_memory_to_use: 0.15
          CTEST_PARALLEL_LEVEL: 2
          WITH_GPU: "ON"
          CUDA_ARCH_NAME: Hopper
          WITH_AVX: "ON"
          PY_VERSION: "3.10"
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          no_proxy: "bcebos.com,apiin.im.baidu.com,gitee.com,aliyun.com,.baidu.com,.tuna.tsinghua.edu.cn"
        run: |
          docker_image=ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:cuda129-coverage-test
          docker run -d -t --gpus all --name ${{ env.container_name }} \
            -v "/dev/shm:/dev/shm" \
            -v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
            -v ${{ github.workspace }}:/workspace \
            -e FLAGS_fraction_of_gpu_memory_to_use \
            -e CTEST_PARALLEL_LEVEL \
            -e WITH_GPU \
            -e CUDA_ARCH_NAME \
            -e WITH_AVX \
            -e PY_VERSION \
            -e GITHUB_TOKEN \
            -e no_proxy \
            -w /workspace \
            --network host \
            ${docker_image}

      - name: Checkout repository
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source ${{ github.workspace }}/../../../proxy
            git config --global --add safe.directory "*"
            # Clean workspace
            find . -maxdepth 1 ! -name "." -exec rm -rf {} +
            # Checkout
            git init
            git remote add origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
            git fetch origin ${{ github.ref }} --depth=1
            git checkout FETCH_HEAD
            git submodule update --init --recursive
          '
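      # Note: every step here runs inside the container via docker exec, which is
      # presumably why the standard actions/checkout step is not used above.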

      - name: Install dependencies
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source ${{ github.workspace }}/../../../proxy

            # Install uv
            curl -LsSf https://astral.sh/uv/install.sh | sh
            source $HOME/.local/bin/env

            # Create and activate virtual environment
            uv venv .venv --seed
            source .venv/bin/activate

            # Install paddle
            uv pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/
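            # The cu129 nightly index presumably matches the CUDA 12.9 toolchain
            # of the container image above.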

            # Install project and minimal test runner
            uv pip install pytest
            uv pip install -e .
          '

      - name: Run tests
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source .venv/bin/activate
            pytest tests_paddle/
          '

      - name: Terminate and delete the container
        if: always()
        run: |
          set +e
          docker stop ${{ env.container_name }}
          docker rm ${{ env.container_name }}
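
For anyone reproducing the "Install dependencies" step outside CI, a minimal sanity check of the nightly GPU wheel could look like the sketch below (hypothetical, not part of this PR; paddle.utils.run_check is Paddle's standard install check):

# Hypothetical sanity check for the nightly paddlepaddle-gpu install above.
import paddle

paddle.utils.run_check()      # launches a small kernel to verify the GPU build
print(paddle.version.cuda())  # CUDA version the wheel was built against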
7 changes: 5 additions & 2 deletions .github/workflows/dist-paddle.yml
@@ -1,7 +1,10 @@
-name: Dist
+name: Dist Paddle
on:
  push:
+    branches: [paddle]
    tags: ["v*"]
  pull_request:
+  merge_group:
  workflow_dispatch:

permissions:
@@ -28,7 +31,7 @@ jobs:
      # Otherwise, the version of the SDist has a git hash suffix (e.g., 0.1.0+gitabcdef12),
      # but the package built from the SDist has no way to get the git hash (it is not a git repo),
      # leading to inconsistent versions between SDist and built packages (+gitabcdef12 vs. +gitunknown).
-      NO_VERSION_LABEL: 'ON'
+      NO_VERSION_LABEL: "ON"

    steps:
      - name: Checkout repository
81 changes: 81 additions & 0 deletions tests_paddle/test_quick_start.py
@@ -0,0 +1,81 @@
import numpy as np
import paddle

paddle.compat.enable_torch_proxy(scope={"tilelang"})
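# The proxy is enabled before importing tilelang so that tilelang's internal
# torch imports resolve to Paddle's torch-compatibility layer; scope={"tilelang"}
# presumably limits the proxy to that package, leaving native paddle untouched.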

import tilelang
import tilelang.language as T


# @tilelang.jit(target="cuda")
# The target can currently be "cuda", "hip", or "cpu"; if not specified, it is
# inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear the local accumulator
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A; T.copy is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy a tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers;
                # this currently dispatches to CUTE (NVIDIA) or HIP (AMD) backends
                T.gemm(A_shared, B_shared, C_local)

            # ReLU
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy the result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


def test_quick_start():
    M = 1024  # M = T.dynamic("m") if you want dynamic shapes
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32

    # Define the kernel (matmul) and compile/lower it into an executable module
    matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

    # Test the kernel from Python with Paddle tensors:
    # create random input tensors on the GPU
    a = paddle.randn(M, K, device="cuda", dtype=paddle.float16)
    b = paddle.randn(K, N, device="cuda", dtype=paddle.float16)
    c = paddle.empty(M, N, device="cuda", dtype=paddle.float16)

    # Run the compiled kernel
    matmul_relu_kernel(a, b, c)

    print(c)
    # Reference computation using Paddle
    ref_c = paddle.nn.functional.relu(a @ b)

    # Validate correctness
    np.testing.assert_allclose(c.numpy(), ref_c.numpy(), rtol=1e-2, atol=1e-2)
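
The inline comment on M above mentions dynamic shapes. A minimal sketch of how that might look, assuming T.dynamic("m") defers M to call time as the comment suggests (hypothetical, not part of this PR):

# Hypothetical dynamic-shape variant of the quick-start test.
def test_quick_start_dynamic():
    dyn_kernel = matmul(T.dynamic("m"), 1024, 1024, 128, 128, 32)  # symbolic M
    a = paddle.randn(512, 1024, device="cuda", dtype=paddle.float16)
    b = paddle.randn(1024, 1024, device="cuda", dtype=paddle.float16)
    c = paddle.empty(512, 1024, device="cuda", dtype=paddle.float16)
    dyn_kernel(a, b, c)  # M is taken from a's shape at call time
    ref_c = paddle.nn.functional.relu(a @ b)
    np.testing.assert_allclose(c.numpy(), ref_c.numpy(), rtol=1e-2, atol=1e-2)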