Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1111,20 +1111,15 @@ void FlashMaskV2GradBaseKernel(
return {64, 64};
}
}
} else if (head_size_rounded <= 192) {
// umiswing: head dim > 128 is not supported now
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
"not supported in FlashMask V3 now.",
head_size_rounded));
return {0, 0};
} else if (head_size_rounded <= 256) {
// umiswing: head dim > 128 is not supported now
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
"not supported in FlashMask V3 now.",
head_size_rounded));
return {0, 0};
// umiswing: by now, we reuse template instantiation of head dim 256 for
// head dim in range (128, 256], and therefore no separate dispatch for
// head dim in range (128, 192]
if(has_lt_end && has_ut_start) {
return {64, 32};
} else {
return {64, 64};
}
} else {
PADDLE_THROW(
common::errors::Unimplemented("head dim is rounded to %d, which is "
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ void FlashMaskV2BaseKernel(
common::errors::InvalidArgument(
"batch_size must be equal to batch_size_k"));
}
int const max_headdim = std::min(flashmaskv2_get_max_headdim(), 128);
int const max_headdim = flashmaskv2_get_max_headdim();
PADDLE_ENFORCE_LE(
head_size,
max_headdim,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/flash_attn_v3_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ inline int get_max_headdim() {
return 0;
}

inline int flashmaskv2_get_max_headdim() { return 128; }
inline int flashmaskv2_get_max_headdim() { return 256; }

inline int round_up_headdim(int head_size) {
#ifndef FLASHATTENTION_DISABLE_HDIM64
Expand Down
Loading