ConvertShapeConstraints.cpp
5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
using namespace mlir;
namespace {
/// Rewrites `shape.cstr_broadcastable` into eager, side-effecting checks.
///
/// The pattern emits an `scf.for` loop that walks the overlapping trailing
/// extents of the two operands and asserts, extent by extent, that they are
/// broadcast-compatible. Once those assertions are in place the constraint is
/// replaced by `shape.const_witness true`.
class ConvertCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // Anything typed as `!shape.shape` may carry an error value; this lowering
    // does not handle that, so refuse to match.
    auto isShapeTyped = [](Type type) { return type.isa<shape::ShapeType>(); };
    if (isShapeTyped(op.getType()) || isShapeTyped(op.lhs().getType()) ||
        isShapeTyped(op.rhs().getType()))
      return rewriter.notifyMatchFailure(
          op, "cannot convert error-propagating shapes");

    Location opLoc = op.getLoc();
    Value cstZero = rewriter.create<ConstantIndexOp>(opLoc, 0);
    Value cstOne = rewriter.create<ConstantIndexOp>(opLoc, 1);

    // The size of dimension 0 of each operand is the rank of the shape it
    // describes.
    Value rankOfLhs = rewriter.create<DimOp>(opLoc, op.lhs(), cstZero);
    Value rankOfRhs = rewriter.create<DimOp>(opLoc, op.rhs(), cstZero);
    Value lhsIsLesser = rewriter.create<CmpIOp>(opLoc, CmpIPredicate::ule,
                                                rankOfLhs, rankOfRhs);
    Type indexType = rewriter.getIndexType();

    // Order the ranks and the operands by rank so that the loop below can
    // always index the greater-rank operand directly.
    Value minRank =
        rewriter.create<SelectOp>(opLoc, lhsIsLesser, rankOfLhs, rankOfRhs);
    Value maxRank =
        rewriter.create<SelectOp>(opLoc, lhsIsLesser, rankOfRhs, rankOfLhs);
    Value minRankShape =
        rewriter.create<SelectOp>(opLoc, lhsIsLesser, op.lhs(), op.rhs());
    Value maxRankShape =
        rewriter.create<SelectOp>(opLoc, lhsIsLesser, op.rhs(), op.lhs());
    Value rankGap =
        rewriter.create<SubIOp>(opLoc, indexType, maxRank, minRank);

    // Body of the checking loop. Two extents are broadcast-compatible when
    //   1. they are equal, or
    //   2. at least one of them is 1.
    // Anything else trips the assertion.
    auto emitExtentCheck = [&](OpBuilder &b, Location loc, Value iv,
                               ValueRange) {
      Value maxExtent =
          b.create<ExtractElementOp>(loc, maxRankShape, ValueRange{iv});
      // Translate the index into the lesser-rank operand's index space.
      Value minIdx = b.create<SubIOp>(loc, indexType, iv, rankGap);
      Value minExtent =
          b.create<ExtractElementOp>(loc, minRankShape, ValueRange{minIdx});

      Value maxExtentIsOne =
          b.create<CmpIOp>(loc, CmpIPredicate::eq, maxExtent, cstOne);
      Value minExtentIsOne =
          b.create<CmpIOp>(loc, CmpIPredicate::eq, minExtent, cstOne);
      Value sameExtent =
          b.create<CmpIOp>(loc, CmpIPredicate::eq, maxExtent, minExtent);

      Value eitherIsOne =
          b.create<OrOp>(loc, maxExtentIsOne, minExtentIsOne);
      Value compatible =
          b.create<OrOp>(loc, b.getI1Type(), sameExtent, eitherIsOne);
      b.create<AssertOp>(loc, compatible, "invalid broadcast");
      b.create<scf::YieldOp>(loc);
    };

    // Only the trailing `minRank` extents of the greater-rank operand have a
    // counterpart, hence the iteration range [rankGap, maxRank).
    rewriter.create<scf::ForOp>(opLoc, rankGap, maxRank, cstOne, llvm::None,
                                emitExtentCheck);

    // The checks are now performed eagerly, so the witness is trivially true.
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace
namespace {
/// Rewrites `shape.cstr_require` into an assertion on its predicate that
/// carries the op's message attribute, then replaces the constraint with
/// `shape.const_witness true`.
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    Location opLoc = op.getLoc();
    Value predicate = op.pred();
    auto message = op.msgAttr();

    // Emit the eager check first; afterwards the witness always holds.
    rewriter.create<AssertOp>(opLoc, predicate, message);
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace
/// Adds the patterns that lower `shape.cstr_broadcastable` and
/// `shape.cstr_require` to `patterns`, constructing each with `ctx`.
void mlir::populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // Registration order matches the order of the template arguments.
  patterns.insert<ConvertCstrBroadcastableOp, ConvertCstrRequireOp>(ctx);
}
namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After eager error handling code
// is emitted, witnesses are satisfied, so they are replaced with
// `shape.const_witness true`.
class ConvertShapeConstraints
    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
  /// Collects the shape-constraint conversion patterns and applies them
  /// greedily to the current operation, signalling pass failure when the
  /// greedy driver reports failure.
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    populateConvertShapeConstraintsConversionPatterns(patterns, &getContext());

    if (succeeded(applyPatternsAndFoldGreedily(getOperation(), patterns)))
      return;
    signalPassFailure();
  }
};
} // namespace
/// Creates a new instance of the pass that converts shape constraints into
/// eager error handling code.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<ConvertShapeConstraints>();
  return pass;
}