testutils.py
3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (C) 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Python Fire's tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import re
import sys
import unittest
from fire import core
from fire import trace
import mock
import six
class BaseTestCase(unittest.TestCase):
  """Shared test case for Python Fire tests."""

  @contextlib.contextmanager
  def assertOutputMatches(self, stdout='.*', stderr='.*', capture=True):
    """Asserts that the context generates stdout and stderr matching regexps.

    While the context is active, sys.stdout and sys.stderr are patched to
    in-memory buffers, so everything the wrapped code prints is captured.
    Each regexp is applied with re.search (flags re.DOTALL | re.MULTILINE),
    so it may match anywhere in the captured text.

    Note: If wrapped code raises an exception, stdout and stderr will not be
    checked.

    Args:
      stdout: (str) regexp to match against stdout (None will check no stdout)
      stderr: (str) regexp to match against stderr (None will check no stderr)
      capture: (bool, default True) do not bubble up stdout or stderr. When
        False, the captured text is written to the real streams once the
        context exits, even if the wrapped code raised.

    Yields:
      Yields to the wrapped context.

    Raises:
      AssertionError: if a stream whose regexp is None received any output,
        or if a stream's output does not match its regexp.
    """
    stdout_fp = six.StringIO()
    stderr_fp = six.StringIO()
    try:
      with mock.patch.object(sys, 'stdout', stdout_fp):
        with mock.patch.object(sys, 'stderr', stderr_fp):
          yield
    finally:
      if not capture:
        # The patches have been undone at this point, so this writes the
        # captured text to the original streams.
        sys.stdout.write(stdout_fp.getvalue())
        sys.stderr.write(stderr_fp.getvalue())

    # Only reached when the wrapped code completed without raising.
    for name, regexp, fp in [('stdout', stdout, stdout_fp),
                             ('stderr', stderr, stderr_fp)]:
      value = fp.getvalue()
      if regexp is None:
        if value:
          raise AssertionError('%s: Expected no output. Got: %r' %
                               (name, value))
      elif not re.search(regexp, value, re.DOTALL | re.MULTILINE):
        raise AssertionError('%s: Expected %r to match %r' %
                             (name, value, regexp))

  @contextlib.contextmanager
  def assertRaisesFireExit(self, code, regexp='.*'):
    """Asserts that a FireExit error is raised in the context.

    Allows tests to check that Fire's wrapper around SystemExit is raised
    and that a regexp is matched in the output.

    Args:
      code: The status code that the FireExit should contain.
      regexp: stderr must match this regex (it is forwarded to
        assertOutputMatches as its `stderr` argument; stdout is left
        unconstrained).

    Yields:
      Yields to the wrapped context.

    Raises:
      AssertionError: if no core.FireExit is raised, if its `code` differs
        from `code`, if its `trace` is not a trace.FireTrace, or if the
        captured stderr does not match `regexp`.
    """
    with self.assertOutputMatches(stderr=regexp):
      with self.assertRaises(core.FireExit):
        try:
          yield
        except core.FireExit as exc:
          if exc.code != code:
            raise AssertionError('Incorrect exit code: %r != %r' % (exc.code,
                                                                  code))
          self.assertIsInstance(exc.trace, trace.FireTrace)
          # Re-raise so the enclosing assertRaises observes the FireExit
          # and absorbs it; this lets assertOutputMatches check stderr.
          raise
# Expose unittest's entry point and skip decorators as attributes of this
# module. skipIf keeps unittest's camelCase name, hence the pylint pragma.
# pylint: disable=invalid-name
main, skip, skipIf = unittest.main, unittest.skip, unittest.skipIf
# pylint: enable=invalid-name