vectorize.py 16.8 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449
from __future__ import print_function
from __future__ import absolute_import
from builtins import str
from builtins import zip
from builtins import range
from builtins import object
import sys

import numpy as np

from .pyll import Apply
from .pyll import as_apply
from .pyll import dfs
from .pyll import toposort
from .pyll import scope
from .pyll import stochastic

# Names of the pyll symbols that perform random draws. An ``idxs_map`` whose
# command name is in this collection is collapsed into a single sized draw by
# vectorize_stochastic / replace_repeat_stochastic below.
stoch = stochastic.implicit_stochastic_symbols


def ERR(msg):
    """Write `msg` to standard error, tagged with this module's prefix."""
    # -- resolve sys.stderr at call time so redirections are honoured
    stream = sys.stderr
    print("hyperopt.vectorize.ERR", msg, file=stream)


@scope.define_pure
def vchoice_split(idxs, choices, n_options):
    """
    Partition `idxs` into `n_options` buckets according to `choices`.

    Element k of `idxs` is appended to bucket ``choices[k]``; within each
    bucket the original relative order of `idxs` is kept.

    Raises ValueError when `idxs` and `choices` differ in length.
    """
    buckets = [[] for _ in range(n_options)]
    n_idxs = len(idxs)
    n_choices = len(choices)
    if n_idxs != n_choices:
        raise ValueError("idxs and choices different len", (n_idxs, n_choices))
    for idx, choice in zip(idxs, choices):
        buckets[choice].append(idx)
    return buckets


@scope.define_pure
def vchoice_merge(idxs, choices, *vals):
    """
    Reassemble per-option results into one list ordered like `idxs`.

    For each position k, ``choices[k]`` selects the ``(vi, vv)`` pair
    ``vals[choices[k]]``: `vi` holds the indices routed to that option and
    `vv` the values computed for them. The output element is the entry of
    `vv` aligned with ``idxs[k]`` in `vi` (the first such entry if `vi`
    repeats it). This undoes the routing done by `vchoice_split`.

    Raises
    ------
    ValueError
        If `idxs` and `choices` differ in length, or if ``idxs[k]`` does
        not occur in the selected `vi`.
    """
    if len(idxs) != len(choices):
        raise ValueError("idxs and choices different len", (len(idxs), len(choices)))
    # -- one position table per option, built lazily on first use, so each
    #    lookup is O(1) instead of an O(len(vi)) scan with list.index
    positions = {}
    rval = []
    for idx, ch in zip(idxs, choices):
        vi, vv = vals[ch]
        table = positions.get(ch)
        if table is None:
            table = {}
            for jj, key in enumerate(vi):
                # -- setdefault keeps the first occurrence, like list.index
                table.setdefault(key, jj)
            positions[ch] = table
        try:
            pos = table[idx]
        except KeyError:
            # -- same exception type list.index raised for a missing idx
            raise ValueError("%r is not in list" % (idx,))
        rval.append(vv[pos])
    return rval


