UniformSolvers.cpp 4.84 KB
//===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>

using namespace mlir;
using namespace mlir::quantizer;

/// Solves for the uniform quantization step (delta), the adjusted real range
/// [adjMin, adjMax] and the zero point (zp) for the bounding range
/// [boundingMin, boundingMax] under the configured storage parameters.
///
/// The bounding range is first widened so that it contains zero. A bisection
/// then searches for a delta such that a grid of numLevels steps, anchored at
/// the multiple of delta at or below the adjusted minimum, covers the range,
/// biasing towards overshooting rather than undershooting.
///
/// Returns (and stores in `satisfied`) whether a usable solution was found.
/// Bounds with boundingMin > boundingMax, or NaN bounds, are rejected
/// immediately. stepCount records how many bisection iterations ran.
bool UniformParamsFromMinMaxSolver::compute() {
  // Compute adjMin, adjMax, clamping to ensure that they straddle zero.
  if (boundingMin > 0 && boundingMax >= boundingMin) {
    // Lop-sided to the positive.
    adjMin = 0;
    adjMax = boundingMax;
  } else if (boundingMax < 0 && boundingMin <= boundingMax) {
    // Lop-sided to the negative.
    adjMin = boundingMin;
    adjMax = 0;
  } else if (boundingMin <= 0 && boundingMax >= 0) {
    adjMin = boundingMin;
    adjMax = boundingMax;
  } else {
    // Illegal bounds (min > max, or a NaN bound fails every comparison).
    return satisfied = false;
  }

  const double origMinAdj = adjMin;
  const double origMaxAdj = adjMax;
  const double numLevelsDouble = storageParams.numLevels;

  struct fns {
    // Grid of numLevels steps of size delta, starting at the largest
    // multiple of delta that is <= boundingMin.
    static std::pair<double, double>
    computeMinMax(double boundingMin, double numLevels, double delta) {
      double adjMin = delta * std::floor(boundingMin / delta);
      return std::make_pair(adjMin, adjMin + numLevels * delta);
    }
    static double overshoot(double boundingMin, double boundingMax,
                            double numLevels, double delta) {
      auto adjMinMax = computeMinMax(boundingMin, numLevels, delta);
      double maxOvershoot = adjMinMax.second - boundingMax;
      double minOvershoot = boundingMin - adjMinMax.first;
      // If undershooting on the min or max end, return that because it is
      // to be unconditionally avoided. Otherwise return the end with the
      // greatest magnitude of overshoot.
      if (maxOvershoot < 0)
        return maxOvershoot;
      if (minOvershoot < 0)
        return minOvershoot;
      return std::max(maxOvershoot, minOvershoot);
    }
  };

  // Bisect to find a suitable delta, starting with bounds of deltaInit
  // and deltaMax. deltaInit spans exactly the adjusted range, so its grid
  // (anchored at a floored minimum) does not overshoot the max end. deltaMax
  // widens the span by two steps so that it normally overshoots.
  double deltaInit = (adjMax - adjMin) / numLevelsDouble;
  double deltaMax =
      ((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble;
  double deltaMid = deltaInit;
  double prevDeltaMid = 0.0;
  // Overshoot at the most recent midpoint; negative means undershoot.
  double lastFMid = 0.0;
  for (stepCount = 0; stepCount < 60; ++stepCount) {
    deltaMid = (deltaInit + deltaMax) / 2.0;
    auto fInit =
        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit);
    auto fMid =
        fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid);
    lastFMid = fMid;
    if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) {
      // Solution found (or step size is infinitesimal and an overshoot).
      // Empirically, this seems to terminate around 30-50 steps or so.
      // This will find a zero point for exactly representable ranges and
      // will terminate on a small step size for inexact, biasing towards
      // overshooting.
      break;
    }
    bool signMid = fMid > 0;
    bool signInit = fInit > 0;
    if (signMid == signInit) {
      deltaInit = deltaMid;
    } else {
      deltaMax = deltaMid;
    }
    prevDeltaMid = deltaMid;
  }
  delta = deltaMid;

  // The loop can only end on a negative overshoot by exhausting its step
  // budget. This happens, for example, when the bracket has collapsed to
  // adjacent doubles and the midpoint rounds onto the undershooting end.
  // Undershoot must be avoided, so prefer the upper bracket end when it
  // verifiably overshoots.
  if (lastFMid < 0 &&
      fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMax) >= 0)
    delta = deltaMax;

  // Recalculate adjMin/adjMax based on new delta.
  auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta);
  adjMin = adjMinMax.first;
  adjMax = adjMinMax.second;

  satisfied = false;
  zp = 0;

  // Degenerate inputs (zero-width range, zero levels, infinities) surface
  // here as NaN delta/adjMin/adjMax and are reported as unsatisfied.
  if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) {
    satisfied = true;
    // Finally, scale and zeroPoint. Since it casts to integer, only valid
    // if the inputs are valid. This maps adjMin onto storage code minValue.
    zp = std::round(storageParams.minValue - adjMin / delta);
  }

  return satisfied;
}

/// Quantizes a real value to a storage code using the solved delta and zp.
///
/// Valid codes span [minValue, minValue + numLevels]. This matches compute(),
/// where adjMin maps to code minValue and adjMin + numLevels * delta maps to
/// minValue + numLevels. Results outside that range are clamped to it.
/// A NaN input has no meaningful code and is mapped to the lowest code.
int64_t UniformParamsFromMinMaxSolver::quantize(double x) const {
  const int64_t lowCode = static_cast<int64_t>(storageParams.minValue);
  const int64_t highCode =
      lowCode + static_cast<int64_t>(storageParams.numLevels);

  const double scaled = std::round(x / delta + zp);
  // Clamp in the floating-point domain before converting: converting NaN or
  // a double outside int64_t's range to int64_t is undefined behavior.
  if (std::isnan(scaled) || scaled <= static_cast<double>(lowCode))
    return lowCode;
  if (scaled >= static_cast<double>(highCode))
    return highCode;
  return static_cast<int64_t>(scaled);
}

/// Maps a storage code back to a real value as (code - zp) * delta.
/// Apart from rounding and clamping, this is the inverse of quantize().
double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const {
  const auto codeOffset = xq - zp;
  return codeOffset * delta;
}

/// Prints storage parameters as "UniformStorageParams{<numLevels>, <minValue>}".
raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
                                         const UniformStorageParams &p) {
  return os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue
            << "}";
}

/// Prints a solver's step count and bounding range. If the solver is
/// satisfied, it also prints the adjusted range, scale and zero point;
/// otherwise it prints "unsat".
raw_ostream &
mlir::quantizer::operator<<(raw_ostream &os,
                            const UniformParamsFromMinMaxSolver &s) {
  os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){"
     << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> ";

  if (s.isSatisfied()) {
    os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")"
       << ", scale = " << s.getScale() << ", zp = " << s.getZp() << "}";
  } else {
    os << "unsat}";
  }
  return os;
}