ConvertStandardToSPIRV.cpp
15.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Convert constant operation with IndexType return to SPIR-V constant
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
/// attribute are consistent.
// TODO(ravishankarm) : This should be moved into DRR.
class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert floating-point comparison operations to SPIR-V dialect.
class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
public:
  using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert integer comparison operations to SPIR-V dialect.
class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert integer binary operations to SPIR-V operations. Cannot use
/// tablegen for this. If the integer operation is on variables of IndexType,
/// the type of the return value of the replacement operation differs from
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
// TODO(ravishankarm) : This should be moved into DRR.
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;

  /// Replaces `operation` with a `SPIRVOp` that consumes the converted
  /// `operands` and produces the converted result type. Fails to match when
  /// the result type cannot be converted.
  PatternMatchResult
  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The result type is converted explicitly (e.g. for IndexType) instead of
    // being copied from the source op.
    auto resultType =
        this->typeConverter.convertType(operation.getResult().getType());
    // Without a legal result type the replacement cannot be built; report a
    // match failure rather than creating an op with a null result type.
    if (!resultType)
      return this->matchFailure();
    rewriter.template replaceOpWithNewOp<SPIRVOp>(
        operation, resultType, operands, ArrayRef<NamedAttribute>());
    return this->matchSuccess();
  }
};

/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
// TODO(ravishankarm) : This should be moved into DRR.
class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert return -> spv.Return.
// TODO(ravishankarm) : This should be moved into DRR.
class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
public:
  using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert select -> spv.Select
// TODO(ravishankarm) : This should be moved into DRR.
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
  using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Convert store -> spv.StoreOp. The operands of the replaced operation are
/// of IndexType while that of the replacement operation are of type i32. This
/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : This should be moved into DRR.
class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// Utility functions for operation conversion
//===----------------------------------------------------------------------===//
/// Performs the index computation to get to the element pointed to by
/// `indices` using the layout map of `origBaseType`, and returns a
/// spv.AccessChain into `basePtr` addressing that element.
///
/// The access chain has two indices: a leading 0 (indexing into the struct
/// that wraps the converted memref) followed by the linearized element
/// position `offset + sum_i(strides[i] * indices[i])`.
///
/// Returns nullptr, without creating any ops, when the layout of
/// `origBaseType` does not have fully static strides and offset.
// TODO(ravishankarm) : Handle dynamic strides and offsets.
static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
                                          SPIRVTypeConverter &typeConverter,
                                          Location loc, MemRefType origBaseType,
                                          Value basePtr,
                                          ArrayRef<Value> indices) {
  // Get strides and offset of the MemRefType and verify they are static.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
      offset == MemRefType::getDynamicStrideOrOffset()) {
    return nullptr;
  }
  assert(indices.size() == strides.size() &&
         "must provide one index per memref dimension");

  auto indexType = typeConverter.getIndexType(builder.getContext());
  auto createIndexConstant = [&](int64_t value) -> Value {
    return builder.create<spirv::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, value));
  };

  // Accumulate offset + sum(stride_i * index_i). A zero offset is skipped so
  // that the common case does not emit a redundant constant and add.
  Value ptrLoc = nullptr;
  if (offset != 0)
    ptrLoc = createIndexConstant(offset);
  for (auto index : enumerate(indices)) {
    Value strideVal = createIndexConstant(strides[index.index()]);
    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
    ptrLoc =
        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
                : update);
  }

  SmallVector<Value, 2> linearizedIndices;
  // Add a '0' at the start to index into the struct.
  Value zero = createIndexConstant(0);
  linearizedIndices.push_back(zero);
  // A rank-0 memref with zero offset has nothing to accumulate; its only
  // element is at linear position 0.
  linearizedIndices.push_back(ptrLoc ? ptrLoc : zero);
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
//===----------------------------------------------------------------------===//
// ConstantOp with index type.
//===----------------------------------------------------------------------===//
/// Rewrites a `constant` of IndexType into a spv.constant whose type is the
/// SPIR-V type chosen by the type converter for IndexType. Fails to match for
/// non-index constants, non-integer attributes, attributes that are not of
/// IndexType, or when the result type cannot be converted.
PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
    ConstantOp constIndexOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!constIndexOp.getResult().getType().isa<IndexType>())
    return matchFailure();

  // The attribute has index type which is not directly supported in
  // SPIR-V. Get the integer value and create a new IntegerAttr.
  auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
  if (!constAttr || !constAttr.getType().isa<IndexType>())
    return matchFailure();

  // SPIR-V has no index type; the type converter decides which integer type
  // replaces it, and the constant value is re-wrapped in that type.
  auto spirvConstType =
      typeConverter.convertType(constIndexOp.getResult().getType());
  if (!spirvConstType)
    return matchFailure();

  auto spirvConstVal =
      rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
                                                 spirvConstVal);
  return matchSuccess();
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
PatternMatchResult
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  // Read the comparison inputs from the converted operands; the result type
  // of the original op is reused as-is for the SPIR-V comparison.
  CmpFOpOperandAdaptor adaptor(operands);
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();
  Type resultType = cmpFOp.getResult().getType();

  // Each supported predicate maps one-to-one onto a SPIR-V comparison op.
