Skip to content

Commit 995050e

Browse files
[cherry-pick] Fix transpose for big tensor (#76373)
* Unify the dtype used in the transpose kernel. * Fix build error. --------- Co-authored-by: zhangting2020 <[email protected]>
1 parent 8b12875 commit 995050e

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

paddle/phi/kernels/funcs/transpose_function.cu.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -460,20 +460,20 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
460460
"SelectProperTileSize should return true, but return value is:%d.",
461461
ret));
462462

463-
int tile_long_edge = 0;
464-
int tile_short_edge = 0;
463+
IndexType tile_long_edge = 0;
464+
IndexType tile_short_edge = 0;
465465
float lowest_cost = std::numeric_limits<float>::max();
466-
int input_long_edge = std::max(input_dims[1], input_dims[2]);
466+
IndexType input_long_edge = std::max(input_dims[1], input_dims[2]);
467467

468468
// Find the tile size that best suit in inputs.
469469
for (auto tile_size_pair : tile_sele) {
470470
int proposed_tile_long_edge = tile_size_pair.first;
471471
// data may not aligned to tile, so some threads wasted, we need
472472
// to find least wasted threads, which means we need to find tile
473473
// can split input properly, in another words: num_wasted_threads=0.
474-
int num_full_tiles = input_long_edge / proposed_tile_long_edge;
474+
IndexType num_full_tiles = input_long_edge / proposed_tile_long_edge;
475475

476-
int num_wasted_threads =
476+
IndexType num_wasted_threads =
477477
input_long_edge - num_full_tiles * proposed_tile_long_edge;
478478

479479
float cost = num_wasted_threads;
@@ -490,9 +490,9 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
490490
// The tile size we select should be match with input dim, long side to long
491491
// short side to short.
492492
// First set long side as i if dim1 > Tile min size, then set dim2 as j.
493-
int select_tile_size_i =
493+
IndexType select_tile_size_i =
494494
input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1];
495-
int select_tile_size_j =
495+
IndexType select_tile_size_j =
496496
input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge;
497497

498498
// Check if i is long edge, if not set i as short.
@@ -584,9 +584,9 @@ __global__ void
584584
__launch_bounds__(BLOCK_DIM* BLOCK_DIM) inline fp8_fast_transpose_kernel(
585585
const phi::float8_e4m3fn* __restrict__ src, // Source matrix (M x N)
586586
phi::float8_e4m3fn* __restrict__ dst, // Destination matrix (N x M)
587-
int B,
588-
int M,
589-
int N, // Batch size, M-dimension, N-dimension
587+
uint32_t B,
588+
uint32_t M,
589+
uint32_t N, // Batch size, M-dimension, N-dimension
590590
size_t batch_stride) { // Stride between batches in global memory (M*N
591591
// elements)
592592
// Shared memory tile with padding to avoid bank conflicts, padding instead of
@@ -951,8 +951,8 @@ struct PermTypeClassifier {
951951
type_ = PermuteType::kGeneralTranspose;
952952
num_rows_tile_ = GET_TILE_SIZE(dims[rank - 2], kTileSize);
953953
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
954-
int tile_size = channel * num_rows_tile_ *
955-
GET_TILE_SIZE(dims[last_idx], kTileSize);
954+
int64_t tile_size = channel * num_rows_tile_ *
955+
GET_TILE_SIZE(dims[last_idx], kTileSize);
956956
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
957957
} else {
958958
type_ = PermuteType::kGeneralPermute;
@@ -970,7 +970,7 @@ struct PermTypeClassifier {
970970
num_rows_tile_ = GET_TILE_SIZE(dims[0], kTileSize);
971971

972972
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
973-
int tile_size =
973+
int64_t tile_size =
974974
dims[1] * num_rows_tile_ * GET_TILE_SIZE(dims[2], kTileSize);
975975
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
976976
} else {

0 commit comments

Comments (0)