base.py 34.1 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010
"""Base classes / Design

The design is that there are three components fitting together in this project:

- Trials - a list of documents including at least sub-documents:
    ['spec'] - the specification of hyper-parameters for a job
    ['result'] - the result of Domain.evaluate(). Typically includes:
        ['status'] - one of the STATUS_STRINGS
        ['loss'] - real-valued scalar that hyperopt is trying to minimize
    ['idxs'] - compressed representation of spec
    ['vals'] - compressed representation of spec
    ['tid'] - trial id (unique in Trials list)

- Domain - specifies a search problem

- Ctrl - a channel for two-way communication
         between an Experiment and Domain.evaluate.
         Experiment subclasses may subclass Ctrl to match. For example, if an
         experiment is going to dispatch jobs in other threads, then an
         appropriate thread-aware Ctrl subclass should go with it.

"""
from __future__ import print_function
from __future__ import absolute_import

import numbers
from builtins import str
from builtins import map
from builtins import zip
from builtins import range
from past.builtins import basestring
from builtins import object
import logging
import datetime
import sys

import numpy as np

try:
    import bson  # -- comes with pymongo
    from bson.objectid import ObjectId

    have_bson = True
except ImportError:
    have_bson = False

from . import pyll
from .pyll.stochastic import recursive_set_rng_kwarg

from .exceptions import (
    DuplicateLabel,
    InvalidTrial,
    InvalidResultStatus,
    InvalidLoss,
    AllTrialsFailed,
)
from .utils import pmin_sampled
from .utils import use_obj_for_literal_in_memo
from .vectorize import VectorizeHelper

__authors__ = "James Bergstra"
__license__ = "3-clause BSD License"
__contact__ = "github.com/hyperopt/hyperopt"

logger = logging.getLogger(__name__)


# -- STATUS values
#    An eval_fn returning a dictionary must have a status key with
#    one of these values. They are used by optimization routines
#    and plotting functions.

STATUS_NEW = "new"  # computations have not started
STATUS_RUNNING = "running"  # computations are in progress
STATUS_SUSPENDED = "suspended"  # computations have been suspended, job is not finished
STATUS_OK = "ok"  # computations are finished, terminated normally
STATUS_FAIL = "fail"  # computations are finished, terminated with error
#   - result['status_fail'] should contain more info

# -- every status value, collected from the named constants above
STATUS_STRINGS = (
    STATUS_NEW,
    STATUS_RUNNING,
    STATUS_SUSPENDED,
    STATUS_OK,
    STATUS_FAIL,
)


# -- JOBSTATE values
# These are used internally by the scheduler.
# These values are used to communicate between an Experiment
# and a worker process. Consider moving them to mongoexp.

# -- named constants for job execution pipeline
(
    JOB_STATE_NEW,
    JOB_STATE_RUNNING,
    JOB_STATE_DONE,
    JOB_STATE_ERROR,
    JOB_STATE_CANCEL,
) = range(5)

# -- every job state defined above, in numeric order
JOB_STATES = [
    JOB_STATE_NEW,
    JOB_STATE_RUNNING,
    JOB_STATE_DONE,
    JOB_STATE_ERROR,
    JOB_STATE_CANCEL,
]

# -- the states that Trials.refresh keeps when it builds its trial list
JOB_VALID_STATES = set(JOB_STATES[: JOB_STATE_DONE + 1])


# -- keys that Trials.assert_valid_trial requires in every trial document
TRIAL_KEYS = (
    "tid spec result misc state owner book_time refresh_time exp_key"
).split()

# -- keys that Trials.assert_valid_trial requires in trial["misc"]
TRIAL_MISC_KEYS = "tid cmd idxs vals".split()


def _all_same(*args):
    return 1 == len(set(args))


def SONify(arg, memo=None):
    """Recursively convert `arg` into plain Python values.

    When bson could not be imported, `arg` is returned unchanged.

    Conversions performed:
      * ObjectId, datetime, strings, float, int (including bool) and None
        are returned as-is;
      * numpy floating / integer scalars become `float` / `int`;
      * lists and tuples are rebuilt (same type) from converted elements;
      * dicts are rebuilt from converted keys and values;
      * 0-d ndarrays are converted through `arg.sum()`, other ndarrays
        become a list of converted elements.

    memo: dictionary mapping `id(original)` to the already-converted value.
        An object that appears several times inside `arg` is converted
        once, and the same converted object is reused for each occurrence.

    Raises TypeError for any other type. Exceptions raised while converting
    a nested value get the enclosing `arg` appended to their `args`.
    """
    if not have_bson:
        return arg
    if memo is None:
        memo = {}
    key = id(arg)
    if key in memo:
        # -- already converted during this call: reuse the result
        return memo[key]
    add_arg_to_raise = True
    try:
        if isinstance(arg, (ObjectId, datetime.datetime)):
            rval = arg
        # -- numpy scalars are tested before the builtin numbers because
        #    e.g. np.float64 is also an instance of float
        elif isinstance(arg, np.floating):
            rval = float(arg)
        elif isinstance(arg, np.integer):
            rval = int(arg)
        elif isinstance(arg, (list, tuple)):
            rval = type(arg)([SONify(ai, memo) for ai in arg])
        elif isinstance(arg, dict):
            rval = {SONify(k, memo): SONify(v, memo) for k, v in list(arg.items())}
        # -- bool is a subclass of int, so booleans pass through here unchanged
        elif isinstance(arg, (basestring, float, int, type(None))):
            rval = arg
        elif isinstance(arg, np.ndarray):
            if arg.ndim == 0:
                rval = SONify(arg.sum())
            else:
                rval = list(map(SONify, arg))  # N.B. memo None
        else:
            add_arg_to_raise = False
            raise TypeError("SONify", arg)
    except Exception as e:
        if add_arg_to_raise:
            e.args = e.args + (arg,)
        raise
    # -- key on the *input* object so that later occurrences are found
    memo[key] = rval
    return rval


