Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/test_testsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,23 @@ def test_two(self):
self.assertRaises(ValueError, suite.sort_tests)


class TestIterateTests(TestCase):
def test_iterate_suite(self):
a = PlaceHolder("a")
b = PlaceHolder("b")
suite = unittest.TestSuite([a, b]) # type: ignore[list-item]
self.assertEqual([a, b], list(iterate_tests(suite)))

def test_iterate_single_test(self):
a = PlaceHolder("a")
self.assertEqual([a], list(iterate_tests(a)))

def test_iterate_list(self):
a = PlaceHolder("a")
b = PlaceHolder("b")
self.assertEqual([a, b], list(iterate_tests([a, b])))


class TestSortedTests(TestCase):
def test_sorts_custom_suites(self):
a = PlaceHolder("a")
Expand Down
8 changes: 4 additions & 4 deletions testtools/testsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def iterate_tests(
test_suite_or_case: TestSuiteOrCase,
) -> Generator[unittest.TestCase, None, None]:
"""Iterate through all of the test cases in 'test_suite_or_case'."""
if isinstance(test_suite_or_case, unittest.TestSuite):
# It's a suite, iterate through it
if isinstance(test_suite_or_case, Iterable) and not isinstance(
test_suite_or_case, unittest.TestCase
):
for test in test_suite_or_case:
yield from iterate_tests(test)
else:
# It's a test case (could be unittest.TestCase or duck-typed)
yield test_suite_or_case
yield test_suite_or_case # type: ignore[misc]


class ConcurrentTestSuite(unittest.TestSuite):
Expand Down