unittest: Add exception capturing for subTest.

pull/488/head
Andrew Leech 2022-05-03 17:07:48 +10:00
rodzic 9f6f211506
commit 9b6315a2ba
2 zmienionych plików z 67 dodań i 16 usunięć

Wyświetl plik

@ -148,6 +148,14 @@ class TestUnittestAssertions(unittest.TestCase):
assert global_context is None
global_context = True
def test_subtest_even(self):
"""
Test that numbers between 0 and 5 are all even.
"""
for i in range(0, 10, 2):
with self.subTest("Should only pass for even numbers", i=i):
self.assertEqual(i % 2, 0)
if __name__ == "__main__":
unittest.main()

Wyświetl plik

@ -37,11 +37,35 @@ class AssertRaisesContext:
return False
# These are used to provide required context to things like subTest
__current_test__ = None
__test_result__ = None
class SubtestContext:
def __enter__(self):
pass
def __exit__(self, *exc_info):
if exc_info[0] is not None:
# Exception raised
global __test_result__, __current_test__
handle_test_exception(
__current_test__,
__test_result__,
exc_info
)
# Suppress the exception as we've captured it above
return True
class NullContext:
def __enter__(self):
pass
def __exit__(self, a, b, c):
def __exit__(self, exc_type, exc_value, traceback):
pass
@ -61,7 +85,7 @@ class TestCase:
func(*args, **kwargs)
def subTest(self, msg=None, **params):
return NullContext()
return SubtestContext(msg=msg, params=params)
def skipTest(self, reason):
raise SkipTest(reason)
@ -298,15 +322,29 @@ class TestResult:
return self
def capture_exc(e):
def capture_exc(exc, traceback):
buf = io.StringIO()
if hasattr(sys, "print_exception"):
sys.print_exception(e, buf)
sys.print_exception(exc, buf)
elif traceback is not None:
traceback.print_exception(None, e, sys.exc_info()[2], file=buf)
traceback.print_exception(None, exc, traceback, file=buf)
return buf.getvalue()
def handle_test_exception(current_test: tuple, test_result: TestResult, exc_info: tuple):
exc = exc_info[1]
traceback = exc_info[2]
ex_str = capture_exc(exc, traceback)
if isinstance(exc, AssertionError):
test_result.failuresNum += 1
test_result.failures.append((current_test, ex_str))
print(" FAIL")
else:
test_result.errorsNum += 1
test_result.errors.append((current_test, ex_str))
print(" ERROR")
def run_suite(c, test_result, suite_name=""):
if isinstance(c, TestSuite):
c.run(test_result)
@ -324,29 +362,34 @@ def run_suite(c, test_result, suite_name=""):
except AttributeError:
pass
def run_one(m):
def run_one(test_function):
global __test_result__, __current_test__
print("%s (%s) ..." % (name, suite_name), end="")
set_up()
__test_result__ = test_result
test_container = f"({suite_name})"
__current_test__ = (name, test_container)
try:
test_result.testsRun += 1
m()
test_globals = dict(**globals())
test_globals["test_function"] = test_function
exec("test_function()", test_globals, test_globals)
# No exception occurred, test passed
print(" ok")
except SkipTest as e:
print(" skipped:", e.args[0])
test_result.skippedNum += 1
except Exception as ex:
ex_str = capture_exc(ex)
if isinstance(ex, AssertionError):
test_result.failuresNum += 1
test_result.failures.append(((name, c), ex_str))
print(" FAIL")
else:
test_result.errorsNum += 1
test_result.errors.append(((name, c), ex_str))
print(" ERROR")
handle_test_exception(
current_test=(name, c),
test_result=test_result,
exc_info=sys.exc_info()
)
# Uncomment to investigate failure in detail
# raise
finally:
__test_result__ = None
__current_test__ = None
tear_down()
try:
o.doCleanups()