diff --git a/pycompliance/pycompliance.py b/pycompliance/pycompliance.py index 270dedb..bcfe4de 100644 --- a/pycompliance/pycompliance.py +++ b/pycompliance/pycompliance.py @@ -19,7 +19,7 @@ def traverse(self, node: 'Node') -> 'list[Node]': r = [] for child in node.children: res = self.traverse(child) - r = res + r + r = r + res r.append(node) return r diff --git a/pycompliance/tests/test.py b/pycompliance/tests/test.py index 2c2dbd1..6bef757 100644 --- a/pycompliance/tests/test.py +++ b/pycompliance/tests/test.py @@ -188,15 +188,47 @@ class TestTraveral(unittest.TestCase): def test_traverse_benchmark(self): b = pycompliance.Benchmark('foo') section = pycompliance.Section('1') - subsection = pycompliance.Section('1.1') - c1 = pycompliance.Control('1.1.1') - c2 = pycompliance.Control('1.1.2') - expected = [b, section, subsection, c1, c2] - b.add_section(section) - b.add_section(subsection) - b.add_control(c1) - b.add_control(c2) + for section in ['1', '2', '3']: + b.add_section(pycompliance.Section(section)) + + for subsection in ['1.1', '1.2', '2.1', '2.2', '2.4', '3.1']: + b.add_section(pycompliance.Section(subsection)) + + for control in [ + '1.1.1', '1.1.2', '1.1.3', '1.2.1', '2.1.1', '2.1.2', + '2.2.1', '2.2.2', '2.2.3', '2.4.1', '2.4.2', '2.4.3', + '3.1.1', '3.1.2', '3.1.3', '3.1.4']: + b.add_control(pycompliance.Control(control)) + + expected_order = [ + '1.1.1', '1.1.2', '1.1.3', '1.1', '1.2.1', '1.2', '1', '2.1.1', + '2.1.2', '2.1', '2.2.1', '2.2.2', '2.2.3', '2.2', '2.4.1', '2.4.2', + '2.4.3', '2.4', '2', '3.1.1', '3.1.2', '3.1.3', '3.1.4', '3.1', + '3', 'foo'] nodes = b.traverse(b) - self.assertEqual(len(expected), len(nodes)) - for n in nodes: - self.assertIn(n, expected) + self.assertEqual(len(nodes), len(expected_order)) + for i, n in enumerate(nodes): + self.assertEqual(n.id, expected_order[i]) + + def test_traverse_subtree(self): + b = pycompliance.Benchmark('foo') + section = pycompliance.Section('1') + for section in ['1', '2', '3']: + b.add_section(pycompliance.Section(section)) + + for subsection in ['1.1', '1.2', '2.1', '2.2', '2.4', '3.1']: + b.add_section(pycompliance.Section(subsection)) + + for control in [ + '1.1.1', '1.1.2', '1.1.3', '1.2.1', '2.1.1', '2.1.2', + '2.2.1', '2.2.2', '2.2.3', '2.4.1', '2.4.2', '2.4.3', + '3.1.1', '3.1.2', '3.1.3', '3.1.4']: + b.add_control(pycompliance.Control(control)) + node = b.find('2') + nodes = b.traverse(node) + expected_order = [ + '2.1.1', '2.1.2', '2.1', '2.2.1', '2.2.2', '2.2.3', '2.2', '2.4.1', '2.4.2', + '2.4.3', '2.4', '2'] + self.assertEqual(len(nodes), len(expected_order)) + for i, n in enumerate(nodes): + self.assertEqual(n.id, expected_order[i])