// fusion.mlir
// RUN: mlir-opt %s -linalg-fusion | FileCheck %s

// Affine maps shared by the test functions below.
// #map0-#map2 are only referenced by the affine.apply ops in @f6 and @f8, whose
// results are never used. #map3 and #map4 are not referenced anywhere in the
// visible part of this file. #map5 is the strided 2-D layout (offset s0,
// strides s1/s2) of the subviews in @indexed_generic_test, and #map6 is the
// 2-D identity map used as that test's indexing map.
#map0 = affine_map<(d0) -> (d0 + 2)>
#map1 = affine_map<(d0) -> (d0 + 4)>
#map2 = affine_map<(d0) -> (d0 + 3)>
#map3 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
#map4 = affine_map<(d0) -> (d0)>
#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#map6 = affine_map<(d0, d1) -> (d0, d1)>
// Memref types written as `offset: 0, strides: [?, ?]` are expected to be
// printed with this layout. The captured name is reused by the dim-op
// expectations of the tests below.
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>

// @f1: an untiled matmul(%A, %B, %C), followed by a loop nest that computes a
// matmul on 2x4 / 4x3 / 2x3 subviews of %A, %B and %C.
// Loop steps are i by 2, j by 3, k by 4.
// Both ops take %A and %B as inputs and %C as output. The untiled matmul is
// expected to remain in front of the loop nest. %D is unused, and %E is
// returned without being touched.
func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>, %B: memref<?x?xf32, offset: 0, strides: [?, 1]>, %C: memref<?x?xf32, offset: 0, strides: [?, 1]>, %D: memref<?x?xf32, offset: 0, strides: [?, 1]>, %E: memref<?x?xf32, offset: 0, strides: [?, 1]>) -> memref<?x?xf32, offset: 0, strides: [?, 1]> {
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %0 = dim %A, 0 : memref<?x?xf32, offset: 0, strides: [?, 1]>
  %1 = dim %A, 1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
  %2 = dim %B, 1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, 1]>, memref<?x?xf32, offset: 0, strides: [?, 1]>, memref<?x?xf32, offset: 0, strides: [?, 1]>
  %c1 = constant 1 : index
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// CHECK-LABEL: func @f1
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// No RAW dependences; the pass does not fuse RAR (read-after-read) at the moment.
//      CHECK: linalg.matmul
//      CHECK: loop.for
//      CHECK:   loop.for
//      CHECK:     loop.for
//      CHECK:       linalg.matmul

// @f2: RAW dependence through %C. The untiled matmul writes %C as its output
// operand. The tiled loop nest then reads 2x4 tiles of %C (LHS) and 4x3 tiles
// of %D (RHS), with 2x3 tiles of %E as output. The producer is expected to be
// fused into the innermost loop, so two matmuls appear there. The loop bounds
// are expected to stay the original dims of %C and %D.
func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %0 = dim %C, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %C, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f2
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
//   CHECK-DAG:   %[[C_0:.*]] = dim %[[C]], 0 : memref<?x?xf32, #[[strided2D]]>
//   CHECK-DAG:   %[[C_1:.*]] = dim %[[C]], 1 : memref<?x?xf32, #[[strided2D]]>
//   CHECK-DAG:   %[[D_1:.*]] = dim %[[D]], 1 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
//       CHECK:         linalg.matmul
//       CHECK:         linalg.matmul

// @f3: same structure as @f2, but the producer's output %C is consumed as the
// RHS of the tiled matmul (4x3 tiles of %C). The LHS comes from 2x4 tiles of
// %D. The producer is expected to be fused into the innermost loop, giving two
// matmuls there.
func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %0 = dim %D, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %C, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %7 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f3
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
//          CHECK:   %[[D_0:.*]] = dim %[[D]], 0 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   %[[D_1:.*]] = dim %[[D]], 1 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   %[[C_1:.*]] = dim %[[C]], 1 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
//          CHECK:         linalg.matmul
//          CHECK:         linalg.matmul