def miscs_update_idxs_vals(miscs, idxs, vals, assert_all_vals_used=True, idxs_map=None):
    """
    Unpack the idxs-vals format into the list of dictionaries that is `misc`.

    Every misc first gets fresh `'idxs'` and `'vals'` dictionaries holding an
    empty list per key of `idxs`. Each (tid, val) pair is then written, as
    one-element lists, into the misc whose `'tid'` matches.

    assert_all_vals_used: when true, a tid with no matching misc raises
        KeyError; when false, such pairs are silently skipped.

    idxs_map: a dictionary of id->id mappings so that the misc['idxs'] can
        contain different numbers than the idxs argument. XXX CLARIFY

    Returns the same `miscs` list, updated in place.
    """
    remap = {} if idxs_map is None else idxs_map

    assert set(idxs.keys()) == set(vals.keys())

    by_tid = {misc["tid"]: misc for misc in miscs}
    for misc in miscs:
        misc["idxs"] = {label: [] for label in idxs}
        misc["vals"] = {label: [] for label in idxs}

    for label in idxs:
        label_tids = idxs[label]
        label_vals = vals[label]
        assert len(label_tids) == len(label_vals)
        for raw_tid, val in zip(label_tids, label_vals):
            tid = remap.get(raw_tid, raw_tid)
            if not assert_all_vals_used and tid not in by_tid:
                continue
            by_tid[tid]["idxs"][label] = [tid]
            by_tid[tid]["vals"][label] = [val]

    return miscs


def miscs_to_idxs_vals(miscs, keys=None):
    """Collect the per-trial `misc` dictionaries into the idxs-vals format.

    keys: the labels to collect; when None they are taken from
        `miscs[0]["idxs"]`, which requires `miscs` to be non-empty
        (ValueError otherwise).

    Returns a pair `(idxs, vals)` of dictionaries mapping each key to the
    concatenation, in `miscs` order, of the matching `misc["idxs"][key]`
    and `misc["vals"][key]` lists.
    """
    if keys is None:
        if not len(miscs):
            raise ValueError("cannot infer keys from empty miscs")
        keys = list(miscs[0]["idxs"].keys())
    idxs = {k: [] for k in keys}
    vals = {k: [] for k in keys}
    for misc in miscs:
        for label in idxs:
            trial_idxs = misc["idxs"][label]
            trial_vals = misc["vals"][label]
            assert len(trial_idxs) == len(trial_vals)
            # -- a label is either inactive ([]) or set exactly once ([tid])
            assert trial_idxs == [] or trial_idxs == [misc["tid"]]
            idxs[label] += trial_idxs
            vals[label] += trial_vals
    return idxs, vals


def spec_from_misc(misc):
    """Build a `{label: value}` dictionary from `misc["vals"]`.

    Labels whose value list is empty are left out; one-element lists are
    unpacked. Lists with more than one element raise NotImplementedError.
    """
    spec = {}
    for label, values in list(misc["vals"].items()):
        count = len(values)
        if count > 1:
            raise NotImplementedError("multiple values", (label, values))
        if count == 1:
            spec[label] = values[0]
    return spec


def validate_timeout(timeout):
    """Check that `timeout` is either None or a positive number.

    Booleans are rejected even though they are `numbers.Number` instances.
    Numbers that cannot be ordered against zero (complex) or that are not
    greater than zero (zero, negatives, NaN) are rejected too.

    Raises
    ------
    Exception
        If `timeout` is anything other than None or a positive number.
    """
    if timeout is None:
        return
    is_positive = False
    if isinstance(timeout, numbers.Number) and not isinstance(timeout, bool):
        try:
            is_positive = bool(timeout > 0)
        except TypeError:
            # -- e.g. complex numbers define no ordering
            is_positive = False
    if not is_positive:
        raise Exception(
            "The timeout argument should be None or a positive value. "
            "Given value: {timeout}".format(timeout=timeout)
        )
        
def validate_loss_threshold(loss_threshold):
    """Check that `loss_threshold` is either None or a number.

    Booleans are rejected even though they are `numbers.Number` instances.

    Raises
    ------
    Exception
        If `loss_threshold` is neither None nor a (non-boolean) number.
    """
    if loss_threshold is None:
        return
    if isinstance(loss_threshold, bool) or not isinstance(
        loss_threshold, numbers.Number
    ):
        raise Exception(
            "The loss_threshold argument should be None or a numeric value. "
            "Given value: {loss_threshold}".format(loss_threshold=loss_threshold)
        )


