Function.cpp
8.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
//===- Function.cpp - MLIR Function Classes -------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Function Operation.
//===----------------------------------------------------------------------===//
/// Create a new `func` operation named `name` with signature `type` and the
/// extra attributes `attrs`. The operation is created via Operation::create and
/// is not inserted anywhere; its single body region starts out empty.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  // The builder is only used by build() to materialize attributes.
  Builder attrBuilder(location->getContext());
  OperationState funcState(location, "func");
  FuncOp::build(&attrBuilder, funcState, name, type, attrs);
  Operation *funcOp = Operation::create(funcState);
  return cast<FuncOp>(funcOp);
}
/// Overload of create() that takes the extra attributes as an iterator range of
/// dialect attributes. The range is copied into contiguous storage and
/// forwarded to the ArrayRef-based overload.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      iterator_range<dialect_attr_iterator> attrs) {
  SmallVector<NamedAttribute, 8> attrStorage(attrs.begin(), attrs.end());
  ArrayRef<NamedAttribute> attrList = attrStorage;
  return create(location, name, type, attrList);
}
/// Create a function through the attribute-only create() overload, then attach
/// the per-argument attribute lists in `argAttrs`.
/// NOTE(review): `argAttrs` presumably needs one entry per input of `type`.
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<NamedAttributeList> argAttrs) {
  FuncOp newFunc = FuncOp::create(location, name, type, attrs);
  newFunc.setAllArgAttrs(argAttrs);
  return newFunc;
}
/// Populate `result` for a `func` operation. This adds the symbol name
/// attribute, the function type attribute, every attribute in `attrs`, and
/// a single (empty) body region.
void FuncOp::build(Builder *builder, OperationState &result, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
  StringAttr symbolName = builder->getStringAttr(name);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  for (const NamedAttribute &attr : attrs)
    result.attributes.push_back(attr);
  result.addRegion();
}
/// Populate `result` like the attribute-only build(), and additionally attach
/// the argument attribute dictionary of each input. `argAttrs` must have
/// exactly one entry per input of `type`. Entries whose dictionary is null add
/// no attribute.
void FuncOp::build(Builder *builder, OperationState &result, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<NamedAttributeList> argAttrs) {
  build(builder, result, name, type, attrs);
  assert(type.getNumInputs() == argAttrs.size());

  // Scratch buffer reused for every generated argument attribute name.
  SmallString<8> attrNameBuf;
  for (unsigned argIdx = 0, numArgs = type.getNumInputs(); argIdx < numArgs;
       ++argIdx) {
    auto dict = argAttrs[argIdx].getDictionary();
    if (!dict)
      continue;
    result.addAttribute(getArgAttrName(argIdx, attrNameBuf), dict);
  }
}
/// Parsing/Printing methods.

