Skip to content

Commit

Permalink
Merge pull request #3 from rhmdnd/implement-traversal
Browse files Browse the repository at this point in the history
Implement a way to traverse the entire benchmark
  • Loading branch information
rhmdnd authored Apr 26, 2023
2 parents 44b6d28 + 685e70f commit 5cb3eff
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
17 changes: 15 additions & 2 deletions pycompliance/pycompliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,29 @@ def find(self, id: str) -> 'Node|None':
if n:
return n

def traverse(self, node: 'Node') -> 'list[Node]':
r = []
for child in node.children:
res = self.traverse(child)
r = res + r
r.append(node)
return r


class Benchmark(Node):

def __init__(self, name: str):
super().__init__(name)
self.name = name
self.id = name
self.version = None
self.version = ''

def add_section(self, section: 'Section'):
# see if the section already exists
s = self.find(section.id)
if s:
return s

if '.' not in section.id:
self.children.append(section)
return
Expand All @@ -42,7 +55,7 @@ def add_control(self, control: 'Control'):
parent.children.append(control)

def _find_parent_id(self, id: str) -> str:
pattern = re.compile(r"^(.+).\d$")
pattern = re.compile(r"^(.+)\.(\d+)$")
m = pattern.search(id)
return m.group(1)

Expand Down
47 changes: 46 additions & 1 deletion pycompliance/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ def setUp(self) -> None:

def test_benchmark_defaults(self):
self.assertEqual(self.b.name, 'foo')
self.assertIsNone(self.b.version)
self.assertEqual(self.b.version, '')
self.assertListEqual(self.b.children, [])

def test_benchmark_set_version(self):
self.b.version = '1.2.1'
self.assertEqual(self.b.version, '1.2.1')

def test_find_section(self):
section = pycompliance.Section('1')
self.b.add_section(section)
Expand Down Expand Up @@ -85,6 +89,14 @@ def test_nested_sections(self):
self.assertIn(s, b.children)
self.assertNotIn(sub, b.children)

def test_add_duplicate_section(self):
b = pycompliance.Benchmark('foo')
s = pycompliance.Section('1')
b.add_section(s)
self.assertListEqual(b.children, [s])
b.add_section(s)
self.assertListEqual(b.children, [s])


class TestControl(unittest.TestCase):
def test_default_control(self):
Expand Down Expand Up @@ -155,3 +167,36 @@ def test_add_control_with_invalid_subsection_fails(self):
control = pycompliance.Control('1.1.1')
b.add_section(section)
self.assertRaises(pycompliance.SectionNotFound, b.add_control, control)

def test_add_control_(self):
b = pycompliance.Benchmark('foo')
section = pycompliance.Section('3')
subsection = pycompliance.Section('3.2')
b.add_section(section)
b.add_section(subsection)
expected = []
for i in range(20):
control = pycompliance.Control('3.2.' + str(i))
expected.append(control)
b.add_control(control)
self.assertListEqual(b.children, [section])
self.assertListEqual(section.children, [subsection])
self.assertListEqual(subsection.children, expected)

class TestTraveral(unittest.TestCase):

def test_traverse_benchmark(self):
b = pycompliance.Benchmark('foo')
section = pycompliance.Section('1')
subsection = pycompliance.Section('1.1')
c1 = pycompliance.Control('1.1.1')
c2 = pycompliance.Control('1.1.2')
expected = [b, section, subsection, c1, c2]
b.add_section(section)
b.add_section(subsection)
b.add_control(c1)
b.add_control(c2)
nodes = b.traverse(b)
self.assertEqual(len(expected), len(nodes))
for n in nodes:
self.assertIn(n, expected)

0 comments on commit 5cb3eff

Please sign in to comment.