LowerToAffineLoops.cpp
12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Build the MemRefType that mirrors `type`: it has the same shape and the
/// same element type. Only ranked tensors are accepted; anything else trips
/// the assertion below.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  auto shape = type.getShape();
  auto elementType = type.getElementType();
  return MemRefType::get(shape, elementType);
}
/// Create an AllocOp of the given MemRefType together with a DeallocOp for it,
/// and return the allocated value. Both operations are created at the
/// rewriter's current insertion point and are then moved within the enclosing
/// block: the alloc before the block's first operation, the dealloc before the
/// block's last operation.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter) {
  auto buffer = rewriter.create<AllocOp>(loc, type);
  Operation *allocOp = buffer.getOperation();
  Block *block = allocOp->getBlock();

  // Hoist the allocation to the very beginning of the block.
  allocOp->moveBefore(&block->front());

  // Sink the deallocation to the end of the block (just ahead of its last
  // operation). This is fine as toy functions have no control flow.
  auto release = rewriter.create<DeallocOp>(loc, buffer);
  release.getOperation()->moveBefore(&block->back());

  return buffer;
}
/// Callback type used by `lowerOpToLoops` to emit the computation for a single
/// iteration of a lowered loop nest. It is handed:
///   - `rewriter`: the rewriter to create operations with,
///   - `memRefOperands`: the operands of the operation being lowered, as
///     passed to `lowerOpToLoops`,
///   - `loopIvs`: the induction variables of the generated loops, one per loop.
/// The returned value is what gets stored at the current index of the
/// iteration.
using LoopIterationFn = function_ref<Value(PatternRewriter &rewriter,
ArrayRef<Value> memRefOperands,
ArrayRef<Value> loopIvs)>;
/// Lower `op` to a nest of affine loops, one loop per dimension of the shape
/// of its first result type (which must be a TensorType).
///
/// A buffer matching the result type is allocated, the loops are generated
/// nested inside one another, and `processIteration` is invoked once with the
/// rewriter positioned inside the innermost loop body. The value it returns is
/// stored into the buffer at the loop induction variables, and `op` is finally
/// replaced by the buffer.
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  Location loc = op->getLoc();
  auto resultType = (*op->result_type_begin()).cast<TensorType>();

  // Allocate the buffer that will hold the result of this operation (the
  // helper also takes care of the matching deallocation).
  Value resultBuffer =
      insertAllocAndDealloc(convertTensorToMemRef(resultType), loc, rewriter);

  // Build the loop nest: each new loop is created at the current insertion
  // point, which is the start of the previously created loop's body.
  SmallVector<Value, 4> ivs;
  for (int64_t extent : resultType.getShape()) {
    auto forOp =
        rewriter.create<AffineForOp>(loc, /*lb=*/0, extent, /*step=*/1);
    Block *body = forOp.getBody();
    body->clear();
    ivs.push_back(forOp.getInductionVar());

    // Re-terminate the now empty body, then move the insertion point back to
    // the beginning of the body so later operations land ahead of the
    // terminator.
    rewriter.setInsertionPointToStart(body);
    rewriter.create<AffineTerminatorOp>(loc);
    rewriter.setInsertionPointToStart(body);
  }

  // Let the callback compute the value for the current index and store it
  // into the result buffer.
  Value element = processIteration(rewriter, operands, ivs);
  rewriter.create<AffineStoreOp>(loc, element, resultBuffer,
                                 llvm::makeArrayRef(ivs));

  // The buffer now stands in for the result of the original operation.
  rewriter.replaceOp(op, resultBuffer);
}
namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//
/// Conversion pattern lowering the Toy binary operation `BinaryOp` to affine
/// loops. For every index of the result, the corresponding elements of `lhs`
/// and `rhs` are loaded and combined with a `LoweredBinaryOp`.
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();

    // Emits the computation for one element of the result.
    auto emitElement = [loc](PatternRewriter &builder,
                             ArrayRef<Value> memRefOperands,
                             ArrayRef<Value> loopIvs) {
      // Wrap the remapped operands in the ODS-generated adaptor so they can
      // be reached through the named accessors.
      typename BinaryOp::OperandAdaptor adaptor(memRefOperands);

      // Load the current element of each input.
      auto lhsElement =
          builder.create<AffineLoadOp>(loc, adaptor.lhs(), loopIvs);
      auto rhsElement =
          builder.create<AffineLoadOp>(loc, adaptor.rhs(), loopIvs);

      // Apply the lowered binary operation to the loaded elements.
      return builder.create<LoweredBinaryOp>(loc, lhsElement, rhsElement);
    };

    lowerOpToLoops(op, operands, rewriter, emitElement);
    return matchSuccess();
  }
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//
/// Rewrite pattern lowering `toy.constant`: a buffer matching the constant's
/// type is allocated and one store is generated for each element of the
/// constant value. The operation is then replaced by the buffer.
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ConstantOp op,
                                     PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will be generating constant indices up-to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    //
    // A rank-0 constant has an empty shape: std::max_element would then return
    // the end iterator, which must not be dereferenced. No index constants are
    // needed in that case since the single store below takes no indices.
    auto valueShape = memRefType.getShape();
    SmallVector<Value, 8> constantIndices;
    if (!valueShape.empty()) {
      int64_t largestDim =
          *std::max_element(valueShape.begin(), valueShape.end());
      for (auto i : llvm::seq<int64_t>(0, largestDim))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    SmallVector<Value, 2> indices;
    auto valueIt = constantValue.getValues<FloatAttr>().begin();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // Once an index has been picked for every dimension we have reached the
      // base case of the recursion: store the next element at that index.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
/// Rewrite pattern replacing a `toy.return` that has no operand with a
/// standard `ReturnOp`. A `toy.return` carrying an operand is not matched.
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ReturnOp op,
                                     PatternRewriter &rewriter) const final {
    // All function calls are expected to have been inlined by the time this
    // lowering runs, so only the operand-less form is handled.
    const bool returnsValue = op.hasOperand();
    if (!returnsValue) {
      // "toy.return" maps directly onto "std.return".
      rewriter.replaceOpWithNewOp<ReturnOp>(op);
      return matchSuccess();
    }
    return matchFailure();
  }
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//
/// Conversion pattern lowering `toy.transpose` to affine loops: each element of
/// the result is loaded from the input at the loop indices taken in reverse
/// order.
struct TransposeOpLowering : public ConversionPattern {
  TransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();

