xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (revision 97069a86193a617a9e4cf742a29db6116b2bf449)
1 //===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/Utils.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include <utility>
24 
25 namespace mlir {
26 namespace linalg {
hasAllOneValues(DenseIntElementsAttr attr)27 static bool hasAllOneValues(DenseIntElementsAttr attr) {
28   return llvm::all_of(
29       attr, [](const APInt &element) { return element.getSExtValue() == 1; });
30 }
31 
createAdd(Location loc,Value x,Value y,OpBuilder & builder)32 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
33   if (isa<IntegerType>(x.getType()))
34     return builder.create<arith::AddIOp>(loc, x, y);
35   if (isa<ComplexType>(x.getType()))
36     return builder.create<complex::AddOp>(loc, x, y);
37   return builder.create<arith::AddFOp>(loc, x, y);
38 }
39 
createMul(Location loc,Value x,Value y,Type accType,OpBuilder & builder)40 static Value createMul(Location loc, Value x, Value y, Type accType,
41                        OpBuilder &builder) {
42   // Linalg named ops specify signed extend for named ops.
43   Value xConvert =
44       convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
45   Value yConvert =
46       convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
47   if (isa<ComplexType>(accType))
48     return builder.create<complex::MulOp>(loc, xConvert, yConvert);
49   if (isa<IntegerType>(accType))
50     return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
51   return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
52 }
53 
54 // Delinearizes the given composite `index` by the basis specified in `factors`.
unrollIndex(OpBuilder & b,Location loc,Value index,ArrayRef<int64_t> factors)55 static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
56                                       ArrayRef<int64_t> factors) {
57   assert(!factors.empty() && "empty factor list");
58   SmallVector<Value> basis;
59   for (int64_t f : factors)
60     basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
61   FailureOr<SmallVector<Value>> multiIndex =
62       affine::delinearizeIndex(b, loc, index, basis);
63   assert(!failed(multiIndex) && "Failed to linearize img2col index");
64   return *multiIndex;
65 }
66 
67 // Given indices corresponding to iterators in the output (oIndex) and filter
68 // (fIndex) for a convolution, compute the convolved index for the
69 // input as `oIndex * stride + fIndex`.
getConvolvedIndex(OpBuilder & b,Location loc,Value oIndex,Value fIndex,int64_t stride)70 static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
71                                Value fIndex, int64_t stride) {
72   AffineExpr oExpr, fExpr;
73   bindSymbols(b.getContext(), oExpr, fExpr);
74   AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
75   return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
76 }
77 
// Rewrites a linalg.conv_2d_nhwc_hwcf as an im2col gather followed by a
// batched-matmul-shaped linalg.generic. Returns the (im2col op, final
// reshape op) pair on success, or a match failure for dynamic shapes or
// non-unit dilations.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  // Static shapes are required so the collapsed matmul dimensions (oh*ow,
  // fh*fw*ic) can be computed at rewrite time.
  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output / HWCF filter layout.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
  // Filter: [fh, fw, ic, oc] -> [fh*fw*ic, oc] (K x N).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Output: [n, oh, ow, oc] -> [n, oh*ow, oc] (B x M x N).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The im2col op has no tensor inputs: each output element is gathered
  // directly from `input` via tensor.extract inside the body.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices:
        // spatial input index = output index * stride + filter index.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate in the output element type (createMul converts
  // operands with signed extension, matching linalg named-op semantics).
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand [n, oh*ow, oc] back to the original NHWC output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
212 
213 FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase & rewriter,linalg::DepthwiseConv2DNhwcHwcOp convOp)214 rewriteInIm2Col(RewriterBase &rewriter,
215                 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216   auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
217   auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
218   auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
219 
220   if (!filterType.hasStaticShape())
221     return rewriter.notifyMatchFailure(
222         convOp, "expected a static shape for the filter");
223 
224   if (!inputType.hasStaticShape())
225     return rewriter.notifyMatchFailure(convOp,
226                                        "expected a static shape for the input");
227 
228   // TODO: Support dilation.
229   if (!hasAllOneValues(convOp.getDilations()))
230     return rewriter.notifyMatchFailure(convOp,
231                                        "expected all ones for dilations");
232 
233   Location loc = convOp.getLoc();
234 
235   auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
236     auto operandTensorType = cast<RankedTensorType>(operand.getType());
237     auto nloops = indices.size();
238     ArrayRef<int64_t> inputShape = operandTensorType.getShape();
239 
240     SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
241         llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
242           return rewriter.getAffineDimExpr(index);
243         }));
244 
245     SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
246         indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
247 
248     Value outputTensor = rewriter.create<tensor::EmptyOp>(
249         loc, targetShape, operandTensorType.getElementType());
250 
251     SmallVector<utils::IteratorType> loopAttributeTypes(
252         nloops, utils::IteratorType::parallel);
253 
254     SmallVector<AffineMap> indexingMaps = {
255         inversePermutation(
256             AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
257         AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
258 
259     auto transposedOp = rewriter.create<linalg::GenericOp>(
260         loc, outputTensor.getType(),
261         /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
262         loopAttributeTypes,
263         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
264           nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
265         });
266 
267     return transposedOp.getResult(0);
268   };
269 
270   Value input = convOp.getInputs()[0];
271   Value filter = convOp.getInputs()[1];
272   Value output = convOp.getOutputs()[0];
273 
274   // Transpose input, filter so channels are outermost
275   Value inputT = transposeOperand(input, {0, 3, 1, 2});
276   Value filterT = transposeOperand(filter, {2, 0, 1});
277   ArrayRef<int64_t> filterTShape =
278       cast<RankedTensorType>(filterT.getType()).getShape();
279   ArrayRef<int64_t> outputShape = outputType.getShape();
280 
281   int n = outputShape[0];
282   int oh = outputShape[1];
283   int ow = outputShape[2];
284   int c = outputShape[3];
285   int fh = filterTShape[1];
286   int fw = filterTShape[2];
287 
288   SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
289   Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290 
291   AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
292   bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
293 
294   AffineExpr shSym = rewriter.getAffineConstantExpr(
295       convOp.getStrides().getValues<int64_t>()[0]);
296   AffineExpr swSym = rewriter.getAffineConstantExpr(
297       convOp.getStrides().getValues<int64_t>()[1]);
298 
299   SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
300                                         owDim * swSym + kwDim};
301 
302   auto nloops = colTensorShape.size();
303 
304   SmallVector<utils::IteratorType> loopAttributeTypes(
305       nloops, utils::IteratorType::parallel);
306 
307   SmallVector<AffineMap> indexingMaps = {
308       AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
309       AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
310 
311   Value colTensor = rewriter.create<tensor::EmptyOp>(
312       loc, colTensorShape, inputType.getElementType());
313 
314   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
315       loc, colTensor.getType(),
316       /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
317       loopAttributeTypes,
318       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
319         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
320       });
321 
322   SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
323       {0, 1}, {2, 3}, {4, 5}};
324   SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
325   SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
326                                                                  {2, 3}};
327 
328   auto reshapedImg2ColTensorType = RankedTensorType::get(
329       {n * c, oh * ow, fh * fw}, inputType.getElementType());
330   auto reshapedFilterTensorType =
331       RankedTensorType::get({c, fh * fw}, filterType.getElementType());
332   auto reshapedOutputTensorType =
333       RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
334 
335   Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
336       loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337       img2ColTensorReassocIndices);
338   Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
339       loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340   Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
341       loc, reshapedOutputTensorType, transposedOutputTensor,
342       outputReassociationIndice);
343 
344   auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
345       loc, TypeRange{reshapedoutputTensor.getType()},
346       ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
347       ValueRange{reshapedoutputTensor});
348 
349   SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
350                                                                       {2, 3}};
351 
352   auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
353       loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
354       batchMatVecReassociationIndice);
355 
356   Value transposedResult =
357       transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
358 
359   rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
360   return std::make_pair(img2ColTensor.getOperation(),
361                         transposedResult.getDefiningOp());
362 }
363 
364 FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase & rewriter,linalg::Conv2DNchwFchwOp convOp)365 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
366   auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
369 
370   if (!filterType.hasStaticShape())
371     return rewriter.notifyMatchFailure(
372         convOp, "expected a static shape for the filter");
373 
374   if (!inputType.hasStaticShape())
375     return rewriter.notifyMatchFailure(convOp,
376                                        "expected a static shape for the input");
377 
378   // TODO: Support dilation.
379   if (!hasAllOneValues(convOp.getDilations()))
380     return rewriter.notifyMatchFailure(convOp,
381                                        "expected all ones for dilations");
382 
383   Value input = convOp.getInputs()[0];
384   Value filter = convOp.getInputs()[1];
385   Value output = convOp.getOutputs()[0];
386 
387   auto filterShape = filterType.getShape();
388   auto outputShape = outputType.getShape();
389 
390   int64_t n = outputShape[0];
391   int64_t oc = outputShape[1];
392   int64_t oh = outputShape[2];
393   int64_t ow = outputShape[3];
394   int64_t ic = filterShape[1];
395   int64_t fh = filterShape[2];
396   int64_t fw = filterShape[3];
397 
398   auto loc = convOp.getLoc();
399   MLIRContext *context = rewriter.getContext();
400 
401   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
402   auto reshapedFilterType =
403       RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
404   Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
405       loc, reshapedFilterType, filter, filterReassocIndices);
406 
407   SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
408   auto reshapedOutputType =
409       RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
410   Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
411       loc, reshapedOutputType, output, outputReassocIndices);
412 
413   // Convert the input to a (BKN) tensor.
414   SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
415   Value colTensor = rewriter.create<tensor::EmptyOp>(
416       loc, colTensorShape, inputType.getElementType());
417 
418   auto nloops = colTensorShape.size();
419 
420   auto parallel = utils::IteratorType::parallel;
421   auto reduction = utils::IteratorType::reduction;
422   SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
423 
424   SmallVector<AffineMap, 4> img2colIndexingMaps = {
425       AffineMap::getMultiDimIdentityMap(nloops, context)};
426 
427   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
428       loc, colTensor.getType(),
429       /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
430       img2colIterators,
431       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432         // Get the iterators named based on the matmul (batch, m, k).
433         Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
434         Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
435         Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
436 
437         // Recover the original iteration indices from the problem/input sizes.
438         SmallVector<Value> kIndices = unrollIndex(
439             nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440         auto icIndex = kIndices[0];
441         auto fhIndex = kIndices[1];
442         auto fwIndex = kIndices[2];
443 
444         SmallVector<Value> nIndices = unrollIndex(
445             nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
446         auto ohIndex = nIndices[0];
447         auto owIndex = nIndices[1];
448 
449         // Extract the input element corresponding to the expanded indices.
450         Value hIndex =
451             getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
452                               convOp.getStrides().getValues<int64_t>()[0]);
453         Value wIndex =
454             getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
455                               convOp.getStrides().getValues<int64_t>()[1]);
456 
457         // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458         SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459         Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
460             loc, input, extractionIndices);
461         nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
462       });
463 
464   // Because the filter does not share the same batch dimension,
465   // the batch dimension is only used in indexing the input and output. Thus
466   // we cannot use existing linalg named ops like linalg.batch_matmul.
467   // i.e. M x K * (B x) K x N = (B x) M x N
468   AffineExpr bDim, mDim, nDim, kDim;
469   bindDims(context, bDim, mDim, nDim, kDim);
470   auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
471   auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
472   auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
473   SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
474                                                        parallel, reduction};
475   auto genericOp = rewriter.create<linalg::GenericOp>(
476       loc, reshapedOutputType,
477       /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478       /*outputs=*/ValueRange{reshapedOutput},
479       ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
480       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
481         Value mul =
482             createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
483         Value add = createAdd(loc, mul, args[2], nestedBuilder);
484         nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
485       });
486   Value result = genericOp.getResults().front();
487 
488   auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
489       loc, outputType, result, outputReassocIndices);
490 
491   rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
492 
493   return std::make_pair(img2ColTensor.getOperation(),
494                         reshapedResult.getOperation());
495 }
496 
// Rewrites a linalg.conv_2d_nhwc_fhwc as an im2col gather followed by a
// "row-wise dot product" linalg.generic. Unlike the HWCF variant, the
// filter is not transposed, so the contraction's RHS map is (n, k) rather
// than (k, n). Returns the (im2col op, final reshape op) pair, or a match
// failure for dynamic shapes or non-unit dilations.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output / FHWC filter layout.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a "row-wise" matrix
  // multiplication.
  // Filter: [oc, fh, fw, ic] -> [oc, fh*fw*ic] (N x K).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Output: [n, oh, ow, oc] -> [n, oh*ow, oc] (B x M x N).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The im2col op has no tensor inputs: each output element is gathered
  // directly from `input` via tensor.extract inside the body.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices:
        // spatial input index = output index * stride + filter index.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate in the output element type (createMul converts
  // operands with signed extension, matching linalg named-op semantics).
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand [n, oh*ow, oc] back to the original NHWC output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
631 
632 namespace {
633 
634 class ConvertConv2DNhwcHwcf final
635     : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
636 public:
637   using OpRewritePattern::OpRewritePattern;
638 
matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,PatternRewriter & rewriter) const639   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
640                                 PatternRewriter &rewriter) const override {
641     if (failed(rewriteInIm2Col(rewriter, convOp)))
642       return failure();
643     return success();
644   }
645 };
646 
647 class ConvertDepthwiseConv2DNhwcHwc final
648     : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
649 public:
650   using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
651 
matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,PatternRewriter & rewriter) const652   LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653                                 PatternRewriter &rewriter) const override {
654     if (failed(rewriteInIm2Col(rewriter, convOp)))
655       return failure();
656     return success();
657   }
658 };
659 
660 class ConvertConv2DNchwFchw final
661     : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
662 public:
663   using OpRewritePattern::OpRewritePattern;
664 
matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,PatternRewriter & rewriter) const665   LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666                                 PatternRewriter &rewriter) const override {
667     if (failed(rewriteInIm2Col(rewriter, convOp)))
668       return failure();
669     return success();
670   }
671 };
672 
673 class ConvertConv2DNhwcFhwc final
674     : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
675 public:
676   using OpRewritePattern::OpRewritePattern;
677 
matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,PatternRewriter & rewriter) const678   LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679                                 PatternRewriter &rewriter) const override {
680     if (failed(rewriteInIm2Col(rewriter, convOp)))
681       return failure();
682     return success();
683   }
684 };
685 } // end anonymous namespace
686 
populateConvertConv2DToImg2ColPatterns(RewritePatternSet & patterns)687 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
688   MLIRContext *context = patterns.getContext();
689   patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690                   ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
691 }
692 } // end namespace linalg
693 } // end namespace mlir
694