//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

static llvm::cl::opt<bool> disableCanonicalization(
    "mlir-disable-inline-simplify",
    llvm::cl::desc("Disable running simplifications during inlining"),
    llvm::cl::ReallyHidden, llvm::cl::init(false));

static llvm::cl::opt<unsigned> maxInliningIterations(
    "mlir-max-inline-iterations",
    llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
    llvm::cl::ReallyHidden, llvm::cl::init(4));

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static void runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
  std::vector<CallGraphNode *> currentSCCVec;
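  // llvm::scc_iterator enumerates the SCCs of the callgraph in post-order,
  // i.e. callees before their callers, which yields the bottom-up traversal.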
  auto cgi = llvm::scc_begin(&cg);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCCVec = *cgi;
    ++cgi;
    sccTransformer(currentSCCVec);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the Region
/// (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
      : call(call), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *targetNode;
};
} // end anonymous namespace

/// Collect all of the callable operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
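  // Use an explicit worklist of blocks, rather than recursion, to walk through
  // nested regions.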
  SmallVector<Block *, 8> worklist;
  auto addToWorklist = [&](iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.push_back(&block);
  };

  addToWorklist(blocks);
  while (!worklist.empty()) {
    for (Operation &op : *worklist.pop_back_val()) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        CallInterfaceCallable callable = call.getCallableForCallee();

        // TODO(riverriddle) Support inlining nested call references.
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *node = cg.resolveCallable(callable, &op);
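        // External nodes represent callables with no resolvable callable
        // region (e.g. declarations), so there is nothing to inline from them.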
        if (!node->isExternal())
          calls.emplace_back(call, node);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then skip regions that define their
      // own callgraph node.
      for (auto &nestedRegion : op.getRegions())
        if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
          addToWorklist(nestedRegion);
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg)
      : InlinerInterface(context), cg(cg) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
  }

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner,
                                      ArrayRef<CallGraphNode *> currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (auto *node : currentSCC) {
    if (!node->isExternal())
      collectCallOps(*node->getCallableRegion(), cg, calls,
                     /*traverseNestedCGNodes=*/false);
  }
  if (calls.empty())
    return failure();

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall &it = calls[i];
    LLVM_DEBUG({
      llvm::dbgs() << "* Considering inlining call: ";
      it.call.dump();
    });
    if (!shouldInline(it))
      continue;

    CallOpInterface call = it.call;
    Region *targetRegion = it.targetNode->getCallableRegion();
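    // Attempt to inline the target's callable region at the call site. If this
    // succeeds, any calls exposed within the inlined blocks have already been
    // appended to `calls` via processInlinedBlocks above.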
    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion);
    if (failed(inlineResult))
      continue;

    // If the inlining was successful, then erase the call.
    call.erase();
    inlinedAnyCalls = true;
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children
    // require special handling as we may remove the node during
    // canonicalization. In the future, we should be able to handle this
    // case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also don't apply canonicalizations to nodes whose regions are not
    // isolated from above. This avoids potentially mutating the regions of
    // nodes defined above; it is also a requirement of the
    // 'applyPatternsGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
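  // The parallel diagnostic handler ensures that diagnostics emitted from the
  // worker threads below are presented in a deterministic order.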
  ParallelDiagnosticHandler canonicalizationHandler(context);
  llvm::parallel::for_each_n(
      llvm::parallel::par, /*Begin=*/size_t(0),
      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
        // Set the order for this thread so that diagnostics will be properly
        // ordered.
        canonicalizationHandler.setOrderIDForThread(index);

        // Apply the canonicalization patterns to this region.
        auto *node = nodesToCanonicalize[index];
        applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);

        // Make sure to reset the order ID for the diagnostic handler, as this
        // thread may be used in a different context.
        canonicalizationHandler.eraseOrderIDForThread();
      });
}

/// Attempt to inline calls within the given SCC, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls.
static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
                      MLIRContext *context,
                      const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the scc. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
  }
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

// TODO(riverriddle) This pass should currently only be used for basic testing
// of inlining functionality.
namespace {
struct InlinerPass : public OperationPass<InlinerPass> {
  void runOnOperation() override {
    CallGraph &cg = getAnalysis<CallGraph>();
    auto *context = &getContext();

    // The inliner should only be run on operations that define a symbol table,
    // as the callgraph will need to resolve references.
    Operation *op = getOperation();
    if (!op->hasTrait<OpTrait::SymbolTable>()) {
      op->emitOpError() << "was scheduled to run under the inliner, but does "
                           "not define a symbol table";
      return signalPassFailure();
    }

    // Collect a set of canonicalization patterns to use when simplifying
    // callable regions within an SCC.
    OwningRewritePatternList canonPatterns;
    for (auto *op : context->getRegisteredOperations())
      op->getCanonicalizationPatterns(canonPatterns, context);

    // Run the inline transform in post-order over the SCCs in the callgraph.
    Inliner inliner(context, cg);
    runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
      inlineSCC(inliner, scc, context, canonPatterns);
    });
  }
};
} // end anonymous namespace

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}

static PassRegistration<InlinerPass> pass("inline", "Inline function calls");