Skip to content

Commit c57b4c1

Browse files
committed
change test_fp16 -> _run
1 parent 7662bfe commit c57b4c1

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

paddle/cinn/hlir/dialect/operator/transforms/add_cast_to_elementwise_add_pass.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ namespace cinn {
2727
namespace dialect {
2828
namespace ir {
2929

30-
auto is_integer_or_bool = [](const pir::Type& x) {
31-
return x.isa<pir::IndexType>() || x.isa<pir::Int64Type>() ||
32-
x.isa<pir::Int32Type>() || x.isa<pir::Int16Type>() ||
33-
x.isa<pir::Int8Type>() || x.isa<pir::UInt8Type>() ||
34-
x.isa<pir::BoolType>();
35-
};
36-
3730
pir::Type GetOutputDtype(const pir::Type& x, const pir::Type& y) {
3831
pir::IrContext* context = pir::IrContext::Instance();
3932
// type promotion
@@ -44,6 +37,13 @@ pir::Type GetOutputDtype(const pir::Type& x, const pir::Type& y) {
4437
return pir::Complex64Type::get(context);
4538
}
4639

40+
auto is_integer_or_bool = [](const pir::Type& x) {
41+
return x.isa<pir::IndexType>() || x.isa<pir::Int64Type>() ||
42+
x.isa<pir::Int32Type>() || x.isa<pir::Int16Type>() ||
43+
x.isa<pir::Int8Type>() || x.isa<pir::UInt8Type>() ||
44+
x.isa<pir::BoolType>();
45+
};
46+
4747
if (is_integer_or_bool(x) || is_integer_or_bool(y)) {
4848
PADDLE_THROW(::common::errors::InvalidType(
4949
"Type promotion only support calculations between floating-point "

paddle/cinn/hlir/dialect/operator/transforms/add_cast_to_elementwise_add_pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

test/dygraph_to_static/test_cast_pass.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -32,16 +32,19 @@ def func(x, y):
3232

3333
class TestAddCastToElementwiseAddPass(Dy2StTestBase):
3434
# test AddCastToElementwiseAddPass
35-
def test_bf16(self, dtype="float16"):
35+
def _run(self, dtype):
3636
static_fn = paddle.jit.to_static(func)
3737
x = paddle.randn([200, 200])
3838
y = paddle.randn([200, 200], dtype=dtype)
3939
np.testing.assert_allclose(
4040
static_fn(x, y).numpy(), x.numpy() + y.cast("float32").numpy()
4141
)
4242

43+
def test_bf16(self):
44+
self._run(dtype="bfloat16")
45+
4346
def test_fp16(self):
44-
self.test_bf16("float16")
47+
self._run(dtype="float16")
4548

4649

4750
if __name__ == '__main__':

0 commit comments

Comments
 (0)