test_anneal.py
2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import print_function
from __future__ import absolute_import
from builtins import range
from functools import partial
import unittest
import numpy as np
from hyperopt import anneal
from hyperopt import rand
from hyperopt import Trials, fmin
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
from .test_domains import CasePerDomain
def passthrough(x):
    """Return ``x`` exactly as received.

    Serves as the ``fn`` objective for every ``fmin`` call in this module.
    Because nothing is computed on top of the sampled point, the value the
    search-space expression produces is itself the reported loss
    (presumably each test domain's ``expr`` already evaluates to its
    objective value -- confirm against ``test_domains``).
    """
    sampled = x
    return sampled
class TestItJustRuns(unittest.TestCase, CasePerDomain):
    """Smoke test: ``anneal.suggest`` can drive a short ``fmin`` run.

    No loss threshold is checked; the test passes as long as ten
    evaluations on the current domain complete without raising.
    """

    def work(self):
        """Run ten anneal-guided evaluations of ``self.bandit``.

        ``self.bandit`` is presumably provided per-domain by the
        ``CasePerDomain`` mixin (not visible here -- confirm). The resulting
        trials are discarded.
        """
        domain_space = self.bandit.expr
        history = Trials()
        fmin(
            fn=passthrough,
            space=domain_space,
            trials=history,
            algo=anneal.suggest,
            max_evals=10,
        )
class TestItAtLeastSortOfWorks(unittest.TestCase, CasePerDomain):
    """Check that ``anneal.suggest`` reaches a per-domain loss threshold.

    For the current domain (``self.bandit``, presumably supplied by the
    ``CasePerDomain`` mixin), anneal runs for a domain-specific number of
    evaluations. The best loss it finds must be strictly below
    ``thresholds[bandit.name]``. A random-search run with the same seed and
    budget is also performed. Its best loss is reported in the failure
    message as a point of comparison but is not asserted on.
    """

    # -- best loss anneal must get strictly below, keyed by domain name
    thresholds = dict(
        quadratic1=1e-5,
        q1_lognormal=0.01,
        distractor=-0.96,  # -- anneal is a strategy that can really
        #    get tricked by the distractor.
        gauss_wave=-2.0,
        gauss_wave2=-2.0,
        n_arms=-2.5,
        many_dists=0.0005,
        branin=0.7,
    )

    # -- evaluation budget per domain; unlisted domains use default_iters
    iters_thresholds = dict(
        # -- running a long way out tests overflow/underflow
        #    to some extent
        quadratic1=1000,
        many_dists=200,
        # -- anneal is pretty bad at this kind of function
        distractor=150,
        branin=200,
    )

    # -- evaluation budget for domains not listed in iters_thresholds
    default_iters = 50

    def setUp(self):
        # -- make every NumPy floating-point error raise so numeric problems
        #    inside the optimizer fail the test, but tolerate underflow,
        #    which long runs are expected to hit
        self.olderr = np.seterr("raise")
        np.seterr(under="ignore")

    def tearDown(self, *args):
        # -- restore the error settings that were active before setUp
        np.seterr(**self.olderr)

    def _minimize(self, algo, max_evals, seed=8):
        """Run ``fmin`` on ``self.bandit`` with ``algo`` and return the Trials.

        A fixed-seed ``RandomState`` keeps the run reproducible and makes
        runs with different algorithms use the same seed.
        """
        trials = Trials()
        fmin(
            fn=passthrough,
            space=self.bandit.expr,
            trials=trials,
            algo=algo,
            max_evals=max_evals,
            rstate=np.random.RandomState(seed),
        )
        return trials

    def work(self):
        """Run anneal (and a random-search baseline) on ``self.bandit``.

        Raises:
            AssertionError: if the domain has no name, if anneal does not
                record exactly the budgeted number of trials, or if its best
                loss is not strictly below the domain's threshold.
            KeyError: if the domain has no entry in ``thresholds``.
        """
        bandit = self.bandit
        self.assertIsNotNone(bandit.name)
        n_evals = self.iters_thresholds.get(bandit.name, self.default_iters)

        trials = self._minimize(anneal.suggest, n_evals)
        self.assertEqual(len(trials), n_evals)

        # -- random-search baseline with the same seed and budget; only used
        #    to give context in the failure message below
        rtrials = self._minimize(rand.suggest, n_evals)

        thresh = self.thresholds[bandit.name]
        best = min(trials.losses())
        # -- written as "not <" so NaN losses fail, matching `best < thresh`
        if not best < thresh:
            rand_losses = [loss for loss in rtrials.losses() if loss is not None]
            rand_best = min(rand_losses) if rand_losses else None
            self.fail(
                "anneal best loss %r is not below threshold %r on domain %r "
                "(random search best with same seed/budget: %r)"
                % (best, thresh, bandit.name, rand_best)
            )