// RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s

// Check that the metadata of a memref with a fully static layout is
// materialized as constants; only the base buffer still comes from
// extract_strided_metadata.
//
// CHECK-LABEL: func @extract_strided_metadata_constants
// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
    -> (memref<f32>, index, index, index, index, index) {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index

  // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
    memref<5x4xf32, strided<[4,1], offset:2>>
    -> memref<f32>, index, index, index, index, index

  // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<f32>, index, index, index, index, index
}

// -----

// Check that we simplify subview(src) into:
// base, offset, sizes, strides = extract_strided_metadata src
// final_sizes = subSizes
// final_strides = <some math> strides
// final_offset = <some math> offset
// reinterpret_cast base to final_offset, final_sizes, final_strides
//
// Orig strides: [s0, s1, s2]
// Sub strides: [subS0, subS1, subS2]
// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
// ==> 1 affine map (used for each stride) with two values.
//
// Orig offset: origOff
// Sub offsets: [subO0, subO1, subO2]
// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
// ==> 1 affine map with (rank * 2 + 1) symbols
//
// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
// CHECK-LABEL: func @simplify_subview_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
//
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]]
//
// CHECK: return %[[RES]]
func.func @simplify_subview_all_dynamic(
    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
    %offset0: index, %offset1: index, %offset2: index,
    %size0: index, %size1: index, %size2: index,
    %stride0: index, %stride1: index, %stride2: index)
    -> memref<?x?x?xf32, strided<[?,?,?], offset:?>> {

  %subview = memref.subview %base[%offset0, %offset1, %offset2]
                                 [%size0, %size1, %size2]
                                 [%stride0, %stride1, %stride2] :
    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>

  return %subview : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}

// -----

// Check that we simplify extract_strided_metadata of subview to
// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
// strides = base_stride_i * subview_stride_i
// offset = base_offset + sum(subview_offsets_i * base_strides_i).
//
// This test also checks that we don't create useless arith operations
// when subview_offsets_i is 0.
//
// CHECK-LABEL: func @extract_strided_metadata_of_subview
// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
//
// Materialize the constants for the final offset, sizes, and strides.
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
//
// Plain extract_strided_metadata.
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// Final offset is:
//   origOffset +                       (== 0)
//   base_stride0 * subview_offset0 +   (== 4 * 0 == 0)
//   base_stride1 * subview_offset1     (== 1 * 2)
//   == 2
//
// Return the new tuple.
// CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
    -> (memref<f32>, index, index, index, index, index) {

  %subview = memref.subview %base[0, 2][2, 2][1, 1] :
    memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>

  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
    memref<2x2xf32, strided<[4,1], offset:2>>
    -> memref<f32>, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<f32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of subview properly
// when dynamic sizes are involved.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
// Orig strides: [64, 4, 1]
// Sub strides: [1, 1, 1]
// => New strides: [64, 4, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// Final sizes == subview sizes == [%size, 6, 3]
//
// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
// CHECK-SAME: %[[DYN_SIZE:.*]]: index)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_subview_with_dynamic_size(
    %base: memref<8x16x4xf32>, %size: index)
    -> (memref<f32>, index, index, index, index, index, index, index) {

  %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] :
    memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>

  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
    memref<?x6x3xf32, strided<[64,4,1], offset: 210>>
    -> memref<f32>, index, index, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
    memref<f32>, index, index, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of subview properly
// when the subview reduces the rank.
// In particular the returned strides must come from #1 and #2 of the %strides
// value of the new extract_strided_metadata, not #0 and #1.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [64, 4, 1]
// Sub strides: [1, 1, 1]
// => New strides: [64, 4, 1]
// Final strides == filterOutReducedDim(new strides, 0) == [4, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
//
// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>)
    -> (memref<f32>, index, index, index, index, index) {

  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] :
    memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>

  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
    memref<6x3xf32, strided<[4,1], offset: 210>>
    -> memref<f32>, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<f32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of subview properly