    // Emits the load producing one element of the transposed result.
    auto emitElement = [loc](PatternRewriter &builder,
                             ArrayRef<Value> memRefOperands,
                             ArrayRef<Value> loopIvs) {
      // Wrap the remapped operands in the ODS-generated adaptor so the input
      // can be reached through its named accessor.
      toy::TransposeOpOperandAdaptor adaptor(memRefOperands);
      Value source = adaptor.input();

      // Reading the input at the reversed indices performs the transpose.
      SmallVector<Value, 2> swappedIvs(loopIvs.rbegin(), loopIvs.rend());
      return builder.create<AffineLoadOp>(loc, source, swappedIvs);
    };

    lowerOpToLoops(op, operands, rewriter, emitElement);
    return matchSuccess();
  }
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//
/// Function pass performing a partial lowering of Toy operations to affine
/// loops and standard operations. The lowering is partial: some Toy operations
/// are deliberately kept in the Toy dialect (see the conversion target set up
/// in `runOnFunction`).
namespace {
struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
// Entry point of the pass; defined out of line below.
void runOnFunction() final;
};
} // end anonymous namespace.
void ToyToAffineLoweringPass::runOnFunction() {
auto function = getFunction();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as `legal`.
target.addIllegalDialect<toy::ToyDialect>();
target.addLegalOp<toy::PrintOp>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
OwningRewritePatternList patterns;
patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
/// Create a pass that lowers a subset of the Toy IR to operations in the
/// `Affine` and `Std` dialects. Returns a newly allocated
/// ToyToAffineLoweringPass owned by the returned pointer.
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
return std::make_unique<ToyToAffineLoweringPass>();
}