Commit 90cf522

prefix sum
1 parent 5f8b522 commit 90cf522

2 files changed: +157 -0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ My solutions to CUDA challenges on https://leetgpu.com/
 [Rainbow Table](https://leetgpu.com/challenges/rainbow-table) | [Link](./rainbow_table.cu) | Easy |
 [Reduction](https://leetgpu.com/challenges/reduction) | [Link](./reduction.cu) | Medium |
 [Softmax](https://leetgpu.com/challenges/softmax) | [Link](./softmax.cu) | Medium |
+[Prefix Sum](https://leetgpu.com/challenges/prefix-sum) | [Link](./prefix_sum.cu) | Medium |
 [Dot Product](https://leetgpu.com/challenges/dot-product) | [Link](./dot_product.cu) | Medium |
 [Softmax Attention](https://leetgpu.com/challenges/softmax-attention) | [Link](./softmax_attention.cu) | Medium |
 [Password Cracking (FNV-1a)](https://leetgpu.com/challenges/password-cracking-fnv-1a) | [Link](./password_cracking_fnv_1a.cu) | Medium |

prefix_sum.cu

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
#include "solve.h"
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Per-block totals, written by kernel 1 and scanned in place by kernel 2.
// Capacity caps the grid at 32*1024 blocks.
__device__ float store[1024*32];

template<bool store_value>
__device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
    int tid = threadIdx.x;
    int num_threads = blockDim.x;
    int block_id = blockIdx.x;
    int lane_id = tid % 32;

    // Phase 1: butterfly-reduce each 32-element chunk within a warp and
    // write the chunk's total to shared memory (one entry per chunk).
    int loop_bound = 32*1024;
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = i < N ? input[i] : 0;
        // sum over warp
        for (int d = 16; d >= 1; d >>= 1) {
            f += __shfl_xor_sync(FULL_MASK, f, d);
        }
        // store the sum of these 32 values
        if (lane_id == 0) {
            s[i/32] = f;
        }
    }
    __syncthreads();

    // Phase 2: inclusive scan of the 1024 chunk totals in s
    // (assumes blockDim.x == 1024).
    // up sweep
    int offset = 1;
    for (int d = 512; d > 0; d >>= 1) {
        __syncthreads();
        if (tid < d) {
            int a = (tid+1) * (offset * 2) - 1 - offset;
            int b = (tid+1) * (offset * 2) - 1;
            s[b] += s[a];
        }
        offset *= 2;
    }

    // down sweep
    for (int d = 2; d < 1024; d *= 2) {
        offset >>= 1;
        __syncthreads();
        if (tid < d - 1) {
            int a = (tid+1) * offset - 1;
            int b = (tid+1) * offset - 1 + offset/2;
            s[b] += s[a];
        }
    }
    __syncthreads();

    // Phase 3: inclusive scan within each chunk via warp shuffles, then
    // add the scanned total of all preceding chunks.
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = i < N ? input[i] : 0;
        for (int d = 1; d <= 16; d *= 2) {
            float _f = __shfl_up_sync(FULL_MASK, f, d);
            if (lane_id - d >= 0) f += _f;
        }
        if (i < N) {
            if (i >= 32) {
                f += s[i/32 - 1];
            }
            output[i] = f;
        }
    }
    // Earlier sequential fix-up, kept for reference:
    // for (int i = tid * 32; i < min(N, (tid+1)*32); i++) {
    //     float ans = input[i];
    //     if (i % 32 != 0) {
    //         ans += output[i-1];
    //     }
    //     if (tid > 0) {
    //         ans += s[i/32 - 1];
    //     }
    //     output[i] = ans;
    // }

    if constexpr (store_value) {
        if (tid == 0) {
            store[block_id] = output[N-1];
        }
    }
}

// Earlier per-thread sequential variant, kept for reference:
// template<bool store_value>
// __device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
//     int tid = threadIdx.x;
//     int block_id = blockIdx.x;
//     int start = tid * 32;
//     if (start < N) {
//         output[start] = input[start];
//         for (int i = start + 1; i < min(N, start + 32); i++) {
//             output[i] = output[i-1] + input[i];
//         }
//     }
//     __syncthreads();
//     if (tid == 0) {
//         for (int i = 32+31; i < N; i += 32) {
//             output[i] += output[i-32];
//         }
//     }
//     __syncthreads();
//
//     for (int i = start; i < min(N, start + 31); i++) {
//         if (tid != 0) {
//             output[i] += output[start - 1];
//         }
//     }
//
//     if constexpr (store_value) {
//         store[block_id] = output[N - 1];
//     }
// }

// Kernel 1: prefix sum small chunks of the overall array, each block
// covering blockDim.x * 32 elements; each block also records its total
// in store.
__global__ void prefix_sum_kernel1(const float* input, float* output, int N) {
    extern __shared__ float s[]; // shared memory, one float per 32-element chunk of this block

    int num_per_block = blockDim.x * 32;
    int block_id = blockIdx.x;
    int N_this_block = min(num_per_block, N - num_per_block * block_id);
    prefix_sum_compute<true>(input + num_per_block * block_id,
                             output + num_per_block * block_id, N_this_block, s);
}

// Kernel 2: prefix sum over store (the per-block totals), in place.
__global__ void prefix_sum_kernel2(int N_store) {
    extern __shared__ float s[]; // shared memory, one float per 32-entry chunk of store
    prefix_sum_compute<false>(store, store, N_store, s);
}

// Kernel 3: add the scanned total of all preceding blocks to each element.
__global__ void prefix_sum_kernel3(float* output, int N) {
    int tid = threadIdx.x;
    int block_id = blockIdx.x;
    int num_threads = blockDim.x;
    int num_per_block = num_threads * 32;
    int loop_end = min(N, num_per_block * (block_id + 1));
    // first block is already done
    if (block_id > 0) {
        float store_val = store[block_id - 1];
        for (int i = num_per_block * block_id + tid; i < loop_end; i += num_threads) {
            output[i] += store_val;
        }
    }
}

// input, output are device pointers
void solve(const float* input, float* output, int N) {
    int num_threads = 1024;
    int num_blocks = (N + (32*num_threads - 1)) / (32*num_threads);
    prefix_sum_kernel1<<<num_blocks, num_threads, num_threads * sizeof(float)>>>(input, output, N);
    prefix_sum_kernel2<<<1, num_threads, num_threads * sizeof(float)>>>(num_blocks);
    prefix_sum_kernel3<<<num_blocks, num_threads>>>(output, N);
}
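
Taken together, the three launches in solve() are the standard scan-then-fixup decomposition: kernel 1 scans each 32*1024-element block independently and records block totals in store, kernel 2 scans those totals, and kernel 3 adds each preceding blocks' total to the later blocks. A minimal host-side smoke test along the following lines can check the result against a sequential CPU scan. The harness is not part of this commit; the file name, N, and the error tolerance are placeholder assumptions.

// test_prefix_sum.cu -- hypothetical smoke test, not part of this commit.
// Compares solve() against a double-precision CPU inclusive scan.
#include "solve.h"
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>

int main() {
    const int N = 100000; // placeholder; spans 4 blocks, so kernel3 is exercised
    float* h_in  = (float*)malloc(N * sizeof(float));
    float* h_out = (float*)malloc(N * sizeof(float));
    for (int i = 0; i < N; i++) h_in[i] = (float)(rand() % 10);

    float *d_in, *d_out;
    cudaMalloc(&d_in,  N * sizeof(float));
    cudaMalloc(&d_out, N * sizeof(float));
    cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);

    solve(d_in, d_out, N); // kernels run in order on the default stream
    cudaMemcpy(h_out, d_out, N * sizeof(float), cudaMemcpyDeviceToHost);

    // Reference scan in double; allow small relative drift from the
    // device's float accumulation.
    double ref = 0.0;
    int mismatches = 0;
    for (int i = 0; i < N; i++) {
        ref += h_in[i];
        if (fabs(h_out[i] - ref) > 1e-4 * fmax(1.0, fabs(ref))) mismatches++;
    }
    printf("%s (%d mismatches)\n", mismatches == 0 ? "PASS" : "FAIL", mismatches);

    cudaFree(d_in); cudaFree(d_out);
    free(h_in); free(h_out);
    return mismatches != 0;
}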

0 commit comments