@scope.define_pure
def idxs_map(idxs, cmd, *args, **kwargs):
    """
    Return the cmd applied at positions idxs, by retrieving args and kwargs
    from the (idxs, vals) pair elements of `args` and `kwargs`.

    N.B. args and kwargs may generally include information for more idx values
    than are requested by idxs.

    Parameters
    ----------
    idxs : iterable
        Positions at which to evaluate `cmd`, in output order.
    cmd : str
        Name of the implementation looked up in ``scope._impls``.
    *args, **kwargs : (idxs_j, vals_j) pairs
        For each position ``ii`` in `idxs`, the argument passed to `cmd` is
        the element of ``vals_j`` aligned with ``ii`` in ``idxs_j``.

    Returns
    -------
    list
        ``[cmd(*args_at_ii, **kwargs_at_ii) for ii in idxs]``.

    Raises
    ------
    Whatever the lookup or `cmd` raises (e.g. KeyError when an argument has
    no value for a requested position). Diagnostics identifying the failing
    argument are written to stderr via ERR before re-raising.
    """
    # XXX: consider insisting on sorted idxs
    # XXX: use np.searchsorted instead of dct

    # Precondition (deliberately not checked, because checking is slow):
    # every element of `idxs` appears in the idxs part of every pair in
    # `args` and `kwargs`.

    def _index(idxs_j, vals_j):
        # -- map each position to its value; an empty idxs_j never touches
        #    vals_j, so vals_j need not be iterable in that case
        if len(idxs_j):
            return dict(zip(idxs_j, vals_j))
        return {}

    args_imap = [_index(idxs_j, vals_j) for idxs_j, vals_j in args]
    kwargs_imap = dict(
        (kw, _index(idxs_j, vals_j)) for kw, (idxs_j, vals_j) in kwargs.items()
    )

    f = scope._impls[cmd]
    rval = []
    for ii in idxs:
        args_nn = []
        for pos, arg_imap in enumerate(args_imap):
            try:
                args_nn.append(arg_imap[ii])
            except Exception:
                ERR("args_nn %s" % (cmd,))
                ERR("ii %s" % (ii,))
                ERR("pos %s" % (pos,))
                ERR("arg_imap %s" % str(arg_imap))
                raise
        kwargs_nn = {}
        for kw, kwarg_imap in kwargs_imap.items():
            try:
                kwargs_nn[kw] = kwarg_imap[ii]
            except Exception:
                # -- report the keyword that actually failed, not a stale one
                ERR("kwargs_nn %s" % (cmd,))
                ERR("ii %s" % (ii,))
                ERR("kw %s" % (kw,))
                ERR("kwarg_imap %s" % str(kwarg_imap))
                raise
        try:
            rval_nn = f(*args_nn, **kwargs_nn)
        except Exception:
            ERR("error calling impl of %s" % (cmd,))
            raise
        rval.append(rval_nn)
    return rval


@scope.define_pure
def idxs_take(idxs, vals, which):
    """
    Return `vals[which]` where `which` is a subset of `idxs`

    `idxs` and `vals` are parallel sequences: ``vals[k]`` is the value at
    index ``idxs[k]``. The result is a numpy array holding, for each element
    ``w`` of `which` (in order), the value aligned with ``w``. If `idxs`
    repeats an index, the last aligned value wins.

    Raises
    ------
    ValueError
        If `idxs` and `vals` differ in length.
    KeyError
        If some element of `which` does not occur in `idxs`.
    """
    # TODO: consider insisting on sorted idxs
    # TODO: use np.searchsorted instead of dct
    if len(idxs) != len(vals):
        # -- explicit check (not assert) so it still runs under `python -O`;
        #    otherwise zip would silently truncate the longer sequence
        raise ValueError("idxs and vals different len", (len(idxs), len(vals)))
    table = dict(zip(idxs, vals))
    return np.asarray([table[w] for w in which])


@scope.define_pure
def uniq(lst):
    """
    Return the elements of `lst` with duplicates removed, first occurrence kept.

    Duplicates are detected by object identity (``id``), not by equality, so
    distinct-but-equal objects are all retained.
    """
    seen_ids = set()
    unique = []
    for item in lst:
        key = id(item)
        if key in seen_ids:
            continue
        seen_ids.add(key)
        unique.append(item)
    return unique


def vectorize_stochastic(orig):
    """
    Collapse an ``idxs_map`` over a stochastic symbol into one sized draw.

    If `orig` is ``idxs_map(idxs, <stochastic name>, *pairs, **named_pairs)``,
    return a new Apply node that calls the stochastic symbol directly. Each
    (idxs, vals) pair is replaced by its vals, except that an
    ``idxs_take(..., asarray(repeat(n, x)), ...)`` pair collapses to ``x``.
    A ``size`` argument equal to ``len(idxs)`` is added. Any other node is
    returned unchanged.

    Raises NotImplementedError when a pair is indexed by something other
    than `idxs` itself (and is not a repeated constant), or when the new
    node already carries a ``size`` argument.
    """
    if not (orig.name == "idxs_map" and orig.pos_args[1]._obj in stoch):
        return orig

    # -- this is an idxs_map of a random draw of distribution `dist`
    idxs = orig.pos_args[0]
    dist = orig.pos_args[1]._obj

    def unpack(pair):
        # -- each argument is an idxs, vals pair
        assert pair.name == "pos_args"
        assert len(pair.pos_args) == 2
        pair_idxs, pair_vals = pair.pos_args

        # XXX: write a pattern-substitution rule for this case
        if pair_vals.name == "idxs_take":
            taken = pair_vals.arg["vals"]
            if taken.name == "asarray" and taken.inputs()[0].name == "repeat":
                # -- draws are iid, so forget about
                #    repeating the distribution parameters
                return taken.inputs()[0].inputs()[1]

        if pair_idxs is idxs:
            return pair_vals

        # -- pair_idxs is a superset of idxs
        #    TODO: slice out correct elements using
        #    idxs_take, but more importantly - test this case.
        raise NotImplementedError()

    new_pos_args = [unpack(pair) for pair in orig.pos_args[2:]]
    new_named_args = [[aname, unpack(pair)] for aname, pair in orig.named_args]
    vnode = Apply(dist, new_pos_args, new_named_args, o_len=None)
    n_times = scope.len(idxs)
    if "size" in dict(vnode.named_args):
        raise NotImplementedError("random node already has size")
    vnode.named_args.append(["size", n_times])
    return vnode


