CalledValuePropagation.cpp 17.7 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434
//===- CalledValuePropagation.cpp - Propagate called values -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a transformation that attaches !callees metadata to
// indirect call sites. For a given call site, the metadata, if present,
// indicates the set of functions the call site could possibly target at
// run-time. This metadata is added to indirect call sites when the set of
// possible targets can be determined by analysis and is known to be small. The
// analysis driving the transformation is similar to constant propagation and
// makes use of the generic sparse propagation solver.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/CalledValuePropagation.h"
#include "llvm/Analysis/SparsePropagation.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include <functional>
using namespace llvm;

#define DEBUG_TYPE "called-value-propagation"

/// Upper bound on the number of functions a single lattice value may hold.
/// Once the set of functions a value (and therefore a call site) may refer to
/// grows beyond this limit, its lattice value becomes overdefined. The number
/// of distinct lattice values is bounded by Ch(F, M), where F is the number of
/// functions in the module and M is MaxFunctionsPerValue, so this limit should
/// stay very small. Call sites with many possible targets are unlikely to
/// benefit from the analysis anyway.
static cl::opt<unsigned> MaxFunctionsPerValue(
    "cvp-max-functions-per-value",
    cl::desc("The maximum number of functions to track per lattice value"),
    cl::init(4), cl::Hidden);

namespace {
/// Interprocedural grouping of LLVM values. A single Value may be viewed under
/// more than one group; keeping the groups apart lets us, for example, track a
/// global variable used as an operand separately from the value stored at its
/// location.
enum class IPOGrouping {
  Register, ///< SSA registers: instructions, arguments and constants.
  Return,   ///< The return value of a function.
  Memory    ///< The in-memory contents of a global variable.
};

/// A lattice key pairs an LLVM value with the grouping it is viewed under.
/// Two integer bits are enough to encode the three IPOGrouping enumerators.
using CVPLatticeKey = PointerIntPair<Value *, /*IntBits=*/2, IPOGrouping>;

/// The lattice value type used by our custom lattice function. It holds the
/// lattice state, and a set of functions.
class CVPLatticeVal {
public:
  /// The states of the lattice values. Only the FunctionSet state is
  /// interesting. It indicates the set of functions to which an LLVM value may
  /// refer.
  enum CVPLatticeStateTy { Undefined, FunctionSet, Overdefined, Untracked };

  /// Comparator for sorting the functions set. We want to keep the order
  /// deterministic for testing, etc.
  struct Compare {
    bool operator()(const Function *LHS, const Function *RHS) const {
      return LHS->getName() < RHS->getName();
    }
  };

  CVPLatticeVal() : LatticeState(Undefined) {}
  CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {}
  CVPLatticeVal(std::vector<Function *> &&Functions)
      : LatticeState(FunctionSet), Functions(std::move(Functions)) {
    assert(llvm::is_sorted(this->Functions, Compare()));
  }

  /// Get a reference to the functions held by this lattice value. The number
  /// of functions will be zero for states other than FunctionSet.
  const std::vector<Function *> &getFunctions() const {
    return Functions;
  }

  /// Returns true if the lattice value is in the FunctionSet state.
  bool isFunctionSet() const { return LatticeState == FunctionSet; }

  bool operator==(const CVPLatticeVal &RHS) const {
    return LatticeState == RHS.LatticeState && Functions == RHS.Functions;
  }

  bool operator!=(const CVPLatticeVal &RHS) const {
    return LatticeState != RHS.LatticeState || Functions != RHS.Functions;
  }

private:
  /// Holds the state this lattice value is in.
  CVPLatticeStateTy LatticeState;

  /// Holds functions indicating the possible targets of call sites. This set
  /// is empty for lattice values in the undefined, overdefined, and untracked
  /// states. The maximum size of the set is controlled by
  /// MaxFunctionsPerValue. Since most LLVM values are expected to be in
  /// uninteresting states (i.e., overdefined), CVPLatticeVal objects should be
  /// small and efficiently copyable.
  // FIXME: This could be a TinyPtrVector and/or merge with LatticeState.
  std::vector<Function *> Functions;
};

/// The custom lattice function used by the generic sparse propagation solver.
/// It handles merging lattice values and computing new lattice values for
/// constants, arguments, values returned from trackable functions, and values
/// located in trackable global variables. It also computes the lattice values
/// that change as a result of executing instructions, and records the
/// indirect call sites it visits so metadata can be attached to them later.
class CVPLatticeFunc
    : public AbstractLatticeFunction<CVPLatticeKey, CVPLatticeVal> {
public:
  CVPLatticeFunc()
      : AbstractLatticeFunction(CVPLatticeVal(CVPLatticeVal::Undefined),
                                CVPLatticeVal(CVPLatticeVal::Overdefined),
                                CVPLatticeVal(CVPLatticeVal::Untracked)) {}

  /// Compute and return a CVPLatticeVal for the given CVPLatticeKey.
  ///
  /// Register keys: instructions and arguments of functions whose arguments
  /// can be tracked are Undefined, constants are evaluated directly, and
  /// everything else is Overdefined.
  /// Memory/Return keys: a trackable global variable takes the value of its
  /// initializer, a function whose returns can be tracked is Undefined, and
  /// everything else is Overdefined.
  CVPLatticeVal ComputeLatticeVal(CVPLatticeKey Key) override {
    switch (Key.getInt()) {
    case IPOGrouping::Register:
      if (isa<Instruction>(Key.getPointer())) {
        return getUndefVal();
      } else if (auto *A = dyn_cast<Argument>(Key.getPointer())) {
        if (canTrackArgumentsInterprocedurally(A->getParent()))
          return getUndefVal();
      } else if (auto *C = dyn_cast<Constant>(Key.getPointer())) {
        return computeConstant(C);
      }
      return getOverdefinedVal();
    case IPOGrouping::Memory:
    case IPOGrouping::Return:
      if (auto *GV = dyn_cast<GlobalVariable>(Key.getPointer())) {
        if (canTrackGlobalVariableInterprocedurally(GV))
          return computeConstant(GV->getInitializer());
      } else if (canTrackReturnsInterprocedurally(
                     cast<Function>(Key.getPointer()))) {
        // Any non-global-variable Memory/Return key must be a function.
        return getUndefVal();
      }
    }
    return getOverdefinedVal();
  }

  /// Merge the two given lattice values. The interesting cases are merging two
  /// FunctionSet values and a FunctionSet value with an Undefined value. For
  /// these cases, we simply union the function sets. If the size of the union
  /// is greater than the maximum functions we track, the merged value is
  /// overdefined.
  CVPLatticeVal MergeValues(CVPLatticeVal X, CVPLatticeVal Y) override {
    if (X == getOverdefinedVal() || Y == getOverdefinedVal())
      return getOverdefinedVal();
    if (X == getUndefVal() && Y == getUndefVal())
      return getUndefVal();
    // Function vectors are either empty or built through the FunctionSet
    // constructor, which asserts they are sorted by CVPLatticeVal::Compare,
    // as std::set_union requires.
    std::vector<Function *> Union;
    std::set_union(X.getFunctions().begin(), X.getFunctions().end(),
                   Y.getFunctions().begin(), Y.getFunctions().end(),
                   std::back_inserter(Union), CVPLatticeVal::Compare{});
    if (Union.size() > MaxFunctionsPerValue)
      return getOverdefinedVal();
    return CVPLatticeVal(std::move(Union));
  }

  /// Compute the lattice values that change as a result of executing the given
  /// instruction. The changed values are stored in \p ChangedValues. We handle
  /// just a few kinds of instructions since we're only propagating values that
  /// can be called.
  void ComputeInstructionState(
      Instruction &I, DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
      SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) override {
    switch (I.getOpcode()) {
    case Instruction::Call:
    case Instruction::Invoke:
      return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
    case Instruction::Load:
      return visitLoad(*cast<LoadInst>(&I), ChangedValues, SS);
    case Instruction::Ret:
      return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
    case Instruction::Select:
      return visitSelect(*cast<SelectInst>(&I), ChangedValues, SS);
    case Instruction::Store:
      return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
    default:
      return visitInst(I, ChangedValues, SS);
    }
  }

  /// Print the given CVPLatticeVal to the specified stream.
  void PrintLatticeVal(CVPLatticeVal LV, raw_ostream &OS) override {
    if (LV == getUndefVal())
      OS << "Undefined  ";
    else if (LV == getOverdefinedVal())
      OS << "Overdefined";
    else if (LV == getUntrackedVal())
      OS << "Untracked  ";
    else
      OS << "FunctionSet";
  }

  /// Print the given CVPLatticeKey to the specified stream.
  void PrintLatticeKey(CVPLatticeKey Key, raw_ostream &OS) override {
    if (Key.getInt() == IPOGrouping::Register)
      OS << "<reg> ";
    else if (Key.getInt() == IPOGrouping::Memory)
      OS << "<mem> ";
    else if (Key.getInt() == IPOGrouping::Return)
      OS << "<ret> ";
    if (isa<Function>(Key.getPointer()))
      OS << Key.getPointer()->getName();
    else
      OS << *Key.getPointer();
  }

  /// We collect a set of indirect calls when visiting call sites. This method
  /// returns a reference to that set.
  SmallPtrSetImpl<CallBase *> &getIndirectCalls() { return IndirectCalls; }

private:
  /// Holds the indirect calls we encounter during the analysis. We will attach
  /// metadata to these calls after the analysis indicating the functions the
  /// calls can possibly target.
  SmallPtrSet<CallBase *, 32> IndirectCalls;

  /// Compute a new lattice value for the given constant. The constant, after
  /// stripping any pointer casts, should be a Function. A null pointer maps to
  /// an empty FunctionSet: calling it is undefined behavior, so it contributes
  /// no targets.
  CVPLatticeVal computeConstant(Constant *C) {
    if (isa<ConstantPointerNull>(C))
      return CVPLatticeVal(CVPLatticeVal::FunctionSet);
    if (auto *F = dyn_cast<Function>(C->stripPointerCasts()))
      return CVPLatticeVal({F});
    return getOverdefinedVal();
  }

  /// Handle return instructions. The function's return state is the merge of
  /// the returned value state and the function's return state.
  void visitReturn(ReturnInst &I,
                   DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                   SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    Function *F = I.getParent()->getParent();
    if (F->getReturnType()->isVoidTy())
      return;
    auto RegI = CVPLatticeKey(I.getReturnValue(), IPOGrouping::Register);
    auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
    ChangedValues[RetF] =
        MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  }

  /// Handle call sites. The state of a called function's formal arguments is
  /// the merge of the argument state with the call site's corresponding actual
  /// argument state. The call site state is the merge of the call site state
  /// with the returned value state of the called function.
  void visitCallBase(CallBase &CB,
                     DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                     SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    Function *F = CB.getCalledFunction();
    auto RegI = CVPLatticeKey(&CB, IPOGrouping::Register);

    // If this is an indirect call, save it so we can quickly revisit it when
    // attaching metadata.
    if (!F)
      IndirectCalls.insert(&CB);

    // If we can't track the function's return values, there's nothing to do.
    if (!F || !canTrackReturnsInterprocedurally(F)) {
      // Void return, No need to create and update CVPLattice state as no one
      // can use it.
      if (CB.getType()->isVoidTy())
        return;
      ChangedValues[RegI] = getOverdefinedVal();
      return;
    }

    // Inform the solver that the called function is executable, and perform
    // the merges for the arguments and return value.
    SS.MarkBlockExecutable(&F->front());
    auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
    for (Argument &A : F->args()) {
      auto RegFormal = CVPLatticeKey(&A, IPOGrouping::Register);
      // A direct call whose function type does not match the callee's may
      // pass fewer arguments than the callee declares. There is no actual
      // argument to merge in that case (and indexing one would be out of
      // bounds), so conservatively treat the formal as overdefined.
      if (A.getArgNo() >= CB.arg_size()) {
        ChangedValues[RegFormal] = getOverdefinedVal();
        continue;
      }
      auto RegActual =
          CVPLatticeKey(CB.getArgOperand(A.getArgNo()), IPOGrouping::Register);
      ChangedValues[RegFormal] =
          MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
    }

    // Void return, No need to create and update CVPLattice state as no one can
    // use it.
    if (CB.getType()->isVoidTy())
      return;

    ChangedValues[RegI] =
        MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  }

  /// Handle select instructions. The select instruction state is the merge of
  /// the true and false value states.
  void visitSelect(SelectInst &I,
                   DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                   SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
    auto RegT = CVPLatticeKey(I.getTrueValue(), IPOGrouping::Register);
    auto RegF = CVPLatticeKey(I.getFalseValue(), IPOGrouping::Register);
    ChangedValues[RegI] =
        MergeValues(SS.getValueState(RegT), SS.getValueState(RegF));
  }

  /// Handle load instructions. If the pointer operand of the load is a global
  /// variable, we attempt to track the value. The loaded value state is the
  /// merge of the loaded value state with the global variable state.
  void visitLoad(LoadInst &I,
                 DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                 SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
    if (auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand())) {
      auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
      ChangedValues[RegI] =
          MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
    } else {
      ChangedValues[RegI] = getOverdefinedVal();
    }
  }

  /// Handle store instructions. If the pointer operand of the store is a
  /// global variable, we attempt to track the value. The global variable state
  /// is the merge of the stored value state with the global variable state.
  void visitStore(StoreInst &I,
                  DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                  SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
    if (!GV)
      return;
    auto RegI = CVPLatticeKey(I.getValueOperand(), IPOGrouping::Register);
    auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
    ChangedValues[MemGV] =
        MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
  }

  /// Handle all other instructions. All other instructions that have users
  /// are marked overdefined.
  void visitInst(Instruction &I,
                 DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
                 SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
    // Simply bail if this instruction has no user.
    if (I.use_empty())
      return;
    auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
    ChangedValues[RegI] = getOverdefinedVal();
  }
};
} // namespace

namespace llvm {
/// Teach the generic sparse solver how to convert between CVPLatticeKeys and
/// LLVM Values. The solver relies on this when adding Values to its work list
/// and when inspecting the state of control-flow related values. A Value on
/// its own is always mapped to its Register-grouping key.
template <> struct LatticeKeyInfo<CVPLatticeKey> {
  static inline Value *getValueFromLatticeKey(CVPLatticeKey Key) {
    Value *Underlying = Key.getPointer();
    return Underlying;
  }
  static inline CVPLatticeKey getLatticeKeyFromValue(Value *V) {
    CVPLatticeKey RegisterKey(V, IPOGrouping::Register);
    return RegisterKey;
  }
};
} // namespace llvm