/// Parse a `func` operation with the generic function-like parser. Variadic
/// signatures are rejected (`allowVariadic` is false). The signature is built
/// as a plain FunctionType from the parsed input and result types. The
/// variadic flag and the std::string out-parameter handed to the callback are
/// ignored.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  return impl::parseFunctionLikeOp(
      parser, result, /*allowVariadic=*/false,
      [](Builder &b, ArrayRef<Type> inputs, ArrayRef<Type> outputs,
         impl::VariadicFlag, std::string &) -> Type {
        return b.getFunctionType(inputs, outputs);
      });
}
/// Print this function with the generic function-like printer, using the
/// inputs and results of its FunctionType. A `func` is never variadic.
void FuncOp::print(OpAsmPrinter &p) {
  FunctionType signature = getType();
  ArrayRef<Type> inputs = signature.getInputs();
  ArrayRef<Type> outputs = signature.getResults();
  impl::printFunctionLikeOp(p, *this, inputs, /*isVariadic=*/false, outputs);
}
/// Verify that each entry block argument has the same type as the matching
/// input of the function signature. External functions (no body) always pass.
/// The argument count itself is checked elsewhere (by the function-like
/// trait), so only the types are compared here.
LogicalResult FuncOp::verify() {
  if (isExternal())
    return success();

  ArrayRef<Type> signatureInputs = getType().getInputs();
  Block &entryBlock = front();
  for (unsigned argIdx = 0, numArgs = entryBlock.getNumArguments();
       argIdx != numArgs; ++argIdx) {
    Type blockArgType = entryBlock.getArgument(argIdx).getType();
    Type signatureType = signatureInputs[argIdx];
    if (blockArgType == signatureType)
      continue;
    return emitOpError("type of entry block argument #")
           << argIdx << '(' << blockArgType
           << ") must match the type of the corresponding argument in "
           << "function signature(" << signatureType << ')';
  }
  return success();
}
/// Erase the arguments at positions `argIndices` (indices into the current
/// signature) from this function. This updates:
///   - the function type,
///   - the per-argument attribute dictionaries, which are renumbered to match
///     the new signature; dictionaries that no longer correspond to any
///     argument are removed,
///   - the entry block arguments, if the function has a body.
/// Duplicate indices are tolerated. Each index must be smaller than the current
/// number of inputs.
void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
  FunctionType oldType = getType();
  unsigned originalNumArgs = oldType.getNumInputs();
  llvm::BitVector eraseIndices(originalNumArgs);
  for (unsigned index : argIndices) {
    assert(index < originalNumArgs && "argument index out of range");
    eraseIndices.set(index);
  }

  // Collect the types and attribute dictionaries of the surviving arguments.
  SmallVector<Type, 4> newInputTypes;
  SmallVector<NamedAttributeList, 4> newArgAttrs;
  for (unsigned i = 0; i != originalNumArgs; ++i) {
    if (eraseIndices.test(i))
      continue;
    newInputTypes.push_back(oldType.getInput(i));
    newArgAttrs.emplace_back(getArgAttrDict(i));
  }
  unsigned newNumArgs = newInputTypes.size();

  // Update the type first so that the argument count seen while writing the
  // argument attributes is the new one.
  setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
  setAllArgAttrs(newArgAttrs);

  // setAllArgAttrs is only given entries for indices [0, newNumArgs). The
  // dictionaries previously stored at [newNumArgs, originalNumArgs) would
  // otherwise linger on the op as stale attributes with no argument behind
  // them, so drop them explicitly.
  SmallString<8> nameBuf;
  for (unsigned i = newNumArgs; i != originalNumArgs; ++i)
    removeAttr(getArgAttrName(i, nameBuf));

  // An external function has no entry block whose arguments need erasing.
  if (isExternal())
    return;

  // Erase block arguments from the back so that erasing one argument never
  // shifts the index of an argument that is still to be erased.
  Block &entry = front();
  for (unsigned i = originalNumArgs; i-- != 0;)
    if (eraseIndices.test(i))
      entry.eraseArgument(i);
}
/// Append an entry block to this function, which must not contain any blocks
/// yet. The block gets one argument per input of the function type, and a
/// pointer to it is returned.
Block *FuncOp::addEntryBlock() {
  assert(empty() && "function already has an entry block");
  push_back(new Block());
  // The body was empty, so the block just appended is the entry block.
  Block &entryBlock = back();
  entryBlock.addArguments(getType().getInputs());
  return &entryBlock;
}
/// Append a new, empty block after all existing blocks of this function and
/// return it. The function must already have an entry block.
Block *FuncOp::addBlock() {
  assert(!empty() && "function should at least have an entry block");
  auto *newBlock = new Block();
  push_back(newBlock);
  return newBlock;
}
/// Clone the body of this function into `dest`, remapping values through
/// `mapper`, and merge this function's attributes into `dest`.
///
/// Attributes already present on `dest` take precedence. An attribute of this
/// function is copied only if `dest` has no attribute with the same name.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
  // MapVector::insert keeps the first value seen for a key, so inserting
  // dest's attributes before ours makes dest's values win on conflicts.
  llvm::MapVector<Identifier, Attribute> mergedAttrs;
  auto insertAll = [&mergedAttrs](ArrayRef<NamedAttribute> attrList) {
    for (const NamedAttribute &attr : attrList)
      mergedAttrs.insert(attr);
  };
  insertAll(dest.getAttrs());
  insertAll(getAttrs());
  dest.getOperation()->setAttrs(
      DictionaryAttr::get(mergedAttrs.takeVector(), getContext()));

  // Clone the body.
  getBody().cloneInto(&dest.getBody(), mapper);
}
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). Replaces references
/// to cloned sub-values with the corresponding value that is copied, and adds
/// those mappings to the mapper.
///
/// If the function has a body, any entry block argument that already has an
/// entry in `mapper` is dropped from the clone's signature. The argument
/// attributes of the remaining arguments are renumbered to match the new
/// signature. Attributes of dropped arguments are not carried over.
FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
  FunctionType oldType = getType();
  unsigned oldNumArgs = oldType.getNumInputs();

  // Determine which arguments survive. External functions have no block
  // arguments that could be mapped, so they keep every argument.
  bool isExternalFn = isExternal();
  SmallVector<unsigned, 4> keptArgs;
  keptArgs.reserve(oldNumArgs);
  for (unsigned i = 0; i != oldNumArgs; ++i)
    if (isExternalFn || !mapper.contains(getArgument(i)))
      keptArgs.push_back(i);

  FunctionType newType = oldType;
  if (keptArgs.size() != oldNumArgs) {
    SmallVector<Type, 4> newInputs;
    newInputs.reserve(keptArgs.size());
    for (unsigned i : keptArgs)
      newInputs.push_back(oldType.getInput(i));
    newType = FunctionType::get(newInputs, oldType.getResults(), getContext());
  }

  // Create the new function. cloneWithoutRegions copies every attribute of
  // this function, including the type and all argument attribute entries.
  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
  newFunc.setType(newType);

  // Renumber the argument attributes of the surviving arguments.
  for (unsigned destI = 0, e = keptArgs.size(); destI != e; ++destI)
    newFunc.setArgAttrs(destI, getArgAttrs(keptArgs[destI]));

  // Remove copied argument attribute entries whose index is now past the end
  // of the new signature.
  SmallString<8> nameBuf;
  for (unsigned i = keptArgs.size(); i < oldNumArgs; ++i)
    newFunc.removeAttr(getArgAttrName(i, nameBuf));

  // Clone only the body. FuncOp::cloneInto would also merge this function's
  // attributes back into newFunc, re-adding any argument attribute entry that
  // was removed above. A surviving argument without attributes would then
  // inherit the attributes of a dropped argument.
  getBody().cloneInto(&newFunc.getBody(), mapper);
  return newFunc;
}
/// Deep-copy this function with no pre-existing value mappings, so every
/// argument is kept and outside values are left unchanged.
FuncOp FuncOp::clone() {
  BlockAndValueMapping emptyMapping;
  return clone(emptyMapping);
}