def replace_repeat_stochastic(expr, return_memo=False):
    """
    Rewrite every ``idxs_map`` over a stochastic symbol in `expr` as one sized draw.

    Each matching node ``idxs_map(idxs, <stochastic name>, *pairs)`` is
    replaced by ``<stochastic name>(*vals, size=len(idxs))``. A pair whose
    vals are ``asarray(repeat(n, x))`` contributes ``x`` directly, because
    the draws are iid. Unlike vectorize_stochastic, the repeat pattern is
    matched directly rather than wrapped in ``idxs_take``.

    The graph is modified in place: every node after the matched one in the
    dfs list has the matched node replaced with the new one via
    ``replace_input``.

    Parameters
    ----------
    expr : pyll node
        Root of the graph to rewrite.
    return_memo : bool
        If true, also return the ``{old_node: new_node}`` replacement map.

    Returns
    -------
    The (possibly new) root, or ``(root, memo)`` when `return_memo` is true.

    Raises
    ------
    NotImplementedError
        When a pair is indexed by something other than `idxs` (and is not a
        repeated constant), or when the new node already has ``size``.
    """
    nodes = dfs(expr)
    memo = {}
    for ii, orig in enumerate(nodes):
        if orig.name == "idxs_map" and orig.pos_args[1]._obj in stoch:
            # -- this is an idxs_map of a random draw of distribution `dist`
            idxs = orig.pos_args[0]
            dist = orig.pos_args[1]._obj

            def foo(arg):
                # -- each argument is an idxs, vals pair
                assert arg.name == "pos_args"
                assert len(arg.pos_args) == 2
                arg_vals = arg.pos_args[1]
                if arg_vals.name == "asarray" and arg_vals.inputs()[0].name == "repeat":
                    # -- draws are iid, so forget about
                    #    repeating the distribution parameters
                    repeated_thing = arg_vals.inputs()[0].inputs()[1]
                    return repeated_thing
                else:
                    if arg.pos_args[0] is idxs:
                        return arg_vals
                    else:
                        # -- arg.pos_args[0] is a superset of idxs
                        #    TODO: slice out correct elements using
                        #    idxs_take, but more importantly - test this case.
                        raise NotImplementedError()

            new_pos_args = [foo(arg) for arg in orig.pos_args[2:]]
            new_named_args = [[aname, foo(arg)] for aname, arg in orig.named_args]
            # NOTE(review): the 4th positional argument presumably is o_len
            # (vectorize_stochastic passes o_len=None by keyword) - confirm.
            vnode = Apply(dist, new_pos_args, new_named_args, None)
            n_times = scope.len(idxs)
            if "size" in dict(vnode.named_args):
                raise NotImplementedError("random node already has size")
            vnode.named_args.append(["size", n_times])
            # -- loop over all nodes that *use* this one, and change them
            #    (presumably dfs lists inputs before the nodes that consume
            #    them, so only later nodes can be clients - confirm in pyll)
            for client in nodes[ii + 1 :]:
                client.replace_input(orig, vnode)
            # -- the root itself may have been the replaced node
            if expr is orig:
                expr = vnode
            memo[orig] = vnode
    if return_memo:
        return expr, memo
    else:
        return expr