// when the subview reduces the rank and some of the strides are variable.
// In particular, we check that:
// A. The dynamic stride is multiplied with the base stride to create the new
//    stride for dimension 1.
// B. The first returned stride is the value computed in #A.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [64, 4, 1]
// Sub strides: [1, %stride, 1]
// => New strides: [64, 4 * %stride, 1]
// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride, 1]
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
// CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
//
// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
//
// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
    %base: memref<8x16x4xf32>, %stride: index)
    -> (memref<f32>, index, index, index, index, index) {

  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
    memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>

  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
    memref<6x3xf32, strided<[?, 1], offset: 210>>
    -> memref<f32>, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<f32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of subview properly
// when the subview uses variable offsets.
// See extract_strided_metadata_of_subview for an explanation of the actual
// expansion.
//
// Orig strides: [128, 1]
// Sub strides: [1, 1]
// => New strides: [128, 1]
//
// Orig offset: 0
// Sub offsets: [%arg1, %arg2]
// => Final offset: 128 * %arg1 + 1 * %arg2 + 0
//
// CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
// CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
//
// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]]
//
// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
func.func @extract_strided_metadata_of_subview_w_variable_offset(
    %arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index)
    -> (memref<f32>, index, index, index, index, index) {

  %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] :
    memref<384x128xf32> to memref<64x64xf32, strided<[128, 1], offset: ?>>

  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
    memref<64x64xf32, strided<[128, 1], offset: ?>> -> memref<f32>, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<f32>, index, index, index, index, index
}

// -----

// Check that all the math is correct for all types of computations.
// We achieve that by using dynamic values for all the different types:
// - Offsets
// - Sizes
// - Strides
//
// Orig strides: [s0, s1, s2]
// Sub strides: [subS0, subS1, subS2]
// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
// ==> 1 affine map (used for each stride) with two values.
//
// Orig offset: origOff
// Sub offsets: [subO0, subO1, subO2]
// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
// ==> 1 affine map with (rank * 2 + 1) symbols
//
// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
//
// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
//
// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
func.func @extract_strided_metadata_of_subview_all_dynamic(
    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
    %offset0: index, %offset1: index, %offset2: index,
    %size0: index, %size1: index, %size2: index,
    %stride0: index, %stride1: index, %stride2: index)
    -> (memref<f32>, index, index, index, index, index, index, index) {

  %subview = memref.subview %base[%offset0, %offset1, %offset2]
                                 [%size0, %size1, %size2]
                                 [%stride0, %stride1, %stride2] :
    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>

  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
    memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    -> memref<f32>, index, index, index, index, index, index, index

  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
    memref<f32>, index, index, index, index, index, index, index
}

// -----

// Check that we properly simplify expand_shape into:
// reinterpret_cast(extract_strided_metadata) + <some math>
//
// Here we have:
// For the group applying to dim0:
// size 0 = baseSizes#0 / (all static sizes in that group)
//        = baseSizes#0 / (7 * 8 * 9)
//        = baseSizes#0 / 504
// size 1 = 7
// size 2 = 8
// size 3 = 9
// stride 0 = baseStrides#0 * 7 * 8 * 9
//          = baseStrides#0 * 504
// stride 1 = baseStrides#0 * 8 * 9
//          = baseStrides#0 * 72
// stride 2 = baseStrides#0 * 9
// stride 3 = baseStrides#0
//
// For the group applying to dim1:
// size 4 = 10
// size 5 = 2
// size 6 = baseSizes#1 / (all static sizes in that group)
//        = baseSizes#1 / (10 * 2 * 3)
//        = baseSizes#1 / 60
// size 7 = 3
// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
//          = baseStrides#1 * (baseSizes#1 / 60) * 6
//          and since we know that baseSizes#1 is a multiple of 60:
//          = baseStrides#1 * (baseSizes#1 / 10)
// stride 5 = baseStrides#1 * size 6 * size 7
//          = baseStrides#1 * (baseSizes#1 / 60) * 3
//          = baseStrides#1 * (baseSizes#1 / 20)
// stride 6 = baseStrides#1 * size 7
//          = baseStrides#1 * 3
// stride 7 = baseStrides#1
//
// Base and offset are unchanged.
//
// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
//
// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @simplify_expand_shape
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
//
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
//
// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
//
// CHECK: return %[[REINTERPRET_CAST]]
func.func @simplify_expand_shape(
    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
    %offset0: index, %offset1: index, %offset2: index,
    %size0: index, %size1: index, %size2: index,
    %stride0: index, %stride1: index, %stride2: index,
    %sz0: index, %sz1: index)
    -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {

  %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
    memref<?x?xf32, strided<[?,?], offset: ?>> into
      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>

  return %subview :
    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
}

