
Commit 1e74ecc

prefix sum

1 parent 5f8b522 commit 1e74ecc

File tree

2 files changed: +162, -0 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ My solutions to CUDA challenges on https://leetgpu.com/
 [Rainbow Table](https://leetgpu.com/challenges/rainbow-table) | [Link](./rainbow_table.cu) | Easy |
 [Reduction](https://leetgpu.com/challenges/reduction) | [Link](./reduction.cu) | Medium |
 [Softmax](https://leetgpu.com/challenges/softmax) | [Link](./softmax.cu) | Medium |
+[Prefix Sum](https://leetgpu.com/challenges/prefix-sum) | [Link](./prefix_sum.cu) | Medium |
 [Dot Product](https://leetgpu.com/challenges/dot-product) | [Link](./dot_product.cu) | Medium |
 [Softmax Attention](https://leetgpu.com/challenges/softmax-attention) | [Link](./softmax_attention.cu) | Medium |
 [Password Cracking (FNV-1a)](https://leetgpu.com/challenges/password-cracking-fnv-1a) | [Link](./password_cracking_fnv_1a.cu) | Medium |
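
For context on what the new file computes: the challenge asks for an inclusive prefix sum, output[i] = input[0] + ... + input[i]. A minimal sequential reference for checking the kernels (a sketch, not part of this commit):

// CPU reference: inclusive scan in O(N), sequential.
void prefix_sum_cpu(const float* input, float* output, int N) {
    float acc = 0.0f;
    for (int i = 0; i < N; i++) {
        acc += input[i];
        output[i] = acc;
    }
}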

prefix_sum.cu

Lines changed: 161 additions & 0 deletions

@@ -0,0 +1,161 @@
#include "solve.h"
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Per-block totals written by kernel1 and scanned in place by kernel2.
__device__ float store[1024*32];

// Inclusive prefix sum over up to blockDim.x * 32 elements.
// s points to shared memory with one float per 32-element group
// (blockDim.x floats in total).
template<bool store_value>
__device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
    int tid = threadIdx.x;
    int num_threads = blockDim.x;
    int block_id = blockIdx.x;
    int lane_id = tid % 32;

    s[tid] = 0;
    __syncthreads();

    // Round N up to a multiple of 32 so whole warps stay in lockstep.
    int loop_bound = (N + 31);
    loop_bound -= (loop_bound % 32);
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = i < N ? input[i] : 0;
        // sum over warp
        for (int off = 16; off >= 1; off >>= 1) {
            f += __shfl_xor_sync(FULL_MASK, f, off);
        }
        // store the sum of these 32 values
        if (lane_id == 0) {
            s[i/32] = f;
        }
    }
    __syncthreads();

    // Scan the warp sums in shared memory: up sweep
    int offset = 1;
    for (int d = 512; d > 0; d >>= 1) {
        __syncthreads();
        if (tid < d) {
            int a = (tid+1) * (offset * 2) - 1 - offset;
            int b = (tid+1) * (offset * 2) - 1;
            s[b] += s[a];
        }
        offset *= 2;
    }

    // down sweep; afterwards s[k] holds the inclusive prefix sum of the warp sums
    for (int d = 2; d < 1024; d *= 2) {
        offset >>= 1;
        __syncthreads();
        if (tid < d - 1) {
            int a = (tid+1) * offset - 1;
            int b = (tid+1) * offset - 1 + offset/2;
            s[b] += s[a];
        }
    }
    __syncthreads();

    // Second pass: intra-warp inclusive scan via shuffles, plus the
    // carry-in of all preceding warps from s.
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = i < N ? input[i] : 0;
        for (int d = 1; d <= 16; d *= 2) {
            float _f = __shfl_up_sync(FULL_MASK, f, d);
            if (lane_id >= d) f += _f;
        }
        if (i < N) {
            if (i >= 32) {
                f += s[i/32 - 1];
            }
            output[i] = f;
        }
    }
    // Earlier per-thread variant, kept for reference:
    // for (int i = tid * 32; i < min(N, (tid+1)*32); i++) {
    //     float ans = input[i];
    //     if (i % 32 != 0) {
    //         ans += output[i-1];
    //     }
    //     if (tid > 0) {
    //         ans += s[i/32 - 1];
    //     }
    //     output[i] = ans;
    // }

    if constexpr (store_value) {
        // output[N-1] is written by thread (N-1) % num_threads in the loop
        // above; synchronize before thread 0 reads it.
        __syncthreads();
        if (tid == 0) {
            store[block_id] = output[N-1];
        }
    }
}
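
// For intuition, the up-sweep/down-sweep above traced on an 8-element
// analogue (d running 4, 2, 1 on the way up and 2, 4 on the way down;
// the values are illustrative, not from the commit):
//   s              = [3, 1, 7,  0,  4,  1,  6,  3]
//   up-sweep d=4   : [3, 4, 7,  7,  4,  5,  6,  9]
//   up-sweep d=2   : [3, 4, 7, 11,  4,  5,  6, 14]
//   up-sweep d=1   : [3, 4, 7, 11,  4,  5,  6, 25]
//   down-sweep d=2 : [3, 4, 7, 11,  4, 16,  6, 25]  (s[5] += s[3])
//   down-sweep d=4 : [3, 4, 11, 11, 15, 16, 22, 25] (s[2]+=s[1], s[4]+=s[3], s[6]+=s[5])
// which is exactly the inclusive prefix sum of s.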

// template<bool store_value>
// __device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
//     int tid = threadIdx.x;
//     int block_id = blockIdx.x;
//     int start = tid * 32;
//     if (start < N) {
//         output[start] = input[start];
//         for (int i = start + 1; i < min(N, start + 32); i++) {
//             output[i] = output[i-1] + input[i];
//         }
//     }
//     __syncthreads();
//     if (tid == 0) {
//         for (int i = 32+31; i < N; i += 32) {
//             output[i] += output[i-32];
//         }
//     }
//     __syncthreads();
//
//     for (int i = start; i < min(N, start + 31); i++) {
//         if (tid != 0) {
//             output[i] += output[start - 1];
//         }
//     }
//
//     if constexpr (store_value) {
//         store[block_id] = output[N - 1];
//     }
// }

// prefix sum small chunks of the overall array, blockDim.x * 32 elements per block
__global__ void prefix_sum_kernel1(const float* input, float* output, int N) {
    extern __shared__ float s[]; // shared memory, sized to blockDim.x floats

    int num_per_block = blockDim.x * 32;
    int block_id = blockIdx.x;
    int N_this_block = min(num_per_block, N - num_per_block * block_id);
    prefix_sum_compute<true>(input + num_per_block * block_id, output + num_per_block * block_id, N_this_block, s);
}

// prefix sum over store (the per-block totals), in a single block
__global__ void prefix_sum_kernel2(int N_store) {
    extern __shared__ float s[]; // shared memory, sized to blockDim.x floats
    prefix_sum_compute<false>(store, store, N_store, s);
}

// add store's scanned sums to each element of the later blocks
__global__ void prefix_sum_kernel3(float* output, int N) {
    int tid = threadIdx.x;
    int block_id = blockIdx.x;
    int num_threads = blockDim.x;
    int num_per_block = num_threads * 32;
    int loop_end = min(N, num_per_block * (block_id + 1));
    // first block is already done
    if (block_id > 0) {
        float store_val = store[block_id - 1];
        for (int i = num_per_block * block_id + tid; i < loop_end; i += num_threads) {
            output[i] += store_val;
        }
    }
}

// input, output are device pointers
void solve(const float* input, float* output, int N) {
    int num_threads = 1024;
    // each block scans 32 * num_threads elements
    int num_blocks = (N + (32*num_threads - 1)) / (32*num_threads);
    prefix_sum_kernel1<<<num_blocks, num_threads, num_threads * sizeof(float)>>>(input, output, N);
    prefix_sum_kernel2<<<1, num_threads, num_threads * sizeof(float)>>>(num_blocks);
    prefix_sum_kernel3<<<num_blocks, num_threads>>>(output, N);
}
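
A minimal host-side harness for exercising solve (a sketch: it assumes solve.h declares void solve(const float*, float*, int) taking device pointers, as the comment above solve states; the harness itself is not part of the commit):

#include "solve.h"
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    const int N = 100000;
    float *h_in = new float[N], *h_out = new float[N];
    for (int i = 0; i < N; i++) h_in[i] = 1.0f; // expect output[i] == i + 1

    float *d_in, *d_out;
    cudaMalloc(&d_in, N * sizeof(float));
    cudaMalloc(&d_out, N * sizeof(float));
    cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);

    solve(d_in, d_out, N);
    cudaDeviceSynchronize();

    cudaMemcpy(h_out, d_out, N * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0]=%.0f out[N-1]=%.0f (expect 1 and %d)\n", h_out[0], h_out[N-1], N);

    cudaFree(d_in);
    cudaFree(d_out);
    delete[] h_in;
    delete[] h_out;
    return 0;
}

With 32768 elements per block, N = 100000 spans four blocks, so this exercises the cross-block carry path (kernels 2 and 3) and not just the single-block scan; any block that misses its carry-in shows up as a visible jump in the all-ones output.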