// @f4: two independent producers, matmul(%A, %B, %C) and matmul(%A, %B, %D),
// both feed the tiled consumer. The consumer reads %C tiles as LHS and %D tiles
// as RHS, and writes %E tiles. Both producers are expected to be fused, which
// leaves three matmuls in the innermost loop.
func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %B, %D) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %0 = dim %C, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %C, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f4
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
//          CHECK:   %[[C_0:.*]] = dim %[[C]], 0 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   %[[C_1:.*]] = dim %[[C]], 1 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   %[[D_1:.*]] = dim %[[D]], 1 : memref<?x?xf32, #[[strided2D]]>
//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// Fuse D, then fuse C; no false dependence prevents it.
//          CHECK:         linalg.matmul
//          CHECK:         linalg.matmul
//          CHECK:         linalg.matmul

// @f5: a producer chain. matmul(%A, %B, %C) writes %C, which is then read by
// matmul(%C, %B, %D). The tiled consumer reads %D tiles (LHS) and %B tiles
// (RHS) and writes %E tiles. One matmul is expected to stay in front of the
// loop nest, and two matmuls are expected in the innermost loop. The comment
// below attributes the unfused op to a false dependence.
func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %0 = dim %B, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %D, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%C, %B, %D) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %1 step %c2 {
    loop.for %arg6 = %c0 to %0 step %c3 {
      loop.for %arg7 = %c0 to %2 step %c4 {
        %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f5
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
//      CHECK-DAG:   %[[B_1:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
//      CHECK-DAG:   %[[D_0:.*]] = dim %[[D]], 0 : memref<?x?xf32, #[[strided2D]]>
//      CHECK-DAG:   %[[D_1:.*]] = dim %[[D]], 1 : memref<?x?xf32, #[[strided2D]]>
// Don't fuse C due to a false dependence; note that this is too conservative though.
//          CHECK:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} {
//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
//          CHECK:         linalg.matmul
//          CHECK:         linalg.matmul

// @f6: matmul(%A, %B, %C) writes %C. Next, matmul(%A, %C, %E) reads %C and
// writes %E. Finally, the tiled consumer reads %C tiles (LHS) and %D tiles
// (RHS) and writes %E tiles. Neither untiled matmul is expected to be fused.
// Both stay before the loop nest, and no further matmul follows the single one
// in the innermost loop.
func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %0 = dim %C, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %C, %E) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %C, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %1 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %0 step %c4 {
        // The affine.apply results %3, %4 and %6 are not used by any op below.
        %3 = affine.apply #map0(%arg5)
        %4 = affine.apply #map1(%arg7)
        %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %6 = affine.apply #map2(%arg6)
        %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f6
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// Cannot fuse C due to interleaved read of C that would be bypassed.
// Cannot fuse E (WAW).
//   CHECK:   linalg.matmul
//   CHECK:   linalg.matmul
//   CHECK:   loop.for
//   CHECK:     loop.for
//   CHECK:       loop.for
//   CHECK:         linalg.matmul
// CHECK-NOT:       linalg.matmul

// @f7: matmul(%A, %C, %E) reads %C before matmul(%A, %B, %C) overwrites it.
// Two tiled loop nests follow:
//  - the first reads %A and %C tiles and writes %E tiles;
//  - the second reads %C and %D tiles and writes %E tiles.
// matmul(%A, %C, %E) is expected to remain before the loops. The first nest is
// expected to hold two matmuls (presumably the %C producer fused in). The
// second nest is expected to keep exactly one matmul.
func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %0 = dim %A, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %A, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %C, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %3 = dim %C, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %4 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %C, %E) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  // First tiled nest: tiles of %A (LHS) x %C (RHS) into %E.
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        %7 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %9 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  // Second tiled nest: tiles of %C (LHS) x %D (RHS) into %E.
  loop.for %arg5 = %c0 to %3 step %c2 {
    loop.for %arg6 = %c0 to %4 step %c3 {
      loop.for %arg7 = %c0 to %2 step %c4 {
        %7 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %9 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f7
//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
//       CHECK:   %[[A_0:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   %[[A_1:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   %[[C_1:.*]] = dim %[[C]], 1 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   %[[C_0:.*]] = dim %[[C]], 0 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   %[[D_1:.*]] = dim %[[D]], 1 : memref<?x?xf32, #[[strided2D]]>
//       CHECK:   linalg.matmul(%[[A]], %[[C]], %[[E]])
//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
//       CHECK:         linalg.matmul
//       CHECK:         linalg.matmul
//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
//       CHECK:         linalg.matmul
//   CHECK-NOT:         linalg.matmul

// @f8: matmul(%A, %C, %D) reads %C and writes %D, and then matmul(%A, %B, %C)
// overwrites %C. The tiled consumer reads %A tiles (LHS) and %D tiles (RHS)
// and writes %E tiles. Neither untiled matmul is expected to be fused. Both
// stay before the loop nest, and no further matmul follows the single one in
// the innermost loop.
func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %0 = dim %A, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %A, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %C, %D) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %2 = dim %D, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg5 = %c0 to %0 step %c2 {
    loop.for %arg6 = %c0 to %2 step %c3 {
      loop.for %arg7 = %c0 to %1 step %c4 {
        // The affine.apply results %3, %4 and %6 are not used by any op below.
        %3 = affine.apply #map0(%arg5)
        %4 = affine.apply #map1(%arg7)
        %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %6 = affine.apply #map2(%arg6)
        %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
        linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
      }
    }
  }
  return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f8
//       CHECK:   (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}})
//   CHECK:   linalg.matmul
//   CHECK:   linalg.matmul
//   CHECK:   loop.for
//   CHECK:     loop.for
//   CHECK:       loop.for
//   CHECK:         linalg.matmul
// CHECK-NOT:       linalg.matmul

// 2-D identity indexing map, and a linalg.generic trait for elementwise ops
// with two inputs and one output over two parallel dimensions. Used by
// @pointwise, @pointwise_no_view and the producer in @indexed_generic_test.
#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = [#id_2d, #id_2d, #id_2d],
  iterator_types = ["parallel", "parallel"]
}
// @pointwise: an untiled generic writes %B by adding its two %A inputs
// elementwise. A loop nest tiled 2x3 then runs a generic that multiplies tiles
// of %B and %C into tiles of %D. The producer is expected to be fused into the
// inner loop: exactly two loops, containing the addf generic followed by the
// mulf generic.
func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  linalg.generic #pointwise_2d_trait %A, %A, %B {
  ^bb0(%E: f32, %arg5: f32, %arg6: f32):   // no predecessors
    %2 = addf %E, %arg5 : f32
    linalg.yield %2 : f32
  }: memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
  %0 = dim %B, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  %1 = dim %B, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
  loop.for %arg4 = %c0 to %0 step %c2 {
    loop.for %arg5 = %c0 to %1 step %c3 {
      %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      linalg.generic #pointwise_2d_trait %4, %5, %6 {
      ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):       // no predecessors
        %7 = mulf %arg6, %arg7 : f32
        linalg.yield %7 : f32
      }: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
    }
  }
  return
}
// CHECK-LABEL: func @pointwise
//       CHECK:   loop.for
//       CHECK:     loop.for
//   CHECK-NOT:   loop.for
//       CHECK:       linalg.generic
//       CHECK:         addf
//       CHECK:       linalg.generic
//       CHECK:         mulf

