progress.py
975 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Progress is reported using context managers.
A progress context manager takes an `initial` and a `total` argument
and should yield an object with an `update(n)` method.
"""
from __future__ import print_function
from __future__ import absolute_import
import contextlib
from tqdm import tqdm
from .std_out_err_redirect_tqdm import std_out_err_redirect_tqdm
@contextlib.contextmanager
def tqdm_progress_callback(initial, total):
    """Progress context manager backed by a tqdm progress bar.

    Args:
        initial: Value passed to tqdm as the bar's ``initial`` counter.
        total: Value passed to tqdm as the bar's ``total``.

    Yields:
        The tqdm bar object; callers advance it with ``update(n)``.
    """
    # The bar writes to the stream handed out by std_out_err_redirect_tqdm
    # (presumably so ordinary prints do not corrupt the bar -- confirm in
    # that helper's module).
    with std_out_err_redirect_tqdm() as wrapped_stdout:
        bar_options = dict(
            total=total,
            file=wrapped_stdout,
            postfix={"best loss": "?"},
            disable=False,
            dynamic_ncols=True,
            unit="trial",
            initial=initial,
        )
        # Entered after (and exited before) the redirect, exactly like the
        # single combined ``with`` statement.
        with tqdm(**bar_options) as progress_bar:
            yield progress_bar
@contextlib.contextmanager
def no_progress_callback(initial, total):
    """Progress context manager that reports nothing.

    Accepts the same arguments as the other progress callbacks so it can be
    used interchangeably with them, but ignores both.

    Args:
        initial: Ignored.
        total: Ignored.

    Yields:
        An object whose ``update(n)`` method does nothing and returns None.
    """

    class NoProgressContext(object):
        """Silent stand-in for a progress bar."""

        def update(self, n=1):
            """Discard a progress increment.

            ``n`` defaults to 1, mirroring ``tqdm.update``, so code that
            calls ``update()`` without an argument works with either
            callback.
            """
            pass

    yield NoProgressContext()
# Module-level alias that selects the tqdm-based progress callback as the
# default one exported by this module.
default_callback = tqdm_progress_callback
"""Use tqdm for progress by default"""