// -----

// Check that we properly simplify extract_strided_metadata of expand_shape
// into:
// baseBuffer, baseOffset, baseSizes, baseStrides =
//     extract_strided_metadata(memref)
// sizes#reassIdx =
//     baseSizes#reassDim / product(expandShapeSizes#j,
//                                  for j in group excluding reassIdx)
// strides#reassIdx =
//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
//                                    reassIdx+1..reassIdx+group.size)
//
// Here we have:
// For the group applying to dim0:
// size 0 = 3
// size 1 = 5
// size 2 = 2
// stride 0 = baseStrides#0 * 5 * 2
//          = 4 * 5 * 2
//          = 40
// stride 1 = baseStrides#0 * 2
//          = 4 * 2
//          = 8
// stride 2 = baseStrides#0
//          = 4
//
// For the group applying to dim1:
// size 3 = 2
// size 4 = 2
// stride 3 = baseStrides#1 * 2
//          = 1 * 2
//          = 2
// stride 4 = baseStrides#1
//          = 1
//
// Base and offset are unchanged.
//
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
//
// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
//
// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_static(
    %arg : memref<30x4xi16>)
    -> (memref<i16>, index,
        index, index, index, index, index,
        index, index, index, index, index) {

  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] output_shape [3, 5, 2, 2, 2] :
    memref<30x4xi16> into memref<3x5x2x2x2xi16>

  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
    memref<3x5x2x2x2xi16>
    -> memref<i16>, index,
       index, index, index, index, index,
       index, index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
      memref<i16>, index,
      index, index, index, index, index,
      index, index, index, index, index
}

// -----

// Check that we properly simplify extract_strided_metadata of expand_shape
// when dynamic sizes, strides, and offsets are involved.
// See extract_strided_metadata_of_expand_shape_all_static for an explanation
// of the expansion.
//
// One of the important characteristics of this test is that the dynamic
// dimensions produced by the expand_shape appear both in the first dimension
// (for group 1) and in a non-first dimension (the third dimension of group 2,
// i.e., dim 6).
// The idea is to make sure that:
// 1. We properly account for dynamic shapes even when the strides are not
//    affected by them. (When the dynamic dimension is the first one.)
// 2. We properly compute the strides affected by dynamic shapes. (When the
//    dynamic dimension is not the first one.)
//
// Here we have:
// For the group applying to dim0:
// size 0 = baseSizes#0 / (all static sizes in that group)
//        = baseSizes#0 / (7 * 8 * 9)
//        = baseSizes#0 / 504
// size 1 = 7
// size 2 = 8
// size 3 = 9
// stride 0 = baseStrides#0 * 7 * 8 * 9
//          = baseStrides#0 * 504
// stride 1 = baseStrides#0 * 8 * 9
//          = baseStrides#0 * 72
// stride 2 = baseStrides#0 * 9
// stride 3 = baseStrides#0
//
// For the group applying to dim1:
// size 4 = 10
// size 5 = 2
// size 6 = baseSizes#1 / (all static sizes in that group)
//        = baseSizes#1 / (10 * 2 * 3)
//        = baseSizes#1 / 60
// size 7 = 3
// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
//          = baseStrides#1 * (baseSizes#1 / 60) * 6
//          and since we know that baseSizes#1 is a multiple of 60:
//          = baseStrides#1 * (baseSizes#1 / 10)
// stride 5 = baseStrides#1 * size 6 * size 7
//          = baseStrides#1 * (baseSizes#1 / 60) * 3
//          = baseStrides#1 * (baseSizes#1 / 20)
// stride 6 = baseStrides#1 * size 7
//          = baseStrides#1 * 3
// stride 7 = baseStrides#1
//
// Base and offset are unchanged.
//
// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
//
// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
//
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
//
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]

// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
    %offset0: index, %offset1: index, %offset2: index,
    %size0: index, %size1: index, %size2: index,
    %stride0: index, %stride1: index, %stride2: index,
    %sz0: index, %sz1: index)
    -> (memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index) {

  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
    memref<?x?xf32, strided<[?,?], offset: ?>> into
      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>

  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
    -> memref<f32>, index,
       index, index, index, index, index, index, index, index,
       index, index, index, index, index, index, index, index

  return %base_buffer, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
      memref<f32>, index,
      index, index, index, index, index, index, index, index,
      index, index, index, index, index, index, index, index
}


// -----

// Check that we properly handle extract_strided_metadata of expand_shape for
// 0-D input.
// The 0-D case is pretty boring:
// All expanded sizes are 1, likewise for the strides, and we keep the
// original base and offset.
// We still have a test for it because, since the input reassociation map
// of the expand_shape is empty, handling such shapes hits a corner case.
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
// CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
//
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
    %arg : memref<i16, strided<[], offset: ?>>)
    -> (memref<i16>, index,
        index, index, index, index, index,
        index, index, index, index, index) {

  %expand_shape = memref.expand_shape %arg[] output_shape [1, 1, 1, 1, 1] :
    memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>

  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
    memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
    -> memref<i16>, index,
       index, index, index, index, index,
       index, index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
      memref<i16>, index,
      index, index, index, index, index,
      index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata(alloc)
// into simply the alloc with the information extracted from
// the memref type and arguments of the alloc.
//
// baseBuffer = reinterpret_cast alloc
// offset = 0
// sizes = shape(memref)
// strides = strides(memref)
//
// For dynamic shapes, we simply use the values that feed the alloc.
//
// Simple rank 0 test: we don't need a reinterpret_cast here.
// CHECK-LABEL: func @extract_strided_metadata_of_alloc_all_static_0_rank
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
// CHECK: return %[[ALLOC]], %[[C0]] : memref<i16>, index
func.func @extract_strided_metadata_of_alloc_all_static_0_rank()
    -> (memref<i16>, index) {

  %A = memref.alloc() : memref<i16>
  %base, %offset = memref.extract_strided_metadata %A :
    memref<i16>
    -> memref<i16>, index

  return %base, %offset :
    memref<i16>, index
}

// -----

// Simplification of extract_strided_metadata(alloc).
// Check that we properly use the dynamic sizes to
// create the new sizes and strides.
// size 0 = dyn_size0
// size 1 = 4
// size 2 = dyn_size2
// size 3 = dyn_size3
//
// stride 0 = size 1 * size 2 * size 3
//          = 4 * dyn_size2 * dyn_size3
// stride 1 = size 2 * size 3
//          = dyn_size2 * dyn_size3
// stride 2 = size 3
//          = dyn_size3
// stride 3 = 1
//
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: extract_strided_metadata_of_alloc_dyn_size
// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_SIZE3:.*]]: index)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYN_SIZE0]], %[[DYN_SIZE2]], %[[DYN_SIZE3]])
//
// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
//
// CHECK-DAG: %[[CASTED_ALLOC:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<?x4x?x?xi16> to memref<i16>
//
// CHECK: return %[[CASTED_ALLOC]], %[[C0]], %[[DYN_SIZE0]], %[[C4]], %[[DYN_SIZE2]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloc_dyn_size(
    %dyn_size0 : index, %dyn_size2 : index, %dyn_size3 : index)
    -> (memref<i16>, index,
        index, index, index, index,
        index, index, index, index) {

  %A = memref.alloc(%dyn_size0, %dyn_size2, %dyn_size3) : memref<?x4x?x?xi16>

  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
    memref<?x4x?x?xi16>
    -> memref<i16>, index,
       index, index, index, index,
       index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
    %strides#0, %strides#1, %strides#2, %strides#3 :
    memref<i16>, index,
    index, index, index, index,
    index, index, index, index
}

// -----

