Traits.cpp 7.76 KB
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working the
  // way backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1
  // The result shape has the maximum among the two inputs at every
  // dimension index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both dimensions is unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
  // Returns the scalar type out of the given type.
  auto getScalarType = [](Type type) -> Type {
    if (auto shapedType = type.dyn_cast<ShapedType>())
      return shapedType.getElementType();
    return type;
  };

  // Make sure underlying scalar type is the same.
  auto scalarType = getScalarType(type1);
  if (scalarType != getScalarType(type2))
    return {};

  // If one of the types is unranked tensor, then the other type shouldn't be
  // vector and the result should have unranked tensor type.
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(scalarType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns llvm::None otherwise.
  auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
      return static_cast<StandardTypes::Kind>(type.getKind());
    return llvm::None;
  };

  // Make sure the composite type, if has, is consistent.
  auto compositeKind1 = getCompositeTypeKind(type1);
  auto compositeKind2 = getCompositeTypeKind(type2);
  Optional<StandardTypes::Kind> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Get the shape of each type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type
  if (resultCompositeKind == StandardTypes::Vector)
    return VectorType::get(resultShape, scalarType);
  if (resultCompositeKind == StandardTypes::RankedTensor)
    return RankedTensorType::get(resultShape, scalarType);
  return scalarType;
}

/// Returns true if the given types has both vector types and tensor types.
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
}

static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
                                ArrayRef<int64_t> shape2) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
  };
  if (shape1.size() != shape2.size())
    return false;
  for (auto p : llvm::zip(shape1, shape2))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  assert(op->getNumOperands() == 2 &&
         "only support broadcast check on two operands");
  assert(op->getNumResults() == 1 &&
         "only support broadcast check on one result");

  auto type1 = op->getOperand(0).getType();
  auto type2 = op->getOperand(1).getType();
  auto retType = op->getResult(0).getType();

  // We forbid broadcasting vector and tensor.
  if (hasBothVectorAndTensorType({type1, type2, retType}))
    return op->emitError("cannot broadcast vector with tensor");

  if (retType.isa<UnrankedTensorType>())
    return success();

  bool isUnranked1 = type1.isa<UnrankedTensorType>();
  bool isUnranked2 = type2.isa<UnrankedTensorType>();

  // If both operands are unranked, then all result shapes are possible.
  if (isUnranked1 && isUnranked2)
    return success();

  // If one of the operands is unranked, then the known dimensions in the result
  // should be compatible with the other shaped operand.
  if (isUnranked1 || isUnranked2) {
    // Result should have higher rank than the shaped operand's rank and then
    // the result's trailing dimensions should be compatible with the operand
    // shape.
    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
    if (!areCompatibleShapes(actualSuffix, shape))
      return op->emitOpError()
             << "result type " << retType
             << " has shape incompatible with a ranked operand type";
    return success();
  }

  // If both operands are shaped, then the computed broadcasted shape should be
  // compatible with the result shape.
  SmallVector<int64_t, 4> resultShape;
  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return op->emitOpError("operands don't have broadcast-compatible shapes");

  if (!areCompatibleShapes(resultShape, getShape(retType)))
    return op->emitOpError() << "result type " << retType
                             << " does not have shape compatible with the one "
                                "computed from the operand types";

  return success();
}