CallGraph.cpp
7.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces and analyses for defining a nested callgraph.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// CallInterfaces
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/CallInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// CallGraphNode
//===----------------------------------------------------------------------===//
/// Returns if this node refers to the indirect/external node.
bool CallGraphNode::isExternal() const { return !callableRegion; }
/// Return the callable region this node represents. This can only be called
/// on non-external nodes.
Region *CallGraphNode::getCallableRegion() const {
assert(!isExternal() && "the external node has no callable region");
return callableRegion;
}
/// Adds an reference edge to the given node. This is only valid on the
/// external node.
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
assert(isExternal() && "abstract edges are only valid on external nodes");
addEdge(node, Edge::Kind::Abstract);
}
/// Add an outgoing call edge from this node.
void CallGraphNode::addCallEdge(CallGraphNode *node) {
addEdge(node, Edge::Kind::Call);
}
/// Adds a reference edge to the given child node.
void CallGraphNode::addChildEdge(CallGraphNode *child) {
addEdge(child, Edge::Kind::Child);
}
/// Returns true if this node has any child edges.
bool CallGraphNode::hasChildren() const {
return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
}
/// Add an edge to 'node' with the given kind.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
edges.insert({node, kind});
}
//===----------------------------------------------------------------------===//
// CallGraph
//===----------------------------------------------------------------------===//
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
CallGraphNode *parentNode, bool resolveCalls) {
if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
// If there is no parent node, we ignore this operation. Even if this
// operation was a call, there would be no callgraph node to attribute it
// to.
if (!resolveCalls || !parentNode)
return;
parentNode->addCallEdge(
cg.resolveCallable(call.getCallableForCallee(), op));
return;
}
// Compute the callgraph nodes and edges for each of the nested operations.
if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
if (auto *callableRegion = callable.getCallableRegion())
parentNode = cg.getOrAddNode(callableRegion, parentNode);
else
return;
}
for (Region ®ion : op->getRegions())
for (Block &block : region)
for (Operation &nested : block)
computeCallGraph(&nested, cg, parentNode, resolveCalls);
}
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
// Make two passes over the graph, one to compute the callables and one to
// resolve the calls. We split these up as we may have nested callable objects
// that need to be reserved before the calls.
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
}
/// Get or add a call graph node for the given region.
CallGraphNode *CallGraph::getOrAddNode(Region *region,
CallGraphNode *parentNode) {
assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
"expected parent operation to be callable");
std::unique_ptr<CallGraphNode> &node = nodes[region];
if (!node) {
node.reset(new CallGraphNode(region));
// Add this node to the given parent node if necessary.
if (parentNode)
parentNode->addChildEdge(node.get());
else
// Otherwise, connect all callable nodes to the external node, this allows
// for conservatively including all callable nodes within the graph.
// FIXME(riverriddle) This isn't correct, this is only necessary for
// callable nodes that *could* be called from external sources. This
// requires extending the interface for callables to check if they may be
// referenced externally.
externalNode.addAbstractEdge(node.get());
}
return node.get();
}
/// Lookup a call graph node for the given region, or nullptr if none is
/// registered.
CallGraphNode *CallGraph::lookupNode(Region *region) const {
auto it = nodes.find(region);
return it == nodes.end() ? nullptr : it->second.get();
}
/// Resolve the callable for given callee to a node in the callgraph, or the
/// external node if a valid node was not resolved.
CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
Operation *from) const {
// Get the callee operation from the callable.
Operation *callee;
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
else
callee = callable.get<Value>().getDefiningOp();
// If the callee is non-null and is a valid callable object, try to get the
// called region from it.
if (callee && callee->getNumRegions()) {
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
}
}
// If we don't have a valid direct region, this is an external call.
return getExternalNode();
}
//===----------------------------------------------------------------------===//
// Printing
/// Dump the graph in a human readable format.
void CallGraph::dump() const { print(llvm::errs()); }
void CallGraph::print(raw_ostream &os) const {
os << "// ---- CallGraph ----\n";
// Functor used to output the name for the given node.
auto emitNodeName = [&](const CallGraphNode *node) {
if (node->isExternal()) {
os << "<External-Node>";
return;
}
auto *callableRegion = node->getCallableRegion();
auto *parentOp = callableRegion->getParentOp();
os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
<< callableRegion->getRegionNumber();
if (auto attrs = parentOp->getAttrList().getDictionary())
os << " : " << attrs;
};
for (auto &nodeIt : nodes) {
const CallGraphNode *node = nodeIt.second.get();
// Dump the header for this node.
os << "// - Node : ";
emitNodeName(node);
os << "\n";
// Emit each of the edges.
for (auto &edge : *node) {
os << "// -- ";
if (edge.isCall())
os << "Call";
else if (edge.isChild())
os << "Child";
os << "-Edge : ";
emitNodeName(edge.getTarget());
os << "\n";
}
os << "//\n";
}
os << "// -- SCCs --\n";
for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
os << "// - SCC : \n";
for (auto &node : scc) {
os << "// -- Node :";
emitNodeName(node);
os << "\n";
}
os << "\n";
}
os << "// -------------------\n";
}