class VectorizeHelper(object):
    """
    Convert a pyll expression representing a single trial into a pyll
    expression representing multiple trials.

    The resulting multi-trial expression is not meant to be evaluated
    directly. It is meant to serve as the input to a suggest algo.

    idxs_memo - node in expr graph -> all elements we might need for it
    take_memo - node in expr graph -> all exprs retrieving computed elements

    All work happens in the constructor, in three steps:
      1. every ``hyperopt_param`` node is recorded by label in ``self.params``;
      2. the vectorized graph is built into ``self.v_expr``;
      3. the memo invariants are checked.

    The invariants, checked by ``assert_integrity_idxs_take``, are:
      * ``idxs_memo`` and ``take_memo`` have the same keys;
      * every ``idxs_memo`` value is an ``array_union`` node;
      * every ``take_memo`` entry is an ``idxs_take`` node whose first two
        positional args are that node's union and the vals node of the
        node's first take.
    """

    def __init__(self, expr, expr_idxs, build=True):
        """
        Vectorize `expr` over the trial indices `expr_idxs`.

        `expr` is the single-trial pyll graph. `expr_idxs` is the set of
        trial indices to vectorize over; it is also used as the universal
        set for literal nodes (see build_idxs_vals).

        NOTE(review): `build` is accepted but never read - the vectorized
        graph is always built.
        """
        self.expr = expr
        self.expr_idxs = expr_idxs
        # -- remembered so assert_integrity_idxs_take can verify that
        #    vectorization did not mutate the original graph
        self.dfs_nodes = dfs(expr)
        # -- label -> the expression wrapped by each hyperopt_param node
        self.params = {}
        for ii, node in enumerate(self.dfs_nodes):
            if node.name == "hyperopt_param":
                label = node.arg["label"].obj
                self.params[label] = node.arg["obj"]
        # -- recursive construction
        #    This makes one term in each idxs, vals memo for every
        #    directed path through the switches in the graph.

        self.idxs_memo = {}  # node -> union, all idxs computed
        self.take_memo = {}  # node -> list of idxs_take retrieving node vals
        self.v_expr = self.build_idxs_vals(expr, expr_idxs)

        # TODO: graph-optimization pass to remove cruft:
        #  - unions of 1
        #  - unions of full sets with their subsets
        #  - idxs_take that can be merged

        self.assert_integrity_idxs_take()

    def assert_integrity_idxs_take(self):
        """
        Assert the memo invariants described in the class docstring.

        Also asserts that a fresh dfs of ``self.expr`` equals the node list
        captured in ``__init__``, i.e. the original graph is unchanged.
        """
        idxs_memo = self.idxs_memo
        take_memo = self.take_memo
        after = dfs(self.expr)
        assert after == self.dfs_nodes
        assert set(idxs_memo.keys()) == set(take_memo.keys())
        for node in idxs_memo:
            idxs = idxs_memo[node]
            assert idxs.name == "array_union"
            # -- all takes of a node must share its union and one vals node
            vals = take_memo[node][0].pos_args[1]
            for take in take_memo[node]:
                assert take.name == "idxs_take"
                assert [idxs, vals] == take.pos_args[:2]

    def build_idxs_vals(self, node, wanted_idxs):
        """
        This recursive procedure should be called on an output-node.

        Return an ``idxs_take`` node that yields the values of `node` at
        `wanted_idxs`. As a side effect, record in ``self.idxs_memo`` and
        ``self.take_memo`` the union of indices at which `node` is computed
        and every take that reads from it. When a node is revisited with new
        indices, the union grows in place and the node's existing takes are
        repointed at the recomputed vals.
        """
        # -- set to True to run (expensive) integrity checks at each step
        checkpoint_asserts = False

        def checkpoint():
            if checkpoint_asserts:
                self.assert_integrity_idxs_take()
                if node in self.idxs_memo:
                    toposort(self.idxs_memo[node])
                if node in self.take_memo:
                    for take in self.take_memo[node]:
                        toposort(take)

        checkpoint()

        # wanted_idxs are fixed, whereas idxs_memo
        # is full of unions, that can grow in subsequent recursive
        # calls to build_idxs_vals with node as argument.
        assert wanted_idxs != self.idxs_memo.get(node)

        # -- easy exit case
        if node.name == "hyperopt_param":
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg["obj"], wanted_idxs)

        # -- easy exit case
        elif node.name == "hyperopt_result":
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg["obj"], wanted_idxs)

        # -- literal case: always take from universal set
        elif node.name == "literal":
            if node in self.idxs_memo:
                # -- reuse the existing full-set vals; only add a new take
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- initialize idxs_memo to full set
                all_idxs = self.expr_idxs
                n_times = scope.len(all_idxs)
                # -- put array_union into graph for consistency, though it is
                # not necessary
                all_idxs = scope.array_union(all_idxs)
                self.idxs_memo[node] = all_idxs
                # -- the literal's value repeated once per trial index
                all_vals = scope.asarray(scope.repeat(n_times, node))
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                assert node not in self.take_memo
                self.take_memo[node] = [wanted_vals]
                checkpoint()
            return wanted_vals

        # -- switch case: complicated
        elif node.name == "switch":
            if node in self.idxs_memo and wanted_idxs in self.idxs_memo[node].pos_args:
                # -- phew, easy case
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- we need to add some indexes
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]
                    assert all_idxs.name == "array_union"
                    all_idxs.pos_args.append(wanted_idxs)
                else:
                    all_idxs = scope.array_union(wanted_idxs)

                # -- vectorize the selector over every index this switch needs
                choice = node.pos_args[0]
                all_choices = self.build_idxs_vals(choice, all_idxs)

                # -- route each index to its chosen option, vectorize each
                #    option on its own subset, then merge back in idxs order
                options = node.pos_args[1:]
                args_idxs = scope.vchoice_split(all_idxs, all_choices, len(options))
                all_vals = scope.vchoice_merge(all_idxs, all_choices)
                for opt_ii, idxs_ii in zip(options, args_idxs):
                    all_vals.pos_args.append(
                        as_apply([idxs_ii, self.build_idxs_vals(opt_ii, idxs_ii)])
                    )

                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs,
                )  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == "array_union"
                    # NOTE(review): all_idxs is this same union object, and
                    # wanted_idxs was already appended to it above, so this
                    # adds it a second time. Looks redundant (the general case
                    # appends only once); presumably harmless for a union.
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    # -- repoint existing takes at the recomputed vals
                    for take in self.take_memo[node]:
                        assert take.name == "idxs_take"
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        # -- general case
        else:
            # -- this is a general node.
            #    It is generally handled with idxs_memo,
            #    but vectorize_stochastic may immediately transform it into
            #    a more compact form.
            if node in self.idxs_memo and wanted_idxs in self.idxs_memo[node].pos_args:
                # -- phew, easy case
                # -- return the take already recorded for exactly these idxs
                for take in self.take_memo[node]:
                    if take.pos_args[2] == wanted_idxs:
                        return take
                # -- wanted_idxs is in the union but no take was recorded
                raise NotImplementedError("how did this happen?")
                # all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                # wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                # self.take_memo[node].append(wanted_vals)
                # checkpoint()
            else:
                # XXX
                # -- determine if wanted_idxs is actually a subset of the idxs
                # that we are already computing.  This is not only an
                # optimization, but prevents the creation of cycles, which
                # would otherwise occur if we have a graph of the form
                # switch(f(a), g(a), 0). If there are other switches inside f
                # and g, does this get trickier?

                # -- assume we need to add some indexes
                checkpoint()
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]

                else:
                    all_idxs = scope.array_union(wanted_idxs)
                checkpoint()

                # -- elementwise application of node's command; every input
                #    is passed as an (all_idxs, vectorized input) pair
                all_vals = scope.idxs_map(all_idxs, node.name)
                for ii, aa in enumerate(node.pos_args):
                    all_vals.pos_args.append(
                        as_apply([all_idxs, self.build_idxs_vals(aa, all_idxs)])
                    )
                    checkpoint()
                for ii, (nn, aa) in enumerate(node.named_args):
                    all_vals.named_args.append(
                        [nn, as_apply([all_idxs, self.build_idxs_vals(aa, all_idxs)])]
                    )
                    checkpoint()
                # -- random draws collapse into a single sized call
                all_vals = vectorize_stochastic(all_vals)

                checkpoint()
                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs,
                )  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == "array_union"
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    toposort(self.idxs_memo[node])
                    # -- this catches the cycle bug mentioned above
                    # -- repoint existing takes at the recomputed vals
                    for take in self.take_memo[node]:
                        assert take.name == "idxs_take"
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        return wanted_vals

    def idxs_by_label(self):
        """Return {param label: array_union of indices computed for that param}."""
        return dict(
            [(name, self.idxs_memo[node]) for name, node in list(self.params.items())]
        )

    def vals_by_label(self):
        """Return {param label: vals node shared by that param's idxs_take nodes}."""
        return dict(
            [
                (name, self.take_memo[node][0].pos_args[1])
                for name, node in list(self.params.items())
            ]
        )