pyll_utils.py
6.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from __future__ import absolute_import
from builtins import str
from builtins import zip
from builtins import range
from past.builtins import basestring
from builtins import object
from functools import partial, wraps
from .base import DuplicateLabel
from .pyll.base import Apply, Literal, MissingArgument
from .pyll import scope
from .pyll import as_apply
def validate_label(f):
    """Decorator that rejects calls whose first argument is not a string label.

    The wrapped function is invoked unchanged when `label` is either a plain
    string or a pyll `Literal` whose `.obj` is a string; any other value
    raises `TypeError`.
    """

    @wraps(f)
    def wrapper(label, *args, **kwargs):
        # Plain string: accept immediately.
        if isinstance(label, basestring):
            return f(label, *args, **kwargs)
        # Literal node wrapping a string: also acceptable.
        if isinstance(label, Literal) and isinstance(label.obj, basestring):
            return f(label, *args, **kwargs)
        raise TypeError("require string label")

    return wrapper
#
# Hyperparameter Types
#
@scope.define
def hyperopt_param(label, obj):
    """Identity node used to tag a stochastic expression with a label.

    Evaluating it simply yields `obj`; `label` is carried only as an
    annotation. VectorizeHelper watches for subgraphs shaped like
        hyperopt_param(<stochastic_expression>(...))
    and optimizes them.
    """
    return obj
@validate_label
def hp_pchoice(label, p_options):
    """Build a weighted choice among options.

    label: string
    p_options: list of (probability, option) pairs

    Returns a `switch` node whose index is a `categorical` draw (over the
    probabilities) wrapped in a `hyperopt_param` carrying `label`.
    """
    # Transpose the pairs into one sequence of probabilities and one of options.
    probs, options = tuple(zip(*p_options))
    selector = scope.hyperopt_param(label, scope.categorical(probs))
    return scope.switch(selector, *options)
@validate_label
def hp_choice(label, options):
    """Build a choice among `options`.

    Returns a `switch` node whose index is a `randint(len(options))` draw
    wrapped in a `hyperopt_param` carrying `label`.
    """
    n_options = len(options)
    selector = scope.hyperopt_param(label, scope.randint(n_options))
    return scope.switch(selector, *options)
@validate_label
def hp_randint(label, *args, **kwargs):
    """Return a `hyperopt_param` node labelled `label` around `scope.randint`.

    All extra positional and keyword arguments are forwarded to
    `scope.randint` untouched.
    """
    draw = scope.randint(*args, **kwargs)
    return scope.hyperopt_param(label, draw)
