// RUN: mlir-opt %s \
// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN:   -lower-vector-mask \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN:   -e=entry -entry-point-result=void \
// RUN:   -march=aarch64 -mattr="+sve,+sme" \
// RUN:   -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib | \
// RUN: FileCheck %s

func.func @entry() {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %step = arith.constant 1 : index

  %c123_f32 = arith.constant 123.0 : f32

  // "svl" refers to the Streaming Vector Length and "svl_s" to the number of
  // 32-bit elements in a vector of SVL bits.
  %svl_s = arm_sme.streaming_vl <word>

  %tile_init = bufferization.alloc_tensor(%svl_s, %svl_s) : tensor<?x?xf32>

  // Initialize tile with "123.0".
  // TODO: this could be simplified to tensor.splat + tensor.insert_slice once
  // splat supports dynamically shaped tensors.
  %tile_0 = scf.for %i = %c0 to %svl_s step %step iter_args(%tile_partial = %tile_init) -> tensor<?x?xf32> {
    %inner_tile = scf.for %j = %c0 to %svl_s step %step iter_args(%inner_tile_partial = %tile_partial) -> tensor<?x?xf32> {
      %tile_update = tensor.insert %c123_f32 into %inner_tile_partial[%i, %j] : tensor<?x?xf32>
      scf.yield %tile_update : tensor<?x?xf32>
    }
    scf.yield %inner_tile : tensor<?x?xf32>
  }

  // Print tile after initialization. The smallest SVL is 128 bits, so the
  // tile will be at least 4x4xf32.
  //
  // CHECK:      ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  scf.for %i = %c0 to %svl_s step %step {
    vector.print punctuation <open>
    scf.for %j = %c0 to %svl_s step %step {
      %element = tensor.extract %tile_0[%i, %j] : tensor<?x?xf32>
      vector.print %element : f32 punctuation <no_punctuation>

      // Print comma unless last element.
      %c1_index = arith.constant 1 : index
      %last_i = arith.subi %svl_s, %c1_index : index
      %isNotLastIter = arith.cmpi ult, %j, %last_i : index
      scf.if %isNotLastIter {
        vector.print punctuation <comma>
      }
    }
    vector.print punctuation <close>
    vector.print punctuation <newline>
  }

  // Fill tile with pi.
  %pi = arith.constant 3.14 : f32
  %tile_1 = linalg.fill ins(%pi : f32) outs(%tile_0 : tensor<?x?xf32>) -> tensor<?x?xf32>

  // Print tile after filling with pi. The smallest SVL is 128 bits, so the
  // tile will be at least 4x4xf32.
  //
  // CHECK:      ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  scf.for %i = %c0 to %svl_s step %step {
    vector.print punctuation <open>
    scf.for %j = %c0 to %svl_s step %step {
      %element = tensor.extract %tile_1[%i, %j] : tensor<?x?xf32>
      vector.print %element : f32 punctuation <no_punctuation>

      // Print comma unless last element.
      %c1_index = arith.constant 1 : index
      %last_i = arith.subi %svl_s, %c1_index : index
      %isNotLastIter = arith.cmpi ult, %j, %last_i : index
      scf.if %isNotLastIter {
        vector.print punctuation <comma>
      }
    }
    vector.print punctuation <close>
    vector.print punctuation <newline>
  }

  // CHECK: SME: END OF TEST OUTPUT
  vector.print str "SME: END OF TEST OUTPUT\n"

  return
}

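// Transform script: vectorize the linalg.fill above using scalable vector
// sizes [[4], [4]], so that the -test-lower-to-arm-sme pass in the RUN lines
// can lower the resulting vector operations to SME tile operations.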
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [[4], [4]] : !transform.any_op
    transform.yield
  }
}