class Trials(object):
    """Database interface supporting data-driven model-based optimization.

    The model-based optimization algorithms used by hyperopt's fmin function
    work by analyzing samples of a response surface--a history of what points
    in the search space were tested, and what was discovered by those tests.
    A Trials instance stores that history and makes it available to fmin and
    to the various optimization algorithms.

    This class (`base.Trials`) is a pure-Python implementation of the database
    in terms of lists of dictionaries.  Subclass `mongoexp.MongoTrials`
    implements the same API in terms of a mongodb database running in another
    process. Other subclasses may be implemented in future.

    The elements of `self.trials` represent all of the completed, in-progress,
    and scheduled evaluation points from an e.g. `fmin` call.

    Each element of `self.trials` is a dictionary with *at least* the following
    keys:

    * **tid**: a unique trial identification object within this Trials instance
      usually it is an integer, but it isn't obvious that other sortable,
      hashable objects couldn't be used at some point.

    * **result**: a sub-dictionary representing what was returned by the fmin
      evaluation function. This sub-dictionary has a key 'status' with a value
      from `STATUS_STRINGS`. If the status is `STATUS_OK`, then there should be
      a 'loss' key as well with a floating-point value.  Other special keys in
      this sub-dictionary may be used by optimization algorithms  (see them
      for details). Other keys in this sub-dictionary can be used by the
      evaluation function to store miscellaneous diagnostics and debugging
      information.

    * **misc**: despite generic name, this is currently where the trial's
      hyperparameter assignments are stored. This sub-dictionary has two
      elements: `'idxs'` and `'vals'`. The `vals` dictionary is
      a sub-sub-dictionary mapping each hyperparameter to either `[]` (if the
      hyperparameter is inactive in this trial), or `[<val>]` (if the
      hyperparameter is active). The `idxs` dictionary is technically
      redundant -- it is the same as `vals` but it maps hyperparameter names
      to either `[]` or `[<tid>]`.

    """

    asynchronous = False

    def __init__(self, exp_key=None, refresh=True):
        """Create an empty trials store.

        exp_key: when not None, `refresh` only exposes trials whose
            `'exp_key'` equals it.
        refresh: when true, `self.refresh()` is called right away so that
            `self.trials` is available.
        """
        self._ids = set()
        self._dynamic_trials = []
        self._exp_key = exp_key
        self.attachments = {}
        if refresh:
            self.refresh()

    def view(self, exp_key=None, refresh=True):
        """Return a new instance of the same class restricted to `exp_key`.

        The view shares (does not copy) this instance's ids, trial documents
        and attachments.
        """
        rval = object.__new__(self.__class__)
        rval._exp_key = exp_key
        rval._ids = self._ids
        rval._dynamic_trials = self._dynamic_trials
        rval.attachments = self.attachments
        if refresh:
            rval.refresh()
        return rval

    def aname(self, trial, name):
        """Return the key under which attachment `name` of `trial` is stored."""
        return "ATTACH::%s::%s" % (trial["tid"], name)

    def trial_attachments(self, trial):
        """
        Support syntax for load:  self.trial_attachments(doc)[name]
        # -- does this work syntactically?
        #    (In any event a 2-stage store will work)
        Support syntax for store: self.trial_attachments(doc)[name] = value
        """

        # don't offer more here than in MongoCtrl
        class Attachments(object):
            # -- `_self` is the Attachments proxy; `self` is the enclosing
            #    Trials instance, captured by closure
            def __contains__(_self, name):
                return self.aname(trial, name) in self.attachments

            def __getitem__(_self, name):
                return self.attachments[self.aname(trial, name)]

            def __setitem__(_self, name, value):
                self.attachments[self.aname(trial, name)] = value

            def __delitem__(_self, name):
                del self.attachments[self.aname(trial, name)]

        return Attachments()

    def __iter__(self):
        try:
            return iter(self._trials)
        except AttributeError:
            # -- self._trials only exists once refresh() has run
            print("You have to refresh before you iterate", file=sys.stderr)
            raise

    def __len__(self):
        try:
            return len(self._trials)
        except AttributeError:
            # -- self._trials only exists once refresh() has run
            print("You have to refresh before you compute len", file=sys.stderr)
            raise

    def __getitem__(self, item):
        # -- how to make it obvious whether indexing is by _trials position
        #    or by tid if both are integers?
        raise NotImplementedError("")

    def refresh(self):
        """Rebuild `self._trials` from the inserted documents.

        Only documents whose state is in JOB_VALID_STATES (and, when this
        instance has an exp_key, whose `'exp_key'` matches) are kept. The
        tids of the kept documents are added to `self._ids`.
        """
        # In MongoTrials, this method fetches from database
        if self._exp_key is None:
            self._trials = [
                tt for tt in self._dynamic_trials if tt["state"] in JOB_VALID_STATES
            ]
        else:
            self._trials = [
                tt
                for tt in self._dynamic_trials
                if (tt["state"] in JOB_VALID_STATES and tt["exp_key"] == self._exp_key)
            ]
        self._ids.update([tt["tid"] for tt in self._trials])

    @property
    def trials(self):
        return self._trials

    @property
    def tids(self):
        return [tt["tid"] for tt in self._trials]

    @property
    def specs(self):
        return [tt["spec"] for tt in self._trials]

    @property
    def results(self):
        return [tt["result"] for tt in self._trials]

    @property
    def miscs(self):
        return [tt["misc"] for tt in self._trials]

    @property
    def idxs_vals(self):
        return miscs_to_idxs_vals(self.miscs)

    @property
    def idxs(self):
        return self.idxs_vals[0]

    @property
    def vals(self):
        return self.idxs_vals[1]

    def assert_valid_trial(self, trial):
        """Return `trial` unchanged if it is a well-formed trial document.

        Raises InvalidTrial when `trial` is not dict-like, lacks one of
        TRIAL_KEYS, its `'misc'` lacks one of TRIAL_MISC_KEYS, the root and
        misc tids differ, or its `'exp_key'` differs from this instance's.
        When bson is available the document must also be BSON-encodable;
        the encoder's own exception propagates otherwise.
        """
        if not (hasattr(trial, "keys") and hasattr(trial, "values")):
            raise InvalidTrial("trial should be dict-like", trial)
        for key in TRIAL_KEYS:
            if key not in trial:
                raise InvalidTrial("trial missing key", key)
        for key in TRIAL_MISC_KEYS:
            if key not in trial["misc"]:
                raise InvalidTrial('trial["misc"] missing key', key)
        if trial["tid"] != trial["misc"]["tid"]:
            raise InvalidTrial("tid mismatch between root and misc", trial)
        # -- check for SON-encodable
        if have_bson:
            try:
                bson.BSON.encode(trial)
            except Exception:
                # TODO: save the trial object somewhere to inspect, fix, re-insert
                #       so that precious data is not simply deallocated and lost.
                print("-" * 80)
                print("CANT ENCODE")
                print("-" * 80)
                raise
        if trial["exp_key"] != self._exp_key:
            raise InvalidTrial("wrong exp_key", (trial["exp_key"], self._exp_key))
        # XXX how to assert that tids are unique?
        return trial

    def _insert_trial_docs(self, docs):
        """insert with no error checking

        Returns the list of tids of the inserted documents.
        """
        rval = [doc["tid"] for doc in docs]
        self._dynamic_trials.extend(docs)
        return rval

    def insert_trial_doc(self, doc):
        """insert trial after error checking

        Does not refresh. Call self.refresh() for the trial to appear in
        self.specs, self.results, etc.
        """
        doc = self.assert_valid_trial(SONify(doc))
        return self._insert_trial_docs([doc])[0]
        # refreshing could be done fast in this base implementation, but with
        # a real DB the steps should be separated.

    def insert_trial_docs(self, docs):
        """ trials - something like is returned by self.new_trial_docs()

        Each document is converted with SONify and validated before
        insertion. Returns the list of inserted tids. Does not refresh.
        """
        docs = [self.assert_valid_trial(SONify(doc)) for doc in docs]
        return self._insert_trial_docs(docs)

    def new_trial_ids(self, n):
        """Reserve and return `n` consecutive integer ids.

        The ids start at the number of ids already known to this instance
        and are recorded in `self._ids`.
        """
        aa = len(self._ids)
        rval = list(range(aa, aa + n))
        self._ids.update(rval)
        return rval

    def new_trial_docs(self, tids, specs, results, miscs):
        """Build (but do not insert) one JOB_STATE_NEW document per trial.

        The four arguments are parallel sequences of equal length.
        """
        assert len(tids) == len(specs) == len(results) == len(miscs)
        trials_docs = []
        for tid, spec, result, misc in zip(tids, specs, results, miscs):
            doc = {
                "state": JOB_STATE_NEW,
                "tid": tid,
                "spec": spec,
                "result": result,
                "misc": misc,
                "exp_key": self._exp_key,
                "owner": None,
                "version": 0,
                "book_time": None,
                "refresh_time": None,
            }
            trials_docs.append(doc)
        return trials_docs

    def source_trial_docs(self, tids, specs, results, miscs, sources):
        """Build (but do not insert) documents derived from existing trials.

        The five arguments are parallel sequences of equal length. Each new
        document copies `state`, `exp_key`, `owner`, `book_time` and
        `refresh_time` from the matching source trial. Its `misc` is updated
        in place to carry `tid`, `cmd` (None) and `from_tid` (the source's
        tid). A `misc` that already holds a different value for one of these
        fields fails the assertion.
        """
        assert (
            len(tids) == len(specs) == len(results) == len(miscs) == len(sources)
        )
        rval = []
        for tid, spec, result, misc, source in zip(
            tids, specs, results, miscs, sources
        ):
            doc = dict(
                version=0,
                tid=tid,
                spec=spec,
                result=result,
                misc=misc,
                state=source["state"],
                exp_key=source["exp_key"],
                owner=source["owner"],
                book_time=source["book_time"],
                refresh_time=source["refresh_time"],
            )
            # -- ensure that misc has the following fields,
            #    some of which may already be set correctly.
            assign = ("tid", tid), ("cmd", None), ("from_tid", source["tid"])
            for k, v in assign:
                # -- keep the mutation outside of the assert so that it
                #    still happens when asserts are disabled (python -O)
                existing = doc["misc"].setdefault(k, v)
                assert existing == v
            rval.append(doc)
        return rval

    def delete_all(self):
        """Drop every trial document and attachment, then refresh."""
        self._dynamic_trials = []
        self.attachments = {}
        self.refresh()

    def count_by_state_synced(self, arg, trials=None):
        """
        Return trial counts by looking at self._trials

        arg: a single job state, or an iterable of job states.
        trials: documents to count over instead of self._trials.
        """
        if trials is None:
            trials = self._trials
        if arg in JOB_STATES:
            queue = [doc for doc in trials if doc["state"] == arg]
        elif hasattr(arg, "__iter__"):
            states = set(arg)
            assert all([x in JOB_STATES for x in states])
            queue = [doc for doc in trials if doc["state"] in states]
        else:
            raise TypeError(arg)
        rval = len(queue)
        return rval

    def count_by_state_unsynced(self, arg):
        """
        Return trial counts that count_by_state_synced would return if we
        called refresh() first.
        """
        if self._exp_key is not None:
            exp_trials = [
                tt for tt in self._dynamic_trials if tt["exp_key"] == self._exp_key
            ]
        else:
            exp_trials = self._dynamic_trials
        return self.count_by_state_synced(arg, trials=exp_trials)

    def losses(self, bandit=None):
        """Return the loss of every trial (None where a result has no loss).

        When `bandit` is given, `bandit.loss(result, spec)` is used instead.
        """
        if bandit is None:
            return [r.get("loss") for r in self.results]
        else:
            return list(map(bandit.loss, self.results, self.specs))

    def statuses(self, bandit=None):
        """Return the status of every trial (None where a result has none).

        When `bandit` is given, `bandit.status(result, spec)` is used instead.
        """
        if bandit is None:
            return [r.get("status") for r in self.results]
        else:
            return list(map(bandit.status, self.results, self.specs))

    def average_best_error(self, bandit=None):
        """Return the average best error of the experiment

        Average best error is defined as the average of bandit.true_loss,
        weighted by the probability that the corresponding bandit.loss is best.

        For domains with loss measurement variance of 0, this function simply
        returns the true_loss corresponding to the result with the lowest loss.

        Raises ValueError when no trial has status STATUS_OK, or (with a
        bandit) when a loss, variance or true loss is not finite.
        """

        if bandit is None:
            results = self.results
            loss = [r["loss"] for r in results if r["status"] == STATUS_OK]
            loss_v = [
                r.get("loss_variance", 0) for r in results if r["status"] == STATUS_OK
            ]
            true_loss = [
                r.get("true_loss", r["loss"])
                for r in results
                if r["status"] == STATUS_OK
            ]
        else:

            def fmap(f):
                rval = np.asarray(
                    [
                        f(r, s)
                        for (r, s) in zip(self.results, self.specs)
                        if bandit.status(r) == STATUS_OK
                    ]
                ).astype("float")
                if not np.all(np.isfinite(rval)):
                    raise ValueError()
                return rval

            loss = fmap(bandit.loss)
            loss_v = fmap(bandit.loss_variance)
            true_loss = fmap(bandit.true_loss)
        loss3 = list(zip(loss, loss_v, true_loss))
        if not loss3:
            raise ValueError("Empty loss vector")
        loss3.sort()
        loss3 = np.asarray(loss3)
        if np.all(loss3[:, 1] == 0):
            best_idx = np.argmin(loss3[:, 0])
            return loss3[best_idx, 2]
        else:
            # -- only consider results whose loss lies within 3 standard
            #    deviations (of the lowest-loss result) of the lowest loss
            cutoff = 0
            sigma = np.sqrt(loss3[0][1])
            while cutoff < len(loss3) and loss3[cutoff][0] < loss3[0][0] + 3 * sigma:
                cutoff += 1
            pmin = pmin_sampled(loss3[:cutoff, 0], loss3[:cutoff, 1])
            avg_true_loss = (pmin * loss3[:cutoff, 2]).sum()
            return avg_true_loss

    @property
    def best_trial(self):
        """
        Trial with lowest non-NaN loss and status=STATUS_OK.

        Raises AllTrialsFailed if no such trial exists.
        """
        candidates = [
            t
            for t in self.trials
            if t["result"]["status"] == STATUS_OK and not np.isnan(t["result"]["loss"])
        ]
        if not candidates:
            raise AllTrialsFailed
        losses = [float(t["result"]["loss"]) for t in candidates]
        best = np.argmin(losses)
        return candidates[best]

    @property
    def argmin(self):
        """Hyperparameter assignment `{label: value}` of `self.best_trial`."""
        best_trial = self.best_trial
        vals = best_trial["misc"]["vals"]
        # unpack the one-element lists to values
        # and skip over the 0-element lists
        rval = {}
        for k, v in list(vals.items()):
            if v:
                rval[k] = v[0]
        return rval

    def fmin(
        self,
        fn,
        space,
        algo,
        max_evals,
        timeout=None,
        loss_threshold=None,
        max_queue_len=1,
        rstate=None,
        verbose=False,
        pass_expr_memo_ctrl=None,
        catch_eval_exceptions=False,
        return_argmin=True,
        show_progressbar=True,
    ):
        """Minimize a function over a hyperparameter space.

        For most parameters, see `hyperopt.fmin.fmin`.

        Parameters
        ----------

        catch_eval_exceptions : bool, default False
            If set to True, exceptions raised by either the evaluation of the
            configuration space from hyperparameters or the execution of `fn`
            , will be caught by fmin, and recorded in self._dynamic_trials as
            error jobs (JOB_STATE_ERROR).  If set to False, such exceptions
            will not be caught, and so they will propagate to calling code.

        show_progressbar : bool or context manager, default True.
            Show a progressbar. See `hyperopt.progress` for customizing progress reporting.

        """
        # -- Stop-gap implementation!
        #    fmin should have been a Trials method in the first place
        #    but for now it's still sitting in another file.
        from .fmin import fmin

        return fmin(
            fn,
            space,
            algo,
            max_evals,
            timeout=timeout,
            loss_threshold=loss_threshold,
            trials=self,
            rstate=rstate,
            verbose=verbose,
            max_queue_len=max_queue_len,
            allow_trials_fmin=False,  # -- prevent recursion
            pass_expr_memo_ctrl=pass_expr_memo_ctrl,
            catch_eval_exceptions=catch_eval_exceptions,
            return_argmin=return_argmin,
            show_progressbar=show_progressbar,
        )