// Same check as extract_strided_metadata_of_alloc_dyn_size but alloca
// instead of alloc. Just to make sure we handle allocas the same way
// we do with alloc.
// While at it, test a slightly different shape than
// extract_strided_metadata_of_alloc_dyn_size.
//
// size 0 = dyn_size0
// size 1 = dyn_size1
// size 2 = 4
// size 3 = dyn_size3
//
// stride 0 = size 1 * size 2 * size 3
//          = dyn_size1 * 4 * dyn_size3
// stride 1 = size 2 * size 3
//          = 4 * dyn_size3
// stride 2 = size 3
//          = dyn_size3
// stride 3 = 1
//
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-LABEL: extract_strided_metadata_of_alloca_dyn_size
// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE3:.*]]: index)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca(%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE3]])
//
// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE1]], %[[DYN_SIZE3]]]
// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE3]]]
//
// CHECK-DAG: %[[CASTED_ALLOCA:.*]] = memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [], strides: [] : memref<?x?x4x?xi16> to memref<i16>
//
// CHECK: return %[[CASTED_ALLOCA]], %[[C0]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[C4]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloca_dyn_size(
    %dyn_size0 : index, %dyn_size1 : index, %dyn_size3 : index)
    -> (memref<i16>, index,
        index, index, index, index,
        index, index, index, index) {

  %A = memref.alloca(%dyn_size0, %dyn_size1, %dyn_size3) : memref<?x?x4x?xi16>

  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
    memref<?x?x4x?xi16>
    -> memref<i16>, index,
       index, index, index, index,
       index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
    %strides#0, %strides#1, %strides#2, %strides#3 :
    memref<i16>, index,
    index, index, index, index,
    index, index, index, index
}

// -----

// The following few alloc tests are negative tests (the simplification
// doesn't happen). They make sure that memref types with non-trivial
// layouts are treated as not normalized.
// CHECK-LABEL: extract_strided_metadata_of_alloc_with_variable_offset
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
func.func @extract_strided_metadata_of_alloc_with_variable_offset(%arg : index)
    -> (memref<i16>, index, index, index) {

  %A = memref.alloc()[%arg] : memref<4xi16, #map0>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, #map0>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
    memref<i16>, index, index, index
}

// -----

// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
#map0 = affine_map<(d0) -> (d0 + 12)>
func.func @extract_strided_metadata_of_alloc_with_cst_offset(%arg : index)
    -> (memref<i16>, index, index, index) {

  %A = memref.alloc() : memref<4xi16, #map0>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, #map0>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
    memref<i16>, index, index, index
}

// -----

// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset_in_type
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_cst_offset_in_type(%arg : index)
    -> (memref<i16>, index, index, index) {

  %A = memref.alloc() : memref<4xi16, strided<[1], offset : 10>>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, strided<[1], offset : 10>>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
    memref<i16>, index, index, index
}

// -----

// CHECK-LABEL: extract_strided_metadata_of_alloc_with_strided
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
// CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_strided(%arg : index)
    -> (memref<i16>, index, index, index) {

  %A = memref.alloc() : memref<4xi16, strided<[12]>>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, strided<[12]>>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
    memref<i16>, index, index, index
}

// -----

// Check that extract_aligned_pointer_as_index looks through a view-like op:
// the subview must disappear and the pointer must be taken from the original
// argument.
// CHECK-LABEL: extract_aligned_pointer_as_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<f32>
func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
  // CHECK-NOT: memref.subview
  // CHECK: memref.extract_aligned_pointer_as_index %[[ARG0]]
  %c = memref.subview %arg0[] [] [] : memref<f32> to memref<f32>
  // Read through the subview (not %arg0 directly); otherwise the subview is
  // dead and the test never exercises the view-like path.
  %r = memref.extract_aligned_pointer_as_index %c : memref<f32> -> index
  return %r : index
}

// -----

// CHECK-LABEL:
extract_aligned_pointer_as_index_of_unranked_source
// CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
  // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
  // CHECK: return %[[I]]

  // The pointer is expected to be read straight from the unranked source,
  // bypassing the reinterpret_cast.
  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
  %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
  return %i : index
}

// -----

// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
//
// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
// Size 0 = origSize0
// Size 1 = origSize1 * origSize2 * origSize3
//        = origSize1 * 4 * origSize3
// Size 2 = origSize4 * origSize5
//        = 6 * 7
//        = 42
// The stride of a collapsed group is the stride of the innermost dimension
// of that group whose size is not 1:
// Stride 0 = origStride0
// Stride 1 = origStride3
//          = 6 * 7
//          = 42
// Stride 2 = origStride5
//          = 1
//
// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @simplify_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
    -> memref<?x?x42xi32> {

  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>

  return %collapsed_view : memref<?x?x42xi32>

}

