xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (revision 91bbebc7e118cceae1fc0e349de08094a3cd2fe7)
1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/StaticValueUtils.h"
20 #include "mlir/Interfaces/TilingInterface.h"
21 #include <optional>
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility methods for implementation of Tiling Interface for Linalg ops
28 //===----------------------------------------------------------------------===//
29 
30 /// Return the SSA values that represent the data point accessed using a given
31 /// `indexingMap` for a given point in the iteration space represented by `ivs`.
32 static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
33                                               AffineMap indexingMap,
34                                               ValueRange ivs) {
35   SmallVector<Value> indices;
36   indices.reserve(indexingMap.getNumResults());
37   for (auto result : indexingMap.getResults()) {
38     AffineMap m = AffineMap::get(indexingMap.getNumDims(),
39                                  indexingMap.getNumSymbols(), result);
40     Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
41     indices.push_back(v);
42   }
43   return indices;
44 }
45 
46 /// Method to inline the payload of a `linalgOp` given the iteration space
47 /// point and values for the arguments of the payload.
48 static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
49                                    ValueRange ivs, ValueRange argValues) {
50   Block *body = linalgOp.getBlock();
51   IRMapping map;
52   map.map(body->getArguments(), argValues);
53   for (auto &op : body->without_terminator()) {
54     if (auto indexOp = dyn_cast<IndexOp>(&op)) {
55       map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
56       continue;
57     }
58     b.clone(op, map);
59   }
60 
61   Operation *terminator = body->getTerminator();
62   Location loc = terminator->getLoc();
63   for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
64     Value toStore = map.lookupOrDefault(operand.value());
65     OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
66     auto indices = getIndicesForAccess(
67         b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
68     b.create<memref::StoreOp>(
69         loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
70         indices);
71   }
72   return success();
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // External Model for implementing `TilingInterface` for `LinalgOp`s.
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now, until the use of `TilingInterface` is
/// on par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
85 template <typename LinalgOpTy>
86 struct LinalgOpTilingInterface
87     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
88                                             LinalgOpTy> {
89   /// Return the loop iterator type.
90   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
91     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
92     return concreteOp.getIteratorTypesArray();
93   }
94 
95   /// Return the iteration domain range.
96   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
97     OpBuilder::InsertionGuard g(b);
98     b.setInsertionPoint(op);
99     Location loc = op->getLoc();
100     LinalgOp linalgOp = cast<LinalgOp>(op);
101     SmallVector<OpFoldResult> allShapesSizes =
102         linalgOp.createFlatListOfOperandDims(b, loc);
103     AffineMap map = linalgOp.getShapesToLoopsMap();
104 
105     return llvm::to_vector(
106         llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
107           OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
108               b, loc, loopExpr, allShapesSizes);
109           return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
110         }));
111   }
112 
113   /// Instantiate the tiled implementation of the operation.
114   FailureOr<TilingResult>
115   getTiledImplementation(Operation *op, OpBuilder &b,
116                          ArrayRef<OpFoldResult> offsets,
117                          ArrayRef<OpFoldResult> sizes) const {
118     // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
119     // specified could lead to out of bounds accesses.
120     Location loc = op->getLoc();
121     LinalgOp linalgOp = cast<LinalgOp>(op);
122     SmallVector<Value> valuesToTile = linalgOp->getOperands();
123     SmallVector<Value> tiledOperands = makeTiledShapes(
124         b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
125     SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
126         llvm::make_filter_range(
127             tiledOperands,
128             [](Value v) -> bool {
129               return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
130                   v.getDefiningOp());
131             }),
132         [](Value v) -> Operation * { return v.getDefiningOp(); });
133 
134     SmallVector<Type> resultTensorTypes =
135         getTensorOutputTypes(linalgOp, tiledOperands);
136 
137     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
138     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
139 
140     return TilingResult{
141         {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
142   }
143 
144   /// Utility to fetch the offsets and sizes when applied as per the indexing
145   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
146   /// a given slice op.
147   void
148   getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
149                          ArrayRef<OpFoldResult> offsets,
150                          ArrayRef<OpFoldResult> sizes,
151                          SmallVectorImpl<OpFoldResult> &mappedOffsets,
152                          SmallVectorImpl<OpFoldResult> &mappedSizes) const {
153     unsigned numLoops = linalgOp.getNumLoops();
154     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
155     mappedOffsets.resize(numLoops);
156     mappedSizes.resize(numLoops);
157     if (!indexingMap.isPermutation()) {
158       SmallVector<Range> iterationDomain =
159           tilingInterfaceOp.getIterationDomain(b);
160       for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
161         mappedOffsets[index] = value.offset;
162         mappedSizes[index] = value.size;
163       }
164     }
165     for (const auto &&[index, value] :
166          llvm::enumerate(indexingMap.getResults())) {
167       unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
168       mappedOffsets[dimPosition] = offsets[index];
169       mappedSizes[dimPosition] = sizes[index];
170     }
171   }
172 
173   /// Method to return the position of the result tile computed by the tiled
174   /// operation.
175   LogicalResult getIterationDomainTileFromOperandTile(
176       Operation *op, OpBuilder &b, unsigned operandNumber,
177       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
178       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
179       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
180     auto linalgOp = cast<LinalgOp>(op);
181 
182     // Check that the indexing map used for the operand is a projected
183     // permutation. This could be relaxed with a more general approach that can
184     // map the offsets and sizes from the operand to iteration space tiles
185     // (filling in full extent for dimensions not used to access the result).
186     AffineMap indexingMap =
187         linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
188     if (!indexingMap.isProjectedPermutation()) {
189       return op->emitError()
190              << "unhandled get iter domain position when operand is not "
191                 "accessed using a permuted projection";
192     }
193 
194     getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
195                            iterDomainOffsets, iterDomainSizes);
196     return success();
197   }
198 
199   /// Return the details of the output tile generated by the tiled
200   /// implementation.
201   LogicalResult
202   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
203                         ArrayRef<OpFoldResult> offsets,
204                         ArrayRef<OpFoldResult> sizes,
205                         SmallVector<OpFoldResult> &resultOffsets,
206                         SmallVector<OpFoldResult> &resultSizes) const {
207     Location loc = op->getLoc();
208     LinalgOp linalgOp = cast<LinalgOp>(op);
209 
210     AffineExpr d0;
211     bindDims(b.getContext(), d0);
212     SmallVector<OpFoldResult> subShapeSizes =
213         llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
214           return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
215         }));
216 
217     OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
218     SliceParameters sliceParams = computeSliceParameters(
219         b, loc, outOperand->get(), sizes,
220         linalgOp.getMatchingIndexingMap(outOperand), offsets,
221         /*ubs*/ {}, subShapeSizes, true);
222     resultOffsets = sliceParams.offsets;
223     resultSizes = sliceParams.sizes;
224     return success();
225   }
226 
227   LogicalResult getIterationDomainTileFromResultTile(
228       Operation *op, OpBuilder &b, unsigned resultNumber,
229       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
230       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
231       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
232     auto linalgOp = cast<LinalgOp>(op);
233 
234     // Check that the indexing map used for the output is a projected
235     // permutation. This could be relaxed with a more general approach that can
236     // map the offsets and sizes from the result to iteration space tiles
237     // (filling in full extent for dimensions not used to access the result).
238     AffineMap indexingMap =
239         linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
240     if (!indexingMap.isProjectedPermutation()) {
241       return op->emitOpError(
242           "unhandled tiled implementation generation when result is not "
243           "accessed using a permuted projection");
244     }
245 
246     getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
247                            iterDomainOffsets, iterDomainSizes);
248     return success();
249   }
250 
251   FailureOr<TilingResult>
252   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
253                           ArrayRef<OpFoldResult> offsets,
254                           ArrayRef<OpFoldResult> sizes) const {
255     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
256     if (failed(getIterationDomainTileFromResultTile(
257             op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
258       return failure();
259     }
260     auto tilingInterfaceOp = cast<TilingInterface>(op);
261     FailureOr<TilingResult> tilingResult =
262         tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
263 
264     if (failed(tilingResult))
265       return failure();
266 
267     if (tilingResult->tiledOps.size() != 1)
268       return op->emitOpError("failed to generate tiled implementation");
269 
270     return TilingResult{
271         tilingResult->tiledOps,
272         SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
273         tilingResult->generatedSlices};
274   }
275 
276   /// Method to generate the tiled implementation of an operation from the tile
277   /// of the operand.
278   FailureOr<TilingResult> getTiledImplementationFromOperandTile(
279       Operation *op, OpBuilder &b, unsigned operandNumber,
280       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
281     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
282     if (failed(getIterationDomainTileFromOperandTile(
283             op, b, operandNumber, offsets, sizes, mappedOffsets,
284             mappedSizes))) {
285       return failure();
286     }
287     return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
288   }
289 
290   LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
291                                              Location loc,
292                                              ValueRange ivs) const {
293     auto linalgOp = cast<LinalgOp>(op);
294     if (!linalgOp.hasPureBufferSemantics())
295       return op->emitOpError("expected operation to have buffer semantics");
296 
297     SmallVector<Value> indexedValues;
298     indexedValues.reserve(linalgOp->getNumOperands());
299     Location linalgOpLoc = op->getLoc();
300     /// Load the data corresponding to the block arguments that
301     /// represent input operands.
302     for (OpOperand &operand : linalgOp->getOpOperands()) {
303       if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
304         indexedValues.push_back(nullptr);
305         continue;
306       }
307       if (linalgOp.isScalar(&operand)) {
308         indexedValues.push_back(operand.get());
309         continue;
310       }
311       SmallVector<Value> indices = getIndicesForAccess(
312           builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
313       Value load =
314           builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
315       indexedValues.push_back(load);
316     }
317 
318     /// Inline the op payload and store the result.
319     return inlinePayload(builder, linalgOp, ivs, indexedValues);
320   }
321 };
322 
323 //===----------------------------------------------------------------------===//
324 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
325 //===----------------------------------------------------------------------===//
326 
327 /// Return an AffineMap for a partial result for the given result number,
328 /// assuming the partial tiling strategy is outer-reduction loop +
329 /// inner-parallel tile. The returned AffineMap can be used as the replacement
330 /// AffineMap for the inner-parallel tile linalg op for the given result number.
331 ///
332 /// The new AffineMap is the old AffineMap with reduction dimensions appended
333 /// at end.
334 static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
335                                            ArrayRef<int> reductionDims,
336                                            unsigned resultNumber) {
337   AffineMap map =
338       linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
339   for (int redPos : reductionDims) {
340     map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
341                            map.getNumResults());
342   }
343   return map;
344 }
345 
346 /// External model implementation of PartialReductionInterface for
347 /// LinalgOps.
348 template <typename LinalgOpTy>
349 struct LinalgOpPartialReductionInterface
350     : public PartialReductionOpInterface::ExternalModel<
351           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
352   FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
353       Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
354       ArrayRef<int> reductionDims) const {
355     auto linalgOp = cast<LinalgOp>(op);
356     OpBuilder::InsertionGuard guard(b);
357 
358     if (linalgOp.hasPureBufferSemantics())
359       return op->emitOpError("expected operation to have tensor semantics");
360 
361     // LinalgOp implements TilingInterface.
362     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
363     SmallVector<OpFoldResult> shape =
364         llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
365                             [](Range x) { return x.size; });
366 
367     SmallVector<OpFoldResult> tiledShape;
368     for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
369       if (isZeroIndex(tileSize)) {
370         tiledShape.push_back(dimSize);
371       } else {
372         tiledShape.push_back(tileSize);
373       }
374     }
375 
376     SmallVector<Value> inits;
377     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
378          ++initIdx) {
379       SmallVector<Operation *, 4> combinerOps;
380       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
381                           combinerOps) ||
382           combinerOps.size() != 1)
383         return op->emitOpError("Failed to anaysis the reduction operation.");
384 
385       Operation *reductionOp = combinerOps[0];
386       std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
387       if (!identity.has_value())
388         return op->emitOpError(
389             "Failed to get an identity value for the reduction operation.");
390 
391       // Append the new partial result dimensions.
392       AffineMap partialMap =
393           getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
394       SmallVector<OpFoldResult> partialResultShape;
395       for (AffineExpr dimExpr : partialMap.getResults()) {
396         auto dim = cast<AffineDimExpr>(dimExpr);
397         partialResultShape.push_back(tiledShape[dim.getPosition()]);
398       }
399 
400       Type elType =
401           getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
402       Value emptyTensor =
403           b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
404       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
405       auto identityTensor =
406           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
407       inits.push_back(identityTensor.getResult(0));
408     }
409 
410     return inits;
411   }
412 
413   FailureOr<TilingResult>
414   tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
415                          ValueRange init, ArrayRef<OpFoldResult> offsets,
416                          ArrayRef<OpFoldResult> sizes,
417                          ArrayRef<int> reductionDims) const {
418     OpBuilder::InsertionGuard guard(b);
419     auto linalgOp = cast<LinalgOp>(op);
420 
421     // Step 1. Extend init maps to have reduction dimension dims, since we
422     // are converting them to parallel dimensions.
423     SmallVector<AffineMap> newInitMaps;
424     newInitMaps.reserve(linalgOp.getNumDpsInits());
425     for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
426       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
427       // this with a for range loop when we have it.
428       AffineMap newMap =
429           getPartialResultAffineMap(linalgOp, reductionDims, idx);
430       newInitMaps.push_back(newMap);
431     }
432 
433     // Step 2a: Extract a slice of the input operands.
434     SmallVector<Value> tiledInputs = makeTiledShapes(
435         b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
436     SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
437         llvm::make_filter_range(
438             tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
439         [](Value v) -> Operation * { return v.getDefiningOp(); });
440 
441     // Step 2b: Extract a slice of the init operands.
442     SmallVector<Value, 1> tiledInits;
443     for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
444       int64_t initRank = valueMap.getNumResults();
445       SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
446       SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
447       SmallVector<OpFoldResult> initSizes;
448       for (AffineExpr dimExpr : valueMap.getResults()) {
449         auto dim = cast<AffineDimExpr>(dimExpr);
450         initSizes.push_back(sizes[dim.getPosition()]);
451       }
452       // TODO: Use SubsetExtractOpInterface here once available.
453       auto extractSlice = b.create<tensor::ExtractSliceOp>(
454           loc, valueToTile, initOffset, initSizes, initStride);
455       tiledInits.push_back(extractSlice);
456       generatedSlices.push_back(extractSlice);
457     }
458 
459     // Update the indexing maps.
460     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
461     // Change the init maps.
462     for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
463       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
464       // this with a for range loop when we have it.
465       OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
466       int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
467       newMaps[mapIdx] = newInitMaps[idx];
468     }
469 
470     // Step 3. Change the reduction dim iterator types.
471     SmallVector<utils::IteratorType> newIteratorTypes =
472         linalgOp.getIteratorTypesArray();
473     for (int dim : reductionDims)
474       newIteratorTypes[dim] = utils::IteratorType::parallel;
475 
476     // Step 4. Create the new generic op.
477     auto genericOp =
478         b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
479                             tiledInits, newMaps, newIteratorTypes);
480     IRMapping mapping;
481     op->getRegion(0).cloneInto(&genericOp.getRegion(),
482                                genericOp.getRegion().begin(), mapping);
483     return TilingResult{
484         {genericOp.getOperation()},
485         llvm::map_to_vector(genericOp->getResults(),
486                             [](OpResult r) -> Value { return r; }),
487         generatedSlices};
488   }
489 
490   FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
491                                          Location loc, ValueRange partialReduce,
492                                          ArrayRef<int> reductionDims) const {
493     auto linalgOp = cast<LinalgOp>(op);
494 
495     // Permute the reduction dims as permuted by the partial result map.
496 
497     int64_t numInits = linalgOp.getNumDpsInits();
498     SmallVector<Operation *> mergeOperations;
499     SmallVector<Value> replacements;
500     for (int idx : llvm::seq(numInits)) {
501       // linalg.reduce's iteration space is the tiled result's iteration space
502       // (and not the tiled operation's iteration space). To account for this,
503       // permute the reduction dimensions based on the partial result map of the
504       // tiled result.
505       AffineMap partialMap =
506           getPartialResultAffineMap(linalgOp, reductionDims, idx);
507       SmallVector<int64_t> partialReductionDims;
508       for (auto [resultNum, dimExpr] :
509            llvm::enumerate(partialMap.getResults())) {
510         unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
511         if (llvm::find(reductionDims, dim) != reductionDims.end()) {
512           partialReductionDims.push_back(resultNum);
513         }
514       }
515 
516       Value partialResult = partialReduce[idx];
517       Value init = linalgOp.getDpsInits()[idx];
518 
519       auto reduction = b.create<linalg::ReduceOp>(
520           loc, partialResult, init, partialReductionDims,
521           [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
522             // Get the combiner op.
523             SmallVector<Operation *, 4> combinerOps;
524             matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
525             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
526             // Combine the input at idx and output at numInits + idx.
527             clonedReductionOp->setOperand(0, inputs[0]);
528             clonedReductionOp->setOperand(1, inputs[1]);
529             b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
530           });
531 
532       mergeOperations.push_back(reduction);
533       replacements.push_back(reduction->getResult(0));
534     }
535 
536     return MergeResult{mergeOperations, replacements};
537   }
538 
539   LogicalResult getPartialResultTilePosition(
540       Operation *op, OpBuilder &b, unsigned resultNumber,
541       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
542       SmallVector<OpFoldResult> &resultOffsets,
543       SmallVector<OpFoldResult> &resultSizes,
544       ArrayRef<int> reductionDims) const {
545     auto linalgOp = cast<LinalgOp>(op);
546 
547     AffineMap partialMap =
548         getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
549     for (AffineExpr dimExpr : partialMap.getResults()) {
550       unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
551       resultSizes.push_back(sizes[dim]);
552 
553       if (llvm::find(reductionDims, dim) != reductionDims.end()) {
554         // Reduction dims are reduced, and are always outputed in the same
555         // place. So use offset 0 for them.
556         resultOffsets.push_back(b.getIndexAttr(0));
557       } else {
558         resultOffsets.push_back(offsets[dim]);
559       }
560     }
561 
562     return success();
563   }
564 };
565 
566 } // namespace
567 
568 template <typename OpType>
569 static void registerOne(MLIRContext *ctx) {
570   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
571   OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
572       *ctx);
573 }
574 
575 /// Variadic helper function.
576 template <typename... OpTypes>
577 static void registerAll(MLIRContext *ctx) {
578   (registerOne<OpTypes>(ctx), ...);
579 }
580 
581 #define GET_OP_LIST
582 
583 void mlir::linalg::registerTilingInterfaceExternalModels(
584     DialectRegistry &registry) {
585   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
586     registerOne<linalg::GenericOp>(ctx);
587     registerAll<
588 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
589         >(ctx);
590   });
591 }
592