SparsePropagation.cpp
20.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/SparsePropagation.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/IR/IRBuilder.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
/// Groups used to make the test analysis interprocedural. A single LLVM Value
/// may belong to more than one group, so the group is part of the lattice key.
/// Keeping the groups apart lets the analysis track, for example, a global
/// variable separately from the value stored at its location.
enum class IPOGrouping {
  /// An SSA register.
  Register,
  /// The return value of a function.
  Return,
  /// A value held in memory.
  Memory,
};
/// The lattice key used by the tests: an LLVM Value pointer paired with the
/// IPOGrouping it is being tracked under, with two bits reserved for the
/// grouping. The PointerIntPair header provides a DenseMapInfo specialization,
/// so these keys can be used directly as LatticeKeys.
using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
} // namespace
namespace llvm {
/// LatticeKeyInfo specialization for TestLatticeKey. The generic solver uses
/// it to convert between lattice keys and LLVM Values when it adds Values to
/// its work list and when it inspects the state of control-flow related
/// values.
template <> struct LatticeKeyInfo<TestLatticeKey> {
  /// Return the LLVM Value stored in \p Key, ignoring its grouping.
  static Value *getValueFromLatticeKey(TestLatticeKey Key) {
    Value *Underlying = Key.getPointer();
    return Underlying;
  }

  /// Build the key for \p V; Values handed over by the solver are always
  /// placed in the register grouping.
  static TestLatticeKey getLatticeKeyFromValue(Value *V) {
    TestLatticeKey RegisterKey(V, IPOGrouping::Register);
    return RegisterKey;
  }
};
} // namespace llvm
namespace {
/// This class defines a simple test lattice value that could be used for
/// solving problems similar to constant propagation. The value is maintained
/// as a PointerIntPair.
class TestLatticeVal {
public:
/// The states of the lattices value. Only the ConstantVal state is
/// interesting; the rest are special states used by the generic solver. The
/// UntrackedVal state differs from the other three in that the generic
/// solver uses it to avoid doing unnecessary work. In particular, when a
/// value moves to the UntrackedVal state, it's users are not notified.
enum TestLatticeStateTy {
UndefinedVal,
ConstantVal,
OverdefinedVal,
UntrackedVal
};
TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
TestLatticeVal(Constant *C, TestLatticeStateTy State)
: LatticeVal(C, State) {}
/// Return true if this lattice value is in the Constant state. This is used
/// for checking the solver results.
bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
/// Return true if this lattice value is in the Overdefined state. This is
/// used for checking the solver results.
bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
bool operator==(const TestLatticeVal &RHS) const {
return LatticeVal == RHS.LatticeVal;
}
bool operator!=(const TestLatticeVal &RHS) const {
return LatticeVal != RHS.LatticeVal;
}
private:
/// A simple lattice value type for problems similar to constant propagation.
/// It holds the constant value and the lattice state.
PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
};
/// A simple lattice function for constant-propagation-like problems, used only
/// by these tests. It differs from a "real" lattice function in a few ways.
/// First, every key that is not a Constant in the register grouping (return
/// values, values stored in global variables, arguments, ...) starts in the
/// undefined state, so there is no limit on what is tracked interprocedurally.
/// For simplicity the tests give all global values internal linkage, since
/// linkage is not something this lattice function looks at. Second, it only
/// handles the few instructions the tests need.
class TestLatticeFunc
    : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
public:
  /// Construct the lattice function, supplying the special values the base
  /// class uses for the Undefined, Overdefined, and Untracked states.
  TestLatticeFunc()
      : AbstractLatticeFunction(
            TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
            TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
            TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}