def trials_from_docs(docs, validate=True, **kwargs):
    """Construct a Trials base class instance from a list of trials documents

    validate: when true the documents go through `insert_trial_docs` (which
        checks them); otherwise they are inserted with no error checking.
    kwargs: forwarded to the `Trials` constructor.
    """
    trials = Trials(**kwargs)
    insert = trials.insert_trial_docs if validate else trials._insert_trial_docs
    insert(docs)
    trials.refresh()
    return trials


class Ctrl(object):
    """Control object for interruptible, checkpoint-able evaluation
    """

    # -- logging shortcuts bound to this module's logger
    info = logger.info
    warn = logger.warning
    error = logger.error
    debug = logger.debug

    def __init__(self, trials, current_trial=None):
        """
        trials: the Trials instance to communicate with; a fresh `Trials()`
            is created when None.
        current_trial: the trial document being evaluated, if any.
        """
        # -- attachments should be used like
        #      attachments[key]
        #      attachments[key] = value
        #    where key and value are strings. Client code should not
        #    expect any dictionary-like behaviour beyond that (no update)
        if trials is None:
            self.trials = Trials()
        else:
            self.trials = trials
        self.current_trial = current_trial

    def checkpoint(self, r=None):
        """Store `r` (when not None) as the current trial's result.

        The current trial must be one of `self.trials._trials`.
        """
        assert self.current_trial in self.trials._trials
        if r is not None:
            self.current_trial["result"] = r

    @property
    def attachments(self):
        """
        Support syntax for load:  self.attachments[name]
        Support syntax for store: self.attachments[name] = value
        """
        return self.trials.trial_attachments(trial=self.current_trial)

    def inject_results(self, specs, results, miscs, new_tids=None):
        """Inject new results into self.trials

        `specs`, `results` and `miscs` are parallel sequences. One new trial
        document per entry is built with the current trial as its source,
        marked JOB_STATE_DONE, and inserted.

        Returns whatever `self.trials.insert_trial_docs` returns (the list
        of inserted tids for the base `Trials` class).

        new_tids can be None, in which case new tids will be generated
        automatically

        """
        trial = self.current_trial
        assert trial is not None
        num_news = len(specs)
        assert len(specs) == len(results) == len(miscs)
        if new_tids is None:
            new_tids = self.trials.new_trial_ids(num_news)
        new_trials = self.trials.source_trial_docs(
            tids=new_tids,
            specs=specs,
            results=results,
            miscs=miscs,
            # -- one source per injected result: source_trial_docs walks its
            #    arguments in lockstep
            sources=[trial] * num_news,
        )
        for t in new_trials:
            t["state"] = JOB_STATE_DONE
        return self.trials.insert_trial_docs(new_trials)