// @pointwise_no_view: the same computation as @pointwise. Here every buffer is
// allocated inside the function as an identity-layout memref<?x?xf32> of size
// %M x %N, rather than being passed in as a strided memref argument. %E is
// allocated but never used. The expected fusion result matches @pointwise.
func @pointwise_no_view(%M: index, %N: index) {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %A = alloc (%M, %N): memref<?x?xf32>
  %B = alloc (%M, %N): memref<?x?xf32>
  %C = alloc (%M, %N): memref<?x?xf32>
  %D = alloc (%M, %N): memref<?x?xf32>
  %E = alloc (%M, %N): memref<?x?xf32>
  linalg.generic #pointwise_2d_trait %A, %A, %B {
  ^bb0(%e: f32, %arg5: f32, %arg6: f32):   // no predecessors
    %2 = addf %e, %arg5 : f32
    linalg.yield %2 : f32
  }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  %0 = dim %B, 0 : memref<?x?xf32>
  %1 = dim %B, 1 : memref<?x?xf32>
  loop.for %arg4 = %c0 to %0 step %c2 {
    loop.for %arg5 = %c0 to %1 step %c3 {
      %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      linalg.generic #pointwise_2d_trait %4, %5, %6 {
      ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):       // no predecessors
        %7 = mulf %arg6, %arg7 : f32
        linalg.yield %7 : f32
      }: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
    }
  }
  return
}
// CHECK-LABEL: func @pointwise_no_view
//       CHECK:   loop.for
//       CHECK:     loop.for
//   CHECK-NOT:   loop.for
//       CHECK:       linalg.generic
//       CHECK:         addf
//       CHECK:       linalg.generic
//       CHECK:         mulf

