forked from ROCm/composable_kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.hpp
More file actions
50 lines (40 loc) · 1.93 KB
/
utils.hpp
File metadata and controls
50 lines (40 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
{
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
ck_tile::type_convert<float>(d1);
e = ck_tile::type_convert<E>(x0_f);
}
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeTypeAB =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}