1b5d847b1SAlex Zinenko //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===// 2b5d847b1SAlex Zinenko // 3b5d847b1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b5d847b1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5b5d847b1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b5d847b1SAlex Zinenko // 7b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 8b5d847b1SAlex Zinenko 9b5d847b1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" 10b5d847b1SAlex Zinenko #include "MemRefDescriptor.h" 11b5d847b1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12b5d847b1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 137fb9bbe5SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 14b5d847b1SAlex Zinenko #include "mlir/IR/Builders.h" 150fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h" 16b5d847b1SAlex Zinenko 17b5d847b1SAlex Zinenko using namespace mlir; 18b5d847b1SAlex Zinenko 19b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 20b5d847b1SAlex Zinenko // MemRefDescriptor implementation 21b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 22b5d847b1SAlex Zinenko 23b5d847b1SAlex Zinenko /// Construct a helper for the given descriptor value. 24b5d847b1SAlex Zinenko MemRefDescriptor::MemRefDescriptor(Value descriptor) 25b5d847b1SAlex Zinenko : StructBuilder(descriptor) { 26b5d847b1SAlex Zinenko assert(value != nullptr && "value cannot be null"); 275550c821STres Popp indexType = cast<LLVM::LLVMStructType>(value.getType()) 28b5d847b1SAlex Zinenko .getBody()[kOffsetPosInMemRefDescriptor]; 29b5d847b1SAlex Zinenko } 30b5d847b1SAlex Zinenko 31b5d847b1SAlex Zinenko /// Builds IR creating an `undef` value of the descriptor type. 32b5d847b1SAlex Zinenko MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, 33b5d847b1SAlex Zinenko Type descriptorType) { 34b5d847b1SAlex Zinenko 35b5d847b1SAlex Zinenko Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); 36b5d847b1SAlex Zinenko return MemRefDescriptor(descriptor); 37b5d847b1SAlex Zinenko } 38b5d847b1SAlex Zinenko 39b5d847b1SAlex Zinenko /// Builds IR creating a MemRef descriptor that represents `type` and 40b5d847b1SAlex Zinenko /// populates it with static shape and stride information extracted from the 41b5d847b1SAlex Zinenko /// type. 42b5d847b1SAlex Zinenko MemRefDescriptor 43b5d847b1SAlex Zinenko MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, 44ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 45b5d847b1SAlex Zinenko MemRefType type, Value memory) { 46200266a0SQuentin Colombet return fromStaticShape(builder, loc, typeConverter, type, memory, memory); 47200266a0SQuentin Colombet } 48200266a0SQuentin Colombet 49200266a0SQuentin Colombet MemRefDescriptor MemRefDescriptor::fromStaticShape( 50ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 51200266a0SQuentin Colombet MemRefType type, Value memory, Value alignedMemory) { 52b5d847b1SAlex Zinenko assert(type.hasStaticShape() && "unexpected dynamic shape"); 53b5d847b1SAlex Zinenko 54b5d847b1SAlex Zinenko // Extract all strides and offsets and verify they are static. 556aaa8f25SMatthias Springer auto [strides, offset] = type.getStridesAndOffset(); 56b28a296cSChristian Ulmann assert(!ShapedType::isDynamic(offset) && "expected static offset"); 57399638f9SAliia Khasanova assert(!llvm::any_of(strides, ShapedType::isDynamic) && 58380a1b20SKazu Hirata "expected static strides"); 59b5d847b1SAlex Zinenko 60b5d847b1SAlex Zinenko auto convertedType = typeConverter.convertType(type); 61b5d847b1SAlex Zinenko assert(convertedType && "unexpected failure in memref type conversion"); 62b5d847b1SAlex Zinenko 63b5d847b1SAlex Zinenko auto descr = MemRefDescriptor::undef(builder, loc, convertedType); 64b5d847b1SAlex Zinenko descr.setAllocatedPtr(builder, loc, memory); 65200266a0SQuentin Colombet descr.setAlignedPtr(builder, loc, alignedMemory); 66b5d847b1SAlex Zinenko descr.setConstantOffset(builder, loc, offset); 67b5d847b1SAlex Zinenko 68b5d847b1SAlex Zinenko // Fill in sizes and strides 69b5d847b1SAlex Zinenko for (unsigned i = 0, e = type.getRank(); i != e; ++i) { 70b5d847b1SAlex Zinenko descr.setConstantSize(builder, loc, i, type.getDimSize(i)); 71b5d847b1SAlex Zinenko descr.setConstantStride(builder, loc, i, strides[i]); 72b5d847b1SAlex Zinenko } 73b5d847b1SAlex Zinenko return descr; 74b5d847b1SAlex Zinenko } 75b5d847b1SAlex Zinenko 76b5d847b1SAlex Zinenko /// Builds IR extracting the allocated pointer from the descriptor. 77b5d847b1SAlex Zinenko Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { 78b5d847b1SAlex Zinenko return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); 79b5d847b1SAlex Zinenko } 80b5d847b1SAlex Zinenko 81b5d847b1SAlex Zinenko /// Builds IR inserting the allocated pointer into the descriptor. 82b5d847b1SAlex Zinenko void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, 83b5d847b1SAlex Zinenko Value ptr) { 84b5d847b1SAlex Zinenko setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); 85b5d847b1SAlex Zinenko } 86b5d847b1SAlex Zinenko 87b5d847b1SAlex Zinenko /// Builds IR extracting the aligned pointer from the descriptor. 88b5d847b1SAlex Zinenko Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { 89b5d847b1SAlex Zinenko return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); 90b5d847b1SAlex Zinenko } 91b5d847b1SAlex Zinenko 92b5d847b1SAlex Zinenko /// Builds IR inserting the aligned pointer into the descriptor. 93b5d847b1SAlex Zinenko void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, 94b5d847b1SAlex Zinenko Value ptr) { 95b5d847b1SAlex Zinenko setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); 96b5d847b1SAlex Zinenko } 97b5d847b1SAlex Zinenko 98b5d847b1SAlex Zinenko // Creates a constant Op producing a value of `resultType` from an index-typed 99b5d847b1SAlex Zinenko // integer attribute. 100b5d847b1SAlex Zinenko static Value createIndexAttrConstant(OpBuilder &builder, Location loc, 101b5d847b1SAlex Zinenko Type resultType, int64_t value) { 1020af643f3SJeff Niu return builder.create<LLVM::ConstantOp>(loc, resultType, 1030af643f3SJeff Niu builder.getIndexAttr(value)); 104b5d847b1SAlex Zinenko } 105b5d847b1SAlex Zinenko 106b5d847b1SAlex Zinenko /// Builds IR extracting the offset from the descriptor. 107b5d847b1SAlex Zinenko Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { 1085c5af910SJeff Niu return builder.create<LLVM::ExtractValueOp>(loc, value, 1095c5af910SJeff Niu kOffsetPosInMemRefDescriptor); 110b5d847b1SAlex Zinenko } 111b5d847b1SAlex Zinenko 112b5d847b1SAlex Zinenko /// Builds IR inserting the offset into the descriptor. 113b5d847b1SAlex Zinenko void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, 114b5d847b1SAlex Zinenko Value offset) { 1155c5af910SJeff Niu value = builder.create<LLVM::InsertValueOp>(loc, value, offset, 1165c5af910SJeff Niu kOffsetPosInMemRefDescriptor); 117b5d847b1SAlex Zinenko } 118b5d847b1SAlex Zinenko 119b5d847b1SAlex Zinenko /// Builds IR inserting the offset into the descriptor. 120b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, 121b5d847b1SAlex Zinenko uint64_t offset) { 122b5d847b1SAlex Zinenko setOffset(builder, loc, 123b5d847b1SAlex Zinenko createIndexAttrConstant(builder, loc, indexType, offset)); 124b5d847b1SAlex Zinenko } 125b5d847b1SAlex Zinenko 126b5d847b1SAlex Zinenko /// Builds IR extracting the pos-th size from the descriptor. 127b5d847b1SAlex Zinenko Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { 128b5d847b1SAlex Zinenko return builder.create<LLVM::ExtractValueOp>( 12982973067SUday Bondhugula loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); 130b5d847b1SAlex Zinenko } 131b5d847b1SAlex Zinenko 132b5d847b1SAlex Zinenko Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, 133b5d847b1SAlex Zinenko int64_t rank) { 134b5d847b1SAlex Zinenko auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); 13550ea17b8SMarkus Böck 136b28a296cSChristian Ulmann auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); 137b5d847b1SAlex Zinenko 138b5d847b1SAlex Zinenko // Copy size values to stack-allocated memory. 139b5d847b1SAlex Zinenko auto one = createIndexAttrConstant(builder, loc, indexType, 1); 140b5d847b1SAlex Zinenko auto sizes = builder.create<LLVM::ExtractValueOp>( 141984b800aSserge-sans-paille loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); 142b28a296cSChristian Ulmann auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one, 14350ea17b8SMarkus Böck /*alignment=*/0); 144b5d847b1SAlex Zinenko builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); 145b5d847b1SAlex Zinenko 146b5d847b1SAlex Zinenko // Load an return size value of interest. 147b28a296cSChristian Ulmann auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr, 148b28a296cSChristian Ulmann ArrayRef<LLVM::GEPArg>{0, pos}); 14950ea17b8SMarkus Böck return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr); 150b5d847b1SAlex Zinenko } 151b5d847b1SAlex Zinenko 152b5d847b1SAlex Zinenko /// Builds IR inserting the pos-th size into the descriptor 153b5d847b1SAlex Zinenko void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, 154b5d847b1SAlex Zinenko Value size) { 155b5d847b1SAlex Zinenko value = builder.create<LLVM::InsertValueOp>( 15682973067SUday Bondhugula loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); 157b5d847b1SAlex Zinenko } 158b5d847b1SAlex Zinenko 159b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, 160b5d847b1SAlex Zinenko unsigned pos, uint64_t size) { 161b5d847b1SAlex Zinenko setSize(builder, loc, pos, 162b5d847b1SAlex Zinenko createIndexAttrConstant(builder, loc, indexType, size)); 163b5d847b1SAlex Zinenko } 164b5d847b1SAlex Zinenko 165b5d847b1SAlex Zinenko /// Builds IR extracting the pos-th stride from the descriptor. 166b5d847b1SAlex Zinenko Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { 167b5d847b1SAlex Zinenko return builder.create<LLVM::ExtractValueOp>( 16882973067SUday Bondhugula loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); 169b5d847b1SAlex Zinenko } 170b5d847b1SAlex Zinenko 171b5d847b1SAlex Zinenko /// Builds IR inserting the pos-th stride into the descriptor 172b5d847b1SAlex Zinenko void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, 173b5d847b1SAlex Zinenko Value stride) { 174b5d847b1SAlex Zinenko value = builder.create<LLVM::InsertValueOp>( 1755c5af910SJeff Niu loc, value, stride, 17682973067SUday Bondhugula ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); 177b5d847b1SAlex Zinenko } 178b5d847b1SAlex Zinenko 179b5d847b1SAlex Zinenko void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, 180b5d847b1SAlex Zinenko unsigned pos, uint64_t stride) { 181b5d847b1SAlex Zinenko setStride(builder, loc, pos, 182b5d847b1SAlex Zinenko createIndexAttrConstant(builder, loc, indexType, stride)); 183b5d847b1SAlex Zinenko } 184b5d847b1SAlex Zinenko 185b5d847b1SAlex Zinenko LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { 1865550c821STres Popp return cast<LLVM::LLVMPointerType>( 1875550c821STres Popp cast<LLVM::LLVMStructType>(value.getType()) 1885550c821STres Popp .getBody()[kAlignedPtrPosInMemRefDescriptor]); 189b5d847b1SAlex Zinenko } 190b5d847b1SAlex Zinenko 191e02d4142SQuentin Colombet Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, 192ce254598SMatthias Springer const LLVMTypeConverter &converter, 193e02d4142SQuentin Colombet MemRefType type) { 194e02d4142SQuentin Colombet // When we convert to LLVM, the input memref must have been normalized 195e02d4142SQuentin Colombet // beforehand. Hence, this call is guaranteed to work. 1966aaa8f25SMatthias Springer auto [strides, offsetCst] = type.getStridesAndOffset(); 197e02d4142SQuentin Colombet 198e02d4142SQuentin Colombet Value ptr = alignedPtr(builder, loc); 199bba9209fSQuentin Colombet // For zero offsets, we already have the base pointer. 200bba9209fSQuentin Colombet if (offsetCst == 0) 201bba9209fSQuentin Colombet return ptr; 202bba9209fSQuentin Colombet 203bba9209fSQuentin Colombet // Otherwise add the offset to the aligned base. 204e02d4142SQuentin Colombet Type indexType = converter.getIndexType(); 205e02d4142SQuentin Colombet Value offsetVal = 206e02d4142SQuentin Colombet ShapedType::isDynamic(offsetCst) 207e02d4142SQuentin Colombet ? offset(builder, loc) 208e02d4142SQuentin Colombet : createIndexAttrConstant(builder, loc, indexType, offsetCst); 209e02d4142SQuentin Colombet Type elementType = converter.convertType(type.getElementType()); 210e02d4142SQuentin Colombet ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr, 211e02d4142SQuentin Colombet offsetVal); 212e02d4142SQuentin Colombet return ptr; 213e02d4142SQuentin Colombet } 214e02d4142SQuentin Colombet 215b5d847b1SAlex Zinenko /// Creates a MemRef descriptor structure from a list of individual values 216b5d847b1SAlex Zinenko /// composing that descriptor, in the following order: 217b5d847b1SAlex Zinenko /// - allocated pointer; 218b5d847b1SAlex Zinenko /// - aligned pointer; 219b5d847b1SAlex Zinenko /// - offset; 220b5d847b1SAlex Zinenko /// - <rank> sizes; 221*69007687SMatthias Springer /// - <rank> strides; 222b5d847b1SAlex Zinenko /// where <rank> is the MemRef rank as provided in `type`. 223b5d847b1SAlex Zinenko Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, 224ce254598SMatthias Springer const LLVMTypeConverter &converter, 225ce254598SMatthias Springer MemRefType type, ValueRange values) { 226b5d847b1SAlex Zinenko Type llvmType = converter.convertType(type); 227b5d847b1SAlex Zinenko auto d = MemRefDescriptor::undef(builder, loc, llvmType); 228b5d847b1SAlex Zinenko 229b5d847b1SAlex Zinenko d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); 230b5d847b1SAlex Zinenko d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); 231b5d847b1SAlex Zinenko d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); 232b5d847b1SAlex Zinenko 233b5d847b1SAlex Zinenko int64_t rank = type.getRank(); 234b5d847b1SAlex Zinenko for (unsigned i = 0; i < rank; ++i) { 235b5d847b1SAlex Zinenko d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); 236b5d847b1SAlex Zinenko d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); 237b5d847b1SAlex Zinenko } 238b5d847b1SAlex Zinenko 239b5d847b1SAlex Zinenko return d; 240b5d847b1SAlex Zinenko } 241b5d847b1SAlex Zinenko 242b5d847b1SAlex Zinenko /// Builds IR extracting individual elements of a MemRef descriptor structure 243b5d847b1SAlex Zinenko /// and returning them as `results` list. 244b5d847b1SAlex Zinenko void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, 245b5d847b1SAlex Zinenko MemRefType type, 246b5d847b1SAlex Zinenko SmallVectorImpl<Value> &results) { 247b5d847b1SAlex Zinenko int64_t rank = type.getRank(); 248b5d847b1SAlex Zinenko results.reserve(results.size() + getNumUnpackedValues(type)); 249b5d847b1SAlex Zinenko 250b5d847b1SAlex Zinenko MemRefDescriptor d(packed); 251b5d847b1SAlex Zinenko results.push_back(d.allocatedPtr(builder, loc)); 252b5d847b1SAlex Zinenko results.push_back(d.alignedPtr(builder, loc)); 253b5d847b1SAlex Zinenko results.push_back(d.offset(builder, loc)); 254b5d847b1SAlex Zinenko for (int64_t i = 0; i < rank; ++i) 255b5d847b1SAlex Zinenko results.push_back(d.size(builder, loc, i)); 256b5d847b1SAlex Zinenko for (int64_t i = 0; i < rank; ++i) 257b5d847b1SAlex Zinenko results.push_back(d.stride(builder, loc, i)); 258b5d847b1SAlex Zinenko } 259b5d847b1SAlex Zinenko 260b5d847b1SAlex Zinenko /// Returns the number of non-aggregate values that would be produced by 261b5d847b1SAlex Zinenko /// `unpack`. 262b5d847b1SAlex Zinenko unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { 263*69007687SMatthias Springer // Two pointers, offset, <rank> sizes, <rank> strides. 264b5d847b1SAlex Zinenko return 3 + 2 * type.getRank(); 265b5d847b1SAlex Zinenko } 266b5d847b1SAlex Zinenko 267b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 268b5d847b1SAlex Zinenko // MemRefDescriptorView implementation. 269b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 270b5d847b1SAlex Zinenko 271b5d847b1SAlex Zinenko MemRefDescriptorView::MemRefDescriptorView(ValueRange range) 272b5d847b1SAlex Zinenko : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} 273b5d847b1SAlex Zinenko 274b5d847b1SAlex Zinenko Value MemRefDescriptorView::allocatedPtr() { 275b5d847b1SAlex Zinenko return elements[kAllocatedPtrPosInMemRefDescriptor]; 276b5d847b1SAlex Zinenko } 277b5d847b1SAlex Zinenko 278b5d847b1SAlex Zinenko Value MemRefDescriptorView::alignedPtr() { 279b5d847b1SAlex Zinenko return elements[kAlignedPtrPosInMemRefDescriptor]; 280b5d847b1SAlex Zinenko } 281b5d847b1SAlex Zinenko 282b5d847b1SAlex Zinenko Value MemRefDescriptorView::offset() { 283b5d847b1SAlex Zinenko return elements[kOffsetPosInMemRefDescriptor]; 284b5d847b1SAlex Zinenko } 285b5d847b1SAlex Zinenko 286b5d847b1SAlex Zinenko Value MemRefDescriptorView::size(unsigned pos) { 287b5d847b1SAlex Zinenko return elements[kSizePosInMemRefDescriptor + pos]; 288b5d847b1SAlex Zinenko } 289b5d847b1SAlex Zinenko 290b5d847b1SAlex Zinenko Value MemRefDescriptorView::stride(unsigned pos) { 291b5d847b1SAlex Zinenko return elements[kSizePosInMemRefDescriptor + rank + pos]; 292b5d847b1SAlex Zinenko } 293b5d847b1SAlex Zinenko 294b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 295b5d847b1SAlex Zinenko // UnrankedMemRefDescriptor implementation 296b5d847b1SAlex Zinenko //===----------------------------------------------------------------------===// 297b5d847b1SAlex Zinenko 298b5d847b1SAlex Zinenko /// Construct a helper for the given descriptor value. 299b5d847b1SAlex Zinenko UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) 300b5d847b1SAlex Zinenko : StructBuilder(descriptor) {} 301b5d847b1SAlex Zinenko 302b5d847b1SAlex Zinenko /// Builds IR creating an `undef` value of the descriptor type. 303b5d847b1SAlex Zinenko UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, 304b5d847b1SAlex Zinenko Location loc, 305b5d847b1SAlex Zinenko Type descriptorType) { 306b5d847b1SAlex Zinenko Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); 307b5d847b1SAlex Zinenko return UnrankedMemRefDescriptor(descriptor); 308b5d847b1SAlex Zinenko } 309d0f19ce7SKrzysztof Drewniak Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { 310b5d847b1SAlex Zinenko return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); 311b5d847b1SAlex Zinenko } 312b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, 313b5d847b1SAlex Zinenko Value v) { 314b5d847b1SAlex Zinenko setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); 315b5d847b1SAlex Zinenko } 316b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, 317d0f19ce7SKrzysztof Drewniak Location loc) const { 318b5d847b1SAlex Zinenko return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); 319b5d847b1SAlex Zinenko } 320b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, 321b5d847b1SAlex Zinenko Location loc, Value v) { 322b5d847b1SAlex Zinenko setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); 323b5d847b1SAlex Zinenko } 324b5d847b1SAlex Zinenko 325b5d847b1SAlex Zinenko /// Builds IR populating an unranked MemRef descriptor structure from a list 326b5d847b1SAlex Zinenko /// of individual constituent values in the following order: 327b5d847b1SAlex Zinenko /// - rank of the memref; 328b5d847b1SAlex Zinenko /// - pointer to the memref descriptor. 329b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, 330ce254598SMatthias Springer const LLVMTypeConverter &converter, 331b5d847b1SAlex Zinenko UnrankedMemRefType type, 332b5d847b1SAlex Zinenko ValueRange values) { 333b5d847b1SAlex Zinenko Type llvmType = converter.convertType(type); 334b5d847b1SAlex Zinenko auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); 335b5d847b1SAlex Zinenko 336b5d847b1SAlex Zinenko d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); 337b5d847b1SAlex Zinenko d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); 338b5d847b1SAlex Zinenko return d; 339b5d847b1SAlex Zinenko } 340b5d847b1SAlex Zinenko 341b5d847b1SAlex Zinenko /// Builds IR extracting individual elements that compose an unranked memref 342b5d847b1SAlex Zinenko /// descriptor and returns them as `results` list. 343b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, 344b5d847b1SAlex Zinenko Value packed, 345b5d847b1SAlex Zinenko SmallVectorImpl<Value> &results) { 346b5d847b1SAlex Zinenko UnrankedMemRefDescriptor d(packed); 347b5d847b1SAlex Zinenko results.reserve(results.size() + 2); 348b5d847b1SAlex Zinenko results.push_back(d.rank(builder, loc)); 349b5d847b1SAlex Zinenko results.push_back(d.memRefDescPtr(builder, loc)); 350b5d847b1SAlex Zinenko } 351b5d847b1SAlex Zinenko 352b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::computeSizes( 353ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 354d0f19ce7SKrzysztof Drewniak ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces, 355d0f19ce7SKrzysztof Drewniak SmallVectorImpl<Value> &sizes) { 356b5d847b1SAlex Zinenko if (values.empty()) 357b5d847b1SAlex Zinenko return; 358d0f19ce7SKrzysztof Drewniak assert(values.size() == addressSpaces.size() && 359d0f19ce7SKrzysztof Drewniak "must provide address space for each descriptor"); 360b5d847b1SAlex Zinenko // Cache the index type. 361b5d847b1SAlex Zinenko Type indexType = typeConverter.getIndexType(); 362b5d847b1SAlex Zinenko 363b5d847b1SAlex Zinenko // Initialize shared constants. 364b5d847b1SAlex Zinenko Value one = createIndexAttrConstant(builder, loc, indexType, 1); 365b5d847b1SAlex Zinenko Value two = createIndexAttrConstant(builder, loc, indexType, 2); 3660fb216fbSRamkumar Ramachandra Value indexSize = createIndexAttrConstant( 3670fb216fbSRamkumar Ramachandra builder, loc, indexType, 368e843f029SRamkumar Ramachandra llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8)); 369b5d847b1SAlex Zinenko 370b5d847b1SAlex Zinenko sizes.reserve(sizes.size() + values.size()); 371d0f19ce7SKrzysztof Drewniak for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) { 372b5d847b1SAlex Zinenko // Emit IR computing the memory necessary to store the descriptor. This 373b5d847b1SAlex Zinenko // assumes the descriptor to be 374b5d847b1SAlex Zinenko // { type*, type*, index, index[rank], index[rank] } 375b5d847b1SAlex Zinenko // and densely packed, so the total size is 376b5d847b1SAlex Zinenko // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). 377b5d847b1SAlex Zinenko // TODO: consider including the actual size (including eventual padding due 378b5d847b1SAlex Zinenko // to data layout) into the unranked descriptor. 379d0f19ce7SKrzysztof Drewniak Value pointerSize = createIndexAttrConstant( 380d0f19ce7SKrzysztof Drewniak builder, loc, indexType, 381e843f029SRamkumar Ramachandra llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); 382b5d847b1SAlex Zinenko Value doublePointerSize = 383b5d847b1SAlex Zinenko builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize); 384b5d847b1SAlex Zinenko 385b5d847b1SAlex Zinenko // (1 + 2 * rank) * sizeof(index) 386b5d847b1SAlex Zinenko Value rank = desc.rank(builder, loc); 387b5d847b1SAlex Zinenko Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank); 388b5d847b1SAlex Zinenko Value doubleRankIncremented = 389b5d847b1SAlex Zinenko builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one); 390b5d847b1SAlex Zinenko Value rankIndexSize = builder.create<LLVM::MulOp>( 391b5d847b1SAlex Zinenko loc, indexType, doubleRankIncremented, indexSize); 392b5d847b1SAlex Zinenko 393b5d847b1SAlex Zinenko // Total allocation size. 394b5d847b1SAlex Zinenko Value allocationSize = builder.create<LLVM::AddOp>( 395b5d847b1SAlex Zinenko loc, indexType, doublePointerSize, rankIndexSize); 396b5d847b1SAlex Zinenko sizes.push_back(allocationSize); 397b5d847b1SAlex Zinenko } 398b5d847b1SAlex Zinenko } 399b5d847b1SAlex Zinenko 40050ea17b8SMarkus Böck Value UnrankedMemRefDescriptor::allocatedPtr( 40150ea17b8SMarkus Böck OpBuilder &builder, Location loc, Value memRefDescPtr, 40250ea17b8SMarkus Böck LLVM::LLVMPointerType elemPtrType) { 403b28a296cSChristian Ulmann return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr); 404b5d847b1SAlex Zinenko } 405b5d847b1SAlex Zinenko 40650ea17b8SMarkus Böck void UnrankedMemRefDescriptor::setAllocatedPtr( 40750ea17b8SMarkus Böck OpBuilder &builder, Location loc, Value memRefDescPtr, 40850ea17b8SMarkus Böck LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { 409b28a296cSChristian Ulmann builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr); 410b5d847b1SAlex Zinenko } 411b5d847b1SAlex Zinenko 41250ea17b8SMarkus Böck static std::pair<Value, Type> 41350ea17b8SMarkus Böck castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, 41450ea17b8SMarkus Böck LLVM::LLVMPointerType elemPtrType) { 415b28a296cSChristian Ulmann auto elemPtrPtrType = LLVM::LLVMPointerType::get(builder.getContext()); 416b28a296cSChristian Ulmann return {memRefDescPtr, elemPtrPtrType}; 41750ea17b8SMarkus Böck } 41850ea17b8SMarkus Böck 419ce254598SMatthias Springer Value UnrankedMemRefDescriptor::alignedPtr( 420ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 421ce254598SMatthias Springer Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { 42250ea17b8SMarkus Böck auto [elementPtrPtr, elemPtrPtrType] = 42350ea17b8SMarkus Böck castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); 424b5d847b1SAlex Zinenko 42550ea17b8SMarkus Böck Value alignedGep = 42650ea17b8SMarkus Böck builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, 42750ea17b8SMarkus Böck elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); 42850ea17b8SMarkus Böck return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep); 429b5d847b1SAlex Zinenko } 430b5d847b1SAlex Zinenko 431ce254598SMatthias Springer void UnrankedMemRefDescriptor::setAlignedPtr( 432ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 433ce254598SMatthias Springer Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) { 43450ea17b8SMarkus Böck auto [elementPtrPtr, elemPtrPtrType] = 43550ea17b8SMarkus Böck castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); 436b5d847b1SAlex Zinenko 43750ea17b8SMarkus Böck Value alignedGep = 43850ea17b8SMarkus Böck builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, 43950ea17b8SMarkus Böck elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); 440b5d847b1SAlex Zinenko builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); 441b5d847b1SAlex Zinenko } 442b5d847b1SAlex Zinenko 4437fb9bbe5SKrzysztof Drewniak Value UnrankedMemRefDescriptor::offsetBasePtr( 444ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 4457fb9bbe5SKrzysztof Drewniak Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { 44650ea17b8SMarkus Böck auto [elementPtrPtr, elemPtrPtrType] = 44750ea17b8SMarkus Böck castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); 448b5d847b1SAlex Zinenko 449b28a296cSChristian Ulmann return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, 45050ea17b8SMarkus Böck elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); 4517fb9bbe5SKrzysztof Drewniak } 45250ea17b8SMarkus Böck 4537fb9bbe5SKrzysztof Drewniak Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, 454ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 4557fb9bbe5SKrzysztof Drewniak Value memRefDescPtr, 4567fb9bbe5SKrzysztof Drewniak LLVM::LLVMPointerType elemPtrType) { 4577fb9bbe5SKrzysztof Drewniak Value offsetPtr = 4587fb9bbe5SKrzysztof Drewniak offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); 4596a10381bSStephan Herhut return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(), 4606a10381bSStephan Herhut offsetPtr); 461b5d847b1SAlex Zinenko } 462b5d847b1SAlex Zinenko 463b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, 464ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 465b5d847b1SAlex Zinenko Value memRefDescPtr, 46650ea17b8SMarkus Böck LLVM::LLVMPointerType elemPtrType, 46750ea17b8SMarkus Böck Value offset) { 4687fb9bbe5SKrzysztof Drewniak Value offsetPtr = 4697fb9bbe5SKrzysztof Drewniak offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); 4707fb9bbe5SKrzysztof Drewniak builder.create<LLVM::StoreOp>(loc, offset, offsetPtr); 471b5d847b1SAlex Zinenko } 472b5d847b1SAlex Zinenko 473ce254598SMatthias Springer Value UnrankedMemRefDescriptor::sizeBasePtr( 474ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 475ce254598SMatthias Springer Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { 476b5d847b1SAlex Zinenko Type indexTy = typeConverter.getIndexType(); 47750ea17b8SMarkus Böck Type structTy = LLVM::LLVMStructType::getLiteral( 47850ea17b8SMarkus Böck indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); 479b28a296cSChristian Ulmann auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); 480b28a296cSChristian Ulmann return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr, 48150ea17b8SMarkus Böck ArrayRef<LLVM::GEPArg>{0, 3}); 482b5d847b1SAlex Zinenko } 483b5d847b1SAlex Zinenko 484b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, 485ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 486b5d847b1SAlex Zinenko Value sizeBasePtr, Value index) { 48750ea17b8SMarkus Böck 48850ea17b8SMarkus Böck Type indexTy = typeConverter.getIndexType(); 48997a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 49050ea17b8SMarkus Böck 491bd7eff1fSMarkus Böck Value sizeStoreGep = 49297a238e8SChristian Ulmann builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); 49350ea17b8SMarkus Böck return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep); 494b5d847b1SAlex Zinenko } 495b5d847b1SAlex Zinenko 496b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, 497ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 498b5d847b1SAlex Zinenko Value sizeBasePtr, Value index, 499b5d847b1SAlex Zinenko Value size) { 50050ea17b8SMarkus Böck Type indexTy = typeConverter.getIndexType(); 50197a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 50250ea17b8SMarkus Böck 503bd7eff1fSMarkus Böck Value sizeStoreGep = 50497a238e8SChristian Ulmann builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); 505b5d847b1SAlex Zinenko builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); 506b5d847b1SAlex Zinenko } 507b5d847b1SAlex Zinenko 508ce254598SMatthias Springer Value UnrankedMemRefDescriptor::strideBasePtr( 509ce254598SMatthias Springer OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, 510b5d847b1SAlex Zinenko Value sizeBasePtr, Value rank) { 51150ea17b8SMarkus Böck Type indexTy = typeConverter.getIndexType(); 51297a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 51350ea17b8SMarkus Böck 51497a238e8SChristian Ulmann return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank); 515b5d847b1SAlex Zinenko } 516b5d847b1SAlex Zinenko 517b5d847b1SAlex Zinenko Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, 518ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 519b5d847b1SAlex Zinenko Value strideBasePtr, Value index, 520b5d847b1SAlex Zinenko Value stride) { 52150ea17b8SMarkus Böck Type indexTy = typeConverter.getIndexType(); 52297a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 52350ea17b8SMarkus Böck 52497a238e8SChristian Ulmann Value strideStoreGep = 52597a238e8SChristian Ulmann builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); 52650ea17b8SMarkus Böck return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep); 527b5d847b1SAlex Zinenko } 528b5d847b1SAlex Zinenko 529b5d847b1SAlex Zinenko void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, 530ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 531b5d847b1SAlex Zinenko Value strideBasePtr, Value index, 532b5d847b1SAlex Zinenko Value stride) { 53350ea17b8SMarkus Böck Type indexTy = typeConverter.getIndexType(); 53497a238e8SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 53550ea17b8SMarkus Böck 53697a238e8SChristian Ulmann Value strideStoreGep = 53797a238e8SChristian Ulmann builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); 538b5d847b1SAlex Zinenko builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); 539b5d847b1SAlex Zinenko } 540