@@ -460,20 +460,20 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
460460 " SelectProperTileSize should return true, but return value is:%d." ,
461461 ret));
462462
463- int tile_long_edge = 0 ;
464- int tile_short_edge = 0 ;
463+ IndexType tile_long_edge = 0 ;
464+ IndexType tile_short_edge = 0 ;
465465 float lowest_cost = std::numeric_limits<float >::max ();
466- int input_long_edge = std::max (input_dims[1 ], input_dims[2 ]);
466+ IndexType input_long_edge = std::max (input_dims[1 ], input_dims[2 ]);
467467
468468 // Find the tile size that best suits the input.
469469 for (auto tile_size_pair : tile_sele) {
470470 int proposed_tile_long_edge = tile_size_pair.first ;
471471 // Data may not be aligned to the tile, so some threads are wasted. We
472472 // need the tile with the fewest wasted threads, i.e. one that splits the
473473 // input evenly; in other words: num_wasted_threads=0.
474- int num_full_tiles = input_long_edge / proposed_tile_long_edge;
474+ IndexType num_full_tiles = input_long_edge / proposed_tile_long_edge;
475475
476- int num_wasted_threads =
476+ IndexType num_wasted_threads =
477477 input_long_edge - num_full_tiles * proposed_tile_long_edge;
478478
479479 float cost = num_wasted_threads;
@@ -490,9 +490,9 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
490490 // The selected tile size should match the input dims: long side to long,
491491 // short side to short.
492492 // If dim1 >= kMinTileSize, i takes the long tile edge and j takes dim2.
493- int select_tile_size_i =
493+ IndexType select_tile_size_i =
494494 input_dims[1 ] >= kMinTileSize ? tile_long_edge : input_dims[1 ];
495- int select_tile_size_j =
495+ IndexType select_tile_size_j =
496496 input_dims[1 ] >= kMinTileSize ? input_dims[2 ] : tile_long_edge;
497497
498498 // Check if i is long edge, if not set i as short.
@@ -584,9 +584,9 @@ __global__ void
584584__launch_bounds__ (BLOCK_DIM* BLOCK_DIM) inline fp8_fast_transpose_kernel (
585585 const phi::float8_e4m3fn* __restrict__ src, // Source matrix (M x N)
586586 phi::float8_e4m3fn* __restrict__ dst, // Destination matrix (N x M)
587- int B,
588- int M,
589- int N, // Batch size, M-dimension, N-dimension
587+ uint32_t B,
588+ uint32_t M,
589+ uint32_t N, // Batch size, M-dimension, N-dimension
590590 size_t batch_stride) { // Stride between batches in global memory (M*N
591591 // elements)
592592 // Shared memory tile with padding to avoid bank conflicts, padding instead of
@@ -951,8 +951,8 @@ struct PermTypeClassifier {
951951 type_ = PermuteType::kGeneralTranspose ;
952952 num_rows_tile_ = GET_TILE_SIZE (dims[rank - 2 ], kTileSize );
953953 int dim_vec_size = GetDimVecSize (dst_vec_size, dims[last_idx], src);
954- int tile_size = channel * num_rows_tile_ *
955- GET_TILE_SIZE (dims[last_idx], kTileSize );
954+ int64_t tile_size = channel * num_rows_tile_ *
955+ GET_TILE_SIZE (dims[last_idx], kTileSize );
956956 vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
957957 } else {
958958 type_ = PermuteType::kGeneralPermute ;
@@ -970,7 +970,7 @@ struct PermTypeClassifier {
970970 num_rows_tile_ = GET_TILE_SIZE (dims[0 ], kTileSize );
971971
972972 int dim_vec_size = GetDimVecSize (dst_vec_size, dims[last_idx], src);
973- int tile_size =
973+ int64_t tile_size =
974974 dims[1 ] * num_rows_tile_ * GET_TILE_SIZE (dims[2 ], kTileSize );
975975 vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
976976 } else {
0 commit comments