19f7fff7fSCullen Rhodes //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// 29f7fff7fSCullen Rhodes // 39f7fff7fSCullen Rhodes // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49f7fff7fSCullen Rhodes // See https://llvm.org/LICENSE.txt for license information. 59f7fff7fSCullen Rhodes // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69f7fff7fSCullen Rhodes // 79f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 89f7fff7fSCullen Rhodes 99f7fff7fSCullen Rhodes #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" 109f7fff7fSCullen Rhodes 119f7fff7fSCullen Rhodes #include "mlir/Dialect/Arith/IR/Arith.h" 129f7fff7fSCullen Rhodes #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 139f7fff7fSCullen Rhodes #include "mlir/Dialect/ArmSME/Utils/Utils.h" 149f7fff7fSCullen Rhodes #include "mlir/Pass/Pass.h" 159f7fff7fSCullen Rhodes #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 169f7fff7fSCullen Rhodes 179f7fff7fSCullen Rhodes namespace mlir { 189f7fff7fSCullen Rhodes #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS 199f7fff7fSCullen Rhodes #include "mlir/Conversion/Passes.h.inc" 209f7fff7fSCullen Rhodes } // namespace mlir 219f7fff7fSCullen Rhodes 229f7fff7fSCullen Rhodes #define DEBUG_TYPE "arith-to-arm-sme" 239f7fff7fSCullen Rhodes 249f7fff7fSCullen Rhodes using namespace mlir; 259f7fff7fSCullen Rhodes 269f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 279f7fff7fSCullen Rhodes // Conversion helpers 289f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 299f7fff7fSCullen Rhodes 309f7fff7fSCullen Rhodes /// Returns true if 'val' is a splat of zero, false otherwise. 319f7fff7fSCullen Rhodes static bool isSplatZero(Type elemType, DenseElementsAttr val) { 329f7fff7fSCullen Rhodes if (llvm::isa<FloatType>(elemType)) 339f7fff7fSCullen Rhodes return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); 349f7fff7fSCullen Rhodes if (llvm::isa<IntegerType>(elemType)) 359f7fff7fSCullen Rhodes return val && val.isSplat() && val.getSplatValue<APInt>().isZero(); 369f7fff7fSCullen Rhodes return false; 379f7fff7fSCullen Rhodes } 389f7fff7fSCullen Rhodes 399f7fff7fSCullen Rhodes namespace { 409f7fff7fSCullen Rhodes 419f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 429f7fff7fSCullen Rhodes // ConstantOp 439f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 449f7fff7fSCullen Rhodes 459f7fff7fSCullen Rhodes /// Conversion pattern for dense arith.constant. 469f7fff7fSCullen Rhodes struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { 479f7fff7fSCullen Rhodes using OpRewritePattern<arith::ConstantOp>::OpRewritePattern; 489f7fff7fSCullen Rhodes 499f7fff7fSCullen Rhodes LogicalResult matchAndRewrite(arith::ConstantOp constantOp, 509f7fff7fSCullen Rhodes PatternRewriter &rewriter) const final { 519f7fff7fSCullen Rhodes auto tileType = dyn_cast<VectorType>(constantOp.getType()); 529f7fff7fSCullen Rhodes if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 539f7fff7fSCullen Rhodes return failure(); 549f7fff7fSCullen Rhodes 559f7fff7fSCullen Rhodes auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); 569f7fff7fSCullen Rhodes if (!denseAttr || !denseAttr.isSplat()) 579f7fff7fSCullen Rhodes return failure(); 589f7fff7fSCullen Rhodes 599f7fff7fSCullen Rhodes auto tileElementType = tileType.getElementType(); 609f7fff7fSCullen Rhodes 619f7fff7fSCullen Rhodes // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. 629f7fff7fSCullen Rhodes if (isSplatZero(tileElementType, denseAttr)) { 639f7fff7fSCullen Rhodes rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType); 649f7fff7fSCullen Rhodes return success(); 659f7fff7fSCullen Rhodes } 669f7fff7fSCullen Rhodes 67c4251243SBenjamin Maxwell // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice' 689f7fff7fSCullen Rhodes // ops that broadcast the constant to each tile slice. 699f7fff7fSCullen Rhodes auto loc = constantOp.getLoc(); 709f7fff7fSCullen Rhodes 719f7fff7fSCullen Rhodes // To fill a tile with a constant, we create a 1-D splat of the constant, 729f7fff7fSCullen Rhodes // then move that into each tile slice (the largest unit we can set at once, 739f7fff7fSCullen Rhodes // outside of operations like the outerproduct). 749f7fff7fSCullen Rhodes VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 759f7fff7fSCullen Rhodes auto denseAttr1D = DenseElementsAttr::get( 769f7fff7fSCullen Rhodes tileSliceType, denseAttr.getSplatValue<Attribute>()); 779f7fff7fSCullen Rhodes auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); 789f7fff7fSCullen Rhodes 799f7fff7fSCullen Rhodes auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 809f7fff7fSCullen Rhodes auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 819f7fff7fSCullen Rhodes Value currentTile) { 82c4251243SBenjamin Maxwell // Create 'arm_sme.insert_tile_slice' to write vector to tile 839f7fff7fSCullen Rhodes // slice. 84c4251243SBenjamin Maxwell auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 859f7fff7fSCullen Rhodes loc, tileType, constantOp1D, currentTile, tileSliceIndex); 869f7fff7fSCullen Rhodes return nextTile.getResult(); 879f7fff7fSCullen Rhodes }; 889f7fff7fSCullen Rhodes auto forOp = mlir::arm_sme::createLoopOverTileSlices( 899f7fff7fSCullen Rhodes rewriter, loc, initTile, makeLoopBody); 909f7fff7fSCullen Rhodes rewriter.replaceOp(constantOp, forOp.getResult(0)); 919f7fff7fSCullen Rhodes 929f7fff7fSCullen Rhodes return success(); 939f7fff7fSCullen Rhodes } 949f7fff7fSCullen Rhodes }; 959f7fff7fSCullen Rhodes 969f7fff7fSCullen Rhodes } // namespace 979f7fff7fSCullen Rhodes 989f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 999f7fff7fSCullen Rhodes // Pattern population 1009f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 1019f7fff7fSCullen Rhodes 1029f7fff7fSCullen Rhodes void mlir::arith::populateArithToArmSMEConversionPatterns( 1039f7fff7fSCullen Rhodes RewritePatternSet &patterns) { 1049f7fff7fSCullen Rhodes patterns.add<ConstantOpToArmSMELowering>(patterns.getContext()); 1059f7fff7fSCullen Rhodes } 1069f7fff7fSCullen Rhodes 1079f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 1089f7fff7fSCullen Rhodes // Pass definition 1099f7fff7fSCullen Rhodes //===----------------------------------------------------------------------===// 1109f7fff7fSCullen Rhodes 1119f7fff7fSCullen Rhodes namespace { 1129f7fff7fSCullen Rhodes struct ArithToArmSMEConversionPass final 1139f7fff7fSCullen Rhodes : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> { 1149f7fff7fSCullen Rhodes using impl::ArithToArmSMEConversionPassBase< 1159f7fff7fSCullen Rhodes ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; 1169f7fff7fSCullen Rhodes 1179f7fff7fSCullen Rhodes void runOnOperation() override { 1189f7fff7fSCullen Rhodes RewritePatternSet patterns(&getContext()); 1199f7fff7fSCullen Rhodes arith::populateArithToArmSMEConversionPatterns(patterns); 120*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 1219f7fff7fSCullen Rhodes return signalPassFailure(); 1229f7fff7fSCullen Rhodes } 1239f7fff7fSCullen Rhodes }; 1249f7fff7fSCullen Rhodes } // namespace 125