ReferenceImplGen.cpp
3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// ReferenceImplGen uses the description of operations to generate reference
// implementations for the ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Operator;
static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("Reference implementation file", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n"
<< " using namespace ::mlir::edsc;\n"
<< "if (false) {}";
for (auto *def : defs) {
Operator op(def);
auto referenceImplGenerator = def->getValueInit("referenceImplementation");
if (!referenceImplGenerator)
continue;
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
<< " edsc::ScopedContext scope(f);\n";
for (auto en : llvm::enumerate(op.getOperands())) {
os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
"(void)arg_{0};\n",
en.value().name, en.index());
// TODO(jpienaar): this is generally incorrect, not all args are memref
// in the general case.
os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
"(void)view_{0};\n",
en.value().name, en.index());
}
unsigned numOperands = op.getNumOperands();
unsigned numResults = op.getNumResults();
for (unsigned idx = 0; idx < numResults; ++idx) {
os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
"(void)arg_{0};\n",
op.getResult(idx).name, numOperands + idx);
// TODO(jpienaar): this is generally incorrect, not all args are memref
// in the general case.
os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
"(void)view_{0};\n",
op.getResult(idx).name, numOperands + idx);
}
// Print the EDSC.
os << referenceImplGenerator->getAsUnquotedString() << "\n";
os.indent(2) << "f->print(llvm::outs());\n\n";
os << "}";
}
os << " else {\n";
os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
os.indent(2) << "return;\n";
os << "}\n";
os << "}\n";
}
// Registers the reference implementation generator under the name
// "gen-reference-implementations".
static mlir::GenRegistration genRegister(
    "gen-reference-implementations", "Generate reference implementations",
    [](const RecordKeeper &recordKeeper, raw_ostream &output) -> bool {
      emitReferenceImplementations(recordKeeper, output);
      // NOTE(review): presumably `false` signals success, following the usual
      // TableGen backend convention — confirm against GenInfo.h.
      return false;
    });