1 //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 struct TestMemRefStrideCalculation 17 : public PassWrapper<TestMemRefStrideCalculation, 18 InterfacePass<SymbolOpInterface>> { 19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefStrideCalculation) 20 21 StringRef getArgument() const final { 22 return "test-memref-stride-calculation"; 23 } 24 StringRef getDescription() const final { 25 return "Test operation constant folding"; 26 } 27 void runOnOperation() override; 28 }; 29 } // namespace 30 31 /// Traverse AllocOp and compute strides of each MemRefType independently. 32 void TestMemRefStrideCalculation::runOnOperation() { 33 llvm::outs() << "Testing: " << getOperation().getName() << "\n"; 34 getOperation().walk([&](memref::AllocOp allocOp) { 35 auto memrefType = cast<MemRefType>(allocOp.getResult().getType()); 36 int64_t offset; 37 SmallVector<int64_t, 4> strides; 38 if (failed(memrefType.getStridesAndOffset(strides, offset))) { 39 llvm::outs() << "MemRefType " << memrefType << " cannot be converted to " 40 << "strided form\n"; 41 return; 42 } 43 llvm::outs() << "MemRefType offset: "; 44 if (ShapedType::isDynamic(offset)) 45 llvm::outs() << "?"; 46 else 47 llvm::outs() << offset; 48 llvm::outs() << " strides: "; 49 llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { 50 if (ShapedType::isDynamic(v)) 51 llvm::outs() << "?"; 52 else 53 llvm::outs() << v; 54 }); 55 llvm::outs() << "\n"; 56 }); 57 llvm::outs().flush(); 58 } 59 60 namespace mlir { 61 namespace test { 62 void registerTestMemRefStrideCalculation() { 63 PassRegistration<TestMemRefStrideCalculation>(); 64 } 65 } // namespace test 66 } // namespace mlir 67