// @indexed_generic_test: an untiled generic writes %C = %A + %B elementwise.
// A loop nest tiled 10x25 then runs an indexed_generic with %C tiles as input
// and %D tiles as output. Its region ignores the %C value (%arg6). Instead it
// stores float(i + %arg2) + float(j + %arg3), i.e. the region indices offset
// by the tile origin. The producer generic is expected to be fused into the
// inner loop, ahead of the indexed_generic.
func @indexed_generic_test(%A: memref<?x?xf32>,
                           %B: memref<?x?xf32>,
                           %C: memref<?x?xf32>,
                           %D: memref<?x?xf32>) {
  linalg.generic #pointwise_2d_trait %A, %B, %C {
  ^bb0(%e: f32, %arg5: f32, %arg6: f32):   // no predecessors
    %2 = addf %e, %arg5 : f32
    linalg.yield %2 : f32
  }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c25 = constant 25 : index
  %c10 = constant 10 : index
  %0 = dim %C, 0 : memref<?x?xf32>
  %1 = dim %C, 1 : memref<?x?xf32>
  // %2 and %3 (the dims of %D) are not used below.
  %2 = dim %D, 0 : memref<?x?xf32>
  %3 = dim %D, 1 : memref<?x?xf32>
  loop.for %arg2 = %c0 to %0 step %c10 {
    loop.for %arg3 = %c0 to %1 step %c25 {
      %4 = std.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] :
          memref<?x?xf32> to memref<?x?xf32, #map5>
      %5 = std.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] :
          memref<?x?xf32> to memref<?x?xf32, #map5>
      linalg.indexed_generic {
        indexing_maps = [#map6, #map6],
        iterator_types = ["parallel", "parallel"],
        args_in = 1,
        args_out = 1
      } %4, %5 {
      // %arg4/%arg5 are the index arguments of the region; %arg6 is the %C
      // element (unused) and %arg7 the %D element.
      ^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32):
        %6 = addi %arg4, %arg2 : index
        %7 = addi %arg5, %arg3 : index
        %8 = index_cast %6 : index to i32
        %9 = sitofp %8 : i32 to f32
        %10 = index_cast %7 : index to i32
        %11 = sitofp %10 : i32 to f32
        %12 = addf %9, %11 : f32
        linalg.yield %12 : f32
      }: memref<?x?xf32, #map5>, memref<?x?xf32, #map5>
    }
  }
  return
}
// CHECK-LABEL: func @indexed_generic_test
//       CHECK:   loop.for
//       CHECK:     loop.for
//   CHECK-NOT:   loop.for
//       CHECK:       linalg.generic
//       CHECK:         addf
//       CHECK:       linalg.indexed_generic
//       CHECK:         index_cast