// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file -verify-diagnostics | FileCheck %s

// Tests for the lowering of arm_sme.tile_load and arm_sme.tile_store into
// scf.for loops that process a tile one slice at a time.

//===----------------------------------------------------------------------===//
// arm_sme.tile_load
//===----------------------------------------------------------------------===//

// Unmasked horizontal load: lowered to an scf.for over the (4 x vscale) tile
// slices, starting from arm_sme.get_tile and loading one horizontal slice per
// iteration with arm_sme.load_tile_slice under an all-true predicate.

// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME:      %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG:     %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT:      %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
// CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// Vertical load: the layout<vertical> attribute is carried over to the
// generated arm_sme.load_tile_slice.

// CHECK-LABEL: @arm_sme_tile_load_ver
// CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// Masked load with a zero pad value: the tile is initialised with arm_sme.zero
// and the loop only visits min(mask rows, tile slices) slices. Each slice is
// loaded with a 1-D mask built from the column operand of the 2-D
// vector.create_mask.

// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
// CHECK-SAME:      %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
// CHECK-DAG:     %[[COL_MASK:.*]] = vector.create_mask %[[C2]] : vector<[4]xi1>
// CHECK-DAG:     %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT:      %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[COL_MASK]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
// CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %pad = arith.constant 0 : i32
  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// Masked load with a non-zero pad value: every tile slice is visited. For each
// slice, the 1-D mask length is the column count when the row index is below
// the mask row count and 0 otherwise (sign-extended i1 AND'ed with the column
// count). The slice is read with vector.maskedload, using a splat of the pad
// value as pass-through, and written into the tile with
// arm_sme.insert_tile_slice.

// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
// CHECK-SAME:      %[[SRC:.*]]: memref<?x?xi32>,
// CHECK-SAME:      %[[PAD:.*]]: i32) {
// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[NUM_COLS:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT:      %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
// CHECK-NEXT:      %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
// CHECK-NEXT:      %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
// CHECK-NEXT:      %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT:      %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK:           %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK:           %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK:           %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// Zero pad with a mask that is a function argument rather than the result of
// vector.create_mask: the conversion leaves the op illegal. Presumably only
// vector.create_mask masks are supported by this lowering; confirm against the
// pattern implementation.
func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %mask : vector<[4]x[4]xi1>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i32
  // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// Non-zero pad with a mask that is a function argument rather than the result
// of vector.create_mask: the conversion leaves the op illegal as well.
func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?xi32>, %pad : i32, %mask : vector<[4]x[4]xi1>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//

// Unmasked horizontal store: lowered to an scf.for over the (4 x vscale) tile
// slices, storing one horizontal slice per iteration with
// arm_sme.store_tile_slice under an all-true predicate.

// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
// CHECK-SAME:      %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME:      %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
  return
}

// -----

// Vertical store: the layout<vertical> attribute is carried over to the
// generated arm_sme.store_tile_slice.

// CHECK-LABEL: @arm_sme_tile_store_ver
// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
  return
}

// -----

// Masked store: the loop only visits min(mask rows, tile slices) slices, and
// each slice is stored with a 1-D mask built from the column operand of the
// 2-D vector.create_mask.

// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
// CHECK-SAME:      %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME:      %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
// CHECK-DAG:     %[[COL_MASK:.*]] = vector.create_mask %[[C2]] : vector<[4]xi1>
// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] {
// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[COL_MASK]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
  return
}