Skip to content

Instantly share code, notes, and snippets.

@hanhanW
Created June 7, 2024 00:59
Show Gist options
  • Save hanhanW/b3652f5887b93fb8f0df6c6c39c1ef87 to your computer and use it in GitHub Desktop.
Save hanhanW/b3652f5887b93fb8f0df6c6c39c1ef87 to your computer and use it in GitHub Desktop.
// Affine-map aliases used as linalg.generic indexing maps (and one
// affine.apply) in @main_graph$async. Loop dimensions carry neutral names
// instead of the printer's d0/d1/d2; the maps themselves are unchanged.

// 2-D loop -> (0, j): always reads index 0 of the leading operand dim.
#map = affine_map<(i, j) -> (0, j)>
// 2-D identity.
#map1 = affine_map<(i, j) -> (i, j)>
// 2-D loop -> first loop dim only (1-D operand read along rows).
#map2 = affine_map<(i, j) -> (i)>
// 3-D loop -> (0, j, k): always reads index 0 of the leading operand dim.
#map3 = affine_map<(i, j, k) -> (0, j, k)>
// 3-D loop -> rank-0 operand.
#map4 = affine_map<(i, j, k) -> ()>
// 3-D identity.
#map5 = affine_map<(i, j, k) -> (i, j, k)>
// 1-D identity.
#map6 = affine_map<(i) -> (i)>
// 2-D loop -> second loop dim only (1-D operand broadcast across rows).
#map7 = affine_map<(i, j) -> (j)>
// 2-D loop -> rank-0 operand.
#map8 = affine_map<(i, j) -> ()>
// Symbol-only map: n -> n * 9.
#map9 = affine_map<()[n] -> (n * 9)>
// 3-D loop -> (i, j, 0): always reads/writes index 0 of the trailing dim.
#map10 = affine_map<(i, j, k) -> (i, j, 0)>
// Module holding a single IREE function, @main_graph$async.
module {
// Reads a 1x9 index tensor from %arg0 (after waiting on %arg1) and returns a
// 1x9x1024 f32 tensor computed as:
//   x[0, j, :] = %cst_5[wrap(idx[0, j], 50265), :] * 1.0
//              + %cst_6[wrap(j + 2, 1026), :]              for j = 0..8
//   out        = x - mean(x over the last dim, size 1024)
// where wrap(v, N) = (v < 0) ? v + N : v.
// NOTE(review): the constants (50265x1024 and 1026x1024 tables, positions
// offset by 2) look like a token + position embedding followed by the
// mean-subtraction step of a normalization; this is inferred from the numbers
// only. TODO confirm against the source model.
// NOTE(review): %arg2 is not referenced anywhere in this body.
util.func public @main_graph$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> tensor<1x9x1024xf32> attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
// %cst, %cst_0, %cst_2 feed the select in %11; element 0 of its result is %12.
%cst = arith.constant dense<[false, true]> : tensor<2xi1>
%cst_0 = arith.constant dense<1> : tensor<2xi32>
// Offset added to the 0..8 position indices (%cst_3) before the %cst_6 gather.
%cst_1 = arith.constant dense<2> : tensor<i32>
%cst_2 = arith.constant dense<[1, -1]> : tensor<2xi32>
%cst_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>
// Scalar multiplier applied to the %cst_5 gather result (value 1.0).
%cst_4 = arith.constant dense<1.000000e+00> : tensor<f32>
// Gather tables; their contents are elided in this dump.
%cst_5 = arith.constant dense_resource<__elided__> : tensor<50265x1024xf32>
%cst_6 = arith.constant dense_resource<__elided__> : tensor<1026x1024xf32>
%c0_i32 = arith.constant 0 : i32
%c9 = arith.constant 9 : index
// Init value for the sum reduction in %27.
%cst_7 = arith.constant 0.000000e+00 : f32
// Row counts of %cst_6 and %cst_5, used to wrap negative indices.
%c1026_i32 = arith.constant 1026 : i32
%c50265_i32 = arith.constant 50265 : i32
// Size of the reduced (last) dim; divides the sum in %28 to get the mean.
%cst_8 = arith.constant 1.024000e+03 : f32
// Import the caller's buffer once %arg1 is reached. The ABI type is 1x9xi64 but
// the body uses it as 1x9xi32 (presumably i64 -> i32 demotion; values outside
// the i32 range would not be preserved; confirm this is intended).
%0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<1x9xi64> as tensor<1x9xi32>
// %2 = (%0 < 0). #map reads row 0 of the 1x9 operand.
%1 = tensor.empty() : tensor<1x9xi1>
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<1x9xi32>) outs(%1 : tensor<1x9xi1>) {
^bb0(%in: i32, %out: i1):
%31 = arith.cmpi slt, %in, %c0_i32 : i32
linalg.yield %31 : i1
} -> tensor<1x9xi1>
// %4 = %0 + 50265.
%3 = tensor.empty() : tensor<1x9xi32>
%4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<1x9xi32>) outs(%3 : tensor<1x9xi32>) {
^bb0(%in: i32, %out: i32):
%31 = arith.addi %in, %c50265_i32 : i32
linalg.yield %31 : i32
} -> tensor<1x9xi32>
// %5 = (%0 < 0) ? %0 + 50265 : %0, i.e. negative indices wrap into %cst_5's rows.
// NOTE(review): there is no bounds check. An index >= 50265 or < -50265 makes the
// tensor.extract in %7 read out of bounds.
%5 = linalg.generic {indexing_maps = [#map, #map, #map, #map1], iterator_types = ["parallel", "parallel"]} ins(%2, %4, %0 : tensor<1x9xi1>, tensor<1x9xi32>, tensor<1x9xi32>) outs(%3 : tensor<1x9xi32>) {
^bb0(%in: i1, %in_12: i32, %in_13: i32, %out: i32):
%31 = arith.select %in, %in_12, %in_13 : i32
linalg.yield %31 : i32
} -> tensor<1x9xi32>
// Row gather: %7[i, j] = %cst_5[%collapsed[i], j], i in 0..8, j in 0..1023.
%collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<1x9xi32> into tensor<9xi32>
%6 = tensor.empty() : tensor<9x1024xf32>
%7 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<9xi32>) outs(%6 : tensor<9x1024xf32>) {
^bb0(%in: i32, %out: f32):
%31 = arith.index_cast %in : i32 to index
%32 = linalg.index 1 : index
%extracted_12 = tensor.extract %cst_5[%31, %32] : tensor<50265x1024xf32>
linalg.yield %extracted_12 : f32
} -> tensor<9x1024xf32>
%expanded = tensor.expand_shape %7 [[0, 1], [2]] output_shape [1, 9, 1024] : tensor<9x1024xf32> into tensor<1x9x1024xf32>
// %9 = %expanded * %cst_4. Multiplying by 1.0 leaves the values unchanged.
%8 = tensor.empty() : tensor<1x9x1024xf32>
%9 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %cst_4 : tensor<1x9x1024xf32>, tensor<f32>) outs(%8 : tensor<1x9x1024xf32>) {
^bb0(%in: f32, %in_12: f32, %out: f32):
%31 = arith.mulf %in, %in_12 : f32
linalg.yield %31 : f32
} -> tensor<1x9x1024xf32>
// %11 = select([false, true], [1, 1], [1, -1]) = [1, 1], so %12 = %11[0] = 1.
// %12 is derived from constants but not folded in this IR. Every tensor sized
// by %12 below therefore has a dynamic leading dim, even though it is always 1.
%10 = tensor.empty() : tensor<2xi32>
%11 = linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_2 : tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) outs(%10 : tensor<2xi32>) {
^bb0(%in: i1, %in_12: i32, %in_13: i32, %out: i32):
%31 = arith.select %in, %in_12, %in_13 : i32
linalg.yield %31 : i32
} -> tensor<2xi32>
%extracted_slice = tensor.extract_slice %11[0] [1] [1] : tensor<2xi32> to tensor<1xi32>
%collapsed_9 = tensor.collapse_shape %extracted_slice [] : tensor<1xi32> into tensor<i32>
%extracted = tensor.extract %collapsed_9[] : tensor<i32>
%12 = arith.index_cast %extracted : i32 to index
// %14[i, j] = %cst_3[j] = j (broadcast over %12 rows); %15 = %14 + 2 = j + 2.
%13 = tensor.empty(%12) : tensor<?x9xi32>
%14 = linalg.generic {indexing_maps = [#map7, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_3 : tensor<9xi32>) outs(%13 : tensor<?x9xi32>) {
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
} -> tensor<?x9xi32>
%15 = linalg.generic {indexing_maps = [#map1, #map8, #map1], iterator_types = ["parallel", "parallel"]} ins(%14, %cst_1 : tensor<?x9xi32>, tensor<i32>) outs(%13 : tensor<?x9xi32>) {
^bb0(%in: i32, %in_12: i32, %out: i32):
%31 = arith.addi %in, %in_12 : i32
linalg.yield %31 : i32
} -> tensor<?x9xi32>
// %19 = (%15 < 0) ? %15 + 1026 : %15 (wrap into %cst_6's rows). Because %15 is
// always 2..10, the select always picks %15.
%16 = tensor.empty(%12) : tensor<?x9xi1>
%17 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%15 : tensor<?x9xi32>) outs(%16 : tensor<?x9xi1>) {
^bb0(%in: i32, %out: i1):
%31 = arith.cmpi slt, %in, %c0_i32 : i32
linalg.yield %31 : i1
} -> tensor<?x9xi1>
%18 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%15 : tensor<?x9xi32>) outs(%13 : tensor<?x9xi32>) {
^bb0(%in: i32, %out: i32):
%31 = arith.addi %in, %c1026_i32 : i32
linalg.yield %31 : i32
} -> tensor<?x9xi32>
%19 = linalg.generic {indexing_maps = [#map1, #map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%17, %18, %15 : tensor<?x9xi1>, tensor<?x9xi32>, tensor<?x9xi32>) outs(%13 : tensor<?x9xi32>) {
^bb0(%in: i1, %in_12: i32, %in_13: i32, %out: i32):
%31 = arith.select %in, %in_12, %in_13 : i32
linalg.yield %31 : i32
} -> tensor<?x9xi32>
// Row gather: %22[i, j] = %cst_6[%collapsed_10[i], j], shape (%12 * 9) x 1024.
%collapsed_10 = tensor.collapse_shape %19 [[0, 1]] : tensor<?x9xi32> into tensor<?xi32>
%20 = affine.apply #map9()[%12]
%21 = tensor.empty(%20) : tensor<?x1024xf32>
%22 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_10 : tensor<?xi32>) outs(%21 : tensor<?x1024xf32>) {
^bb0(%in: i32, %out: f32):
%31 = arith.index_cast %in : i32 to index
%32 = linalg.index 1 : index
%extracted_12 = tensor.extract %cst_6[%31, %32] : tensor<1026x1024xf32>
linalg.yield %extracted_12 : f32
} -> tensor<?x1024xf32>
// %23 = (%12 * 9) / 9 = %12; reshape the gather back to %23 x 9 x 1024.
%23 = arith.divui %20, %c9 : index
%expanded_11 = tensor.expand_shape %22 [[0, 1], [2]] output_shape [%23, 9, 1024] : tensor<?x1024xf32> into tensor<?x9x1024xf32>
// %24 = %9 + %expanded_11.
// NOTE(review): the loop's leading dim comes from the static 1 of %8/%9, while
// %expanded_11 has a dynamic leading dim %23. This is only consistent when
// %12 == 1 at runtime. The constants above guarantee that, but the static
// types do not express it. The same mix of static 1 and dynamic %12 recurs in
// %27 and %29.
%24 = linalg.generic {indexing_maps = [#map3, #map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %expanded_11 : tensor<1x9x1024xf32>, tensor<?x9x1024xf32>) outs(%8 : tensor<1x9x1024xf32>) {
^bb0(%in: f32, %in_12: f32, %out: f32):
%31 = arith.addf %in, %in_12 : f32
linalg.yield %31 : f32
} -> tensor<1x9x1024xf32>
// %27[i, j, 0] = sum over k of %24[i, j, k], accumulated into zero-filled %26.
%25 = tensor.empty(%12) : tensor<?x9x1xf32>
%26 = linalg.fill ins(%cst_7 : f32) outs(%25 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
%27 = linalg.generic {indexing_maps = [#map5, #map10], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<1x9x1024xf32>) outs(%26 : tensor<?x9x1xf32>) {
^bb0(%in: f32, %out: f32):
%31 = arith.addf %in, %out : f32
linalg.yield %31 : f32
} -> tensor<?x9x1xf32>
// %28 = %27 / 1024.0, the mean over the last dim.
%28 = linalg.generic {indexing_maps = [#map10, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%27 : tensor<?x9x1xf32>) outs(%25 : tensor<?x9x1xf32>) {
^bb0(%in: f32, %out: f32):
%31 = arith.divf %in, %cst_8 : f32
linalg.yield %31 : f32
} -> tensor<?x9x1xf32>
// %29 broadcasts the mean along the last dim back to 1x9x1024.
%29 = linalg.generic {indexing_maps = [#map10, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%28 : tensor<?x9x1xf32>) outs(%8 : tensor<1x9x1024xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x9x1024xf32>
// %30 = %24 - %29: %24 with its per-(i, j) mean subtracted.
%30 = linalg.generic {indexing_maps = [#map5, #map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%24, %29 : tensor<1x9x1024xf32>, tensor<1x9x1024xf32>) outs(%8 : tensor<1x9x1024xf32>) {
^bb0(%in: f32, %in_12: f32, %out: f32):
%31 = arith.subf %in, %in_12 : f32
linalg.yield %31 : f32
} -> tensor<1x9x1024xf32>
util.return %30 : tensor<1x9x1024xf32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment