CopyRemoval.cpp 6.64 KB
//===- CopyRemoval.cpp - Removing the redundant copies --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace MemoryEffects;

namespace {

//===----------------------------------------------------------------------===//
// CopyRemovalPass
//===----------------------------------------------------------------------===//
/// This pass removes the redundant Copy operations. Additionally, it
/// removes the leftover definition and deallocation operations by erasing the
/// copy operation.
class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
private:
  /// Operations scheduled for erasure. Erasure is deferred until the walk in
  /// `runOnOperation` has finished so the walk never visits a deleted
  /// operation. The set is cleared at the end of every run because the pass
  /// manager may invoke `runOnOperation` on the same pass instance many times.
  DenseSet<Operation *> eraseList;

  /// Returns a user of `value` that lies directly in `block` and reports a
  /// `Free` memory effect on `value` itself, or nullptr if there is none. If
  /// several such users exist, the first one in use-list order is returned.
  Operation *getDeallocationInBlock(Value value, Block *block) {
    auto valueUsers = value.getUsers();
    auto it = llvm::find_if(valueUsers, [&](Operation *op) {
      if (op->getBlock() != block)
        return false;
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!effectInterface)
        return false;
      // Look only at effects on `value`: an operation that merely uses
      // `value` while freeing some other buffer is not its deallocation.
      SmallVector<MemoryEffects::EffectInstance, 2> effects;
      effectInterface.getEffectsOnValue(value, effects);
      return llvm::any_of(effects, [](MemoryEffects::EffectInstance &effect) {
        return isa<Free>(effect.getEffect());
      });
    });
    return it == valueUsers.end() ? nullptr : *it;
  }

  /// Returns true if `op` is a single-result operation that reports an
  /// `Allocate` memory effect. Only such operations are erased after their
  /// result has been replaced; erasing anything else (an operation with
  /// unknown effects, or one producing a view of another buffer) could drop
  /// side effects or writes the program depends on.
  bool isAllocation(Operation *op) {
    if (op->getNumResults() != 1)
      return false;
    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
    return effectInterface && effectInterface.hasEffect<Allocate>();
  }

  /// Returns true if any operation strictly between `start` and `end` may have
  /// a memory effect. Operations that do not implement MemoryEffectOpInterface
  /// have unknown effects and are conservatively treated as having one;
  /// operations that implement the interface but report no effects are
  /// skipped. `start` must precede `end` in the same block.
  bool hasMemoryEffectOpBetween(Operation *start, Operation *end) {
    assert(start->getBlock() == end->getBlock() &&
           "Start and end operations should be in the same block.");
    assert(start->isBeforeInBlock(end) &&
           "Start operation should precede end operation.");
    // The null check keeps the walk inside the block even if the ordering
    // precondition is violated in a release build.
    for (Operation *op = start->getNextNode(); op && op != end;
         op = op->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!effectInterface)
        return true;
      SmallVector<MemoryEffects::EffectInstance, 4> effects;
      effectInterface.getEffects(effects);
      if (!effects.empty())
        return true;
    }
    return false;
  }

  /// Returns true if `val` has at least one user strictly between `start` and
  /// `end`. A user nested in a region of an operation in that range counts as
  /// well; it is attributed to its ancestor operation in the block.
  bool hasUsersBetween(Value val, Operation *start, Operation *end) {
    Block *block = start->getBlock();
    assert(block == end->getBlock() &&
           "Start and end operations should be in the same block.");
    return llvm::any_of(val.getUsers(), [&](Operation *user) {
      Operation *ancestor = block->findAncestorOpInBlock(*user);
      return ancestor && start->isBeforeInBlock(ancestor) &&
             ancestor->isBeforeInBlock(end);
    });
  }

  /// Returns true if all `operations` belong to one and the same block.
  /// Returns false for an empty list.
  bool areOpsInTheSameBlock(ArrayRef<Operation *> operations) {
    llvm::SmallPtrSet<Block *, 4> blocks;
    for (Operation *op : operations)
      blocks.insert(op->getBlock());
    return blocks.size() == 1;
  }

  /// Removes `copyOp` by letting its target reuse the source buffer.
  ///
  /// Input:
  /// func(){
  ///   %from = alloc()
  ///   write_to(%from)
  ///   %to = alloc()
  ///   copy(%from,%to)
  ///   dealloc(%from)
  ///   return %to
  /// }
  ///
  /// Output:
  /// func(){
  ///   %from = alloc()
  ///   write_to(%from)
  ///   return %from
  /// }
  ///
  /// Constraints:
  /// 1) %from and %to must be defined by operations, %to by a single-result
  ///    allocation. The definition of %to, the copy and the deallocation of
  ///    %from must lie in the same block, with the deallocation after the copy.
  /// 2) There must be no user of %to between its definition and the copy.
  /// 3) There must be no user of %from between the copy and its deallocation.
  /// TODO: Alias analysis is not available at the moment. Currently, we bail
  /// out if any operation between the copy and the deallocation may have a
  /// memory effect.
  void reuseCopySourceAsTarget(CopyOpInterface copyOp) {
    Operation *copy = copyOp.getOperation();
    if (eraseList.count(copy))
      return;

    Value from = copyOp.getSource();
    Value to = copyOp.getTarget();

    Operation *fromDefiningOp = from.getDefiningOp();
    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
    Operation *toDefiningOp = to.getDefiningOp();
    if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
        !isAllocation(toDefiningOp) ||
        !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
        // All three ops are in one block now; the range (copy, dealloc) is
        // only meaningful if the deallocation follows the copy.
        !copy->isBeforeInBlock(fromFreeingOp) ||
        hasUsersBetween(to, toDefiningOp, copy) ||
        hasUsersBetween(from, copy, fromFreeingOp) ||
        hasMemoryEffectOpBetween(copy, fromFreeingOp))
      return;

    // After this, the copy reads and writes %from and %to's allocation is
    // unused, so both can go together with the now-premature deallocation.
    to.replaceAllUsesWith(from);
    eraseList.insert(copy);
    eraseList.insert(toDefiningOp);
    eraseList.insert(fromFreeingOp);
  }

  /// Removes `copyOp` by making the producers of the source write directly
  /// into the target buffer.
  ///
  /// Input:
  /// func(){
  ///   %to = alloc()
  ///   %from = alloc()
  ///   write_to(%from)
  ///   copy(%from,%to)
  ///   dealloc(%from)
  ///   return %to
  /// }
  ///
  /// Output:
  /// func(){
  ///   %to = alloc()
  ///   write_to(%to)
  ///   return %to
  /// }
  ///
  /// Constraints:
  /// 1) %from must be defined by a single-result allocation. That definition,
  ///    the copy and the deallocation of %from must lie in the same block,
  ///    with the deallocation after the copy.
  /// 2) There must be no user of %to between the definition of %from and the
  ///    copy.
  /// 3) There must be no user of %from between the copy and its deallocation.
  /// TODO: Alias analysis is not available at the moment. Currently, we bail
  /// out if any operation between the copy and the deallocation may have a
  /// memory effect.
  void reuseCopyTargetAsSource(CopyOpInterface copyOp) {
    Operation *copy = copyOp.getOperation();
    if (eraseList.count(copy))
      return;

    Value from = copyOp.getSource();
    Value to = copyOp.getTarget();

    Operation *fromDefiningOp = from.getDefiningOp();
    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
    if (!fromDefiningOp || !fromFreeingOp || !isAllocation(fromDefiningOp) ||
        !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
        // All three ops are in one block now; the range (copy, dealloc) is
        // only meaningful if the deallocation follows the copy.
        !copy->isBeforeInBlock(fromFreeingOp) ||
        hasUsersBetween(to, fromDefiningOp, copy) ||
        hasUsersBetween(from, copy, fromFreeingOp) ||
        hasMemoryEffectOpBetween(copy, fromFreeingOp))
      return;

    // After this, everything that wrote %from writes %to instead, and the
    // allocation/deallocation of %from become dead together with the copy.
    from.replaceAllUsesWith(to);
    eraseList.insert(copy);
    eraseList.insert(fromDefiningOp);
    eraseList.insert(fromFreeingOp);
  }

public:
  void runOnOperation() override {
    getOperation()->walk([&](CopyOpInterface copyOp) {
      reuseCopySourceAsTarget(copyOp);
      reuseCopyTargetAsSource(copyOp);
    });
    for (Operation *op : eraseList)
      op->erase();
    // The same pass instance may be run again on another operation; drop the
    // now-dangling pointers so that run starts from an empty list.
    eraseList.clear();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// CopyRemovalPass construction
//===----------------------------------------------------------------------===//
/// Builds a fresh, independently owned instance of the copy-removal pass
/// defined in this file and hands ownership to the caller.
std::unique_ptr<Pass> mlir::createCopyRemovalPass() {
  std::unique_ptr<Pass> pass = std::make_unique<CopyRemovalPass>();
  return pass;
}