19f6c0056SKolya Panchenko //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===// 29f6c0056SKolya Panchenko // 39f6c0056SKolya Panchenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49f6c0056SKolya Panchenko // See https://llvm.org/LICENSE.txt for license information. 59f6c0056SKolya Panchenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69f6c0056SKolya Panchenko // 79f6c0056SKolya Panchenko //===----------------------------------------------------------------------===// 89f6c0056SKolya Panchenko 99f6c0056SKolya Panchenko #include "mlir/Dialect/Arith/IR/Arith.h" 109f6c0056SKolya Panchenko #include "mlir/Dialect/Func/IR/FuncOps.h" 119f6c0056SKolya Panchenko #include "mlir/Dialect/LLVMIR/VCIXDialect.h" 129f6c0056SKolya Panchenko #include "mlir/Dialect/Math/IR/Math.h" 139f6c0056SKolya Panchenko #include "mlir/Dialect/Vector/IR/VectorOps.h" 149f6c0056SKolya Panchenko #include "mlir/IR/PatternMatch.h" 159f6c0056SKolya Panchenko #include "mlir/Pass/Pass.h" 169f6c0056SKolya Panchenko #include "mlir/Pass/PassManager.h" 179f6c0056SKolya Panchenko #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 189f6c0056SKolya Panchenko 199f6c0056SKolya Panchenko namespace mlir { 209f6c0056SKolya Panchenko namespace { 219f6c0056SKolya Panchenko 229f6c0056SKolya Panchenko /// Return number of extracts required to make input VectorType \vt legal and 239f6c0056SKolya Panchenko /// also return thatlegal vector type. 249f6c0056SKolya Panchenko /// For fixed vectors nothing special is needed. Scalable vectors are legalizes 259f6c0056SKolya Panchenko /// according to LLVM's encoding: 269f6c0056SKolya Panchenko /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html 279f6c0056SKolya Panchenko static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { 28a5757c5bSChristian Sigg VectorType vt = cast<VectorType>(type); 299f6c0056SKolya Panchenko // To simplify test pass, avoid multi-dimensional vectors. 309f6c0056SKolya Panchenko if (!vt || vt.getRank() != 1) 319f6c0056SKolya Panchenko return {0, nullptr}; 329f6c0056SKolya Panchenko 339f6c0056SKolya Panchenko if (!vt.isScalable()) 349f6c0056SKolya Panchenko return {1, vt}; 359f6c0056SKolya Panchenko 369f6c0056SKolya Panchenko Type eltTy = vt.getElementType(); 379f6c0056SKolya Panchenko unsigned sew = 0; 389f6c0056SKolya Panchenko if (eltTy.isF32()) 399f6c0056SKolya Panchenko sew = 32; 409f6c0056SKolya Panchenko else if (eltTy.isF64()) 419f6c0056SKolya Panchenko sew = 64; 42a5757c5bSChristian Sigg else if (auto intTy = dyn_cast<IntegerType>(eltTy)) 439f6c0056SKolya Panchenko sew = intTy.getWidth(); 449f6c0056SKolya Panchenko else 459f6c0056SKolya Panchenko return {0, nullptr}; 469f6c0056SKolya Panchenko 479f6c0056SKolya Panchenko unsigned eltCount = vt.getShape()[0]; 489f6c0056SKolya Panchenko const unsigned lmul = eltCount * sew / 64; 499f6c0056SKolya Panchenko 509f6c0056SKolya Panchenko unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1; 519f6c0056SKolya Panchenko return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; 529f6c0056SKolya Panchenko } 539f6c0056SKolya Panchenko 549f6c0056SKolya Panchenko /// Replace math.cos(v) operation with vcix.v.iv(v). 559f6c0056SKolya Panchenko struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { 569f6c0056SKolya Panchenko using OpRewritePattern::OpRewritePattern; 579f6c0056SKolya Panchenko 589f6c0056SKolya Panchenko LogicalResult matchAndRewrite(math::CosOp op, 599f6c0056SKolya Panchenko PatternRewriter &rewriter) const override { 609f6c0056SKolya Panchenko const Type opType = op.getOperand().getType(); 619f6c0056SKolya Panchenko auto [n, legalType] = legalizeVectorType(opType); 629f6c0056SKolya Panchenko if (!legalType) 639f6c0056SKolya Panchenko return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 649f6c0056SKolya Panchenko Location loc = op.getLoc(); 659f6c0056SKolya Panchenko Value vec = op.getOperand(); 669f6c0056SKolya Panchenko Attribute immAttr = rewriter.getI32IntegerAttr(0); 679f6c0056SKolya Panchenko Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 689f6c0056SKolya Panchenko Value rvl = nullptr; 699f6c0056SKolya Panchenko if (legalType.isScalable()) 709f6c0056SKolya Panchenko // Use arbitrary runtime vector length when vector type is scalable. 719f6c0056SKolya Panchenko // Proper conversion pass should take it from the IR. 729f6c0056SKolya Panchenko rvl = rewriter.create<arith::ConstantOp>(loc, 739f6c0056SKolya Panchenko rewriter.getI64IntegerAttr(9)); 749f6c0056SKolya Panchenko Value res; 759f6c0056SKolya Panchenko if (n == 1) { 769f6c0056SKolya Panchenko res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, 779f6c0056SKolya Panchenko immAttr, rvl); 789f6c0056SKolya Panchenko } else { 799f6c0056SKolya Panchenko const unsigned eltCount = legalType.getShape()[0]; 809f6c0056SKolya Panchenko Type eltTy = legalType.getElementType(); 819f6c0056SKolya Panchenko Value zero = rewriter.create<arith::ConstantOp>( 829f6c0056SKolya Panchenko loc, eltTy, rewriter.getZeroAttr(eltTy)); 839f6c0056SKolya Panchenko res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 849f6c0056SKolya Panchenko for (unsigned i = 0; i < n; ++i) { 859f6c0056SKolya Panchenko Value extracted = rewriter.create<vector::ScalableExtractOp>( 869f6c0056SKolya Panchenko loc, legalType, vec, i * eltCount); 879f6c0056SKolya Panchenko Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, 889f6c0056SKolya Panchenko extracted, immAttr, rvl); 899f6c0056SKolya Panchenko res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 909f6c0056SKolya Panchenko i * eltCount); 919f6c0056SKolya Panchenko } 929f6c0056SKolya Panchenko } 939f6c0056SKolya Panchenko rewriter.replaceOp(op, res); 949f6c0056SKolya Panchenko return success(); 959f6c0056SKolya Panchenko } 969f6c0056SKolya Panchenko }; 979f6c0056SKolya Panchenko 989f6c0056SKolya Panchenko // Replace math.sin(v) operation with vcix.v.sv(v, v). 999f6c0056SKolya Panchenko struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { 1009f6c0056SKolya Panchenko using OpRewritePattern::OpRewritePattern; 1019f6c0056SKolya Panchenko 1029f6c0056SKolya Panchenko LogicalResult matchAndRewrite(math::SinOp op, 1039f6c0056SKolya Panchenko PatternRewriter &rewriter) const override { 1049f6c0056SKolya Panchenko const Type opType = op.getOperand().getType(); 1059f6c0056SKolya Panchenko auto [n, legalType] = legalizeVectorType(opType); 1069f6c0056SKolya Panchenko if (!legalType) 1079f6c0056SKolya Panchenko return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 1089f6c0056SKolya Panchenko Location loc = op.getLoc(); 1099f6c0056SKolya Panchenko Value vec = op.getOperand(); 1109f6c0056SKolya Panchenko Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 1119f6c0056SKolya Panchenko Value rvl = nullptr; 1129f6c0056SKolya Panchenko if (legalType.isScalable()) 1139f6c0056SKolya Panchenko // Use arbitrary runtime vector length when vector type is scalable. 1149f6c0056SKolya Panchenko // Proper conversion pass should take it from the IR. 1159f6c0056SKolya Panchenko rvl = rewriter.create<arith::ConstantOp>(loc, 1169f6c0056SKolya Panchenko rewriter.getI64IntegerAttr(9)); 1179f6c0056SKolya Panchenko Value res; 1189f6c0056SKolya Panchenko if (n == 1) { 1199f6c0056SKolya Panchenko res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 1209f6c0056SKolya Panchenko vec, rvl); 1219f6c0056SKolya Panchenko } else { 1229f6c0056SKolya Panchenko const unsigned eltCount = legalType.getShape()[0]; 1239f6c0056SKolya Panchenko Type eltTy = legalType.getElementType(); 1249f6c0056SKolya Panchenko Value zero = rewriter.create<arith::ConstantOp>( 1259f6c0056SKolya Panchenko loc, eltTy, rewriter.getZeroAttr(eltTy)); 1269f6c0056SKolya Panchenko res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 1279f6c0056SKolya Panchenko for (unsigned i = 0; i < n; ++i) { 1289f6c0056SKolya Panchenko Value extracted = rewriter.create<vector::ScalableExtractOp>( 1299f6c0056SKolya Panchenko loc, legalType, vec, i * eltCount); 1309f6c0056SKolya Panchenko Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 1319f6c0056SKolya Panchenko extracted, extracted, rvl); 1329f6c0056SKolya Panchenko res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 1339f6c0056SKolya Panchenko i * eltCount); 1349f6c0056SKolya Panchenko } 1359f6c0056SKolya Panchenko } 1369f6c0056SKolya Panchenko rewriter.replaceOp(op, res); 1379f6c0056SKolya Panchenko return success(); 1389f6c0056SKolya Panchenko } 1399f6c0056SKolya Panchenko }; 1409f6c0056SKolya Panchenko 1419f6c0056SKolya Panchenko // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f). 1429f6c0056SKolya Panchenko struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { 1439f6c0056SKolya Panchenko using OpRewritePattern::OpRewritePattern; 1449f6c0056SKolya Panchenko 1459f6c0056SKolya Panchenko LogicalResult matchAndRewrite(math::TanOp op, 1469f6c0056SKolya Panchenko PatternRewriter &rewriter) const override { 1479f6c0056SKolya Panchenko const Type opType = op.getOperand().getType(); 1489f6c0056SKolya Panchenko auto [n, legalType] = legalizeVectorType(opType); 1499f6c0056SKolya Panchenko Type eltTy = legalType.getElementType(); 1509f6c0056SKolya Panchenko if (!legalType) 1519f6c0056SKolya Panchenko return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 1529f6c0056SKolya Panchenko Location loc = op.getLoc(); 1539f6c0056SKolya Panchenko Value vec = op.getOperand(); 1549f6c0056SKolya Panchenko Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 1559f6c0056SKolya Panchenko Value zero = rewriter.create<arith::ConstantOp>( 1569f6c0056SKolya Panchenko loc, eltTy, rewriter.getZeroAttr(eltTy)); 1579f6c0056SKolya Panchenko Value rvl = nullptr; 1589f6c0056SKolya Panchenko if (legalType.isScalable()) 1599f6c0056SKolya Panchenko // Use arbitrary runtime vector length when vector type is scalable. 1609f6c0056SKolya Panchenko // Proper conversion pass should take it from the IR. 1619f6c0056SKolya Panchenko rvl = rewriter.create<arith::ConstantOp>(loc, 1629f6c0056SKolya Panchenko rewriter.getI64IntegerAttr(9)); 1639f6c0056SKolya Panchenko Value res; 1649f6c0056SKolya Panchenko if (n == 1) { 1659f6c0056SKolya Panchenko res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 1669f6c0056SKolya Panchenko zero, rvl); 1679f6c0056SKolya Panchenko } else { 1689f6c0056SKolya Panchenko const unsigned eltCount = legalType.getShape()[0]; 1699f6c0056SKolya Panchenko res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 1709f6c0056SKolya Panchenko for (unsigned i = 0; i < n; ++i) { 1719f6c0056SKolya Panchenko Value extracted = rewriter.create<vector::ScalableExtractOp>( 1729f6c0056SKolya Panchenko loc, legalType, vec, i * eltCount); 1739f6c0056SKolya Panchenko Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 1749f6c0056SKolya Panchenko extracted, zero, rvl); 1759f6c0056SKolya Panchenko res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 1769f6c0056SKolya Panchenko i * eltCount); 1779f6c0056SKolya Panchenko } 1789f6c0056SKolya Panchenko } 1799f6c0056SKolya Panchenko rewriter.replaceOp(op, res); 1809f6c0056SKolya Panchenko return success(); 1819f6c0056SKolya Panchenko } 1829f6c0056SKolya Panchenko }; 1839f6c0056SKolya Panchenko 1849f6c0056SKolya Panchenko // Replace math.log(v) operation with vcix.v.sv(v, 0). 1859f6c0056SKolya Panchenko struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { 1869f6c0056SKolya Panchenko using OpRewritePattern::OpRewritePattern; 1879f6c0056SKolya Panchenko 1889f6c0056SKolya Panchenko LogicalResult matchAndRewrite(math::LogOp op, 1899f6c0056SKolya Panchenko PatternRewriter &rewriter) const override { 1909f6c0056SKolya Panchenko const Type opType = op.getOperand().getType(); 1919f6c0056SKolya Panchenko auto [n, legalType] = legalizeVectorType(opType); 1929f6c0056SKolya Panchenko if (!legalType) 1939f6c0056SKolya Panchenko return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); 1949f6c0056SKolya Panchenko Location loc = op.getLoc(); 1959f6c0056SKolya Panchenko Value vec = op.getOperand(); 1969f6c0056SKolya Panchenko Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); 1979f6c0056SKolya Panchenko Value rvl = nullptr; 1989f6c0056SKolya Panchenko Value zeroInt = rewriter.create<arith::ConstantOp>( 1999f6c0056SKolya Panchenko loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 2009f6c0056SKolya Panchenko if (legalType.isScalable()) 2019f6c0056SKolya Panchenko // Use arbitrary runtime vector length when vector type is scalable. 2029f6c0056SKolya Panchenko // Proper conversion pass should take it from the IR. 2039f6c0056SKolya Panchenko rvl = rewriter.create<arith::ConstantOp>(loc, 2049f6c0056SKolya Panchenko rewriter.getI64IntegerAttr(9)); 2059f6c0056SKolya Panchenko Value res; 2069f6c0056SKolya Panchenko if (n == 1) { 2079f6c0056SKolya Panchenko res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, 2089f6c0056SKolya Panchenko zeroInt, rvl); 2099f6c0056SKolya Panchenko } else { 2109f6c0056SKolya Panchenko const unsigned eltCount = legalType.getShape()[0]; 2119f6c0056SKolya Panchenko Type eltTy = legalType.getElementType(); 2129f6c0056SKolya Panchenko Value zero = rewriter.create<arith::ConstantOp>( 2139f6c0056SKolya Panchenko loc, eltTy, rewriter.getZeroAttr(eltTy)); 2149f6c0056SKolya Panchenko res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); 2159f6c0056SKolya Panchenko for (unsigned i = 0; i < n; ++i) { 2169f6c0056SKolya Panchenko Value extracted = rewriter.create<vector::ScalableExtractOp>( 2179f6c0056SKolya Panchenko loc, legalType, vec, i * eltCount); 2189f6c0056SKolya Panchenko Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, 2199f6c0056SKolya Panchenko extracted, zeroInt, rvl); 2209f6c0056SKolya Panchenko res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, 2219f6c0056SKolya Panchenko i * eltCount); 2229f6c0056SKolya Panchenko } 2239f6c0056SKolya Panchenko } 2249f6c0056SKolya Panchenko rewriter.replaceOp(op, res); 2259f6c0056SKolya Panchenko return success(); 2269f6c0056SKolya Panchenko } 2279f6c0056SKolya Panchenko }; 2289f6c0056SKolya Panchenko 2299f6c0056SKolya Panchenko struct TestMathToVCIX 2309f6c0056SKolya Panchenko : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> { 2319f6c0056SKolya Panchenko MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX) 2329f6c0056SKolya Panchenko 2339f6c0056SKolya Panchenko StringRef getArgument() const final { return "test-math-to-vcix"; } 2349f6c0056SKolya Panchenko 2359f6c0056SKolya Panchenko StringRef getDescription() const final { 2369f6c0056SKolya Panchenko return "Test lowering patterns that converts some vector operations to " 2379f6c0056SKolya Panchenko "VCIX. Since DLA can implement VCIX instructions in completely " 2389f6c0056SKolya Panchenko "different way, conversions of that test pass only lives here."; 2399f6c0056SKolya Panchenko } 2409f6c0056SKolya Panchenko 2419f6c0056SKolya Panchenko void getDependentDialects(DialectRegistry ®istry) const override { 2429f6c0056SKolya Panchenko registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect, 2439f6c0056SKolya Panchenko vcix::VCIXDialect, vector::VectorDialect>(); 2449f6c0056SKolya Panchenko } 2459f6c0056SKolya Panchenko 2469f6c0056SKolya Panchenko void runOnOperation() override { 2479f6c0056SKolya Panchenko MLIRContext *ctx = &getContext(); 2489f6c0056SKolya Panchenko RewritePatternSet patterns(ctx); 2499f6c0056SKolya Panchenko patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>( 2509f6c0056SKolya Panchenko ctx); 251*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2529f6c0056SKolya Panchenko } 2539f6c0056SKolya Panchenko }; 2549f6c0056SKolya Panchenko 2559f6c0056SKolya Panchenko } // namespace 2569f6c0056SKolya Panchenko 2579f6c0056SKolya Panchenko namespace test { 2589f6c0056SKolya Panchenko void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); } 2599f6c0056SKolya Panchenko } // namespace test 2609f6c0056SKolya Panchenko } // namespace mlir 261