OpenMPToLLVM.cpp
2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
using namespace mlir;
namespace {
/// Conversion pattern for `omp.parallel`.
///
/// The op is recreated with the already type-converted operands and the
/// original attributes (and no results). The original region is moved into
/// the new op, and the region's types are converted with this pattern's
/// type converter. The original op is erased only if that type conversion
/// succeeds.
struct ParallelOpConversion : public ConvertToLLVMPattern {
  explicit ParallelOpConversion(MLIRContext *context,
                                LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto parallelOp = cast<omp::ParallelOp>(op);

    // Build the replacement from the converted operands. The result type list
    // is empty, and the attribute dictionary is copied over unchanged.
    auto replacement = rewriter.create<omp::ParallelOp>(
        parallelOp.getLoc(), TypeRange(), operands, parallelOp.getAttrs());

    // Move (not clone) the body into the replacement, then convert its types.
    Region &body = replacement.region();
    rewriter.inlineRegionBefore(parallelOp.region(), body, body.end());
    if (failed(rewriter.convertRegionTypes(&body, typeConverter)))
      return failure();

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
/// Appends to `patterns` the OpenMP-dialect conversion patterns that rewrite
/// OpenMP ops using `converter` for type conversion.
void mlir::populateOpenMPToLLVMConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  // Currently the only pattern registered here is the omp.parallel conversion.
  patterns.insert<ParallelOpConversion>(context, converter);
}
namespace {
/// Module-level pass driving the OpenMP-to-LLVM conversion.
/// The conversion logic lives in runOnOperation().
class ConvertOpenMPToLLVMPass
    : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
public:
  void runOnOperation() override;
};
} // namespace
void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
MLIRContext *context = &getContext();
// Convert to OpenMP operations with LLVM IR dialect
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp>(
[&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
omp::BarrierOp, omp::TaskwaitOp>();
if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();
}
/// Factory for the OpenMP-to-LLVM conversion pass; the caller takes ownership.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertOpenMPToLLVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertOpenMPToLLVMPass>();
  return pass;
}