// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s

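/// Masked vectorization of a 1-D depthwise convolution with a dynamic channel
/// dimension, using fixed-width vector sizes.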
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
                                                   %filter: tensor<1x?xi8>,
                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
  %res = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
  return %res : tensor<1x8x?xi8>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {

// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8

/// Create a mask for the input tensor
// CHECK:           %[[C2:.*]] = arith.constant 2 : index
// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK:           %[[C8:.*]] = arith.constant 8 : index
// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
/// Read the input tensor
// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Create a mask for the filter tensor
// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
/// Read the filter tensor
// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>

/// Create a mask for the output tensor
// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
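/// Read the output tensor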
// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Convolution
// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<4xi8> from vector<1x4xi8>
// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<4xi8> to vector<1x8x4xi8>
// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>

// -----

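/// Same depthwise convolution as above, but the dynamic channel dimension is
/// vectorized with a scalable vector size ([4]).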
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
      %input: tensor<1x8x?xi8>,
      %filter: tensor<1x?xi8>,
      %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
  %res = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
  return %res : tensor<1x8x?xi8>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
    transform.yield
  }
}

// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {

// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8

/// Create a mask for the input tensor
// CHECK:           %[[C2:.*]] = arith.constant 2 : index
// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK:           %[[C8:.*]] = arith.constant 8 : index
// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
/// Read the input tensor
// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>

/// Create a mask for the filter tensor
// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
/// Read the filter tensor
// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>

/// Create a mask for the output tensor
// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
/// Read the output tensor
// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>

/// Convolution
// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xi8> to vector<1x8x[4]xi8>
// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>

// -----

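/// Depthwise convolution on memrefs with dilation 2; the dynamic channel
/// dimension is vectorized with a scalable vector size ([4]).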
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
      %input: memref<3x5x?xf32>,
      %filter: memref<2x?xf32>,
      %output: memref<3x2x?xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
    outs(%output : memref<3x2x?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
    transform.yield
  }
}

// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {

// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C2:.*]] = arith.constant 2 : index

/// Create a mask for the input memref
// CHECK:           %[[CH_DIM_IN:.*]] = memref.dim %[[INPUT]], %[[C2]] : memref<3x5x?xf32>
// CHECK:           %[[C3:.*]] = arith.constant 3 : index
// CHECK:           %[[C5:.*]] = arith.constant 5 : index
// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
/// Read the input memref
// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>

/// Create a mask for the filter memref
// CHECK:           %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
/// Read the filter memref
// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>

/// Create a mask for the output memref
// CHECK:           %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
/// Read the output memref
// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>

/// Convolution
// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
// CHECK:           %[[IN_2:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK:           %[[FLT_2:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
// CHECK:           %[[FMA_1:.*]] = vector.fma %[[IN_1]], %[[FLT_1_B]], %[[OUT_1]] : vector<3x2x[4]xf32>
// CHECK:           %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
// CHECK:           %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>