  /// Compute the initial lattice value for \p Key. A key starts out undefined
  /// unless it is an LLVM Constant in the register grouping, in which case it
  /// starts as that constant.
  TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
    if (Key.getInt() == IPOGrouping::Register)
      if (auto *C = dyn_cast<Constant>(Key.getPointer()))
        return TestLatticeVal(C, TestLatticeVal::ConstantVal);
    return getUndefVal();
  }

  /// Merge two lattice values the way constant propagation does: the result is
  /// constant only if both inputs are constant and hold the same value.
  /// Untracked dominates everything, then overdefined; undefined is the
  /// identity of the merge.
  TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
    if (X == getUntrackedVal() || Y == getUntrackedVal())
      return getUntrackedVal();
    if (X == getOverdefinedVal() || Y == getOverdefinedVal())
      return getOverdefinedVal();
    // Merging with undefined yields the other operand; when both operands are
    // undefined this returns undefined as well.
    if (X == getUndefVal())
      return Y;
    if (Y == getUndefVal())
      return X;
    // Both are constants at this point: equal constants stay constant,
    // anything else is overdefined.
    if (X == Y)
      return X;
    return getOverdefinedVal();
  }

  /// Compute the lattice values that change as a result of executing \p I and
  /// record them in \p ChangedValues. Only calls, returns, and stores get
  /// dedicated handling; every other instruction is marked overdefined.
  void ComputeInstructionState(
      Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
      SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
    switch (I.getOpcode()) {
    case Instruction::Call:
      return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
    case Instruction::Ret:
      return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
    case Instruction::Store:
      return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
    default:
      return visitInst(I, ChangedValues, SS);
    }
  }

