//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

using namespace mlir;

namespace {
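// Lowers `shape.cstr_broadcastable` on extent tensors into eager,
// side-effecting checks: an `scf.for` loop compares the two shapes extent by
// extent and asserts on any incompatible pair, after which the witness is
// replaced with `shape.const_witness true`.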
class ConvertCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // Operands of type `!shape.shape` can carry error values, which this
    // eager lowering cannot represent.
    if (op.getType().isa<shape::ShapeType>() ||
        op.lhs().getType().isa<shape::ShapeType>() ||
        op.rhs().getType().isa<shape::ShapeType>()) {
      return rewriter.notifyMatchFailure(
          op, "cannot convert error-propagating shapes");
    }

    auto loc = op.getLoc();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);

    // Find smaller and greater rank and extent tensor.
    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
    Value lhsRankULE =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
    Type indexTy = rewriter.getIndexType();
    Value lesserRank =
        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
    Value greaterRank =
        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
    Value lesserRankOperand =
        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
    Value greaterRankOperand =
        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());

    // Only the trailing extents that overlap need checking; the missing
    // leading extents broadcast trivially.
    Value rankDiff =
        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

    // Generate code to compare the shapes extent by extent, and emit errors
    // for non-broadcast-compatible shapes.
    // Two extents are broadcast-compatible if
    // 1. they are equal, or
    // 2. at least one of them is 1.
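    // For example, the extent pairs (3, 3), (5, 1), and (1, 7) are
    // broadcast-compatible, while (2, 3) is not.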

    rewriter.create<scf::ForOp>(
        loc, rankDiff, greaterRank, one, llvm::None,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
              loc, greaterRankOperand, ValueRange{iv});
          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
              loc, lesserRankOperand, ValueRange{ivShifted});

          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
          Value extentsAgree =
              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                               lesserRankOperandExtent);
          auto broadcastIsValid =
              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                            lesserRankOperandExtentIsOne));
          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
          b.create<scf::YieldOp>(loc);
        });

    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

namespace {
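// Lowers `shape.cstr_require` by emitting its predicate as an eager
// `std.assert` carrying the op's message, then replacing the witness with
// `shape.const_witness true`.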
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

void mlir::populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
  patterns.insert<ConvertCstrRequireOp>(ctx);
}

namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After eager error handling code
// is emitted, witnesses are satisfied, so they are replaced with
// `shape.const_witness true`.
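//
// For example (illustrative IR; assembly syntax approximate):
//
//   %w = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
//
// becomes an `scf.for` loop of extent comparisons guarded by `std.assert`,
// followed by `%w = shape.const_witness true`.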
class ConvertShapeConstraints
    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    populateConvertShapeConstraintsConversionPatterns(patterns, context);

    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
  return std::make_unique<ConvertShapeConstraints>();
}