class Domain(object):
    """Picklable representation of search space and evaluation function.

    A Domain bundles the objective `fn` with the pyll expression `expr`
    describing the search space, and knows how to evaluate a configuration
    and normalise whatever `fn` returns into a result dictionary.
    """

    rec_eval_print_node_on_error = False

    # -- the Ctrl object is not used directly, but rather
    #    a live Ctrl instance is inserted for the pyll_ctrl
    #    in self.evaluate so that it can be accessed from within
    #    the pyll graph describing the search space.
    pyll_ctrl = pyll.as_apply(Ctrl)

    def __init__(
        self,
        fn,
        expr,
        workdir=None,
        pass_expr_memo_ctrl=None,
        name=None,
        loss_target=None,
    ):
        """
        Parameters
        ----------

        fn : callable
            This stores the `fn` argument to `fmin`. (See `hyperopt.fmin.fmin`)

        expr : hyperopt.pyll.Apply
            This is the `space` argument to `fmin`. (See `hyperopt.fmin.fmin`)

        workdir : string (or None)
            If non-None, the current working directory will be `workdir` while
            `expr` and `fn` are evaluated. (XXX Currently only respected by
            jobs run via MongoWorker)

        pass_expr_memo_ctrl : bool (or None)
            If True, `fn` will be called like this:
            `fn(self.expr, memo, ctrl)`,
            where `memo` is a dictionary mapping `Apply` nodes to their
            computed values, and `ctrl` is a `Ctrl` instance for communicating
            with a Trials database.  This lower-level calling convention is
            useful if you want to call e.g. `hyperopt.pyll.rec_eval` yourself
            in some customized way.
            If None, the value of `fn.fmin_pass_expr_memo_ctrl` is used when
            that attribute exists, and False otherwise.

        name : string (or None)
            Label, used for pretty-printing.

        loss_target : float (or None)
            The actual or estimated minimum of `fn`.
            Some optimization algorithms may behave differently if their first
            objective is to find an input that achieves a certain value,
            rather than the more open-ended objective of pure minimization.
            XXX: Move this from Domain to be an fmin arg.

        Raises
        ------

        DuplicateLabel
            If two `hyperopt_param` nodes in `expr` carry the same label.

        """
        self.fn = fn
        if pass_expr_memo_ctrl is None:
            self.pass_expr_memo_ctrl = getattr(fn, "fmin_pass_expr_memo_ctrl", False)
        else:
            self.pass_expr_memo_ctrl = pass_expr_memo_ctrl

        self.expr = pyll.as_apply(expr)

        # -- map each hyperparameter label to the node describing it
        self.params = {}
        for node in pyll.dfs(self.expr):
            if node.name == "hyperopt_param":
                label = node.arg["label"].obj
                if label in self.params:
                    raise DuplicateLabel(label)
                self.params[label] = node.arg["obj"]

        self.loss_target = loss_target
        self.name = name

        self.workdir = workdir
        self.s_new_ids = pyll.Literal("new_ids")  # -- list at eval-time
        before = pyll.dfs(self.expr)
        # -- raises exception if expr contains cycles
        pyll.toposort(self.expr)
        vh = self.vh = VectorizeHelper(self.expr, self.s_new_ids)
        # -- raises exception if v_expr contains cycles
        pyll.toposort(vh.v_expr)

        idxs_by_label = vh.idxs_by_label()
        vals_by_label = vh.vals_by_label()
        after = pyll.dfs(self.expr)
        # -- try to detect if VectorizeHelper screwed up anything inplace
        assert before == after
        assert set(idxs_by_label.keys()) == set(vals_by_label.keys())
        assert set(idxs_by_label.keys()) == set(self.params.keys())

        self.s_rng = pyll.Literal("rng-placeholder")
        # -- N.B. operates inplace:
        self.s_idxs_vals = recursive_set_rng_kwarg(
            pyll.scope.pos_args(idxs_by_label, vals_by_label), self.s_rng
        )

        # -- raises an exception if no topological ordering exists
        pyll.toposort(self.s_idxs_vals)

        # -- Protocol for serialization.
        #    self.cmd indicates to e.g. MongoWorker how this domain
        #    should be [un]serialized.
        #    XXX This mechanism deserves review as support for ipython
        #        workers improves.
        self.cmd = ("domain_attachment", "FMinIter_Domain")

    def memo_from_config(self, config):
        """Build a memo mapping each `hyperopt_param` node to its value.

        Labels missing from `config` are mapped to
        `pyll.base.GarbageCollected`.
        """
        # -- hack because it's not really garbagecollected
        #    this does have the desired effect of crashing the
        #    function if rec_eval actually needs a value that
        #    the optimization algorithm thought to be unnecessary
        return {
            node: config.get(node.arg["label"].obj, pyll.base.GarbageCollected)
            for node in pyll.dfs(self.expr)
            if node.name == "hyperopt_param"
        }

    def _normalize_result(self, rval, ctrl, attach_attachments):
        """Turn the raw return value of `fn` into a result dictionary.

        Shared by `evaluate` and `evaluate_async2` so that both code paths
        validate results identically.

        A bare number (float, int or numpy number) becomes
        `{"loss": float(rval), "status": STATUS_OK}`. Anything else is copied
        with `dict(rval)`; it must have a "status" entry (a missing one
        raises KeyError) whose value is one of STATUS_STRINGS, and when the
        status is STATUS_OK its "loss" is converted to float.

        If `attach_attachments` is true, the "attachments" entry (if any) is
        removed from the result and each item is stored in `ctrl.attachments`.

        Raises
        ------

        InvalidResultStatus
            If the status is not one of STATUS_STRINGS.

        InvalidLoss
            If the status is STATUS_OK and the loss is missing or cannot be
            converted to float.
        """
        if isinstance(rval, (float, int, np.number)):
            dict_rval = {"loss": float(rval), "status": STATUS_OK}
        else:
            dict_rval = dict(rval)
            status = dict_rval["status"]
            if status not in STATUS_STRINGS:
                raise InvalidResultStatus(dict_rval)

            if status == STATUS_OK:
                # -- make sure that the loss is present and valid.
                #    ValueError covers losses such as "abc" that float()
                #    cannot parse.
                #    NOTE(review): assumes callers that used to see the raw
                #    ValueError can still catch InvalidLoss -- confirm its
                #    base class.
                try:
                    dict_rval["loss"] = float(dict_rval["loss"])
                except (TypeError, KeyError, ValueError):
                    raise InvalidLoss(dict_rval)

        if attach_attachments:
            attachments = dict_rval.pop("attachments", {})
            for key, val in attachments.items():
                ctrl.attachments[key] = val

        # -- don't do this here because SON-compatibility is only a requirement
        #    for trials destined for a mongodb. In-memory rvals can contain
        #    anything.
        return dict_rval

    def evaluate(self, config, ctrl, attach_attachments=True):
        """Evaluate `fn` at `config` and return the normalised result dict.

        When `self.pass_expr_memo_ctrl` is true, `fn` is called as
        `fn(expr=self.expr, memo=memo, ctrl=ctrl)`; otherwise the expression
        is evaluated with `pyll.rec_eval` and `fn` is called on the outcome.

        See `_normalize_result` for how the return value of `fn` is validated
        and for the exceptions that may be raised.
        """
        memo = self.memo_from_config(config)
        use_obj_for_literal_in_memo(self.expr, ctrl, Ctrl, memo)
        if self.pass_expr_memo_ctrl:
            rval = self.fn(expr=self.expr, memo=memo, ctrl=ctrl)
        else:
            # -- the "work" of evaluating `config` can be written
            #    either into the pyll part (self.expr)
            #    or the normal Python part (self.fn)
            pyll_rval = pyll.rec_eval(
                self.expr,
                memo=memo,
                print_node_on_error=self.rec_eval_print_node_on_error,
            )
            rval = self.fn(pyll_rval)

        return self._normalize_result(rval, ctrl, attach_attachments)

    def evaluate_async(self, config, ctrl, attach_attachments=True):
        """
        this is the first part of async evaluation for ipython parallel engines (see ipy.py)

        This breaks evaluate into two parts to allow for the apply_async call
        to only pass the objective function and arguments.

        Returns `(self.fn, pyll_rval)` where `pyll_rval` is the evaluated
        expression, unless `self.pass_expr_memo_ctrl` is true, in which case
        `fn` is called here and None is returned. `attach_attachments` is
        accepted but not used by this method.
        """
        memo = self.memo_from_config(config)
        use_obj_for_literal_in_memo(self.expr, ctrl, Ctrl, memo)
        if self.pass_expr_memo_ctrl:
            # -- NOTE(review): in this mode `fn` runs right here and its
            #    return value is dropped, so the caller receives None rather
            #    than a (fn, args) pair. This looks unintended, but the
            #    desired contract is not visible from this file; behaviour is
            #    preserved and only made explicit.
            self.fn(expr=self.expr, memo=memo, ctrl=ctrl)
            return None

        # -- the "work" of evaluating `config` can be written
        #    either into the pyll part (self.expr)
        #    or the normal Python part (self.fn)
        pyll_rval = pyll.rec_eval(
            self.expr,
            memo=memo,
            print_node_on_error=self.rec_eval_print_node_on_error,
        )
        return (self.fn, pyll_rval)

    def evaluate_async2(self, rval, ctrl, attach_attachments=True):
        """
        this is the second part of async evaluation for ipython parallel engines (see ipy.py)

        `rval` is the raw return value of the objective function; it is
        validated and converted exactly as in `evaluate` (see
        `_normalize_result`).
        """
        return self._normalize_result(rval, ctrl, attach_attachments)

    def short_str(self):
        """Return a short label of the form `Domain{<str(fn)>}`."""
        return "Domain{%s}" % str(self.fn)

    def loss(self, result, config=None):
        """Extract the scalar-valued loss from a result document

        Returns None when the document has no "loss" entry.
        """
        return result.get("loss", None)

    def loss_variance(self, result, config=None):
        """Return the variance in the estimate of the loss

        Defaults to 0.0 when the document has no "loss_variance" entry.
        """
        return result.get("loss_variance", 0.0)

    def true_loss(self, result, config=None):
        """Return a true loss, in the case that the `loss` is a surrogate

        Falls back to `self.loss(result, config=config)` when the document
        has no "true_loss" entry.
        """
        # N.B. don't use get() here, it evaluates self.loss un-necessarily
        try:
            return result["true_loss"]
        except KeyError:
            return self.loss(result, config=config)

    def true_loss_variance(self, config=None):
        """Return the variance in  true loss,
        in the case that the `loss` is a surrogate.

        Not implemented here: always raises NotImplementedError.
        """
        raise NotImplementedError()

    def status(self, result, config=None):
        """Extract the job status from a result document

        Raises KeyError when the document has no "status" entry.
        """
        return result["status"]

    def new_result(self):
        """Return a JSON-encodable object
        to serve as the 'result' for new jobs.
        """
        return {"status": STATUS_NEW}


# -- flake8 doesn't like blank last line