private:
  /// Handle call sites. The state of each formal argument of the callee is the
  /// merge of its current state with the state of the corresponding actual
  /// argument. The state of the call site is the merge of its current state
  /// with the return state of the callee.
  void visitCallBase(CallBase &I,
                     DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                     SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    Function *F = I.getCalledFunction();
    auto RegI = TestLatticeKey(&I, IPOGrouping::Register);

    // Without a known callee (indirect call), or with a callee that is only a
    // declaration, there is nothing to analyze. A declaration has no basic
    // blocks, so asking for its entry block below would be invalid; treat the
    // result of such calls as overdefined.
    if (!F || F->isDeclaration()) {
      ChangedValues[RegI] = getOverdefinedVal();
      return;
    }

    // The callee is now reachable: make the solver visit its entry block.
    SS.MarkBlockExecutable(&F->front());

    for (Argument &A : F->args()) {
      auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
      auto RegActual =
          TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register);
      ChangedValues[RegFormal] =
          MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
    }

    auto RetF = TestLatticeKey(F, IPOGrouping::Return);
    ChangedValues[RegI] =
        MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  }

  /// Handle return instructions. The return state of the enclosing function is
  /// the merge of the returned value's state with the function's current
  /// return state. Functions returning void are ignored.
  void visitReturn(ReturnInst &I,
                   DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                   SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    Function *F = I.getParent()->getParent();
    if (F->getReturnType()->isVoidTy())
      return;
    auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
    auto RetF = TestLatticeKey(F, IPOGrouping::Return);
    ChangedValues[RetF] =
        MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
  }

  /// Handle store instructions. Only stores whose pointer operand is a global
  /// variable are tracked: the memory state of that global becomes the merge
  /// of the stored value's state with the global's current memory state.
  void visitStore(StoreInst &I,
                  DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                  SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
    if (!GV)
      return;
    auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
    auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
    ChangedValues[MemPtr] =
        MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
  }

  /// Handle all other instructions by marking their result overdefined.
  void visitInst(Instruction &I,
                 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
    ChangedValues[RegI] = getOverdefinedVal();
  }
};
/// Fixture shared by all of the tests below. Each test adds IR to the module
/// M through Builder and then runs Solver, which is driven by Lattice.
class SparsePropagationTest : public testing::Test {
protected:
  // Members are initialized in declaration order: the module and builder need
  // the context, and the solver needs the lattice function.
  LLVMContext Context;
  Module M{"", Context};
  IRBuilder<> Builder{Context};
  TestLatticeFunc Lattice;
  SparseSolver<TestLatticeKey, TestLatticeVal> Solver{&Lattice};

public:
  SparsePropagationTest() = default;
};
} // namespace
/// Check that functions discovered through calls are marked executable.
///
///   define internal void @f() {
///     call void @g()
///     ret void
///   }
///
///   define internal void @g() {
///     call void @f()
///     ret void
///   }
///
/// Only "f" is marked executable up front; the solver has to discover "g"
/// through the call in "f". Because "g" calls back into "f", this also checks
/// that a block which is already executable is not added to the basic block
/// work list again, which would send the solver into an infinite loop.
TEST_F(SparsePropagationTest, MarkBlockExecutable) {
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *Caller =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *Callee =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  BasicBlock *CallerEntry = BasicBlock::Create(Context, "", Caller);
  BasicBlock *CalleeEntry = BasicBlock::Create(Context, "", Callee);

  // Body of @f: call @g, then return.
  Builder.SetInsertPoint(CallerEntry);
  Builder.CreateCall(Callee);
  Builder.CreateRetVoid();

  // Body of @g: call @f, then return.
  Builder.SetInsertPoint(CalleeEntry);
  Builder.CreateCall(Caller);
  Builder.CreateRetVoid();

  Solver.MarkBlockExecutable(CallerEntry);
  Solver.Solve();

  EXPECT_TRUE(Solver.isBlockExecutable(CalleeEntry));
}
/// Check that information flows through global variables (constant case).
///
///   @gv = internal global i64
///
///   define internal void @f() {
///     store i64 1, i64* @gv
///     ret void
///   }
///
///   define internal void @g() {
///     store i64 1, i64* @gv
///     ret void
///   }
///
/// Both "f" and "g" are marked executable up front. Since both store the same
/// value, the solver must compute the memory state of "gv" as constant.
TEST_F(SparsePropagationTest, GlobalVariableConstant) {
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *FirstWriter =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *SecondWriter =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  GlobalVariable *GV = new GlobalVariable(
      M, Builder.getInt64Ty(), /*isConstant=*/false,
      GlobalValue::InternalLinkage, /*Initializer=*/nullptr, "gv");
  BasicBlock *FirstEntry = BasicBlock::Create(Context, "", FirstWriter);
  BasicBlock *SecondEntry = BasicBlock::Create(Context, "", SecondWriter);

  // Fill a block with "store i64 <Stored>, @gv; ret void".
  auto EmitStoreAndRet = [&](BasicBlock *BB, uint64_t Stored) {
    Builder.SetInsertPoint(BB);
    Builder.CreateStore(Builder.getInt64(Stored), GV);
    Builder.CreateRetVoid();
  };
  EmitStoreAndRet(FirstEntry, 1);
  EmitStoreAndRet(SecondEntry, 1);

  Solver.MarkBlockExecutable(FirstEntry);
  Solver.MarkBlockExecutable(SecondEntry);
  Solver.Solve();

  TestLatticeKey MemGV(GV, IPOGrouping::Memory);
  EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
}
/// Check that information flows through global variables (overdefined case).
///
///   @gv = internal global i64
///
///   define internal void @f() {
///     store i64 0, i64* @gv
///     ret void
///   }
///
///   define internal void @g() {
///     store i64 1, i64* @gv
///     ret void
///   }
///
/// Both "f" and "g" are marked executable up front. Since they store different
/// values, the solver must compute the memory state of "gv" as overdefined.
TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *FirstWriter =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *SecondWriter =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  GlobalVariable *GV = new GlobalVariable(
      M, Builder.getInt64Ty(), /*isConstant=*/false,
      GlobalValue::InternalLinkage, /*Initializer=*/nullptr, "gv");
  BasicBlock *FirstEntry = BasicBlock::Create(Context, "", FirstWriter);
  BasicBlock *SecondEntry = BasicBlock::Create(Context, "", SecondWriter);

  // Fill a block with "store i64 <Stored>, @gv; ret void".
  auto EmitStoreAndRet = [&](BasicBlock *BB, uint64_t Stored) {
    Builder.SetInsertPoint(BB);
    Builder.CreateStore(Builder.getInt64(Stored), GV);
    Builder.CreateRetVoid();
  };
  EmitStoreAndRet(FirstEntry, 0);
  EmitStoreAndRet(SecondEntry, 1);

  Solver.MarkBlockExecutable(FirstEntry);
  Solver.MarkBlockExecutable(SecondEntry);
  Solver.Solve();

  TestLatticeKey MemGV(GV, IPOGrouping::Memory);
  EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
}
/// Check that information flows through function returns (constant case).
///
///   define internal i64 @f(i1* %cond) {
///   if:
///     %0 = load i1, i1* %cond
///     br i1 %0, label %then, label %else
///
///   then:
///     ret i64 1
///
///   else:
///     ret i64 1
///   }
///
/// "f" is marked executable up front. Both return paths yield the same value,
/// so the solver must compute the return state of "f" as constant.
TEST_F(SparsePropagationTest, FunctionDefined) {
  FunctionType *FnTy = FunctionType::get(
      Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)}, false);
  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
  BasicBlock *IfBB = BasicBlock::Create(Context, "if", F);
  BasicBlock *ThenBB = BasicBlock::Create(Context, "then", F);
  BasicBlock *ElseBB = BasicBlock::Create(Context, "else", F);
  Argument *CondPtr = F->arg_begin();
  CondPtr->setName("cond");

  // Entry block: branch on the loaded condition.
  Builder.SetInsertPoint(IfBB);
  LoadInst *Loaded = Builder.CreateLoad(Type::getInt1Ty(Context), CondPtr);
  Builder.CreateCondBr(Loaded, ThenBB, ElseBB);

  // Fill a block with "ret i64 <Result>".
  auto EmitRet = [&](BasicBlock *BB, uint64_t Result) {
    Builder.SetInsertPoint(BB);
    Builder.CreateRet(Builder.getInt64(Result));
  };
  EmitRet(ThenBB, 1);
  EmitRet(ElseBB, 1);

  Solver.MarkBlockExecutable(IfBB);
  Solver.Solve();

  TestLatticeKey RetF(F, IPOGrouping::Return);
  EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
}
/// Check that information flows through function returns (overdefined case).
///
///   define internal i64 @f(i1* %cond) {
///   if:
///     %0 = load i1, i1* %cond
///     br i1 %0, label %then, label %else
///
///   then:
///     ret i64 0
///
///   else:
///     ret i64 1
///   }
///
/// "f" is marked executable up front. The two return paths yield different
/// values, so the solver must compute the return state of "f" as overdefined.
TEST_F(SparsePropagationTest, FunctionOverDefined) {
  FunctionType *FnTy = FunctionType::get(
      Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)}, false);
  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
  BasicBlock *IfBB = BasicBlock::Create(Context, "if", F);
  BasicBlock *ThenBB = BasicBlock::Create(Context, "then", F);
  BasicBlock *ElseBB = BasicBlock::Create(Context, "else", F);
  Argument *CondPtr = F->arg_begin();
  CondPtr->setName("cond");

  // Entry block: branch on the loaded condition.
  Builder.SetInsertPoint(IfBB);
  LoadInst *Loaded = Builder.CreateLoad(Type::getInt1Ty(Context), CondPtr);
  Builder.CreateCondBr(Loaded, ThenBB, ElseBB);

  // Fill a block with "ret i64 <Result>".
  auto EmitRet = [&](BasicBlock *BB, uint64_t Result) {
    Builder.SetInsertPoint(BB);
    Builder.CreateRet(Builder.getInt64(Result));
  };
  EmitRet(ThenBB, 0);
  EmitRet(ElseBB, 1);

  Solver.MarkBlockExecutable(IfBB);
  Solver.Solve();

  TestLatticeKey RetF(F, IPOGrouping::Return);
  EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
}
/// Check that information flows through arguments.
///
///   define internal void @f() {
///     call void @g(i64 0, i64 1)
///     call void @g(i64 1, i64 1)
///     ret void
///   }
///
///   define internal void @g(i64 %a, i64 %b) {
///     ret void
///   }
///
/// Only "f" is marked executable up front; the solver discovers "g" through
/// the calls in "f". Argument "a" receives two different values and must end
/// up overdefined, while "b" always receives the same value and must end up
/// constant.
///
/// This also demonstrates that ComputeInstructionState may change the state
/// of several lattice values besides the one associated with the instruction
/// itself: each call instruction here updates the state of both "a" and "b".
TEST_F(SparsePropagationTest, ComputeInstructionState) {
  Type *I64Ty = Builder.getInt64Ty();
  FunctionType *CallerTy = FunctionType::get(Builder.getVoidTy(), false);
  FunctionType *CalleeTy =
      FunctionType::get(Builder.getVoidTy(), {I64Ty, I64Ty}, false);
  Function *Caller =
      Function::Create(CallerTy, GlobalValue::InternalLinkage, "f", &M);
  Function *Callee =
      Function::Create(CalleeTy, GlobalValue::InternalLinkage, "g", &M);

  Argument *ArgA = Callee->arg_begin();
  Argument *ArgB = std::next(ArgA);
  ArgA->setName("a");
  ArgB->setName("b");

  BasicBlock *CallerEntry = BasicBlock::Create(Context, "", Caller);
  BasicBlock *CalleeEntry = BasicBlock::Create(Context, "", Callee);

  // Body of @f: two calls to @g that differ only in the first argument.
  Value *Zero = Builder.getInt64(0);
  Value *One = Builder.getInt64(1);
  Builder.SetInsertPoint(CallerEntry);
  Builder.CreateCall(Callee, {Zero, One});
  Builder.CreateCall(Callee, {One, One});
  Builder.CreateRetVoid();

  // Body of @g: just return.
  Builder.SetInsertPoint(CalleeEntry);
  Builder.CreateRetVoid();

  Solver.MarkBlockExecutable(CallerEntry);
  Solver.Solve();

  TestLatticeKey RegA(ArgA, IPOGrouping::Register);
  TestLatticeKey RegB(ArgB, IPOGrouping::Register);
  EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
  EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
}
/// Check that the solver copes with exceptional terminator instructions.
///
///   declare internal void @p()
///
///   declare internal void @g()
///
///   define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
///   entry:
///     invoke void @g()
///             to label %exit unwind label %catch.pad
///
///   catch.pad:
///     %0 = catchswitch within none [label %catch.body] unwind to caller
///
///   catch.body:
///     %1 = catchpad within %0 []
///     catchret from %1 to label %exit
///
///   exit:
///     ret void
///   }
///
/// Only the entry block is marked executable up front. The solver must then
/// discover that every other block of the function is executable as well.
TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *Personality =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "p", &M);
  Function *Invoked =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  Function *F =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);

  // Attach @p, cast to i8*, as the personality function of @f.
  Constant *PersonalityAsI8Ptr = ConstantExpr::getCast(
      Instruction::BitCast, Personality, Builder.getInt8PtrTy());
  F->setPersonalityFn(PersonalityAsI8Ptr);

  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
  BasicBlock *PadBB = BasicBlock::Create(Context, "catch.pad", F);
  BasicBlock *BodyBB = BasicBlock::Create(Context, "catch.body", F);
  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);

  // entry: invoke @g, continuing at exit or unwinding to catch.pad.
  Builder.SetInsertPoint(EntryBB);
  Builder.CreateInvoke(Invoked, ExitBB, PadBB);

  // catch.pad: a catchswitch whose single handler is catch.body.
  Builder.SetInsertPoint(PadBB);
  CatchSwitchInst *Switch = Builder.CreateCatchSwitch(
      ConstantTokenNone::get(Context), /*UnwindBB=*/nullptr,
      /*NumHandlers=*/1);
  Switch->addHandler(BodyBB);

  // catch.body: catchpad followed by a catchret to exit.
  Builder.SetInsertPoint(BodyBB);
  CatchPadInst *Pad = Builder.CreateCatchPad(Switch, {});
  Builder.CreateCatchRet(Pad, ExitBB);

  // exit: return.
  Builder.SetInsertPoint(ExitBB);
  Builder.CreateRetVoid();

  Solver.MarkBlockExecutable(EntryBB);
  Solver.Solve();

  EXPECT_TRUE(Solver.isBlockExecutable(PadBB));
  EXPECT_TRUE(Solver.isBlockExecutable(BodyBB));
  EXPECT_TRUE(Solver.isBlockExecutable(ExitBB));
}