xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp (revision 74ed79f7f123788d95f1552800e1af9ceaee4a08)
1 //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements constant folding on Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Support/LLVM.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include <optional>
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 namespace {
26 /// Base class for constant folding linalg structured ops with N inputs, 1
27 /// output, and permutation indexing maps.
28 ///
29 /// `ConcreteType` should provide methods with signatures
30 ///
31 /// ```c++
32 ///   bool matchIndexingMaps(LinalgOp linalgOp) const;
33 ///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
34 /// ```
35 ///
36 /// The latter inspects the region and returns the computation inside as a
37 /// functor. The functor will be invoked with constant elements for all inputs
38 /// and should return the corresponding computed constant element for output.
39 template <typename ConcreteType>
40 class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
41 public:
42   struct APIntOrFloat {
43     std::optional<APInt> apInt;
44     std::optional<APFloat> apFloat;
45   };
46   struct APIntOrFloatArray {
47     SmallVector<APInt> apInts;
48     SmallVector<APFloat> apFloats;
49   };
50   using RegionComputationFn =
51       std::function<APIntOrFloat(const APIntOrFloatArray &)>;
52 
FoldConstantBase(MLIRContext * context,const ControlFusionFn & controlFn,PatternBenefit benefit=1)53   FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
54                    PatternBenefit benefit = 1)
55       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
56         controlFn(controlFn) {}
57 
matchAndRewrite(LinalgOp linalgOp,PatternRewriter & rewriter) const58   LogicalResult matchAndRewrite(LinalgOp linalgOp,
59                                 PatternRewriter &rewriter) const override {
60     // Mixed and buffer sematics aren't supported.
61     if (!linalgOp.hasPureTensorSemantics())
62       return failure();
63 
64     // Only support ops generating one output for now.
65     if (linalgOp.getNumDpsInits() != 1)
66       return failure();
67 
68     auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
69     // Require the output types to be static given that we are generating
70     // constants.
71     if (!outputType || !outputType.hasStaticShape())
72       return failure();
73 
74     if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
75           return isa<ShapedType>(input.getType());
76         }))
77       return failure();
78 
79     // Make sure all element types are the same.
80     auto getOperandElementType = [](Value value) {
81       return cast<ShapedType>(value.getType()).getElementType();
82     };
83     if (!llvm::all_equal(
84             llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
85       return failure();
86 
87     // We can only handle the case where we have int/float elements.
88     auto elementType = outputType.getElementType();
89     if (!elementType.isIntOrFloat())
90       return failure();
91 
92     // Require all indexing maps to be permutations for now. This is common and
93     // it simplifies input/output access greatly: we can do the data shuffling
94     // entirely in the compiler, without needing to turn all indices into
95     // Values, and then do affine apply on them, and then match back the
96     // constant again.
97     if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
98                       [](AffineMap map) { return map.isPermutation(); }))
99       return failure();
100 
101     for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
102       if (linalgOp.payloadUsesValueFromOperand(&operand))
103         return failure();
104     }
105 
106     // Further check the indexing maps are okay for the ConcreteType.
107     if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
108       return failure();
109 
110     // Defer to the concrete type to check the region and discover the
111     // computation inside.
112     RegionComputationFn computeFn =
113         static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
114     if (!computeFn)
115       return failure();
116 
117     // All inputs should be constants.
118     int numInputs = linalgOp.getNumDpsInputs();
119     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
120     for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
121       if (!matchPattern(en.value()->get(),
122                         m_Constant(&inputValues[en.index()])))
123         return failure();
124     }
125 
126     // Identified this as a potential candidate for folding. Now check the
127     // policy to see whether we are allowed to proceed.
128     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
129       if (!controlFn(operand))
130         return failure();
131     }
132 
133     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
134     int64_t numElements = outputType.getNumElements();
135 
136     // Use APInt/APFloat instead of Attribute here for constructing the output.
137     // This helps to avoid blowing up compiler memory usage: Attributes would
138     // unify the following cases but they have lifetime as the MLIRContext.
139     SmallVector<APInt> intOutputValues;
140     SmallVector<APFloat> fpOutputValues;
141     if (isa<FloatType>(elementType))
142       fpOutputValues.resize(numElements, APFloat(0.f));
143     else
144       intOutputValues.resize(numElements);
145 
146     // Return the constant dim positions from the given permutation map.
147     auto getDimPositions = [](AffineMap map) {
148       SmallVector<unsigned> dims;
149       dims.reserve(map.getNumResults());
150       for (AffineExpr result : map.getResults()) {
151         dims.push_back(cast<AffineDimExpr>(result).getPosition());
152       }
153       return dims;
154     };
155 
156     SmallVector<SmallVector<unsigned>> inputDims;
157     for (int i = 0; i < numInputs; ++i)
158       inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
159     auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
160     auto outputShape = outputType.getShape();
161 
162     // Allocate small vectors for index delinearization. Initial values do not
163     // matter here as they will be overwritten later.
164     SmallVector<uint64_t> indices(loopBounds.size(), 0);
165     SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
166     SmallVector<SmallVector<uint64_t>> srcIndices(
167         numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
168     SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
169     uint64_t dstLinearIndex = 0;
170 
171     // Allocate spaces for compute function inputs. Initial values do not matter
172     // here as they will be overwritten later.
173     APIntOrFloatArray computeFnInputs;
174 
175     auto inputShapes = llvm::to_vector<4>(
176         llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
177           return cast<ShapedType>(value.getType()).getShape();
178         }));
179 
180     // Given a `linearIndex`, remap it to a linear index to access linalg op
181     // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
182     // `srcLinearIndices`, `dstLinearIndex` in place.
183     auto computeRemappedLinearIndex = [&](int linearIndex) {
184       int totalCount = linearIndex;
185       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
186         indices[dim] = totalCount % loopBounds[dim];
187         totalCount /= loopBounds[dim];
188       }
189 
190       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
191         for (int i = 0; i < numInputs; ++i)
192           srcIndices[i][dim] = indices[inputDims[i][dim]];
193         dstIndices[dim] = indices[outputDims[dim]];
194       }
195 
196       dstLinearIndex = dstIndices.front();
197       for (int i = 0; i < numInputs; ++i)
198         srcLinearIndices[i] = srcIndices[i].front();
199 
200       for (int dim = 1; dim < outputType.getRank(); ++dim) {
201         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
202         for (int i = 0; i < numInputs; ++i)
203           srcLinearIndices[i] =
204               srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
205       }
206     };
207 
208     bool isFloat = isa<FloatType>(elementType);
209     if (isFloat) {
210       SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
211       for (int i = 0; i < numInputs; ++i)
212         inFpRanges.push_back(inputValues[i].getValues<APFloat>());
213 
214       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
215 
216       // Transpose the input constant. Because we don't know its rank in
217       // advance, we need to loop over the range [0, element count) and
218       // delinearize the index.
219       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
220         computeRemappedLinearIndex(linearIndex);
221 
222         // Collect constant elements for all inputs at this loop iteration.
223         for (int i = 0; i < numInputs; ++i)
224           computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
225 
226         // Invoke the computation to get the corresponding constant output
227         // element.
228         fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
229       }
230     } else {
231       SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
232       for (int i = 0; i < numInputs; ++i)
233         inIntRanges.push_back(inputValues[i].getValues<APInt>());
234 
235       computeFnInputs.apInts.resize(numInputs);
236 
237       // Transpose the input constant. Because we don't know its rank in
238       // advance, we need to loop over the range [0, element count) and
239       // delinearize the index.
240       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
241         computeRemappedLinearIndex(linearIndex);
242 
243         // Collect constant elements for all inputs at this loop iteration.
244         for (int i = 0; i < numInputs; ++i)
245           computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
246 
247         // Invoke the computation to get the corresponding constant output
248         // element.
249         intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
250       }
251     }
252 
253     DenseElementsAttr outputAttr =
254         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
255                 : DenseElementsAttr::get(outputType, intOutputValues);
256 
257     rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
258     return success();
259   }
260 
261 private:
262   ControlFusionFn controlFn;
263 };
264 
265 // Folds linalg.transpose (and linalg.generic ops that are actually transposes)
266 // on constant values.
267 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
268 
269   using FoldConstantBase::FoldConstantBase;
270 
matchIndexingMaps__anond55d2d340111::FoldConstantTranspose271   bool matchIndexingMaps(LinalgOp linalgOp) const {
272     // We should have one input and one output.
273     return linalgOp.getIndexingMapsArray().size() == 2;
274   }
275 
getRegionComputeFn__anond55d2d340111::FoldConstantTranspose276   RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
277     // Make sure the region only contains a yield op.
278     Block &body = linalgOp->getRegion(0).front();
279     if (!llvm::hasSingleElement(body))
280       return nullptr;
281     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
282     if (!yieldOp)
283       return nullptr;
284 
285     // The yield op should return the block argument corresponds to the input.
286     for (Value yieldVal : yieldOp.getValues()) {
287       auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
288       if (!yieldArg || yieldArg.getOwner() != &body)
289         return nullptr;
290       if (yieldArg.getArgNumber() != 0)
291         return nullptr;
292     }
293 
294     // No computation; just return the orginal value.
295     return [](const APIntOrFloatArray &inputs) {
296       if (inputs.apFloats.empty())
297         return APIntOrFloat{inputs.apInts.front(), std::nullopt};
298       return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
299     };
300   }
301 
302   ControlFusionFn controlFn;
303 };
304 } // namespace
305 
populateConstantFoldLinalgOperations(RewritePatternSet & patterns,const ControlFusionFn & controlFn)306 void mlir::linalg::populateConstantFoldLinalgOperations(
307     RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
308   MLIRContext *context = patterns.getContext();
309   patterns.insert<FoldConstantTranspose>(context, controlFn);
310 }
311