ReferenceImplGen.cpp 3.39 KB
//===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// ReferenceImplGen uses the description of operations to generate reference
// implementations for the ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace llvm;
using namespace mlir;

using mlir::tblgen::Operator;

static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
                                         raw_ostream &os) {
  emitSourceFileHeader("Reference implementation file", os);
  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");

  os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n"
     << "  using namespace ::mlir::edsc;\n"
     << "if (false) {}";
  for (auto *def : defs) {
    Operator op(def);
    auto referenceImplGenerator = def->getValueInit("referenceImplementation");
    if (!referenceImplGenerator)
      continue;
    os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
       << "  edsc::ScopedContext scope(f);\n";

    for (auto en : llvm::enumerate(op.getOperands())) {
      os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
                              "(void)arg_{0};\n",
                              en.value().name, en.index());
      // TODO(jpienaar): this is generally incorrect, not all args are memref
      // in the general case.
      os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
                              "(void)view_{0};\n",
                              en.value().name, en.index());
    }
    unsigned numOperands = op.getNumOperands();
    unsigned numResults = op.getNumResults();
    for (unsigned idx = 0; idx < numResults; ++idx) {
      os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
                              "(void)arg_{0};\n",
                              op.getResult(idx).name, numOperands + idx);
      // TODO(jpienaar): this is generally incorrect, not all args are memref
      // in the general case.
      os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
                              "(void)view_{0};\n",
                              op.getResult(idx).name, numOperands + idx);
    }

    // Print the EDSC.
    os << referenceImplGenerator->getAsUnquotedString() << "\n";
    os.indent(2) << "f->print(llvm::outs());\n\n";
    os << "}";
  }
  os << " else {\n";
  os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
  os.indent(2) << "return;\n";
  os << "}\n";
  os << "}\n";
}

// Static registration object that exposes the reference-implementation
// generator under the name "gen-reference-implementations". The callback
// forwards to emitReferenceImplementations() and always returns false
// (presumably the "no error" result expected by the registry -- confirm
// against mlir/TableGen/GenInfo.h).
static mlir::GenRegistration genRegister(
    "gen-reference-implementations", "Generate reference implementations",
    [](const RecordKeeper &recordKeeper, raw_ostream &out) -> bool {
      emitReferenceImplementations(recordKeeper, out);
      return false;
    });