// -----

// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
// when there are dimensions of size 1 involved.
//
// We transform: 3x1 to [0, 1]
//
// The tricky bit here is that the strides between dimensions 0 and 1
// are not truly contiguous, but since we are dealing with a dimension of
// size 1 this is actually fine (i.e., we are not going to jump around).
//
// As a result the resulting stride needs to ignore the strides of the
// dimensions of size 1.
//
// Size 0 = origSize0 * origSize1
//        = 3 * 1
//        = 3
// Stride 0 = stride of the innermost dimension of the reassociation group
//            whose size is not 1
//          = origStride0
//          = 2
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
//
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {

  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>>

  // The copy only gives the collapsed view a use.
  memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32>

  return
}


// -----

// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
//
// The tricky bit here is also the resulting stride is meaningless, we still
// have to please the type system.
//
// Here both collapsed dimensions have size 1, with strides 2 and 1, and the
// result type asks for a stride of 2.
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>>
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2]
func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
    %arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>)
    -> memref<1xi32, strided<[2], offset: ?>> {
  %flat = memref.collapse_shape %arg0 [[0, 1]]
      : memref<1x1xi32, strided<[2, 1], offset: ?>>
        into memref<1xi32, strided<[2], offset: ?>>
  return %flat : memref<1xi32, strided<[2], offset: ?>>
}

// -----

// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
// We also have a couple of collapsed dimensions before the 1x1x...x1 group
// to make sure we properly index into the dynamic strides based on the
// group ID.
//
// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic,
// so we have to propagate one of the dynamic strides of this group.
//
// For this test we have:
// Size0 = origSize0 * origSize1
//       = 2 * 3
//       = 6
// Size1 = origSize2 * origSize3 * origSize4
//       = 1 * 1 * 1
//       = 1
//
// Stride0 = stride of the innermost dimension of the group whose size is
//           not 1
//         = origStride1
// Stride1 = we actually don't know, this is dynamic but we don't know
//           which one to pick.
//           We just return the first dynamic one for this group.
//
// Stride1 thus resolves to origStride2, the first dynamic stride of [2, 3, 4].
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
//
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
    %arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
    -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
  %merged = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]]
      : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
        into memref<6x1xi32, strided<[?, ?], offset: ?>>
  return %merged : memref<6x1xi32, strided<[?, ?], offset: ?>>
}

// -----

// Check that we simplify extract_strided_metadata of collapse_shape.
//
// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
// Size 0 = origSize0
// Size 1 = origSize1 * origSize2 * origSize3
//        = origSize1 * 4 * origSize3
// Size 2 = origSize4 * origSize5
//        = 6 * 7
//        = 42
// Stride 0 = origStride0
// Stride 1 = origStride3 (orig stride of the innermost dimension)
//          = 42
// Stride 2 = origStride5
//          = 1
//
// Note: the map captured as SIZE0_MAP below is the one that computes
// Size 1 (origSize1 * 4 * origSize3).
// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
    -> (memref<i32>, index,
        index, index, index,
        index, index, index) {

  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>

  %base, %offset, %sizes:3, %strides:3 =
    memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
    -> memref<i32>, index,
       index, index, index,
       index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2,
    %strides#0, %strides#1, %strides#2 :
    memref<i32>, index,
    index, index, index,
    index, index, index

}

// -----

// Check that we simplify extract_strided_metadata of collapse_shape to
// a 0-ranked shape.
// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
//
// CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
    -> (memref<i32>, index) {
  // An empty reassociation list collapses all six unit dimensions away.
  %scalar = memref.collapse_shape %arg []
      : memref<1x1x1x1x1x1xi32> into memref<i32>
  %scalar_base, %scalar_offset = memref.extract_strided_metadata %scalar
      : memref<i32> -> memref<i32>, index
  return %scalar_base, %scalar_offset : memref<i32>, index
}

// -----

// Check that extract_strided_metadata applied to the base buffer produced by
// another extract_strided_metadata is simplified.
//
// CHECK-LABEL: func @extract_strided_metadata_of_extract_strided_metadata(
// CHECK-SAME: %[[ARG:.*]]: memref<i32>)
//
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i32>)
    -> (memref<i32>, index) {
  %outer_base, %outer_offset = memref.extract_strided_metadata %arg
      : memref<i32> -> memref<i32>, index
  %inner_base, %inner_offset = memref.extract_strided_metadata %outer_base
      : memref<i32> -> memref<i32>, index
  return %inner_base, %inner_offset : memref<i32>, index
}

// -----

// Check that we simplify extract_strided_metadata of reinterpret_cast
// when the source of the reinterpret_cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the reinterpret_cast.
// Only the base buffer is taken from the cast source; its own offset, sizes
// and strides are ignored.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast(
    %arg : memref<?x?xi32, strided<[?, ?], offset:?>>,
    %offset: index,
    %size0 : index, %size1 : index,
    %stride0 : index, %stride1 : index)
    -> (memref<i32>, index,
        index, index,
        index, index) {

  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<?x?xi32, strided<[?, ?], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
    memref<i32>, index,
    index, index,
    index, index
}

// -----

// Check that we don't simplify extract_strided_metadata of
// reinterpret_cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
    %arg : memref<*xi32>,
    %offset: index,
    %size0 : index, %size1 : index,
    %stride0 : index, %stride1 : index)
    -> (memref<i32>, index,
        index, index,
        index, index) {

  // The source is unranked, so this cast is expected to survive the pass
  // unchanged, together with the extract_strided_metadata that consumes it.
  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<*xi32> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
    memref<i32>, index,
    index, index,
    index, index
}

// -----

// Similar to @extract_strided_metadata_of_reinterpret_cast, just make sure
// we handle a 0-D source properly.
//
// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
// CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
//
// CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
    %arg : memref<i32, strided<[], offset:?>>,
    %offset: index,
    %size0 : index, %size1 : index,
    %stride0 : index, %stride1 : index)
    -> (memref<i32>, index,
        index, index,
        index, index) {

  // Only the source is 0-D; the cast result is 2-D, so the returned offset,
  // sizes and strides all come from the cast operands.
  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<i32, strided<[], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
    memref<i32>, index,
    index, index,
    index, index
}

// -----

// Check that `memref.get_global` -> `memref.extract_strided_metadata` resolves,
// with the consumer replaced by the strides, sizes and offset computed from
// `memref.get_global`. Since the result of `memref.get_global` is always
// statically shaped, there is no need to check for dynamic shapes.
// Expected: sizes [512, 384] and strides [384, 1] become constants, and the
// base is a 0-D reinterpret_cast of the global.
// CHECK-LABEL: func @extract_strided_metadata_of_get_global()
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C384:.+]] = arith.constant 384 : index
// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
// CHECK: %[[CAST:.+]] = memref.reinterpret_cast %[[GET_GLOBAL]]
// CHECK-SAME: offset: [0], sizes: [], strides: []
// CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]

memref.global "private" constant @const_i32 : memref<512x384xi32> = dense<42>

func.func @extract_strided_metadata_of_get_global()
    -> (memref<i32>, index, index, index, index, index) {
  %glob = memref.get_global @const_i32 : memref<512x384xi32>
  %buf, %off, %dims:2, %steps:2 = memref.extract_strided_metadata %glob
      : memref<512x384xi32> -> memref<i32>, index, index, index, index, index
  return %buf, %off, %dims#0, %dims#1, %steps#0, %steps#1
      : memref<i32>, index, index, index, index, index
}

// -----

// Check that for `memref.get_global` -> `memref.extract_strided_metadata` does not
// resolve when the strides are not identity.
This is an unhandled case that could
// be covered in the future.
// Here the outer stride is 420 rather than the row size 384.

// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_strides()
// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
// CHECK: memref.extract_strided_metadata %[[GET_GLOBAL]]
memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>> = dense<42>

func.func @extract_strided_metadata_of_get_global_with_strides()
    -> (memref<i32>, index, index, index, index, index) {

  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>>

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
    memref<512x384xi32, strided<[420, 1], offset: 0>>
    -> memref<i32>, index, index, index, index, index

  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<i32>, index, index, index, index, index
}

// -----

