//===- ConvertSimQuant.cpp - Converts simulated quant ops ----------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/Passes.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::quant;

namespace {

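/// Pass that converts training-time simulated quantization ops (fake quant)
/// into corresponding quantize/dequantize barrier casts.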
class ConvertSimulatedQuantPass
    : public FunctionPass<ConvertSimulatedQuantPass> {
public:
  void runOnFunction() override;
};

/// Base CRTP class that rewrites a simulated fake-quant op (ConstFakeQuant or
/// ConstFakeQuantPerAxis) into a qbarrier/dbarrier pair.
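///
/// As an illustrative sketch only (IR syntax is approximate, and the concrete
/// !quant.uniform scale/zero-point shown as "..." are computed by
/// fakeQuantAttrsToType from the op's min/max/num_bits attributes):
///
///   %y = "quant.const_fake_quant"(%x)
///       {min = -1.0 : f32, max = 1.0 : f32, num_bits = 8 : i64}
///       : (tensor<8xf32>) -> tensor<8xf32>
///
/// is rewritten to:
///
///   %q = "quant.qcast"(%x)
///       : (tensor<8xf32>) -> tensor<8x!quant.uniform<u8:f32, ...>>
///   %y = "quant.dcast"(%q)
///       : (tensor<8x!quant.uniform<u8:f32, ...>>) -> tensor<8xf32>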
template <typename ConcreteRewriteClass, typename FakeQuantOp>
class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
public:
  using OpRewritePattern<FakeQuantOp>::OpRewritePattern;

  FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
      : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}

  PatternMatchResult matchAndRewrite(FakeQuantOp op,
                                     PatternRewriter &rewriter) const override {
    // TODO: If this pattern comes up more frequently, consider adding core
    // support for failable rewrites.
    if (failableRewrite(op, rewriter)) {
      *hadFailure = true;
      return Pattern::matchFailure();
    }

    return Pattern::matchSuccess();
  }

private:
  bool *hadFailure;

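  /// Performs the rewrite, returning true on failure. A diagnostic will
  /// already have been emitted by the time this returns true.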
  bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
    auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
    if (!converter) {
      op.emitError("unsupported quantized type conversion");
      return true;
    }

    QuantizedType elementType =
        static_cast<const ConcreteRewriteClass *>(this)
            ->convertFakeQuantAttrsToType(op, converter.expressedType);

    if (!elementType) {
      // Note that fakeQuantAttrsToType (called by the concrete rewrite class)
      // will already have emitted the error.
      return true;
    }

    Type quantizedType = converter.convert(elementType);
    assert(quantizedType &&
           "Converter accepted a type that it did not convert");

    // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
    // this is a forced/hard-coded constraint.
    auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
                                                    op.inputs());
    rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
                                                  qbarrier.getResult());

    return false;
  }
};

class ConstFakeQuantRewrite
    : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
public:
  using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;

  ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
      : BaseRewrite(ctx, hadFailure) {}

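  /// Converts the scalar min/max/num_bits/narrow_range/is_signed attributes of
  /// the fake-quant op into a matching per-layer quantized element type.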
  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
                                            Type expressedType) const {
    return fakeQuantAttrsToType(
        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
        fqOp.narrow_range(), expressedType, fqOp.is_signed());
  }
};

class ConstFakeQuantPerAxisRewrite
    : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
                              ConstFakeQuantPerAxis> {
public:
  using BaseRewrite =
      FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;

  ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
      : BaseRewrite(ctx, hadFailure) {}

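  /// Converts the per-axis min/max lists, the quantization axis, and the
  /// remaining fake-quant attributes into a per-axis quantized element type.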
  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
                                            Type expressedType) const {
    SmallVector<double, 4> min, max;
    min.reserve(fqOp.min().size());
    max.reserve(fqOp.max().size());
    for (auto m : fqOp.min())
      min.push_back(m.cast<FloatAttr>().getValueAsDouble());
    for (auto m : fqOp.max())
      max.push_back(m.cast<FloatAttr>().getValueAsDouble());

    return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
                                fqOp.axis().getSExtValue(), min, max,
                                fqOp.narrow_range(), expressedType,
                                fqOp.is_signed());
  }
};

} // namespace

void ConvertSimulatedQuantPass::runOnFunction() {
  bool hadFailure = false;
  OwningRewritePatternList patterns;
  auto func = getFunction();
  auto ctx = func.getContext();
  patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
      ctx, &hadFailure);
  applyPatternsGreedily(func, patterns);
  if (hadFailure)
    signalPassFailure();
}

std::unique_ptr<OpPassBase<FuncOp>>
mlir::quant::createConvertSimulatedQuantPass() {
  return std::make_unique<ConvertSimulatedQuantPass>();
}

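// The registration below also exposes the pass on the mlir-opt command line
// under the flag -quant-convert-simulated-quantization, e.g. (illustrative
// invocation; the input file name is a placeholder):
//
//   mlir-opt -quant-convert-simulated-quantization input.mlir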
static PassRegistration<ConvertSimulatedQuantPass>
    pass("quant-convert-simulated-quantization",
         "Converts training-time simulated quantization ops to corresponding "
         "quantize/dequantize casts.");