@validate_label
def hp_uniform(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.uniform` hyperparameter node.

    Extra arguments are forwarded to `scope.uniform`; the draw is wrapped in
    `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.uniform(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_uniformint(label, *args, **kwargs):
    """Return an integer-cast, labelled quantized-uniform hyperparameter node.

    Delegates to `hp_quniform` with the quantization step fixed at 1.0 and
    wraps the result in `scope.int`.

    label: string
    *args, **kwargs: the remaining `hp_quniform` arguments (the bounds),
        given positionally or by keyword. Any caller-supplied `q` is
        overridden, since the step is fixed.
    """
    # Supply the step by keyword instead of appending it to `args`. Appending
    # breaks when the bounds are passed as keywords (1.0 would land in the
    # first positional slot) and silently mis-binds when an extra positional
    # argument is present.
    kwargs["q"] = 1.0
    return scope.int(hp_quniform(label, *args, **kwargs))
@validate_label
def hp_quniform(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.quniform` hyperparameter node.

    Extra arguments are forwarded to `scope.quniform`; the draw is wrapped in
    `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.quniform(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_loguniform(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.loguniform` hyperparameter node.

    Extra arguments are forwarded to `scope.loguniform`; the draw is wrapped
    in `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.loguniform(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_qloguniform(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.qloguniform` hyperparameter node.

    Extra arguments are forwarded to `scope.qloguniform`; the draw is wrapped
    in `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.qloguniform(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_normal(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.normal` hyperparameter node.

    Extra arguments are forwarded to `scope.normal`; the draw is wrapped in
    `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.normal(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_qnormal(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.qnormal` hyperparameter node.

    Extra arguments are forwarded to `scope.qnormal`; the draw is wrapped in
    `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.qnormal(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_lognormal(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.lognormal` hyperparameter node.

    Extra arguments are forwarded to `scope.lognormal`; the draw is wrapped
    in `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.lognormal(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
@validate_label
def hp_qlognormal(label, *args, **kwargs):
    """Return a float-cast, labelled `scope.qlognormal` hyperparameter node.

    Extra arguments are forwarded to `scope.qlognormal`; the draw is wrapped
    in `hyperopt_param(label, ...)` and then passed through `scope.float`.
    """
    draw = scope.qlognormal(*args, **kwargs)
    labelled = scope.hyperopt_param(label, draw)
    return scope.float(labelled)
#
# Tools for extracting a search space from a Pyll graph
#
class Cond(object):
    """A single condition of the form `<name> <op> <val>`.

    Instances compare equal, and hash equal, when their `op`, `name` and
    `val` all match, so they can be stored in sets and frozensets.

    Arguments:
        name - name of the variable the condition refers to
        val - value the variable is compared against
        op - string naming the comparison operator (e.g. "=")
    """

    def __init__(self, name, val, op):
        self.op = op
        self.name = name
        self.val = val

    def __str__(self):
        return "Cond{%s %s %s}" % (self.name, self.op, self.val)

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on objects without op/name/val (None, True, str...).
        if not isinstance(other, Cond):
            return NotImplemented
        return self.op == other.op and self.name == other.name and self.val == other.val

    def __ne__(self, other):
        # Explicit for Python 2, where `!=` is not derived from `__eq__`.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        return hash((self.op, self.name, self.val))
# Convenience constructor for equality conditions: EQ(name, val) builds
# Cond(name, val, op="=").
EQ = partial(Cond, op="=")
def _expr_to_config(expr, conditions, hps):
    """Recursive worker for `expr_to_config`.

    Walks the graph rooted at `expr` and records every `hyperopt_param` node
    in `hps`, keyed by label. `conditions` is the tuple of conditions under
    which `expr` is reached; descending into option `i` of a `switch`
    extends it with `EQ(<switch label>, i)`.

    Raises `DuplicateLabel` when a label is already registered in `hps` with
    a different distribution node.
    """
    kind = expr.name

    if kind == "hyperopt_param":
        label = expr.arg["label"].obj
        dist_node = expr.arg["obj"]
        if label not in hps:
            # First sighting: register the node with this single path.
            hps[label] = {
                "node": dist_node,
                "conditions": set((conditions,)),
                "label": label,
            }
            return
        if hps[label]["node"] != dist_node:
            raise DuplicateLabel(label)
        # Same hyperparameter reached through another path.
        hps[label]["conditions"].add(conditions)
        return

    if kind != "switch":
        # Ordinary node: every input inherits the current conditions.
        for child in expr.inputs():
            _expr_to_config(child, conditions, hps)
        return

    # -- switch node: first input selects, remaining inputs are the options
    inputs = expr.inputs()
    selector = inputs[0]
    branches = inputs[1:]
    assert selector.name == "hyperopt_param"
    assert selector.arg["obj"].name in (
        "randint",  # -- in case of hp.choice
        "categorical",  # -- in case of hp.pchoice
    )
    _expr_to_config(selector, conditions, hps)
    choice_label = selector.arg["label"].obj
    for position, branch in enumerate(branches):
        branch_conditions = conditions + (EQ(choice_label, position),)
        _expr_to_config(branch, branch_conditions, hps)
def expr_to_config(expr, conditions, hps):
    """
    Fill `hps` with the hyperparameters found in the pyll graph `expr`,
    together with the conditions under which each one takes part in the
    evaluation of `expr`.

    Arguments:
        expr - root of a pyll expression (converted with `as_apply`).
        conditions - tuple of conditions (`Cond`) that must hold for `expr`
            to be evaluated; `None` is treated as the empty tuple.
        hps - dictionary to populate (modified in place).

    Entries written to `hps`:
        label -> { 'node': apply node of hyperparameter distribution,
                   'conditions': set of condition tuples,
                   'label': label
                 }
    """
    if conditions is None:
        conditions = ()
    root = as_apply(expr)
    assert isinstance(root, Apply)
    # Collect every labelled hyperparameter, then prune false dependencies.
    _expr_to_config(root, conditions, hps)
    _remove_allpaths(hps, conditions)
def _remove_allpaths(hps, conditions):
    """Hacky way to recognize some kinds of false dependencies.

    A hyperparameter recorded under several condition tuples is reset to the
    single base tuple `conditions` when either
      * none of its condition tuples contains anything other than `True`, or
      * every tuple is a single condition on the same variable and together
        they cover every value that variable can take.

    Better would be logic programming.
    """
    # Pass 1: for each randint/categorical hyperparameter, the complete set
    # of equality conditions it can generate.
    full_domains = {}
    for label, info in list(hps.items()):
        node = info["node"]
        if node.name == "randint":
            low = node.arg["low"].obj
            # if high is None, the domain is [0, low), else it is [low, high)
            if node.arg["high"] != MissingArgument:
                n_values = node.arg["high"].obj - low
            else:
                n_values = low
            full_domains[label] = frozenset([EQ(label, i) for i in range(n_values)])
        elif node.name == "categorical":
            # NOTE(review): relies on `p` exposing `.size` (array-like) - confirm
            probs = node.arg["p"].obj
            full_domains[label] = frozenset([EQ(label, i) for i in range(probs.size)])

    # Pass 2: collapse hyperparameters whose paths cover a whole domain.
    for label, info in list(hps.items()):
        if len(info["conditions"]) <= 1:
            continue

        # Drop `True` placeholders, then drop paths left empty.
        paths = []
        for path in info["conditions"]:
            kept = [c for c in path if c is not True]
            if len(kept) >= 1:
                paths.append(kept)

        if not paths:
            info["conditions"] = set([conditions])
            continue

        depvar = paths[0][0].name
        single_var = all(len(path) == 1 and path[0].name == depvar for path in paths)
        if single_var:
            seen = frozenset([path[0] for path in paths])
            if seen == full_domains[depvar]:
                info["conditions"] = set([conditions])
# -- eof