@@ -130,41 +130,6 @@ func.func @donot_replace_leakyrelu(%arg0 : tensor<1x104x104x128xf32, #zhigh.layo
130130
131131// -----
132132
133- func.func @replace_sqrt (%arg0 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
134- %0 = " zhigh.Unstick" (%arg0 ) : (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <4 x256 x1 xf32 >
135- %1 = " onnx.Sqrt" (%0 ) : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 >
136- %2 = " zhigh.Stick" (%1 ) {layout = " 3D" } : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
137- return %2 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
138-
139- // CHECK-LABEL: func.func @replace_sqrt
140- // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>> {
141- // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Log"([[PARAM_0_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
142- // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
143- // CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3D"} : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
144- // CHECK: [[VAR_3_:%.+]] = "zhigh.Mul"([[VAR_0_]], [[VAR_2_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>, tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
145- // CHECK: [[VAR_4_:%.+]] = "zhigh.Exp"([[VAR_3_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
146- // CHECK: return [[VAR_4_]] : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
147- // CHECK: }
148- }
149-
150- // -----
151-
152- // Do not replace square root because of unknown dimension.
153- // In this case, there is no static shape to create a constant of 2.
154- func.func @donot_replace_sqrt (%arg0 : tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
155- %0 = " zhigh.Unstick" (%arg0 ) : (tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <?x256 x1 xf32 >
156- %1 = " onnx.Sqrt" (%0 ) : (tensor <?x256 x1 xf32 >) -> tensor <?x256 x1 xf32 >
157- %2 = " zhigh.Stick" (%1 ) {layout = " 3D" } : (tensor <?x256 x1 xf32 >) -> tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
158- return %2 : tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
159-
160- // CHECK-LABEL: func.func @donot_replace_sqrt
161- // CHECK: zhigh.Unstick
162- // CHECK: onnx.Sqrt
163- // CHECK: zhigh.Stick
164- }
165-
166- // -----
167-
168133func.func @replace_reciprocal_sqrt (%arg0 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
169134 %0 = " zhigh.Unstick" (%arg0 ) : (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <4 x256 x1 xf32 >
170135 %1 = " onnx.Sqrt" (%0 ) : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 >
0 commit comments