ReductionNode.cpp
4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the reduction nodes which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionNode.h"

#include <algorithm>
#include <utility>
using namespace mlir;
/// Creates a node for `module` that has not been evaluated yet.
///
/// When `parent` is non-null, this node is registered as one of its child
/// variants through linkVariant(), which wraps `this` in a std::unique_ptr.
/// The parent therefore owns (and will delete) the node, so a node created
/// with a parent must be heap-allocated.
ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
    : module(module), evaluated(false) {
  if (!parent)
    return;
  parent->linkVariant(this);
}
/// Creates an unevaluated node for `module` together with its transform space.
///
/// `transformSpace` is a per-index flag vector. transformSpaceSize() counts
/// its `false` entries as the indices not yet transformed. The vector is taken
/// by value and moved into the member, so callers that pass an rvalue pay for
/// no copy at all.
///
/// When `parent` is non-null, this node is registered as one of its child
/// variants through linkVariant(), which takes ownership of `this`. Such a
/// node must therefore be heap-allocated.
ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
                             std::vector<bool> transformSpace)
    : module(module), evaluated(false),
      transformSpace(std::move(transformSpace)) {
  if (parent != nullptr)
    parent->linkVariant(this);
}
/// Calculates and updates the size and interesting values of the module.
void ReductionNode::measureAndTest(const Tester &test) {
SmallString<128> filepath;
int fd;
// Print module to temprary file.
std::error_code ec =
llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
if (ec)
llvm::report_fatal_error("Error making unique filename: " + ec.message());
llvm::ToolOutputFile out(filepath, fd);
module.print(out.os());
out.os().close();
if (out.os().has_error())
llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
size = out.os().tell();
interesting = test.isInteresting(filepath);
evaluated = true;
}
/// Reports whether measureAndTest() has already run for this node.
bool ReductionNode::isEvaluated() const {
  return evaluated;
}

/// Byte count of the printed module, as recorded by measureAndTest().
int ReductionNode::getSize() const {
  return size;
}

/// Whether the tester judged the printed module interesting, as recorded by
/// measureAndTest().
bool ReductionNode::isInteresting() const {
  return interesting;
}
/// Returns a non-owning pointer to the child variant at `index`, or nullptr
/// when `index` is out of range. Ownership remains with this node.
ReductionNode *ReductionNode::getVariant(unsigned long index) const {
  return index < variants.size() ? variants[index].get() : nullptr;
}
/// Number of child variants currently owned by this node.
int ReductionNode::variantsSize() const {
  return static_cast<int>(variants.size());
}

/// True when this node currently owns no child variants.
bool ReductionNode::variantsEmpty() const {
  return variants.empty();
}
/// Appends `newVariant` to this node's children and takes ownership of it.
///
/// The pointer is wrapped in a std::unique_ptr, so it must refer to a
/// heap-allocated node. It is deleted when this node drops it.
void ReductionNode::linkVariant(ReductionNode *newVariant) {
  variants.push_back(std::unique_ptr<ReductionNode>(newVariant));
}
/// Evaluates every child variant, discards the uninteresting ones, and orders
/// the remaining variants by ascending size, so the smallest interesting
/// variant ends up first.
///
/// Filtering before sorting means the comparator only has to order by size,
/// which is a valid strict weak ordering. std::stable_sort moves the owning
/// unique_ptrs correctly (unlike a qsort-based byte-wise sort). It also keeps
/// equally sized variants in the order they were generated, which makes
/// reductions deterministic.
void ReductionNode::organizeVariants(const Tester &test) {
  // Ensure all variants are evaluated.
  for (auto &var : variants)
    if (!var->isEvaluated())
      var->measureAndTest(test);

  // Remove uninteresting variants. Their nodes are deleted by the owning
  // unique_ptrs as the vector is compacted and truncated.
  variants.erase(std::remove_if(variants.begin(), variants.end(),
                                [](const auto &var) {
                                  return !var->isInteresting();
                                }),
                 variants.end());

  // Smallest interesting variant first.
  std::stable_sort(variants.begin(), variants.end(),
                   [](const auto &lhs, const auto &rhs) {
                     return lhs->getSize() < rhs->getSize();
                   });
}
/// Counts the `false` entries of the transform space, i.e. the indices that
/// have not been transformed yet.
int ReductionNode::transformSpaceSize() {
  unsigned long untransformed = 0;
  for (bool done : transformSpace)
    if (!done)
      ++untransformed;
  return untransformed;
}
/// Returns a copy of this node's transform-space flag vector. Callers receive
/// their own vector and cannot modify the node's member through it.
const std::vector<bool> ReductionNode::getTransformSpace() {
  std::vector<bool> snapshot(transformSpace);
  return snapshot;
}