// Check that for `memref.get_global` -> `memref.extract_strided_metadata` does not
// resolve when the offset is non-zero.
This is an unhandled case that could
// be covered in the future.
// Here the global's layout carries a static offset of 20.

// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_offset()
// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
// CHECK: memref.extract_strided_metadata %[[GET_GLOBAL]]
memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>> = dense<42>

func.func @extract_strided_metadata_of_get_global_with_offset()
    -> (memref<i32>, index, index, index, index, index) {

  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>>

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
    memref<512x384xi32, strided<[384, 1], offset: 20>>
    -> memref<i32>, index, index, index, index, index

  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
    memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
    %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
    -> (memref<i32>, index, index, index, index, index) {
  // Erase the static size (3) and stride (4) from the type; they are expected
  // to reappear as constants after the rewrite.
  %erased = memref.cast %arg
      : memref<3x?xi32, strided<[4, ?], offset: ?>>
        to memref<?x?xi32, strided<[?, ?], offset: ?>>
  %buf, %buf_offset, %dims:2, %steps:2 = memref.extract_strided_metadata %erased
      : memref<?x?xi32, strided<[?, ?], offset: ?>>
        -> memref<i32>, index, index, index, index, index
  return %buf, %buf_offset, %dims#0, %dims#1, %steps#0, %steps#1
      : memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
    %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
    -> (memref<i32>, index,
        index, index,
        index, index) {

  // The destination type pins offset = 25, size 0 = 4 and stride 1 = 18; those
  // constants are expected in the results instead of the dynamic source values.
  %cast =
    memref.cast %arg :
      memref<?x?xi32, strided<[?, ?], offset: ?>> to
      memref<4x?xi32, strided<[?, 18], offset: 25>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
    memref<i32>, index,
    index, index,
    index, index
}

// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
    %arg : memref<*xi32>)
    -> (memref<i32>, index,
        index, index,
        index, index) {

  %cast =
    memref.cast %arg :
      memref<*xi32> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
    memref<i32>, index,
    index, index,
    index, index
}


// -----

// Check that extract_strided_metadata of a zero-sized (0xf16) global is
// rewritten into a 0-D reinterpret_cast of the get_global result.
memref.global "private" @dynamicShmem : memref<0xf16,3>

// CHECK-LABEL: func @zero_sized_memred
func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index) {
  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>

  // CHECK: %[[BASE:.*]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
  // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [], strides: [] : memref<0xf16, 3> to memref<f16, 3>
  // CHECK: return %[[CAST]]

  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %dynamicMem : memref<0xf16, 3> -> memref<f16, 3>, index, index, index
  return %base_buffer, %offset,
    %sizes, %strides :
    memref<f16,3>, index,
    index, index
}

// -----

// Check that extract_strided_metadata of a statically shaped collapse_shape
// folds the offset, size and stride to constants (0, 20 and 1), while the
// base buffer is read from the collapse_shape source.
// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
// CHECK-DAG: %[[STRIDE:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STRIDE]] : memref<f32>, index, index, index
func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
    -> (memref<f32>, index, index, index) {

  %collapse = memref.collapse_shape %base[[0, 1]] :
    memref<5x4xf32> into memref<20xf32>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
    memref<20xf32> -> memref<f32>, index, index, index

  return %base_buffer, %offset, %size, %stride :
    memref<f32>, index, index, index
}

// -----

// Check that extract_strided_metadata of a memory_space_cast is rewritten to
// extract the metadata from the cast source. The base buffer is then cast to
// the destination memory space, and the offset, size and stride fold to
// constants.
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
// CHECK-DAG: %[[STRIDE:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STRIDE]] : memref<f32, 1>, index, index, index
func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
    -> (memref<f32, 1>, index, index, index) {

  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
    memref<20xf32, 1> -> memref<f32, 1>, index, index, index

  return %base_buffer, %offset, %size, %stride :
    memref<f32, 1>, index, index, index
}

// -----

// Same as above, but the base buffer is unused, so no memory_space_cast
// should be materialized at all.
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
// CHECK-NOT: memref.memory_space_cast
func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
    -> (index, index, index) {

  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
    memref<20xf32, 1> -> memref<f32, 1>, index, index, index

  return %offset, %size, %stride : index, index, index
}