// Source: llvm-project/mlir/test/Dialect/Quant/strip-func-quant-types.mlir (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s

// Quantized function arguments are rewritten to their storage types (i8, i16)
// and converted back with quant.scast inside the body; the non-quantized f32
// argument is used directly. The expected cast order is the reverse of the
// argument order: the cast of the second argument precedes that of the first.

// CHECK-LABEL: @strip_operands
// CHECK-SAME: %[[ARG_0:.*]]: i8
// CHECK-SAME: %[[ARG_1:.*]]: i16
// CHECK-SAME: %[[ARG_2:.*]]: f32

// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}>
// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}>

// CHECK: "test.custom_op"(%[[ARG_0_CAST]])
// CHECK: "test.custom_op"(%[[ARG_1_CAST]])
// CHECK: "test.custom_op"(%[[ARG_2]])
// CHECK: return

!qalias = !quant.uniform<i8:f32, 2.0:128>
!qalias1 = !quant.uniform<i16:f32, 3.0:128>

func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) {
  "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias>
  "test.custom_op"(%arg1) : (!qalias1) -> tensor<?x!qalias1>
  "test.custom_op"(%arg2) : (f32) -> tensor<4xf32>
  // Explicit terminator: func.func bodies must end in func.return rather than
  // relying on the unknown test op being accepted as a possible terminator.
  return
}

// -----

// Quantized function results are rewritten to their storage types. Each
// quantized value that is returned is converted with quant.scast following its
// producer, and the non-quantized tensor<4xf32> result is returned directly.
// Static, dynamic (?) and unranked (*) tensor shapes are all covered.

// CHECK-LABEL: @strip_results
// CHECK-SAME: tensor<4xi8>, tensor<?xi16>, tensor<*xi8>, tensor<4xf32>

// CHECK: %[[RESULT_0:.*]] = "test.custom_op"()
// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>

// CHECK: %[[RESULT_1:.*]] = "test.custom_op"()
// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>

// CHECK: %[[RESULT_2:.*]] = "test.custom_op"()
// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8>

// CHECK: %[[RESULT_3:.*]] = "test.custom_op"()

// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]]

!qalias = !quant.uniform<i8:f32, 2.0:128>
!qalias1 = !quant.uniform<i16:f32, 3.0:128>

func.func @strip_results() -> (tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>) {
  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
  %2 = "test.custom_op"() : () -> tensor<*x!qalias>
  %3 = "test.custom_op"() : () -> tensor<4xf32>
  return %0, %1, %2, %3 : tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>
}

// -----

// Calls to a function whose signature uses quantized types. The private
// declaration's signature is stripped to storage types. At the call site,
// quantized operands are converted to storage types with quant.scast before
// the call. The quantized call result is converted back to the quantized type
// before its use. Non-quantized values (the tensor<4xf32> result) are used
// directly.

// CHECK-LABEL: @callee
// CHECK-SAME: (tensor<4xi8>, tensor<?xi16>) -> (tensor<*xi8>, tensor<4xf32>)

// CHECK-LABEL: @strip_call

// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"()
// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>

// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"()
// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>

// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]])

// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>>
// CHECK: "test.custom_op"(%[[RESULT_0_CAST]])

// CHECK: "test.custom_op"(%[[RESULTS]]#1)

// CHECK: return

!qalias = !quant.uniform<i8:f32, 2.0:128>
!qalias1 = !quant.uniform<i16:f32, 3.0:128>

func.func private @callee(tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)

func.func @strip_call() {
  %0 = "test.custom_op"() : () -> tensor<4x!qalias>
  %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
  %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
  "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> ()
  "test.custom_op"(%2#1) : (tensor<4xf32>) -> ()
  return
}