#define CMPF_TO_SPIRV(predicate, spirvCmpOp)                                   \
  case predicate:                                                              \
    rewriter.replaceOpWithNewOp<spirvCmpOp>(cmpFOp, resultType, lhs, rhs);     \
    return matchSuccess()

  switch (cmpFOp.getPredicate()) {
    // Ordered comparisons.
    CMPF_TO_SPIRV(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    CMPF_TO_SPIRV(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    CMPF_TO_SPIRV(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered comparisons.
    CMPF_TO_SPIRV(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    CMPF_TO_SPIRV(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    CMPF_TO_SPIRV(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    CMPF_TO_SPIRV(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
  default:
    // Any other predicate has no mapping here.
    return matchFailure();
  }
#undef CMPF_TO_SPIRV
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
PatternMatchResult
CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  // Read the comparison inputs from the converted operands; the result type
  // of the original op is reused as-is for the SPIR-V comparison.
  CmpIOpOperandAdaptor adaptor(operands);
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();
  Type resultType = cmpIOp.getResult().getType();

  // Each predicate maps one-to-one onto a SPIR-V integer comparison op.
#define CMPI_TO_SPIRV(predicate, spirvCmpOp)                                   \
  case predicate:                                                              \
    rewriter.replaceOpWithNewOp<spirvCmpOp>(cmpIOp, resultType, lhs, rhs);     \
    return matchSuccess()

  switch (cmpIOp.getPredicate()) {
    // Equality.
    CMPI_TO_SPIRV(CmpIPredicate::eq, spirv::IEqualOp);
    CMPI_TO_SPIRV(CmpIPredicate::ne, spirv::INotEqualOp);
    // Signed ordering.
    CMPI_TO_SPIRV(CmpIPredicate::slt, spirv::SLessThanOp);
    CMPI_TO_SPIRV(CmpIPredicate::sle, spirv::SLessThanEqualOp);
    CMPI_TO_SPIRV(CmpIPredicate::sgt, spirv::SGreaterThanOp);
    CMPI_TO_SPIRV(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    // Unsigned ordering.
    CMPI_TO_SPIRV(CmpIPredicate::ult, spirv::ULessThanOp);
    CMPI_TO_SPIRV(CmpIPredicate::ule, spirv::ULessThanEqualOp);
    CMPI_TO_SPIRV(CmpIPredicate::ugt, spirv::UGreaterThanOp);
    CMPI_TO_SPIRV(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
  }
#undef CMPI_TO_SPIRV

  // Only reached for a predicate value not listed above.
  return matchFailure();
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
/// Lowers `load` to a spv.Load through an access chain computed from the
/// memref layout. Fails to match when the layout cannot be linearized.
PatternMatchResult
LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  LoadOpOperandAdaptor loadOperands(operands);
  auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
                               loadOp.memref().getType().cast<MemRefType>(),
                               loadOperands.memref(), loadOperands.indices());
  // getElementPtr returns null for layouts without static strides; report a
  // match failure instead of building a spv.Load from a null pointer.
  if (!loadPtr)
    return matchFailure();
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
                                             /*memory_access =*/nullptr,
                                             /*alignment =*/nullptr);
  return matchSuccess();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
PatternMatchResult
ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // Only an operand-less `return` is lowered (to spv.Return); a return that
  // yields values is rejected by this pattern.
  if (returnOp.getNumOperands() == 0) {
    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
    return matchSuccess();
  }
  return matchFailure();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
PatternMatchResult
SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // Forward the converted condition and both alternatives, in order, to a
  // spv.Select that replaces the original op.
  SelectOpOperandAdaptor adaptor(operands);
  Value condition = adaptor.condition();
  Value trueValue = adaptor.true_value();
  Value falseValue = adaptor.false_value();
  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, condition, trueValue,
                                               falseValue);
  return matchSuccess();
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
/// Lowers `store` to a spv.Store through an access chain computed from the
/// memref layout. Fails to match when the layout cannot be linearized.
PatternMatchResult
StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  StoreOpOperandAdaptor storeOperands(operands);
  auto storePtr =
      getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
                    storeOp.memref().getType().cast<MemRefType>(),
                    storeOperands.memref(), storeOperands.indices());
  // getElementPtr returns null for layouts without static strides; report a
  // match failure instead of building a spv.Store to a null pointer.
  if (!storePtr)
    return matchFailure();
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                              storeOperands.value(),
                                              /*memory_access =*/nullptr,
                                              /*alignment =*/nullptr);
  return matchSuccess();
}
namespace {
/// Import the Standard Ops to SPIR-V Patterns from the generated
/// StandardToSPIRV.cpp.inc (this provides `populateWithGenerated`, called by
/// populateStandardToSPIRVPatterns). The anonymous namespace keeps the
/// generated definitions local to this translation unit.
#include "StandardToSPIRV.cpp.inc"
} // namespace
namespace mlir {
/// Appends to `patterns` every Standard-to-SPIR-V conversion pattern defined
/// in this file: first the generated ones, then the hand-written patterns
/// that are constructed with `typeConverter`.
void populateStandardToSPIRVPatterns(MLIRContext *context,
                                     SPIRVTypeConverter &typeConverter,
                                     OwningRewritePatternList &patterns) {
  // Patterns coming from the generated StandardToSPIRV.cpp.inc.
  populateWithGenerated(context, &patterns);

  // Constants and comparisons.
  patterns.insert<ConstantIndexOpConversion, CmpFOpConversion,
                  CmpIOpConversion>(context, typeConverter);

  // Integer arithmetic, whose result types go through the type converter.
  patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
                  IntegerOpConversion<MulIOp, spirv::IMulOp>,
                  IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
                  IntegerOpConversion<SignedRemIOp, spirv::SModOp>,
                  IntegerOpConversion<SubIOp, spirv::ISubOp>>(context,
                                                              typeConverter);

  // Memory access, control flow and selection.
  patterns.insert<LoadOpConversion, ReturnOpConversion, SelectOpConversion,
                  StoreOpConversion>(context, typeConverter);
}
} // namespace mlir