xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (revision b0a4e958e85784cff46303c92b6a3a14b20fa1d8)
//===- FuncBufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include <optional>

namespace mlir {
/// Return all func.return ops in the given function.
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
  SmallVector<func::ReturnOp> result;
  for (Block &b : funcOp.getBody())
    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
      result.push_back(returnOp);
  return result;
}

namespace bufferization {
namespace func_ext {

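/// Mark `funcOp` as in-progress and create empty entries for its equivalence,
/// aliasing, read, and written bbArg information.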
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
  auto *result = static_cast<const OneShotAnalysisState &>(state)
                     .getExtension<FuncAnalysisState>();
  assert(result && "FuncAnalysisState does not exist");
  return *result;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  if (!isa<OneShotAnalysisState>(state))
    return FuncOpAnalysisState::NotAnalyzed;
  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
                        .getExtension<FuncAnalysisState>();
  if (!funcState)
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static std::optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return std::nullopt;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return std::nullopt;

  return retValIt->getSecond();
}

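/// Bufferization model for func.call. Read/write effects and aliasing of
/// operands are derived from the analysis results of the callee when
/// available; otherwise, conservative assumptions are made.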
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      return detail::unknownGetAliasingValues(opOperand);

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());

    // Check if the aliasing OpResult is equivalent to the OpOperand.
    std::optional<int64_t> equivalent = {};
    if (aliasingReturnVals.size() == 1) {
      equivalent = getEquivalentFuncArgIdx(funcOp, funcState,
                                           aliasingReturnVals.front());
      assert((!equivalent.has_value() ||
              *equivalent == opOperand.getOperandNumber()) &&
             "inconsistent analysis state");
    }
    AliasingValueList result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.addAlias({callOp->getOpResult(resultIdx),
                       equivalent.has_value() ? BufferRelation::Equivalent
                                              : BufferRelation::Unknown,
                       /*isDefinite=*/equivalent.has_value()});
    return result;
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    // If the callee was already bufferized, we can directly take the type from
    // its signature.
    FunctionType funcType = funcOp.getFunctionType();
    Type resultType =
        funcType.getResult(cast<OpResult>(value).getResultNumber());
    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
      return bufferizedType;

    // Otherwise, call the type converter to compute the bufferized type.
    auto tensorType = cast<TensorType>(resultType);
    return options.functionArgTypeConverterFn(
        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);

    // 1. Compute the result types of the new CallOp.
    SmallVector<Type> resultTypes;
    for (Value result : callOp.getResults()) {
      Type returnType = result.getType();
      if (!isa<TensorType>(returnType)) {
        // Non-tensor values are returned.
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      FailureOr<BaseMemRefType> resultType =
          bufferization::getBufferType(result, options);
      if (failed(resultType))
        return failure();
      resultTypes.push_back(*resultType);
    }

    // 2. Rewrite tensor operands as memrefs based on the types of the already
    //    bufferized callee.
    SmallVector<Value> newOperands;
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    for (OpOperand &opOperand : callOp->getOpOperands()) {
      // Non-tensor operands are just copied.
      if (!isa<TensorType>(opOperand.get().getType())) {
        newOperands.push_back(opOperand.get());
        continue;
      }

      // Retrieve buffers for tensor operands.
      FailureOr<Value> maybeBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
      if (!isa<BaseMemRefType>(memRefType)) {
        // The called function was not bufferized yet. This can happen when
        // there are cycles in the function call graph. Compute the bufferized
        // result type.
        FailureOr<BaseMemRefType> maybeMemRefType =
            bufferization::getBufferType(
                funcOp.getArgument(opOperand.getOperandNumber()), options);
        if (failed(maybeMemRefType))
          return failure();
        memRefType = *maybeMemRefType;
      }

      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into a more dynamic memref type than
      // necessary. If the buffer type does not match the callee's expected
      // memref type, introduce an extra memref.cast that will either
      // canonicalize away or fail compilation until we can do something
      // better. Insert a reallocation + copy if it cannot be statically
      // guaranteed that a direct cast would be valid.
      if (buffer.getType() != memRefType) {
        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
        assert(memrefDstType &&
               "buffer layout not supported on unranked tensors");
        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
            rewriter, buffer, memrefDstType, options);
        if (failed(replacement))
          return failure();
        buffer = *replacement;
      }
      newOperands.push_back(buffer);
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());

    return success();
  }
};

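/// Bufferization model for func.return. The op itself is a no-op during
/// bufferization; its operands are rewritten when the enclosing FuncOp is
/// bufferized.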
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

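/// Bufferization model for func.func. Rewrites the function signature and all
/// block arguments to buffer types and converts func.return operands to
/// memrefs.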
struct FuncOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          FuncOpInterface, FuncOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool hasTensorSemantics(Operation *op) const {
    auto isaTensor = llvm::IsaPred<TensorType>;

    // A function has tensor semantics if it has tensor arguments/results.
    auto funcOp = cast<FuncOp>(op);
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    if (hasTensorArg || hasTensorResult)
      return true;

    // It also has tensor semantics if it has tensor block arguments.
    // TODO: Decouple bufferization of unstructured control flow from
    // BufferizableOpInterface implementations. We should only care about
    // region entry block arguments here (which are already covered by the
    // argument types of the function).
    for (Block &block : funcOp.getBody())
      if (any_of(block.getArgumentTypes(), isaTensor))
        return true;

    return false;
  }

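  /// A block argument aliases the operands that predecessor branch ops forward
  /// to it, as computed by the unstructured control flow helper.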
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto funcOp = cast<FuncOp>(op);
    auto bbArg = cast<BlockArgument>(value);

    // Function arguments are special.
    if (bbArg.getOwner() == &funcOp.getBody().front())
      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
                                          options);

    return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
        getBufferType(op, value, options, invocationStack);
  }

  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Compute the argument types.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (isa<TensorType>(argType)) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Compute the result types.
    SmallVector<Type> retTypes;
    for (Type resultType : funcType.getResults()) {
      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
        BaseMemRefType resultType = options.functionArgTypeConverterFn(
            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
            options);
        retTypes.push_back(resultType);
        continue;
      }
      retTypes.push_back(resultType);
    }

    // Compute the new function type.
    auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

    // If the function has no body, set the new function type and we are done.
    if (funcOp.isExternal()) {
      funcOp.setType(newFuncType);
      return success();
    }

    // 1. Bufferize every block.
    for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // 2. Bufferize the operands of all return ops.
    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
      assert(returnOp->getNumOperands() == retTypes.size() &&
             "incorrect number of return values");
      SmallVector<Value> returnValues;
      for (auto [returnVal, bufferizedType] :
           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
        rewriter.setInsertionPoint(returnOp);

        // If not a tensor type, just forward it.
        if (!tensorType) {
          returnValues.push_back(returnVal);
          continue;
        }

        // Note: If `inferFunctionResultLayout = true`, casts are later folded
        // away.
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            returnOp.getLoc(), bufferizedType, returnVal);
        returnValues.push_back(toMemrefOp);
      }

      returnOp.getOperandsMutable().assign(returnValues);
    }

    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = dyn_cast<BlockArgument>(value);
    assert(bbArg && "expected BlockArgument");

    // Non-entry block arguments are always writable. (They may alias with
    // values that are not writable, which effectively makes them read-only.)
    if (bbArg.getOwner() != &funcOp.getBody().front())
      return true;

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

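/// Register the BufferizableOpInterface external models for func.call,
/// func.func, and func.return with the given dialect registry.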
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}