/// Run the analysis over \p M and attach !callees metadata to every indirect
/// call whose callee operand resolves to a known, non-empty set of functions.
/// Returns true if any metadata was attached.
static bool runCVP(Module &M) {
  // The custom lattice function and the generic sparse solver that drives it.
  CVPLatticeFunc Lattice;
  SparseSolver<CVPLatticeKey, CVPLatticeVal> Solver(&Lattice);

  // Seed the solver: every defined function whose arguments we cannot track
  // is assumed to be executable.
  for (Function &F : M) {
    if (F.isDeclaration() || canTrackArgumentsInterprocedurally(&F))
      continue;
    Solver.MarkBlockExecutable(&F.front());
  }

  // Solve the custom lattice. While doing so, the lattice function collects
  // the indirect call sites it visits.
  Solver.Solve();

  // Annotate the collected indirect calls with the set of functions they can
  // possibly target.
  MDBuilder MDB(M.getContext());
  bool Changed = false;
  for (CallBase *Call : Lattice.getIndirectCalls()) {
    CVPLatticeKey CalleeKey(Call->getCalledOperand(), IPOGrouping::Register);
    CVPLatticeVal CalleeVal = Solver.getExistingValueState(CalleeKey);
    if (!CalleeVal.isFunctionSet())
      continue;
    const auto &Targets = CalleeVal.getFunctions();
    if (Targets.empty())
      continue;
    Call->setMetadata(LLVMContext::MD_callees, MDB.createCallees(Targets));
    Changed = true;
  }
  return Changed;
}

PreservedAnalyses CalledValuePropagationPass::run(Module &M,
                                                  ModuleAnalysisManager &) {
  // runCVP only attaches !callees metadata, so every analysis is reported as
  // preserved whether or not anything changed.
  (void)runCVP(M);
  return PreservedAnalyses::all();
}

namespace {
/// Legacy pass manager wrapper that runs the called value propagation
/// analysis (runCVP) over a module.
class CalledValuePropagationLegacyPass : public ModulePass {
public:
  static char ID;

  CalledValuePropagationLegacyPass() : ModulePass(ID) {
    initializeCalledValuePropagationLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override {
    // Respect the pass manager's request to skip this module (e.g. via
    // opt-bisect); only run the analysis otherwise.
    return !skipModule(M) && runCVP(M);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only attaches metadata, so all analyses stay valid.
    AU.setPreservesAll();
  }
};
} // namespace

char CalledValuePropagationLegacyPass::ID = 0;
INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation",
                "Called Value Propagation", /*cfg=*/false, /*analysis=*/false)

/// Create a new instance of the legacy called value propagation pass.
ModulePass *llvm::createCalledValuePropagationPass() {
  auto *Pass = new CalledValuePropagationLegacyPass();
  return Pass;
}