diff --git a/tests/test_testsuite.py b/tests/test_testsuite.py index 71428281..b197c6a8 100644 --- a/tests/test_testsuite.py +++ b/tests/test_testsuite.py @@ -387,6 +387,23 @@ def test_two(self): self.assertRaises(ValueError, suite.sort_tests) +class TestIterateTests(TestCase): + def test_iterate_suite(self): + a = PlaceHolder("a") + b = PlaceHolder("b") + suite = unittest.TestSuite([a, b]) # type: ignore[list-item] + self.assertEqual([a, b], list(iterate_tests(suite))) + + def test_iterate_single_test(self): + a = PlaceHolder("a") + self.assertEqual([a], list(iterate_tests(a))) + + def test_iterate_list(self): + a = PlaceHolder("a") + b = PlaceHolder("b") + self.assertEqual([a, b], list(iterate_tests([a, b]))) + + class TestSortedTests(TestCase): def test_sorts_custom_suites(self): a = PlaceHolder("a") diff --git a/testtools/testsuite.py b/testtools/testsuite.py index 8db7549d..2aa461c1 100644 --- a/testtools/testsuite.py +++ b/testtools/testsuite.py @@ -58,13 +58,13 @@ def iterate_tests( test_suite_or_case: TestSuiteOrCase, ) -> Generator[unittest.TestCase, None, None]: """Iterate through all of the test cases in 'test_suite_or_case'.""" - if isinstance(test_suite_or_case, unittest.TestSuite): - # It's a suite, iterate through it + if isinstance(test_suite_or_case, Iterable) and not isinstance( + test_suite_or_case, unittest.TestCase + ): for test in test_suite_or_case: yield from iterate_tests(test) else: - # It's a test case (could be unittest.TestCase or duck-typed) - yield test_suite_or_case + yield test_suite_or_case # type: ignore[misc] class ConcurrentTestSuite(unittest.TestSuite):