13fef2d26SRiver Riddle //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle 93fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 103fef2d26SRiver Riddle #include "mlir/IR/BuiltinTypes.h" 113fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 123fef2d26SRiver Riddle 133fef2d26SRiver Riddle using namespace mlir; 143fef2d26SRiver Riddle 153fef2d26SRiver Riddle namespace { 163fef2d26SRiver Riddle struct TestMemRefStrideCalculation 1787d6bf37SRiver Riddle : public PassWrapper<TestMemRefStrideCalculation, 1887d6bf37SRiver Riddle InterfacePass<SymbolOpInterface>> { 195e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefStrideCalculation) 205e50dd04SRiver Riddle 21b5e22e6dSMehdi Amini StringRef getArgument() const final { 22b5e22e6dSMehdi Amini return "test-memref-stride-calculation"; 23b5e22e6dSMehdi Amini } 24b5e22e6dSMehdi Amini StringRef getDescription() const final { 25b5e22e6dSMehdi Amini return "Test operation constant folding"; 26b5e22e6dSMehdi Amini } 2741574554SRiver Riddle void runOnOperation() override; 283fef2d26SRiver Riddle }; 29be0a7e9fSMehdi Amini } // namespace 303fef2d26SRiver Riddle 313fef2d26SRiver Riddle /// Traverse AllocOp and compute strides of each MemRefType independently. 3241574554SRiver Riddle void TestMemRefStrideCalculation::runOnOperation() { 3341574554SRiver Riddle llvm::outs() << "Testing: " << getOperation().getName() << "\n"; 3441574554SRiver Riddle getOperation().walk([&](memref::AllocOp allocOp) { 355550c821STres Popp auto memrefType = cast<MemRefType>(allocOp.getResult().getType()); 363fef2d26SRiver Riddle int64_t offset; 373fef2d26SRiver Riddle SmallVector<int64_t, 4> strides; 38*6aaa8f25SMatthias Springer if (failed(memrefType.getStridesAndOffset(strides, offset))) { 393fef2d26SRiver Riddle llvm::outs() << "MemRefType " << memrefType << " cannot be converted to " 403fef2d26SRiver Riddle << "strided form\n"; 413fef2d26SRiver Riddle return; 423fef2d26SRiver Riddle } 433fef2d26SRiver Riddle llvm::outs() << "MemRefType offset: "; 44399638f9SAliia Khasanova if (ShapedType::isDynamic(offset)) 453fef2d26SRiver Riddle llvm::outs() << "?"; 463fef2d26SRiver Riddle else 473fef2d26SRiver Riddle llvm::outs() << offset; 483fef2d26SRiver Riddle llvm::outs() << " strides: "; 493fef2d26SRiver Riddle llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { 50399638f9SAliia Khasanova if (ShapedType::isDynamic(v)) 513fef2d26SRiver Riddle llvm::outs() << "?"; 523fef2d26SRiver Riddle else 533fef2d26SRiver Riddle llvm::outs() << v; 543fef2d26SRiver Riddle }); 553fef2d26SRiver Riddle llvm::outs() << "\n"; 563fef2d26SRiver Riddle }); 573fef2d26SRiver Riddle llvm::outs().flush(); 583fef2d26SRiver Riddle } 593fef2d26SRiver Riddle 603fef2d26SRiver Riddle namespace mlir { 613fef2d26SRiver Riddle namespace test { 623fef2d26SRiver Riddle void registerTestMemRefStrideCalculation() { 63b5e22e6dSMehdi Amini PassRegistration<TestMemRefStrideCalculation>(); 643fef2d26SRiver Riddle } 653fef2d26SRiver Riddle } // namespace test 663fef2d26SRiver Riddle } // namespace mlir 67