1e95e94adSJeff Niu //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// 2e95e94adSJeff Niu // 3e95e94adSJeff Niu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e95e94adSJeff Niu // See https://llvm.org/LICENSE.txt for license information. 5e95e94adSJeff Niu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e95e94adSJeff Niu // 7e95e94adSJeff Niu //===----------------------------------------------------------------------===// 8e95e94adSJeff Niu 9e95e94adSJeff Niu #include "TestDialect.h" 10e95e94adSJeff Niu #include "TestOps.h" 11e95e94adSJeff Niu #include "mlir/Dialect/Tensor/IR/Tensor.h" 12e95e94adSJeff Niu #include "mlir/IR/Verifier.h" 13e95e94adSJeff Niu #include "mlir/Interfaces/FunctionImplementation.h" 14eeafc9daSChristian Ulmann #include "mlir/Interfaces/MemorySlotInterfaces.h" 15e95e94adSJeff Niu 16e95e94adSJeff Niu using namespace mlir; 17e95e94adSJeff Niu using namespace test; 18e95e94adSJeff Niu 19e95e94adSJeff Niu //===----------------------------------------------------------------------===// 20e95e94adSJeff Niu // TestBranchOp 21e95e94adSJeff Niu //===----------------------------------------------------------------------===// 22e95e94adSJeff Niu 23e95e94adSJeff Niu SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 24e95e94adSJeff Niu assert(index == 0 && "invalid successor index"); 25e95e94adSJeff Niu return SuccessorOperands(getTargetOperandsMutable()); 26e95e94adSJeff Niu } 27e95e94adSJeff Niu 28e95e94adSJeff Niu //===----------------------------------------------------------------------===// 29e95e94adSJeff Niu // TestProducingBranchOp 30e95e94adSJeff Niu //===----------------------------------------------------------------------===// 31e95e94adSJeff Niu 32e95e94adSJeff Niu SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 33e95e94adSJeff Niu assert(index <= 1 && "invalid successor index"); 34e95e94adSJeff Niu if (index == 1) 35e95e94adSJeff Niu return SuccessorOperands(getFirstOperandsMutable()); 36e95e94adSJeff Niu return SuccessorOperands(getSecondOperandsMutable()); 37e95e94adSJeff Niu } 38e95e94adSJeff Niu 39e95e94adSJeff Niu //===----------------------------------------------------------------------===// 40e95e94adSJeff Niu // TestInternalBranchOp 41e95e94adSJeff Niu //===----------------------------------------------------------------------===// 42e95e94adSJeff Niu 43e95e94adSJeff Niu SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 44e95e94adSJeff Niu assert(index <= 1 && "invalid successor index"); 45e95e94adSJeff Niu if (index == 0) 46e95e94adSJeff Niu return SuccessorOperands(0, getSuccessOperandsMutable()); 47e95e94adSJeff Niu return SuccessorOperands(1, getErrorOperandsMutable()); 48e95e94adSJeff Niu } 49e95e94adSJeff Niu 50e95e94adSJeff Niu //===----------------------------------------------------------------------===// 51e95e94adSJeff Niu // TestCallOp 52e95e94adSJeff Niu //===----------------------------------------------------------------------===// 53e95e94adSJeff Niu 54e95e94adSJeff Niu LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 55e95e94adSJeff Niu // Check that the callee attribute was specified. 56e95e94adSJeff Niu auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 57e95e94adSJeff Niu if (!fnAttr) 58e95e94adSJeff Niu return emitOpError("requires a 'callee' symbol reference attribute"); 59e95e94adSJeff Niu if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 60e95e94adSJeff Niu return emitOpError() << "'" << fnAttr.getValue() 61e95e94adSJeff Niu << "' does not reference a valid function"; 62e95e94adSJeff Niu return success(); 63e95e94adSJeff Niu } 64e95e94adSJeff Niu 65e95e94adSJeff Niu //===----------------------------------------------------------------------===// 66e95e94adSJeff Niu // FoldToCallOp 67e95e94adSJeff Niu //===----------------------------------------------------------------------===// 68e95e94adSJeff Niu 69e95e94adSJeff Niu namespace { 70e95e94adSJeff Niu struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 71e95e94adSJeff Niu using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 72e95e94adSJeff Niu 73e95e94adSJeff Niu LogicalResult matchAndRewrite(FoldToCallOp op, 74e95e94adSJeff Niu PatternRewriter &rewriter) const override { 75e95e94adSJeff Niu rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 76e95e94adSJeff Niu op.getCalleeAttr(), ValueRange()); 77e95e94adSJeff Niu return success(); 78e95e94adSJeff Niu } 79e95e94adSJeff Niu }; 80e95e94adSJeff Niu } // namespace 81e95e94adSJeff Niu 82e95e94adSJeff Niu void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 83e95e94adSJeff Niu MLIRContext *context) { 84e95e94adSJeff Niu results.add<FoldToCallOpPattern>(context); 85e95e94adSJeff Niu } 86e95e94adSJeff Niu 87e95e94adSJeff Niu //===----------------------------------------------------------------------===// 88e95e94adSJeff Niu // IsolatedRegionOp - test parsing passthrough operands 89e95e94adSJeff Niu //===----------------------------------------------------------------------===// 90e95e94adSJeff Niu 91e95e94adSJeff Niu ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 92e95e94adSJeff Niu OperationState &result) { 93e95e94adSJeff Niu // Parse the input operand. 94e95e94adSJeff Niu OpAsmParser::Argument argInfo; 95e95e94adSJeff Niu argInfo.type = parser.getBuilder().getIndexType(); 96e95e94adSJeff Niu if (parser.parseOperand(argInfo.ssaName) || 97e95e94adSJeff Niu parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) 98e95e94adSJeff Niu return failure(); 99e95e94adSJeff Niu 100e95e94adSJeff Niu // Parse the body region, and reuse the operand info as the argument info. 101e95e94adSJeff Niu Region *body = result.addRegion(); 102e95e94adSJeff Niu return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); 103e95e94adSJeff Niu } 104e95e94adSJeff Niu 105e95e94adSJeff Niu void IsolatedRegionOp::print(OpAsmPrinter &p) { 106e95e94adSJeff Niu p << ' '; 107e95e94adSJeff Niu p.printOperand(getOperand()); 108e95e94adSJeff Niu p.shadowRegionArgs(getRegion(), getOperand()); 109e95e94adSJeff Niu p << ' '; 110e95e94adSJeff Niu p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 111e95e94adSJeff Niu } 112e95e94adSJeff Niu 113e95e94adSJeff Niu //===----------------------------------------------------------------------===// 114e95e94adSJeff Niu // SSACFGRegionOp 115e95e94adSJeff Niu //===----------------------------------------------------------------------===// 116e95e94adSJeff Niu 117e95e94adSJeff Niu RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 118e95e94adSJeff Niu return RegionKind::SSACFG; 119e95e94adSJeff Niu } 120e95e94adSJeff Niu 121e95e94adSJeff Niu //===----------------------------------------------------------------------===// 122e95e94adSJeff Niu // GraphRegionOp 123e95e94adSJeff Niu //===----------------------------------------------------------------------===// 124e95e94adSJeff Niu 125e95e94adSJeff Niu RegionKind GraphRegionOp::getRegionKind(unsigned index) { 126e95e94adSJeff Niu return RegionKind::Graph; 127e95e94adSJeff Niu } 128e95e94adSJeff Niu 129e95e94adSJeff Niu //===----------------------------------------------------------------------===// 130b084111cSThéo Degioanni // IsolatedGraphRegionOp 131b084111cSThéo Degioanni //===----------------------------------------------------------------------===// 132b084111cSThéo Degioanni 133b084111cSThéo Degioanni RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) { 134b084111cSThéo Degioanni return RegionKind::Graph; 135b084111cSThéo Degioanni } 136b084111cSThéo Degioanni 137b084111cSThéo Degioanni //===----------------------------------------------------------------------===// 138e95e94adSJeff Niu // AffineScopeOp 139e95e94adSJeff Niu //===----------------------------------------------------------------------===// 140e95e94adSJeff Niu 141e95e94adSJeff Niu ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 142e95e94adSJeff Niu // Parse the body region, and reuse the operand info as the argument info. 143e95e94adSJeff Niu Region *body = result.addRegion(); 144e95e94adSJeff Niu return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 145e95e94adSJeff Niu } 146e95e94adSJeff Niu 147e95e94adSJeff Niu void AffineScopeOp::print(OpAsmPrinter &p) { 148e95e94adSJeff Niu p << " "; 149e95e94adSJeff Niu p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 150e95e94adSJeff Niu } 151e95e94adSJeff Niu 152e95e94adSJeff Niu //===----------------------------------------------------------------------===// 153e95e94adSJeff Niu // TestRemoveOpWithInnerOps 154e95e94adSJeff Niu //===----------------------------------------------------------------------===// 155e95e94adSJeff Niu 156e95e94adSJeff Niu namespace { 157e95e94adSJeff Niu struct TestRemoveOpWithInnerOps 158e95e94adSJeff Niu : public OpRewritePattern<TestOpWithRegionPattern> { 159e95e94adSJeff Niu using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 160e95e94adSJeff Niu 161e95e94adSJeff Niu void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 162e95e94adSJeff Niu 163e95e94adSJeff Niu LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 164e95e94adSJeff Niu PatternRewriter &rewriter) const override { 165e95e94adSJeff Niu rewriter.eraseOp(op); 166e95e94adSJeff Niu return success(); 167e95e94adSJeff Niu } 168e95e94adSJeff Niu }; 169e95e94adSJeff Niu } // namespace 170e95e94adSJeff Niu 171e95e94adSJeff Niu //===----------------------------------------------------------------------===// 172e95e94adSJeff Niu // TestOpWithRegionPattern 173e95e94adSJeff Niu //===----------------------------------------------------------------------===// 174e95e94adSJeff Niu 175e95e94adSJeff Niu void TestOpWithRegionPattern::getCanonicalizationPatterns( 176e95e94adSJeff Niu RewritePatternSet &results, MLIRContext *context) { 177e95e94adSJeff Niu results.add<TestRemoveOpWithInnerOps>(context); 178e95e94adSJeff Niu } 179e95e94adSJeff Niu 180e95e94adSJeff Niu //===----------------------------------------------------------------------===// 181e95e94adSJeff Niu // TestOpWithRegionFold 182e95e94adSJeff Niu //===----------------------------------------------------------------------===// 183e95e94adSJeff Niu 184e95e94adSJeff Niu OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { 185e95e94adSJeff Niu return getOperand(); 186e95e94adSJeff Niu } 187e95e94adSJeff Niu 188e95e94adSJeff Niu //===----------------------------------------------------------------------===// 189e95e94adSJeff Niu // TestOpConstant 190e95e94adSJeff Niu //===----------------------------------------------------------------------===// 191e95e94adSJeff Niu 192e95e94adSJeff Niu OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } 193e95e94adSJeff Niu 194e95e94adSJeff Niu //===----------------------------------------------------------------------===// 195e95e94adSJeff Niu // TestOpWithVariadicResultsAndFolder 196e95e94adSJeff Niu //===----------------------------------------------------------------------===// 197e95e94adSJeff Niu 198e95e94adSJeff Niu LogicalResult TestOpWithVariadicResultsAndFolder::fold( 199e95e94adSJeff Niu FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { 200e95e94adSJeff Niu for (Value input : this->getOperands()) { 201e95e94adSJeff Niu results.push_back(input); 202e95e94adSJeff Niu } 203e95e94adSJeff Niu return success(); 204e95e94adSJeff Niu } 205e95e94adSJeff Niu 206e95e94adSJeff Niu //===----------------------------------------------------------------------===// 207e95e94adSJeff Niu // TestOpInPlaceFold 208e95e94adSJeff Niu //===----------------------------------------------------------------------===// 209e95e94adSJeff Niu 210e95e94adSJeff Niu OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { 211e95e94adSJeff Niu // Exercise the fact that an operation created with createOrFold should be 212e95e94adSJeff Niu // allowed to access its parent block. 213e95e94adSJeff Niu assert(getOperation()->getBlock() && 214e95e94adSJeff Niu "expected that operation is not unlinked"); 215e95e94adSJeff Niu 216e95e94adSJeff Niu if (adaptor.getOp() && !getProperties().attr) { 217e95e94adSJeff Niu // The folder adds "attr" if not present. 218e95e94adSJeff Niu getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()); 219e95e94adSJeff Niu return getResult(); 220e95e94adSJeff Niu } 221e95e94adSJeff Niu return {}; 222e95e94adSJeff Niu } 223e95e94adSJeff Niu 224e95e94adSJeff Niu //===----------------------------------------------------------------------===// 225e95e94adSJeff Niu // OpWithInferTypeInterfaceOp 226e95e94adSJeff Niu //===----------------------------------------------------------------------===// 227e95e94adSJeff Niu 228e95e94adSJeff Niu LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 229e95e94adSJeff Niu MLIRContext *, std::optional<Location> location, ValueRange operands, 230e95e94adSJeff Niu DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 231e95e94adSJeff Niu SmallVectorImpl<Type> &inferredReturnTypes) { 232e95e94adSJeff Niu if (operands[0].getType() != operands[1].getType()) { 233e95e94adSJeff Niu return emitOptionalError(location, "operand type mismatch ", 234e95e94adSJeff Niu operands[0].getType(), " vs ", 235e95e94adSJeff Niu operands[1].getType()); 236e95e94adSJeff Niu } 237e95e94adSJeff Niu inferredReturnTypes.assign({operands[0].getType()}); 238e95e94adSJeff Niu return success(); 239e95e94adSJeff Niu } 240e95e94adSJeff Niu 241e95e94adSJeff Niu //===----------------------------------------------------------------------===// 242e95e94adSJeff Niu // OpWithShapedTypeInferTypeInterfaceOp 243e95e94adSJeff Niu //===----------------------------------------------------------------------===// 244e95e94adSJeff Niu 245e95e94adSJeff Niu LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 246e95e94adSJeff Niu MLIRContext *context, std::optional<Location> location, 247e95e94adSJeff Niu ValueShapeRange operands, DictionaryAttr attributes, 248e95e94adSJeff Niu OpaqueProperties properties, RegionRange regions, 249e95e94adSJeff Niu SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 250e95e94adSJeff Niu // Create return type consisting of the last element of the first operand. 251e95e94adSJeff Niu auto operandType = operands.front().getType(); 252e95e94adSJeff Niu auto sval = dyn_cast<ShapedType>(operandType); 253e95e94adSJeff Niu if (!sval) 254e95e94adSJeff Niu return emitOptionalError(location, "only shaped type operands allowed"); 255e95e94adSJeff Niu int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; 256e95e94adSJeff Niu auto type = IntegerType::get(context, 17); 257e95e94adSJeff Niu 258e95e94adSJeff Niu Attribute encoding; 259e95e94adSJeff Niu if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) 260e95e94adSJeff Niu encoding = rankedTy.getEncoding(); 261e95e94adSJeff Niu inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); 262e95e94adSJeff Niu return success(); 263e95e94adSJeff Niu } 264e95e94adSJeff Niu 265e95e94adSJeff Niu LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 266e95e94adSJeff Niu OpBuilder &builder, ValueRange operands, 267e95e94adSJeff Niu llvm::SmallVectorImpl<Value> &shapes) { 268e95e94adSJeff Niu shapes = SmallVector<Value, 1>{ 269e95e94adSJeff Niu builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 270e95e94adSJeff Niu return success(); 271e95e94adSJeff Niu } 272e95e94adSJeff Niu 273e95e94adSJeff Niu //===----------------------------------------------------------------------===// 274e95e94adSJeff Niu // OpWithResultShapeInterfaceOp 275e95e94adSJeff Niu //===----------------------------------------------------------------------===// 276e95e94adSJeff Niu 277e95e94adSJeff Niu LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 278e95e94adSJeff Niu OpBuilder &builder, ValueRange operands, 279e95e94adSJeff Niu llvm::SmallVectorImpl<Value> &shapes) { 280e95e94adSJeff Niu Location loc = getLoc(); 281e95e94adSJeff Niu shapes.reserve(operands.size()); 282e95e94adSJeff Niu for (Value operand : llvm::reverse(operands)) { 283e95e94adSJeff Niu auto rank = cast<RankedTensorType>(operand.getType()).getRank(); 284e95e94adSJeff Niu auto currShape = llvm::to_vector<4>( 285e95e94adSJeff Niu llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 286e95e94adSJeff Niu return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 287e95e94adSJeff Niu })); 288e95e94adSJeff Niu shapes.push_back(builder.create<tensor::FromElementsOp>( 289e95e94adSJeff Niu getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 290e95e94adSJeff Niu currShape)); 291e95e94adSJeff Niu } 292e95e94adSJeff Niu return success(); 293e95e94adSJeff Niu } 294e95e94adSJeff Niu 295e95e94adSJeff Niu //===----------------------------------------------------------------------===// 296e95e94adSJeff Niu // OpWithResultShapePerDimInterfaceOp 297e95e94adSJeff Niu //===----------------------------------------------------------------------===// 298e95e94adSJeff Niu 299e95e94adSJeff Niu LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 300e95e94adSJeff Niu OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 301e95e94adSJeff Niu Location loc = getLoc(); 302e95e94adSJeff Niu shapes.reserve(getNumOperands()); 303e95e94adSJeff Niu for (Value operand : llvm::reverse(getOperands())) { 304e95e94adSJeff Niu auto tensorType = cast<RankedTensorType>(operand.getType()); 305e95e94adSJeff Niu auto currShape = llvm::to_vector<4>(llvm::map_range( 306e95e94adSJeff Niu llvm::seq<int64_t>(0, tensorType.getRank()), 307e95e94adSJeff Niu [&](int64_t dim) -> OpFoldResult { 308e95e94adSJeff Niu return tensorType.isDynamicDim(dim) 309e95e94adSJeff Niu ? static_cast<OpFoldResult>( 310e95e94adSJeff Niu builder.createOrFold<tensor::DimOp>(loc, operand, 311e95e94adSJeff Niu dim)) 312e95e94adSJeff Niu : static_cast<OpFoldResult>( 313e95e94adSJeff Niu builder.getIndexAttr(tensorType.getDimSize(dim))); 314e95e94adSJeff Niu })); 315e95e94adSJeff Niu shapes.emplace_back(std::move(currShape)); 316e95e94adSJeff Niu } 317e95e94adSJeff Niu return success(); 318e95e94adSJeff Niu } 319e95e94adSJeff Niu 320e95e94adSJeff Niu //===----------------------------------------------------------------------===// 321e95e94adSJeff Niu // SideEffectOp 322e95e94adSJeff Niu //===----------------------------------------------------------------------===// 323e95e94adSJeff Niu 324e95e94adSJeff Niu namespace { 325e95e94adSJeff Niu /// A test resource for side effects. 326e95e94adSJeff Niu struct TestResource : public SideEffects::Resource::Base<TestResource> { 327e95e94adSJeff Niu MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 328e95e94adSJeff Niu 329e95e94adSJeff Niu StringRef getName() final { return "<Test>"; } 330e95e94adSJeff Niu }; 331e95e94adSJeff Niu } // namespace 332e95e94adSJeff Niu 333e95e94adSJeff Niu void SideEffectOp::getEffects( 334e95e94adSJeff Niu SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 335e95e94adSJeff Niu // Check for an effects attribute on the op instance. 336e95e94adSJeff Niu ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 337e95e94adSJeff Niu if (!effectsAttr) 338e95e94adSJeff Niu return; 339e95e94adSJeff Niu 340e95e94adSJeff Niu for (Attribute element : effectsAttr) { 341e95e94adSJeff Niu DictionaryAttr effectElement = cast<DictionaryAttr>(element); 342e95e94adSJeff Niu 343e95e94adSJeff Niu // Get the specific memory effect. 344e95e94adSJeff Niu MemoryEffects::Effect *effect = 345e95e94adSJeff Niu StringSwitch<MemoryEffects::Effect *>( 346e95e94adSJeff Niu cast<StringAttr>(effectElement.get("effect")).getValue()) 347e95e94adSJeff Niu .Case("allocate", MemoryEffects::Allocate::get()) 348e95e94adSJeff Niu .Case("free", MemoryEffects::Free::get()) 349e95e94adSJeff Niu .Case("read", MemoryEffects::Read::get()) 350e95e94adSJeff Niu .Case("write", MemoryEffects::Write::get()); 351e95e94adSJeff Niu 352e95e94adSJeff Niu // Check for a non-default resource to use. 353e95e94adSJeff Niu SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 354e95e94adSJeff Niu if (effectElement.get("test_resource")) 355e95e94adSJeff Niu resource = TestResource::get(); 356e95e94adSJeff Niu 357e95e94adSJeff Niu // Check for a result to affect. 358e95e94adSJeff Niu if (effectElement.get("on_result")) 3592c1ae801Sdonald chen effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); 360e95e94adSJeff Niu else if (Attribute ref = effectElement.get("on_reference")) 361e95e94adSJeff Niu effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); 362e95e94adSJeff Niu else 363e95e94adSJeff Niu effects.emplace_back(effect, resource); 364e95e94adSJeff Niu } 365e95e94adSJeff Niu } 366e95e94adSJeff Niu 367e95e94adSJeff Niu void SideEffectOp::getEffects( 368e95e94adSJeff Niu SmallVectorImpl<TestEffects::EffectInstance> &effects) { 369e95e94adSJeff Niu testSideEffectOpGetEffect(getOperation(), effects); 370e95e94adSJeff Niu } 371e95e94adSJeff Niu 3722c1ae801Sdonald chen void SideEffectWithRegionOp::getEffects( 3732c1ae801Sdonald chen SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3742c1ae801Sdonald chen // Check for an effects attribute on the op instance. 3752c1ae801Sdonald chen ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 3762c1ae801Sdonald chen if (!effectsAttr) 3772c1ae801Sdonald chen return; 3782c1ae801Sdonald chen 3792c1ae801Sdonald chen for (Attribute element : effectsAttr) { 3802c1ae801Sdonald chen DictionaryAttr effectElement = cast<DictionaryAttr>(element); 3812c1ae801Sdonald chen 3822c1ae801Sdonald chen // Get the specific memory effect. 3832c1ae801Sdonald chen MemoryEffects::Effect *effect = 3842c1ae801Sdonald chen StringSwitch<MemoryEffects::Effect *>( 3852c1ae801Sdonald chen cast<StringAttr>(effectElement.get("effect")).getValue()) 3862c1ae801Sdonald chen .Case("allocate", MemoryEffects::Allocate::get()) 3872c1ae801Sdonald chen .Case("free", MemoryEffects::Free::get()) 3882c1ae801Sdonald chen .Case("read", MemoryEffects::Read::get()) 3892c1ae801Sdonald chen .Case("write", MemoryEffects::Write::get()); 3902c1ae801Sdonald chen 3912c1ae801Sdonald chen // Check for a non-default resource to use. 3922c1ae801Sdonald chen SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 3932c1ae801Sdonald chen if (effectElement.get("test_resource")) 3942c1ae801Sdonald chen resource = TestResource::get(); 3952c1ae801Sdonald chen 3962c1ae801Sdonald chen // Check for a result to affect. 3972c1ae801Sdonald chen if (effectElement.get("on_result")) 3982c1ae801Sdonald chen effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); 3992c1ae801Sdonald chen else if (effectElement.get("on_operand")) 4002c1ae801Sdonald chen effects.emplace_back(effect, &getOperation()->getOpOperands()[0], 4012c1ae801Sdonald chen resource); 4022c1ae801Sdonald chen else if (effectElement.get("on_argument")) 4032c1ae801Sdonald chen effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0), 4042c1ae801Sdonald chen resource); 4052c1ae801Sdonald chen else if (Attribute ref = effectElement.get("on_reference")) 4062c1ae801Sdonald chen effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); 4072c1ae801Sdonald chen else 4082c1ae801Sdonald chen effects.emplace_back(effect, resource); 4092c1ae801Sdonald chen } 4102c1ae801Sdonald chen } 4112c1ae801Sdonald chen 4122c1ae801Sdonald chen void SideEffectWithRegionOp::getEffects( 4132c1ae801Sdonald chen SmallVectorImpl<TestEffects::EffectInstance> &effects) { 4142c1ae801Sdonald chen testSideEffectOpGetEffect(getOperation(), effects); 4152c1ae801Sdonald chen } 4162c1ae801Sdonald chen 417e95e94adSJeff Niu //===----------------------------------------------------------------------===// 418e95e94adSJeff Niu // StringAttrPrettyNameOp 419e95e94adSJeff Niu //===----------------------------------------------------------------------===// 420e95e94adSJeff Niu 421e95e94adSJeff Niu // This op has fancy handling of its SSA result name. 422e95e94adSJeff Niu ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 423e95e94adSJeff Niu OperationState &result) { 424e95e94adSJeff Niu // Add the result types. 425e95e94adSJeff Niu for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 426e95e94adSJeff Niu result.addTypes(parser.getBuilder().getIntegerType(32)); 427e95e94adSJeff Niu 428e95e94adSJeff Niu if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 429e95e94adSJeff Niu return failure(); 430e95e94adSJeff Niu 431e95e94adSJeff Niu // If the attribute dictionary contains no 'names' attribute, infer it from 432e95e94adSJeff Niu // the SSA name (if specified). 433e95e94adSJeff Niu bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 434e95e94adSJeff Niu return attr.getName() == "names"; 435e95e94adSJeff Niu }); 436e95e94adSJeff Niu 437e95e94adSJeff Niu // If there was no name specified, check to see if there was a useful name 438e95e94adSJeff Niu // specified in the asm file. 439e95e94adSJeff Niu if (hadNames || parser.getNumResults() == 0) 440e95e94adSJeff Niu return success(); 441e95e94adSJeff Niu 442e95e94adSJeff Niu SmallVector<StringRef, 4> names; 443e95e94adSJeff Niu auto *context = result.getContext(); 444e95e94adSJeff Niu 445e95e94adSJeff Niu for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 446e95e94adSJeff Niu auto resultName = parser.getResultName(i); 447e95e94adSJeff Niu StringRef nameStr; 448e95e94adSJeff Niu if (!resultName.first.empty() && !isdigit(resultName.first[0])) 449e95e94adSJeff Niu nameStr = resultName.first; 450e95e94adSJeff Niu 451e95e94adSJeff Niu names.push_back(nameStr); 452e95e94adSJeff Niu } 453e95e94adSJeff Niu 454e95e94adSJeff Niu auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 455e95e94adSJeff Niu result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 456e95e94adSJeff Niu return success(); 457e95e94adSJeff Niu } 458e95e94adSJeff Niu 459e95e94adSJeff Niu void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 460e95e94adSJeff Niu // Note that we only need to print the "name" attribute if the asmprinter 461e95e94adSJeff Niu // result name disagrees with it. This can happen in strange cases, e.g. 462e95e94adSJeff Niu // when there are conflicts. 463e95e94adSJeff Niu bool namesDisagree = getNames().size() != getNumResults(); 464e95e94adSJeff Niu 465e95e94adSJeff Niu SmallString<32> resultNameStr; 466e95e94adSJeff Niu for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 467e95e94adSJeff Niu resultNameStr.clear(); 468e95e94adSJeff Niu llvm::raw_svector_ostream tmpStream(resultNameStr); 469e95e94adSJeff Niu p.printOperand(getResult(i), tmpStream); 470e95e94adSJeff Niu 471e95e94adSJeff Niu auto expectedName = dyn_cast<StringAttr>(getNames()[i]); 472e95e94adSJeff Niu if (!expectedName || 473e95e94adSJeff Niu tmpStream.str().drop_front() != expectedName.getValue()) { 474e95e94adSJeff Niu namesDisagree = true; 475e95e94adSJeff Niu } 476e95e94adSJeff Niu } 477e95e94adSJeff Niu 478e95e94adSJeff Niu if (namesDisagree) 479e95e94adSJeff Niu p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 480e95e94adSJeff Niu else 481e95e94adSJeff Niu p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 482e95e94adSJeff Niu } 483e95e94adSJeff Niu 484e95e94adSJeff Niu // We set the SSA name in the asm syntax to the contents of the name 485e95e94adSJeff Niu // attribute. 486e95e94adSJeff Niu void StringAttrPrettyNameOp::getAsmResultNames( 487e95e94adSJeff Niu function_ref<void(Value, StringRef)> setNameFn) { 488e95e94adSJeff Niu 489e95e94adSJeff Niu auto value = getNames(); 490e95e94adSJeff Niu for (size_t i = 0, e = value.size(); i != e; ++i) 491e95e94adSJeff Niu if (auto str = dyn_cast<StringAttr>(value[i])) 492e95e94adSJeff Niu if (!str.getValue().empty()) 493e95e94adSJeff Niu setNameFn(getResult(i), str.getValue()); 494e95e94adSJeff Niu } 495e95e94adSJeff Niu 496e95e94adSJeff Niu //===----------------------------------------------------------------------===// 497e95e94adSJeff Niu // CustomResultsNameOp 498e95e94adSJeff Niu //===----------------------------------------------------------------------===// 499e95e94adSJeff Niu 500e95e94adSJeff Niu void CustomResultsNameOp::getAsmResultNames( 501e95e94adSJeff Niu function_ref<void(Value, StringRef)> setNameFn) { 502e95e94adSJeff Niu ArrayAttr value = getNames(); 503e95e94adSJeff Niu for (size_t i = 0, e = value.size(); i != e; ++i) 504e95e94adSJeff Niu if (auto str = dyn_cast<StringAttr>(value[i])) 505e95e94adSJeff Niu if (!str.empty()) 506e95e94adSJeff Niu setNameFn(getResult(i), str.getValue()); 507e95e94adSJeff Niu } 508e95e94adSJeff Niu 509e95e94adSJeff Niu //===----------------------------------------------------------------------===// 510*3c64f863SHongren Zheng // ResultNameFromTypeOp 511*3c64f863SHongren Zheng //===----------------------------------------------------------------------===// 512*3c64f863SHongren Zheng 513*3c64f863SHongren Zheng void ResultNameFromTypeOp::getAsmResultNames( 514*3c64f863SHongren Zheng function_ref<void(Value, StringRef)> setNameFn) { 515*3c64f863SHongren Zheng auto result = getResult(); 516*3c64f863SHongren Zheng auto setResultNameFn = [&](::llvm::StringRef name) { 517*3c64f863SHongren Zheng setNameFn(result, name); 518*3c64f863SHongren Zheng }; 519*3c64f863SHongren Zheng auto opAsmTypeInterface = 520*3c64f863SHongren Zheng ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType()); 521*3c64f863SHongren Zheng opAsmTypeInterface.getAsmName(setResultNameFn); 522*3c64f863SHongren Zheng } 523*3c64f863SHongren Zheng 524*3c64f863SHongren Zheng //===----------------------------------------------------------------------===// 525*3c64f863SHongren Zheng // BlockArgumentNameFromTypeOp 526*3c64f863SHongren Zheng //===----------------------------------------------------------------------===// 527*3c64f863SHongren Zheng 528*3c64f863SHongren Zheng void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames( 529*3c64f863SHongren Zheng ::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) { 530*3c64f863SHongren Zheng for (auto &block : region) { 531*3c64f863SHongren Zheng for (auto arg : block.getArguments()) { 532*3c64f863SHongren Zheng if (auto opAsmTypeInterface = 533*3c64f863SHongren Zheng ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) { 534*3c64f863SHongren Zheng auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); }; 535*3c64f863SHongren Zheng opAsmTypeInterface.getAsmName(setArgNameFn); 536*3c64f863SHongren Zheng } 537*3c64f863SHongren Zheng } 538*3c64f863SHongren Zheng } 539*3c64f863SHongren Zheng } 540*3c64f863SHongren Zheng 541*3c64f863SHongren Zheng //===----------------------------------------------------------------------===// 542e95e94adSJeff Niu // ResultTypeWithTraitOp 543e95e94adSJeff Niu //===----------------------------------------------------------------------===// 544e95e94adSJeff Niu 545e95e94adSJeff Niu LogicalResult ResultTypeWithTraitOp::verify() { 546e95e94adSJeff Niu if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 547e95e94adSJeff Niu return success(); 548e95e94adSJeff Niu return emitError("result type should have trait 'TestTypeTrait'"); 549e95e94adSJeff Niu } 550e95e94adSJeff Niu 551e95e94adSJeff Niu //===----------------------------------------------------------------------===// 552e95e94adSJeff Niu // AttrWithTraitOp 553e95e94adSJeff Niu //===----------------------------------------------------------------------===// 554e95e94adSJeff Niu 555e95e94adSJeff Niu LogicalResult AttrWithTraitOp::verify() { 556e95e94adSJeff Niu if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 557e95e94adSJeff Niu return success(); 558e95e94adSJeff Niu return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 559e95e94adSJeff Niu } 560e95e94adSJeff Niu 561e95e94adSJeff Niu //===----------------------------------------------------------------------===// 562e95e94adSJeff Niu // RegionIfOp 563e95e94adSJeff Niu //===----------------------------------------------------------------------===// 564e95e94adSJeff Niu 565e95e94adSJeff Niu void RegionIfOp::print(OpAsmPrinter &p) { 566e95e94adSJeff Niu p << " "; 567e95e94adSJeff Niu p.printOperands(getOperands()); 568e95e94adSJeff Niu p << ": " << getOperandTypes(); 569e95e94adSJeff Niu p.printArrowTypeList(getResultTypes()); 570e95e94adSJeff Niu p << " then "; 571e95e94adSJeff Niu p.printRegion(getThenRegion(), 572e95e94adSJeff Niu /*printEntryBlockArgs=*/true, 573e95e94adSJeff Niu /*printBlockTerminators=*/true); 574e95e94adSJeff Niu p << " else "; 575e95e94adSJeff Niu p.printRegion(getElseRegion(), 576e95e94adSJeff Niu /*printEntryBlockArgs=*/true, 577e95e94adSJeff Niu /*printBlockTerminators=*/true); 578e95e94adSJeff Niu p << " join "; 579e95e94adSJeff Niu p.printRegion(getJoinRegion(), 580e95e94adSJeff Niu /*printEntryBlockArgs=*/true, 581e95e94adSJeff Niu /*printBlockTerminators=*/true); 582e95e94adSJeff Niu } 583e95e94adSJeff Niu 584e95e94adSJeff Niu ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 585e95e94adSJeff Niu SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 586e95e94adSJeff Niu SmallVector<Type, 2> operandTypes; 587e95e94adSJeff Niu 588e95e94adSJeff Niu result.regions.reserve(3); 589e95e94adSJeff Niu Region *thenRegion = result.addRegion(); 590e95e94adSJeff Niu Region *elseRegion = result.addRegion(); 591e95e94adSJeff Niu Region *joinRegion = result.addRegion(); 592e95e94adSJeff Niu 593e95e94adSJeff Niu // Parse operand, type and arrow type lists. 594e95e94adSJeff Niu if (parser.parseOperandList(operandInfos) || 595e95e94adSJeff Niu parser.parseColonTypeList(operandTypes) || 596e95e94adSJeff Niu parser.parseArrowTypeList(result.types)) 597e95e94adSJeff Niu return failure(); 598e95e94adSJeff Niu 599e95e94adSJeff Niu // Parse all attached regions. 600e95e94adSJeff Niu if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 601e95e94adSJeff Niu parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 602e95e94adSJeff Niu parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 603e95e94adSJeff Niu return failure(); 604e95e94adSJeff Niu 605e95e94adSJeff Niu return parser.resolveOperands(operandInfos, operandTypes, 606e95e94adSJeff Niu parser.getCurrentLocation(), result.operands); 607e95e94adSJeff Niu } 608e95e94adSJeff Niu 609e95e94adSJeff Niu OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { 610e95e94adSJeff Niu assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && 611e95e94adSJeff Niu "invalid region index"); 612e95e94adSJeff Niu return getOperands(); 613e95e94adSJeff Niu } 614e95e94adSJeff Niu 615e95e94adSJeff Niu void RegionIfOp::getSuccessorRegions( 616e95e94adSJeff Niu RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 617e95e94adSJeff Niu // We always branch to the join region. 618e95e94adSJeff Niu if (!point.isParent()) { 619e95e94adSJeff Niu if (point != getJoinRegion()) 620e95e94adSJeff Niu regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 621e95e94adSJeff Niu else 622e95e94adSJeff Niu regions.push_back(RegionSuccessor(getResults())); 623e95e94adSJeff Niu return; 624e95e94adSJeff Niu } 625e95e94adSJeff Niu 626e95e94adSJeff Niu // The then and else regions are the entry regions of this op. 627e95e94adSJeff Niu regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 628e95e94adSJeff Niu regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 629e95e94adSJeff Niu } 630e95e94adSJeff Niu 631e95e94adSJeff Niu void RegionIfOp::getRegionInvocationBounds( 632e95e94adSJeff Niu ArrayRef<Attribute> operands, 633e95e94adSJeff Niu SmallVectorImpl<InvocationBounds> &invocationBounds) { 634e95e94adSJeff Niu // Each region is invoked at most once. 635e95e94adSJeff Niu invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 636e95e94adSJeff Niu } 637e95e94adSJeff Niu 638e95e94adSJeff Niu //===----------------------------------------------------------------------===// 639e95e94adSJeff Niu // AnyCondOp 640e95e94adSJeff Niu //===----------------------------------------------------------------------===// 641e95e94adSJeff Niu 642e95e94adSJeff Niu void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, 643e95e94adSJeff Niu SmallVectorImpl<RegionSuccessor> ®ions) { 644e95e94adSJeff Niu // The parent op branches into the only region, and the region branches back 645e95e94adSJeff Niu // to the parent op. 646e95e94adSJeff Niu if (point.isParent()) 647e95e94adSJeff Niu regions.emplace_back(&getRegion()); 648e95e94adSJeff Niu else 649e95e94adSJeff Niu regions.emplace_back(getResults()); 650e95e94adSJeff Niu } 651e95e94adSJeff Niu 652e95e94adSJeff Niu void AnyCondOp::getRegionInvocationBounds( 653e95e94adSJeff Niu ArrayRef<Attribute> operands, 654e95e94adSJeff Niu SmallVectorImpl<InvocationBounds> &invocationBounds) { 655e95e94adSJeff Niu invocationBounds.emplace_back(1, 1); 656e95e94adSJeff Niu } 657e95e94adSJeff Niu 658e95e94adSJeff Niu //===----------------------------------------------------------------------===// 659e95e94adSJeff Niu // SingleBlockImplicitTerminatorOp 660e95e94adSJeff Niu //===----------------------------------------------------------------------===// 661e95e94adSJeff Niu 662e95e94adSJeff Niu /// Testing the correctness of some traits. 663e95e94adSJeff Niu static_assert( 664e95e94adSJeff Niu llvm::is_detected<OpTrait::has_implicit_terminator_t, 665e95e94adSJeff Niu SingleBlockImplicitTerminatorOp>::value, 666e95e94adSJeff Niu "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 667e95e94adSJeff Niu static_assert(OpTrait::hasSingleBlockImplicitTerminator< 668e95e94adSJeff Niu SingleBlockImplicitTerminatorOp>::value, 669e95e94adSJeff Niu "hasSingleBlockImplicitTerminator does not match " 670e95e94adSJeff Niu "SingleBlockImplicitTerminatorOp"); 671e95e94adSJeff Niu 672e95e94adSJeff Niu //===----------------------------------------------------------------------===// 673e95e94adSJeff Niu // SingleNoTerminatorCustomAsmOp 674e95e94adSJeff Niu //===----------------------------------------------------------------------===// 675e95e94adSJeff Niu 676e95e94adSJeff Niu ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 677e95e94adSJeff Niu OperationState &state) { 678e95e94adSJeff Niu Region *body = state.addRegion(); 679e95e94adSJeff Niu if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 680e95e94adSJeff Niu return failure(); 681e95e94adSJeff Niu return success(); 682e95e94adSJeff Niu } 683e95e94adSJeff Niu 684e95e94adSJeff Niu void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 685e95e94adSJeff Niu printer.printRegion( 686e95e94adSJeff Niu getRegion(), /*printEntryBlockArgs=*/false, 687e95e94adSJeff Niu // This op has a single block without terminators. But explicitly mark 688e95e94adSJeff Niu // as not printing block terminators for testing. 689e95e94adSJeff Niu /*printBlockTerminators=*/false); 690e95e94adSJeff Niu } 691e95e94adSJeff Niu 692e95e94adSJeff Niu //===----------------------------------------------------------------------===// 693e95e94adSJeff Niu // TestVerifiersOp 694e95e94adSJeff Niu //===----------------------------------------------------------------------===// 695e95e94adSJeff Niu 696e95e94adSJeff Niu LogicalResult TestVerifiersOp::verify() { 697e95e94adSJeff Niu if (!getRegion().hasOneBlock()) 698e95e94adSJeff Niu return emitOpError("`hasOneBlock` trait hasn't been verified"); 699e95e94adSJeff Niu 700e95e94adSJeff Niu Operation *definingOp = getInput().getDefiningOp(); 701e95e94adSJeff Niu if (definingOp && failed(mlir::verify(definingOp))) 702e95e94adSJeff Niu return emitOpError("operand hasn't been verified"); 703e95e94adSJeff Niu 704e95e94adSJeff Niu // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier 705e95e94adSJeff Niu // loop. 706e95e94adSJeff Niu mlir::emitRemark(getLoc(), "success run of verifier"); 707e95e94adSJeff Niu 708e95e94adSJeff Niu return success(); 709e95e94adSJeff Niu } 710e95e94adSJeff Niu 711e95e94adSJeff Niu LogicalResult TestVerifiersOp::verifyRegions() { 712e95e94adSJeff Niu if (!getRegion().hasOneBlock()) 713e95e94adSJeff Niu return emitOpError("`hasOneBlock` trait hasn't been verified"); 714e95e94adSJeff Niu 715e95e94adSJeff Niu for (Block &block : getRegion()) 716e95e94adSJeff Niu for (Operation &op : block) 717e95e94adSJeff Niu if (failed(mlir::verify(&op))) 718e95e94adSJeff Niu return emitOpError("nested op hasn't been verified"); 719e95e94adSJeff Niu 720e95e94adSJeff Niu // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier 721e95e94adSJeff Niu // loop. 722e95e94adSJeff Niu mlir::emitRemark(getLoc(), "success run of region verifier"); 723e95e94adSJeff Niu 724e95e94adSJeff Niu return success(); 725e95e94adSJeff Niu } 726e95e94adSJeff Niu 727e95e94adSJeff Niu //===----------------------------------------------------------------------===// 728e95e94adSJeff Niu // Test InferIntRangeInterface 729e95e94adSJeff Niu //===----------------------------------------------------------------------===// 730e95e94adSJeff Niu 731e95e94adSJeff Niu //===----------------------------------------------------------------------===// 732e95e94adSJeff Niu // TestWithBoundsOp 733e95e94adSJeff Niu 734e95e94adSJeff Niu void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 735e95e94adSJeff Niu SetIntRangeFn setResultRanges) { 736e95e94adSJeff Niu setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); 737e95e94adSJeff Niu } 738e95e94adSJeff Niu 739e95e94adSJeff Niu //===----------------------------------------------------------------------===// 740e95e94adSJeff Niu // TestWithBoundsRegionOp 741e95e94adSJeff Niu 742e95e94adSJeff Niu ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, 743e95e94adSJeff Niu OperationState &result) { 744e95e94adSJeff Niu if (parser.parseOptionalAttrDict(result.attributes)) 745e95e94adSJeff Niu return failure(); 746e95e94adSJeff Niu 747e95e94adSJeff Niu // Parse the input argument 748e95e94adSJeff Niu OpAsmParser::Argument argInfo; 749acd10074SFelix Schneider if (failed(parser.parseArgument(argInfo, true))) 750e95e94adSJeff Niu return failure(); 751e95e94adSJeff Niu 752e95e94adSJeff Niu // Parse the body region, and reuse the operand info as the argument info. 753e95e94adSJeff Niu Region *body = result.addRegion(); 754e95e94adSJeff Niu return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); 755e95e94adSJeff Niu } 756e95e94adSJeff Niu 757e95e94adSJeff Niu void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { 758e95e94adSJeff Niu p.printOptionalAttrDict((*this)->getAttrs()); 759e95e94adSJeff Niu p << ' '; 760e95e94adSJeff Niu p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, 761acd10074SFelix Schneider /*omitType=*/false); 762e95e94adSJeff Niu p << ' '; 763e95e94adSJeff Niu p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 764e95e94adSJeff Niu } 765e95e94adSJeff Niu 766e95e94adSJeff Niu void TestWithBoundsRegionOp::inferResultRanges( 767e95e94adSJeff Niu ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 768e95e94adSJeff Niu Value arg = getRegion().getArgument(0); 769e95e94adSJeff Niu setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); 770e95e94adSJeff Niu } 771e95e94adSJeff Niu 772e95e94adSJeff Niu //===----------------------------------------------------------------------===// 773e95e94adSJeff Niu // TestIncrementOp 774e95e94adSJeff Niu 775e95e94adSJeff Niu void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 776e95e94adSJeff Niu SetIntRangeFn setResultRanges) { 777e95e94adSJeff Niu const ConstantIntRanges &range = argRanges[0]; 778e95e94adSJeff Niu APInt one(range.umin().getBitWidth(), 1); 779e95e94adSJeff Niu setResultRanges(getResult(), 780e95e94adSJeff Niu {range.umin().uadd_sat(one), range.umax().uadd_sat(one), 781e95e94adSJeff Niu range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); 782e95e94adSJeff Niu } 783e95e94adSJeff Niu 784e95e94adSJeff Niu //===----------------------------------------------------------------------===// 785e95e94adSJeff Niu // TestReflectBoundsOp 786e95e94adSJeff Niu 787e95e94adSJeff Niu void TestReflectBoundsOp::inferResultRanges( 788e95e94adSJeff Niu ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 789e95e94adSJeff Niu const ConstantIntRanges &range = argRanges[0]; 790e95e94adSJeff Niu MLIRContext *ctx = getContext(); 791e95e94adSJeff Niu Builder b(ctx); 7924636b66dSFelix Schneider Type sIntTy, uIntTy; 7934636b66dSFelix Schneider // For plain `IntegerType`s, we can derive the appropriate signed and unsigned 7944636b66dSFelix Schneider // Types for the Attributes. 795f54cdc5dSIvan Butygin Type type = getElementTypeOrSelf(getType()); 796f54cdc5dSIvan Butygin if (auto intTy = llvm::dyn_cast<IntegerType>(type)) { 7974636b66dSFelix Schneider unsigned bitwidth = intTy.getWidth(); 7984636b66dSFelix Schneider sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true); 7994636b66dSFelix Schneider uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false); 8004636b66dSFelix Schneider } else 801f54cdc5dSIvan Butygin sIntTy = uIntTy = type; 8024636b66dSFelix Schneider 8034636b66dSFelix Schneider setUminAttr(b.getIntegerAttr(uIntTy, range.umin())); 8044636b66dSFelix Schneider setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax())); 8054636b66dSFelix Schneider setSminAttr(b.getIntegerAttr(sIntTy, range.smin())); 8064636b66dSFelix Schneider setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax())); 807e95e94adSJeff Niu setResultRanges(getResult(), range); 808e95e94adSJeff Niu } 809e95e94adSJeff Niu 810e95e94adSJeff Niu //===----------------------------------------------------------------------===// 811e95e94adSJeff Niu // ConversionFuncOp 812e95e94adSJeff Niu //===----------------------------------------------------------------------===// 813e95e94adSJeff Niu 814e95e94adSJeff Niu ParseResult ConversionFuncOp::parse(OpAsmParser &parser, 815e95e94adSJeff Niu OperationState &result) { 816e95e94adSJeff Niu auto buildFuncType = 817e95e94adSJeff Niu [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 818e95e94adSJeff Niu function_interface_impl::VariadicFlag, 819e95e94adSJeff Niu std::string &) { return builder.getFunctionType(argTypes, results); }; 820e95e94adSJeff Niu 821e95e94adSJeff Niu return function_interface_impl::parseFunctionOp( 822e95e94adSJeff Niu parser, result, /*allowVariadic=*/false, 823e95e94adSJeff Niu getFunctionTypeAttrName(result.name), buildFuncType, 824e95e94adSJeff Niu getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 825e95e94adSJeff Niu } 826e95e94adSJeff Niu 827e95e94adSJeff Niu void ConversionFuncOp::print(OpAsmPrinter &p) { 828e95e94adSJeff Niu function_interface_impl::printFunctionOp( 829e95e94adSJeff Niu p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 830e95e94adSJeff Niu getArgAttrsAttrName(), getResAttrsAttrName()); 831e95e94adSJeff Niu } 832e95e94adSJeff Niu 833e95e94adSJeff Niu //===----------------------------------------------------------------------===// 834e95e94adSJeff Niu // ReifyBoundOp 835e95e94adSJeff Niu //===----------------------------------------------------------------------===// 836e95e94adSJeff Niu 837e95e94adSJeff Niu mlir::presburger::BoundType ReifyBoundOp::getBoundType() { 838e95e94adSJeff Niu if (getType() == "EQ") 839e95e94adSJeff Niu return mlir::presburger::BoundType::EQ; 840e95e94adSJeff Niu if (getType() == "LB") 841e95e94adSJeff Niu return mlir::presburger::BoundType::LB; 842e95e94adSJeff Niu if (getType() == "UB") 843e95e94adSJeff Niu return mlir::presburger::BoundType::UB; 844e95e94adSJeff Niu llvm_unreachable("invalid bound type"); 845e95e94adSJeff Niu } 846e95e94adSJeff Niu 847e95e94adSJeff Niu LogicalResult ReifyBoundOp::verify() { 848e95e94adSJeff Niu if (isa<ShapedType>(getVar().getType())) { 849e95e94adSJeff Niu if (!getDim().has_value()) 850e95e94adSJeff Niu return emitOpError("expected 'dim' attribute for shaped type variable"); 851e95e94adSJeff Niu } else if (getVar().getType().isIndex()) { 852e95e94adSJeff Niu if (getDim().has_value()) 853e95e94adSJeff Niu return emitOpError("unexpected 'dim' attribute for index variable"); 854e95e94adSJeff Niu } else { 855e95e94adSJeff Niu return emitOpError("expected index-typed variable or shape type variable"); 856e95e94adSJeff Niu } 857e95e94adSJeff Niu if (getConstant() && getScalable()) 858e95e94adSJeff Niu return emitOpError("'scalable' and 'constant' are mutually exlusive"); 859e95e94adSJeff Niu if (getScalable() != getVscaleMin().has_value()) 860e95e94adSJeff Niu return emitOpError("expected 'vscale_min' if and only if 'scalable'"); 861e95e94adSJeff Niu if (getScalable() != getVscaleMax().has_value()) 862e95e94adSJeff Niu return emitOpError("expected 'vscale_min' if and only if 'scalable'"); 863e95e94adSJeff Niu return success(); 864e95e94adSJeff Niu } 865e95e94adSJeff Niu 866e95e94adSJeff Niu ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { 867e95e94adSJeff Niu if (getDim().has_value()) 868e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); 869e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable(getVar()); 870e95e94adSJeff Niu } 871e95e94adSJeff Niu 872e95e94adSJeff Niu //===----------------------------------------------------------------------===// 873e95e94adSJeff Niu // CompareOp 874e95e94adSJeff Niu //===----------------------------------------------------------------------===// 875e95e94adSJeff Niu 876e95e94adSJeff Niu ValueBoundsConstraintSet::ComparisonOperator 877e95e94adSJeff Niu CompareOp::getComparisonOperator() { 878e95e94adSJeff Niu if (getCmp() == "EQ") 879e95e94adSJeff Niu return ValueBoundsConstraintSet::ComparisonOperator::EQ; 880e95e94adSJeff Niu if (getCmp() == "LT") 881e95e94adSJeff Niu return ValueBoundsConstraintSet::ComparisonOperator::LT; 882e95e94adSJeff Niu if (getCmp() == "LE") 883e95e94adSJeff Niu return ValueBoundsConstraintSet::ComparisonOperator::LE; 884e95e94adSJeff Niu if (getCmp() == "GT") 885e95e94adSJeff Niu return ValueBoundsConstraintSet::ComparisonOperator::GT; 886e95e94adSJeff Niu if (getCmp() == "GE") 887e95e94adSJeff Niu return ValueBoundsConstraintSet::ComparisonOperator::GE; 888e95e94adSJeff Niu llvm_unreachable("invalid comparison operator"); 889e95e94adSJeff Niu } 890e95e94adSJeff Niu 891e95e94adSJeff Niu mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { 892e95e94adSJeff Niu if (!getLhsMap()) 893e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); 894e95e94adSJeff Niu SmallVector<Value> mapOperands( 895e95e94adSJeff Niu getVarOperands().slice(0, getLhsMap()->getNumInputs())); 896e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); 897e95e94adSJeff Niu } 898e95e94adSJeff Niu 899e95e94adSJeff Niu mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { 900e95e94adSJeff Niu int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; 901e95e94adSJeff Niu if (!getRhsMap()) 902e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable( 903e95e94adSJeff Niu getVarOperands()[rhsOperandsBegin]); 904e95e94adSJeff Niu SmallVector<Value> mapOperands( 905e95e94adSJeff Niu getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); 906e95e94adSJeff Niu return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); 907e95e94adSJeff Niu } 908e95e94adSJeff Niu 909e95e94adSJeff Niu LogicalResult CompareOp::verify() { 910e95e94adSJeff Niu if (getCompose() && (getLhsMap() || getRhsMap())) 911e95e94adSJeff Niu return emitOpError( 912e95e94adSJeff Niu "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); 913e95e94adSJeff Niu int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; 914e95e94adSJeff Niu expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; 915e95e94adSJeff Niu if (getVarOperands().size() != size_t(expectedNumOperands)) 916e95e94adSJeff Niu return emitOpError("expected ") 917e95e94adSJeff Niu << expectedNumOperands << " operands, but got " 918e95e94adSJeff Niu << getVarOperands().size(); 919e95e94adSJeff Niu return success(); 920e95e94adSJeff Niu } 921e95e94adSJeff Niu 922e95e94adSJeff Niu //===----------------------------------------------------------------------===// 9234513050fSChristian Ulmann // TestOpInPlaceSelfFold 9244513050fSChristian Ulmann //===----------------------------------------------------------------------===// 9254513050fSChristian Ulmann 9264513050fSChristian Ulmann OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { 9274513050fSChristian Ulmann if (!getFolded()) { 9284513050fSChristian Ulmann // The folder adds the "folded" if not present. 9294513050fSChristian Ulmann setFolded(true); 9304513050fSChristian Ulmann return getResult(); 9314513050fSChristian Ulmann } 9324513050fSChristian Ulmann return {}; 9334513050fSChristian Ulmann } 9344513050fSChristian Ulmann 9354513050fSChristian Ulmann //===----------------------------------------------------------------------===// 936e95e94adSJeff Niu // TestOpFoldWithFoldAdaptor 937e95e94adSJeff Niu //===----------------------------------------------------------------------===// 938e95e94adSJeff Niu 939e95e94adSJeff Niu OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { 940e95e94adSJeff Niu int64_t sum = 0; 941e95e94adSJeff Niu if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) 942e95e94adSJeff Niu sum += value.getValue().getSExtValue(); 943e95e94adSJeff Niu 944e95e94adSJeff Niu for (Attribute attr : adaptor.getVariadic()) 945e95e94adSJeff Niu if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) 946e95e94adSJeff Niu sum += 2 * value.getValue().getSExtValue(); 947e95e94adSJeff Niu 948e95e94adSJeff Niu for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) 949e95e94adSJeff Niu for (Attribute attr : attrs) 950e95e94adSJeff Niu if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) 951e95e94adSJeff Niu sum += 3 * value.getValue().getSExtValue(); 952e95e94adSJeff Niu 953e95e94adSJeff Niu sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); 954e95e94adSJeff Niu 955e95e94adSJeff Niu return IntegerAttr::get(getType(), sum); 956e95e94adSJeff Niu } 957e95e94adSJeff Niu 958e95e94adSJeff Niu //===----------------------------------------------------------------------===// 959e95e94adSJeff Niu // OpWithInferTypeAdaptorInterfaceOp 960e95e94adSJeff Niu //===----------------------------------------------------------------------===// 961e95e94adSJeff Niu 962e95e94adSJeff Niu LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( 963e95e94adSJeff Niu MLIRContext *, std::optional<Location> location, 964e95e94adSJeff Niu OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, 965e95e94adSJeff Niu SmallVectorImpl<Type> &inferredReturnTypes) { 966e95e94adSJeff Niu if (adaptor.getX().getType() != adaptor.getY().getType()) { 967e95e94adSJeff Niu return emitOptionalError(location, "operand type mismatch ", 968e95e94adSJeff Niu adaptor.getX().getType(), " vs ", 969e95e94adSJeff Niu adaptor.getY().getType()); 970e95e94adSJeff Niu } 971e95e94adSJeff Niu inferredReturnTypes.assign({adaptor.getX().getType()}); 972e95e94adSJeff Niu return success(); 973e95e94adSJeff Niu } 974e95e94adSJeff Niu 975e95e94adSJeff Niu //===----------------------------------------------------------------------===// 976e95e94adSJeff Niu // OpWithRefineTypeInterfaceOp 977e95e94adSJeff Niu //===----------------------------------------------------------------------===// 978e95e94adSJeff Niu 979e95e94adSJeff Niu // TODO: We should be able to only define either inferReturnType or 980e95e94adSJeff Niu // refineReturnType, currently only refineReturnType can be omitted. 981e95e94adSJeff Niu LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( 982e95e94adSJeff Niu MLIRContext *context, std::optional<Location> location, ValueRange operands, 983e95e94adSJeff Niu DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 984e95e94adSJeff Niu SmallVectorImpl<Type> &returnTypes) { 985e95e94adSJeff Niu returnTypes.clear(); 986e95e94adSJeff Niu return OpWithRefineTypeInterfaceOp::refineReturnTypes( 987e95e94adSJeff Niu context, location, operands, attributes, properties, regions, 988e95e94adSJeff Niu returnTypes); 989e95e94adSJeff Niu } 990e95e94adSJeff Niu 991e95e94adSJeff Niu LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( 992e95e94adSJeff Niu MLIRContext *, std::optional<Location> location, ValueRange operands, 993e95e94adSJeff Niu DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 994e95e94adSJeff Niu SmallVectorImpl<Type> &returnTypes) { 995e95e94adSJeff Niu if (operands[0].getType() != operands[1].getType()) { 996e95e94adSJeff Niu return emitOptionalError(location, "operand type mismatch ", 997e95e94adSJeff Niu operands[0].getType(), " vs ", 998e95e94adSJeff Niu operands[1].getType()); 999e95e94adSJeff Niu } 1000e95e94adSJeff Niu // TODO: Add helper to make this more concise to write. 1001e95e94adSJeff Niu if (returnTypes.empty()) 1002e95e94adSJeff Niu returnTypes.resize(1, nullptr); 1003e95e94adSJeff Niu if (returnTypes[0] && returnTypes[0] != operands[0].getType()) 1004e95e94adSJeff Niu return emitOptionalError(location, 1005e95e94adSJeff Niu "required first operand and result to match"); 1006e95e94adSJeff Niu returnTypes[0] = operands[0].getType(); 1007e95e94adSJeff Niu return success(); 1008e95e94adSJeff Niu } 1009e95e94adSJeff Niu 1010e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1011e95e94adSJeff Niu // OpWithShapedTypeInferTypeAdaptorInterfaceOp 1012e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1013e95e94adSJeff Niu 1014e95e94adSJeff Niu LogicalResult 1015e95e94adSJeff Niu OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( 1016e95e94adSJeff Niu MLIRContext *context, std::optional<Location> location, 1017e95e94adSJeff Niu OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, 1018e95e94adSJeff Niu SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1019e95e94adSJeff Niu // Create return type consisting of the last element of the first operand. 1020e95e94adSJeff Niu auto operandType = adaptor.getOperand1().getType(); 1021e95e94adSJeff Niu auto sval = dyn_cast<ShapedType>(operandType); 1022e95e94adSJeff Niu if (!sval) 1023e95e94adSJeff Niu return emitOptionalError(location, "only shaped type operands allowed"); 1024e95e94adSJeff Niu int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; 1025e95e94adSJeff Niu auto type = IntegerType::get(context, 17); 1026e95e94adSJeff Niu 1027e95e94adSJeff Niu Attribute encoding; 1028e95e94adSJeff Niu if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) 1029e95e94adSJeff Niu encoding = rankedTy.getEncoding(); 1030e95e94adSJeff Niu inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); 1031e95e94adSJeff Niu return success(); 1032e95e94adSJeff Niu } 1033e95e94adSJeff Niu 1034e95e94adSJeff Niu LogicalResult 1035e95e94adSJeff Niu OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( 1036e95e94adSJeff Niu OpBuilder &builder, ValueRange operands, 1037e95e94adSJeff Niu llvm::SmallVectorImpl<Value> &shapes) { 1038e95e94adSJeff Niu shapes = SmallVector<Value, 1>{ 1039e95e94adSJeff Niu builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1040e95e94adSJeff Niu return success(); 1041e95e94adSJeff Niu } 1042e95e94adSJeff Niu 1043e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1044e95e94adSJeff Niu // TestOpWithPropertiesAndInferredType 1045e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1046e95e94adSJeff Niu 1047e95e94adSJeff Niu LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( 1048e95e94adSJeff Niu MLIRContext *context, std::optional<Location>, ValueRange operands, 1049e95e94adSJeff Niu DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 1050e95e94adSJeff Niu SmallVectorImpl<Type> &inferredReturnTypes) { 1051e95e94adSJeff Niu 1052e95e94adSJeff Niu Adaptor adaptor(operands, attributes, properties, regions); 1053e95e94adSJeff Niu inferredReturnTypes.push_back(IntegerType::get( 1054e95e94adSJeff Niu context, adaptor.getLhs() + adaptor.getProperties().rhs)); 1055e95e94adSJeff Niu return success(); 1056e95e94adSJeff Niu } 1057e95e94adSJeff Niu 1058e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1059e95e94adSJeff Niu // LoopBlockOp 1060e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1061e95e94adSJeff Niu 1062e95e94adSJeff Niu void LoopBlockOp::getSuccessorRegions( 1063e95e94adSJeff Niu RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1064e95e94adSJeff Niu regions.emplace_back(&getBody(), getBody().getArguments()); 1065e95e94adSJeff Niu if (point.isParent()) 1066e95e94adSJeff Niu return; 1067e95e94adSJeff Niu 1068e95e94adSJeff Niu regions.emplace_back((*this)->getResults()); 1069e95e94adSJeff Niu } 1070e95e94adSJeff Niu 1071e95e94adSJeff Niu OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { 1072e95e94adSJeff Niu assert(point == getBody()); 1073e95e94adSJeff Niu return MutableOperandRange(getInitMutable()); 1074e95e94adSJeff Niu } 1075e95e94adSJeff Niu 1076e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1077e95e94adSJeff Niu // LoopBlockTerminatorOp 1078e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1079e95e94adSJeff Niu 1080e95e94adSJeff Niu MutableOperandRange 1081e95e94adSJeff Niu LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { 1082e95e94adSJeff Niu if (point.isParent()) 1083e95e94adSJeff Niu return getExitArgMutable(); 1084e95e94adSJeff Niu return getNextIterArgMutable(); 1085e95e94adSJeff Niu } 1086e95e94adSJeff Niu 1087e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1088e95e94adSJeff Niu // SwitchWithNoBreakOp 1089e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1090e95e94adSJeff Niu 1091e95e94adSJeff Niu void TestNoTerminatorOp::getSuccessorRegions( 1092e95e94adSJeff Niu RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} 1093e95e94adSJeff Niu 1094e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1095e95e94adSJeff Niu // Test InferIntRangeInterface 1096e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1097e95e94adSJeff Niu 1098e95e94adSJeff Niu OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { 1099e95e94adSJeff Niu // Just a simple fold for testing purposes that reads an operands constant 1100e95e94adSJeff Niu // value and returns it. 1101e95e94adSJeff Niu if (!attributes.empty()) 1102e95e94adSJeff Niu return attributes.front(); 1103e95e94adSJeff Niu return nullptr; 1104e95e94adSJeff Niu } 1105e95e94adSJeff Niu 1106e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1107e95e94adSJeff Niu // Tensor/Buffer Ops 1108e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1109e95e94adSJeff Niu 1110e95e94adSJeff Niu void ReadBufferOp::getEffects( 1111e95e94adSJeff Niu SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1112e95e94adSJeff Niu &effects) { 1113e95e94adSJeff Niu // The buffer operand is read. 11142c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(), 1115e95e94adSJeff Niu SideEffects::DefaultResource::get()); 1116e95e94adSJeff Niu // The buffer contents are dumped. 1117e95e94adSJeff Niu effects.emplace_back(MemoryEffects::Write::get(), 1118e95e94adSJeff Niu SideEffects::DefaultResource::get()); 1119e95e94adSJeff Niu } 1120e95e94adSJeff Niu 1121e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1122e95e94adSJeff Niu // Test Dataflow 1123e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1124e95e94adSJeff Niu 1125e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1126e95e94adSJeff Niu // TestCallAndStoreOp 1127e95e94adSJeff Niu 1128e95e94adSJeff Niu CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { 1129e95e94adSJeff Niu return getCallee(); 1130e95e94adSJeff Niu } 1131e95e94adSJeff Niu 1132e95e94adSJeff Niu void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { 113335e89897SKazu Hirata setCalleeAttr(cast<SymbolRefAttr>(callee)); 1134e95e94adSJeff Niu } 1135e95e94adSJeff Niu 1136e95e94adSJeff Niu Operation::operand_range TestCallAndStoreOp::getArgOperands() { 1137e95e94adSJeff Niu return getCalleeOperands(); 1138e95e94adSJeff Niu } 1139e95e94adSJeff Niu 1140e95e94adSJeff Niu MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { 1141e95e94adSJeff Niu return getCalleeOperandsMutable(); 1142e95e94adSJeff Niu } 1143e95e94adSJeff Niu 1144e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1145e95e94adSJeff Niu // TestCallOnDeviceOp 1146e95e94adSJeff Niu 1147e95e94adSJeff Niu CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { 1148e95e94adSJeff Niu return getCallee(); 1149e95e94adSJeff Niu } 1150e95e94adSJeff Niu 1151e95e94adSJeff Niu void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { 115235e89897SKazu Hirata setCalleeAttr(cast<SymbolRefAttr>(callee)); 1153e95e94adSJeff Niu } 1154e95e94adSJeff Niu 1155e95e94adSJeff Niu Operation::operand_range TestCallOnDeviceOp::getArgOperands() { 1156e95e94adSJeff Niu return getForwardedOperands(); 1157e95e94adSJeff Niu } 1158e95e94adSJeff Niu 1159e95e94adSJeff Niu MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { 1160e95e94adSJeff Niu return getForwardedOperandsMutable(); 1161e95e94adSJeff Niu } 1162e95e94adSJeff Niu 1163e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1164e95e94adSJeff Niu // TestStoreWithARegion 1165e95e94adSJeff Niu 1166e95e94adSJeff Niu void TestStoreWithARegion::getSuccessorRegions( 1167e95e94adSJeff Niu RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1168e95e94adSJeff Niu if (point.isParent()) 1169e95e94adSJeff Niu regions.emplace_back(&getBody(), getBody().front().getArguments()); 1170e95e94adSJeff Niu else 1171e95e94adSJeff Niu regions.emplace_back(); 1172e95e94adSJeff Niu } 1173e95e94adSJeff Niu 1174e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1175e95e94adSJeff Niu // TestStoreWithALoopRegion 1176e95e94adSJeff Niu 1177e95e94adSJeff Niu void TestStoreWithALoopRegion::getSuccessorRegions( 1178e95e94adSJeff Niu RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1179e95e94adSJeff Niu // Both the operation itself and the region may be branching into the body or 1180e95e94adSJeff Niu // back into the operation itself. It is possible for the operation not to 1181e95e94adSJeff Niu // enter the body. 1182e95e94adSJeff Niu regions.emplace_back( 1183e95e94adSJeff Niu RegionSuccessor(&getBody(), getBody().front().getArguments())); 1184e95e94adSJeff Niu regions.emplace_back(); 1185e95e94adSJeff Niu } 1186e95e94adSJeff Niu 1187e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1188e95e94adSJeff Niu // TestVersionedOpA 1189e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1190e95e94adSJeff Niu 1191e95e94adSJeff Niu LogicalResult 1192e95e94adSJeff Niu TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, 1193e95e94adSJeff Niu mlir::OperationState &state) { 1194e95e94adSJeff Niu auto &prop = state.getOrAddProperties<Properties>(); 1195e95e94adSJeff Niu if (mlir::failed(reader.readAttribute(prop.dims))) 1196e95e94adSJeff Niu return mlir::failure(); 1197e95e94adSJeff Niu 1198e95e94adSJeff Niu // Check if we have a version. If not, assume we are parsing the current 1199e95e94adSJeff Niu // version. 1200e95e94adSJeff Niu auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); 1201e95e94adSJeff Niu if (succeeded(maybeVersion)) { 1202e95e94adSJeff Niu // If version is less than 2.0, there is no additional attribute to parse. 1203e95e94adSJeff Niu // We can materialize missing properties post parsing before verification. 1204e95e94adSJeff Niu const auto *version = 1205e95e94adSJeff Niu reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1206e95e94adSJeff Niu if ((version->major_ < 2)) { 1207e95e94adSJeff Niu return success(); 1208e95e94adSJeff Niu } 1209e95e94adSJeff Niu } 1210e95e94adSJeff Niu 1211e95e94adSJeff Niu if (mlir::failed(reader.readAttribute(prop.modifier))) 1212e95e94adSJeff Niu return mlir::failure(); 1213e95e94adSJeff Niu return mlir::success(); 1214e95e94adSJeff Niu } 1215e95e94adSJeff Niu 1216e95e94adSJeff Niu void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { 1217e95e94adSJeff Niu auto &prop = getProperties(); 1218e95e94adSJeff Niu writer.writeAttribute(prop.dims); 1219e95e94adSJeff Niu 1220e95e94adSJeff Niu auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); 1221e95e94adSJeff Niu if (succeeded(maybeVersion)) { 1222e95e94adSJeff Niu // If version is less than 2.0, there is no additional attribute to write. 1223e95e94adSJeff Niu const auto *version = 1224e95e94adSJeff Niu reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1225e95e94adSJeff Niu if ((version->major_ < 2)) { 1226e95e94adSJeff Niu llvm::outs() << "downgrading op properties...\n"; 1227e95e94adSJeff Niu return; 1228e95e94adSJeff Niu } 1229e95e94adSJeff Niu } 1230e95e94adSJeff Niu writer.writeAttribute(prop.modifier); 1231e95e94adSJeff Niu } 1232e95e94adSJeff Niu 1233e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1234e95e94adSJeff Niu // TestOpWithVersionedProperties 1235e95e94adSJeff Niu //===----------------------------------------------------------------------===// 1236e95e94adSJeff Niu 1237db791b27SRamkumar Ramachandra llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( 1238e95e94adSJeff Niu mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { 1239e95e94adSJeff Niu uint64_t value1, value2 = 0; 1240e95e94adSJeff Niu if (failed(reader.readVarInt(value1))) 1241e95e94adSJeff Niu return failure(); 1242e95e94adSJeff Niu 1243e95e94adSJeff Niu // Check if we have a version. If not, assume we are parsing the current 1244e95e94adSJeff Niu // version. 1245e95e94adSJeff Niu auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); 1246e95e94adSJeff Niu bool needToParseAnotherInt = true; 1247e95e94adSJeff Niu if (succeeded(maybeVersion)) { 1248e95e94adSJeff Niu // If version is less than 2.0, there is no additional attribute to parse. 1249e95e94adSJeff Niu // We can materialize missing properties post parsing before verification. 1250e95e94adSJeff Niu const auto *version = 1251e95e94adSJeff Niu reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1252e95e94adSJeff Niu if ((version->major_ < 2)) 1253e95e94adSJeff Niu needToParseAnotherInt = false; 1254e95e94adSJeff Niu } 1255e95e94adSJeff Niu if (needToParseAnotherInt && failed(reader.readVarInt(value2))) 1256e95e94adSJeff Niu return failure(); 1257e95e94adSJeff Niu 1258e95e94adSJeff Niu prop.value1 = value1; 1259e95e94adSJeff Niu prop.value2 = value2; 1260e95e94adSJeff Niu return success(); 1261e95e94adSJeff Niu } 1262e95e94adSJeff Niu 1263e95e94adSJeff Niu void TestOpWithVersionedProperties::writeToMlirBytecode( 1264e95e94adSJeff Niu mlir::DialectBytecodeWriter &writer, 1265e95e94adSJeff Niu const test::VersionedProperties &prop) { 1266e95e94adSJeff Niu writer.writeVarInt(prop.value1); 1267e95e94adSJeff Niu writer.writeVarInt(prop.value2); 1268e95e94adSJeff Niu } 1269eeafc9daSChristian Ulmann 1270eeafc9daSChristian Ulmann //===----------------------------------------------------------------------===// 1271eeafc9daSChristian Ulmann // TestMultiSlotAlloca 1272eeafc9daSChristian Ulmann //===----------------------------------------------------------------------===// 1273eeafc9daSChristian Ulmann 1274eeafc9daSChristian Ulmann llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { 1275eeafc9daSChristian Ulmann SmallVector<MemorySlot> slots; 1276eeafc9daSChristian Ulmann for (Value result : getResults()) { 1277eeafc9daSChristian Ulmann slots.push_back(MemorySlot{ 1278eeafc9daSChristian Ulmann result, cast<MemRefType>(result.getType()).getElementType()}); 1279eeafc9daSChristian Ulmann } 1280eeafc9daSChristian Ulmann return slots; 1281eeafc9daSChristian Ulmann } 1282eeafc9daSChristian Ulmann 1283eeafc9daSChristian Ulmann Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, 1284eeafc9daSChristian Ulmann OpBuilder &builder) { 1285eeafc9daSChristian Ulmann return builder.create<TestOpConstant>(getLoc(), slot.elemType, 1286eeafc9daSChristian Ulmann builder.getI32IntegerAttr(42)); 1287eeafc9daSChristian Ulmann } 1288eeafc9daSChristian Ulmann 1289eeafc9daSChristian Ulmann void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, 1290eeafc9daSChristian Ulmann BlockArgument argument, 1291eeafc9daSChristian Ulmann OpBuilder &builder) { 1292eeafc9daSChristian Ulmann // Not relevant for testing. 1293eeafc9daSChristian Ulmann } 1294eeafc9daSChristian Ulmann 12950b5b2027SChristian Ulmann /// Creates a new TestMultiSlotAlloca operation, just without the `slot`. 12960b5b2027SChristian Ulmann static std::optional<TestMultiSlotAlloca> 12970b5b2027SChristian Ulmann createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, 12980b5b2027SChristian Ulmann TestMultiSlotAlloca oldOp) { 1299eeafc9daSChristian Ulmann 13000b5b2027SChristian Ulmann if (oldOp.getNumResults() == 1) { 13010b5b2027SChristian Ulmann oldOp.erase(); 1302eeafc9daSChristian Ulmann return std::nullopt; 1303eeafc9daSChristian Ulmann } 1304eeafc9daSChristian Ulmann 1305eeafc9daSChristian Ulmann SmallVector<Type> newTypes; 1306eeafc9daSChristian Ulmann SmallVector<Value> remainingValues; 1307eeafc9daSChristian Ulmann 13080b5b2027SChristian Ulmann for (Value oldResult : oldOp.getResults()) { 1309eeafc9daSChristian Ulmann if (oldResult == slot.ptr) 1310eeafc9daSChristian Ulmann continue; 1311eeafc9daSChristian Ulmann remainingValues.push_back(oldResult); 1312eeafc9daSChristian Ulmann newTypes.push_back(oldResult.getType()); 1313eeafc9daSChristian Ulmann } 1314eeafc9daSChristian Ulmann 1315eeafc9daSChristian Ulmann OpBuilder::InsertionGuard guard(builder); 13160b5b2027SChristian Ulmann builder.setInsertionPoint(oldOp); 13170b5b2027SChristian Ulmann auto replacement = 13180b5b2027SChristian Ulmann builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); 1319eeafc9daSChristian Ulmann for (auto [oldResult, newResult] : 1320eeafc9daSChristian Ulmann llvm::zip_equal(remainingValues, replacement.getResults())) 1321eeafc9daSChristian Ulmann oldResult.replaceAllUsesWith(newResult); 1322eeafc9daSChristian Ulmann 13230b5b2027SChristian Ulmann oldOp.erase(); 1324eeafc9daSChristian Ulmann return replacement; 1325eeafc9daSChristian Ulmann } 13260b5b2027SChristian Ulmann 13270b5b2027SChristian Ulmann std::optional<PromotableAllocationOpInterface> 13280b5b2027SChristian Ulmann TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, 13290b5b2027SChristian Ulmann Value defaultValue, 13300b5b2027SChristian Ulmann OpBuilder &builder) { 13310b5b2027SChristian Ulmann if (defaultValue && defaultValue.use_empty()) 13320b5b2027SChristian Ulmann defaultValue.getDefiningOp()->erase(); 13330b5b2027SChristian Ulmann return createNewMultiAllocaWithoutSlot(slot, builder, *this); 13340b5b2027SChristian Ulmann } 13350b5b2027SChristian Ulmann 13360b5b2027SChristian Ulmann SmallVector<DestructurableMemorySlot> 13370b5b2027SChristian Ulmann TestMultiSlotAlloca::getDestructurableSlots() { 13380b5b2027SChristian Ulmann SmallVector<DestructurableMemorySlot> slots; 13390b5b2027SChristian Ulmann for (Value result : getResults()) { 13400b5b2027SChristian Ulmann auto memrefType = cast<MemRefType>(result.getType()); 13410b5b2027SChristian Ulmann auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType); 13420b5b2027SChristian Ulmann if (!destructurable) 13430b5b2027SChristian Ulmann continue; 13440b5b2027SChristian Ulmann 13450b5b2027SChristian Ulmann std::optional<DenseMap<Attribute, Type>> destructuredType = 13460b5b2027SChristian Ulmann destructurable.getSubelementIndexMap(); 13470b5b2027SChristian Ulmann if (!destructuredType) 13480b5b2027SChristian Ulmann continue; 13490b5b2027SChristian Ulmann slots.emplace_back( 13500b5b2027SChristian Ulmann DestructurableMemorySlot{{result, memrefType}, *destructuredType}); 13510b5b2027SChristian Ulmann } 13520b5b2027SChristian Ulmann return slots; 13530b5b2027SChristian Ulmann } 13540b5b2027SChristian Ulmann 13550b5b2027SChristian Ulmann DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( 13560b5b2027SChristian Ulmann const DestructurableMemorySlot &slot, 13570b5b2027SChristian Ulmann const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, 13580b5b2027SChristian Ulmann SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { 13590b5b2027SChristian Ulmann OpBuilder::InsertionGuard guard(builder); 13600b5b2027SChristian Ulmann builder.setInsertionPointAfter(*this); 13610b5b2027SChristian Ulmann 13620b5b2027SChristian Ulmann DenseMap<Attribute, MemorySlot> slotMap; 13630b5b2027SChristian Ulmann 13640b5b2027SChristian Ulmann for (Attribute usedIndex : usedIndices) { 136569d3793fSThéo Degioanni Type elemType = slot.subelementTypes.lookup(usedIndex); 13660b5b2027SChristian Ulmann MemRefType elemPtr = MemRefType::get({}, elemType); 13670b5b2027SChristian Ulmann auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); 13680b5b2027SChristian Ulmann newAllocators.push_back(subAlloca); 13690b5b2027SChristian Ulmann slotMap.try_emplace<MemorySlot>(usedIndex, 13700b5b2027SChristian Ulmann {subAlloca.getResult(0), elemType}); 13710b5b2027SChristian Ulmann } 13720b5b2027SChristian Ulmann 13730b5b2027SChristian Ulmann return slotMap; 13740b5b2027SChristian Ulmann } 13750b5b2027SChristian Ulmann 13760b5b2027SChristian Ulmann std::optional<DestructurableAllocationOpInterface> 13770b5b2027SChristian Ulmann TestMultiSlotAlloca::handleDestructuringComplete( 13780b5b2027SChristian Ulmann const DestructurableMemorySlot &slot, OpBuilder &builder) { 13790b5b2027SChristian Ulmann return createNewMultiAllocaWithoutSlot(slot, builder, *this); 13800b5b2027SChristian Ulmann } 1381