
Commit

update test
jiashenC committed Sep 8, 2023
1 parent db819bf commit f8dde26
Showing 2 changed files with 10 additions and 7 deletions.
12 changes: 6 additions & 6 deletions test/unit_tests/readers/test_decord_reader.py
@@ -72,7 +72,7 @@ def test_should_sample_only_iframe(self):
batches = list(video_loader.read())

expected = self._batches_to_reader_convertor(
-            create_dummy_batches(filters=[i for i in range(0, NUM_FRAMES, k)])
+            create_dummy_batches(filters=[i for i in range(0, NUM_FRAMES, k)], is_from_storage=True)
)

print(batches[0].frames)
@@ -96,7 +96,7 @@ def test_should_sample_every_k_frame_with_predicate(self):
value = NUM_FRAMES // 2
start = value + k - (value % k) if value % k else value
expected = self._batches_to_reader_convertor(
-            create_dummy_batches(filters=[i for i in range(start, NUM_FRAMES, k)])
+            create_dummy_batches(filters=[i for i in range(start, NUM_FRAMES, k)], is_from_storage=True)
)
self.assertEqual(batches, expected)

@@ -123,7 +123,7 @@ def test_should_sample_every_k_frame_with_predicate(self):
batches = list(video_loader.read())
start = value + k - (value % k) if value % k else value
expected = self._batches_to_reader_convertor(
-            create_dummy_batches(filters=[i for i in range(start, 8, k)])
+            create_dummy_batches(filters=[i for i in range(start, 8, k)], is_from_storage=True)
)
self.assertEqual(batches, expected)

@@ -132,7 +132,7 @@ def test_should_return_one_batch(self):
file_url=self.video_file_url,
)
batches = list(video_loader.read())
-        expected = self._batches_to_reader_convertor(create_dummy_batches())
+        expected = self._batches_to_reader_convertor(create_dummy_batches(is_from_storage=True))
self.assertEqual(batches, expected)

def test_should_return_batches_equivalent_to_number_of_frames(self):
@@ -141,7 +141,7 @@ def test_should_return_batches_equivalent_to_number_of_frames(self):
batch_mem_size=self.frame_size,
)
batches = list(video_loader.read())
-        expected = self._batches_to_reader_convertor(create_dummy_batches(batch_size=1))
+        expected = self._batches_to_reader_convertor(create_dummy_batches(batch_size=1, is_from_storage=True))
self.assertEqual(batches, expected)

def test_should_sample_every_k_frame(self):
@@ -152,7 +152,7 @@ def test_should_sample_every_k_frame(self):
)
batches = list(video_loader.read())
expected = self._batches_to_reader_convertor(
-            create_dummy_batches(filters=[i for i in range(0, NUM_FRAMES, k)])
+            create_dummy_batches(filters=[i for i in range(0, NUM_FRAMES, k)], is_from_storage=True)
)
self.assertEqual(batches, expected)

5 changes: 4 additions & 1 deletion test/util.py
@@ -451,6 +451,7 @@ def create_dummy_batches(
batch_size=10,
start_id=0,
video_dir=None,
+    is_from_storage=False,  # tests that read directly from storage expect an appended _row_number column
):
video_dir = video_dir or get_tmp_dir()

@@ -461,7 +462,6 @@
data.append(
{
"myvideo._row_id": 1,
"myvideo._row_number": i + start_id,
"myvideo.name": os.path.join(video_dir, "dummy.avi"),
"myvideo.id": i + start_id,
"myvideo.data": np.array(
@@ -471,6 +471,9 @@
}
)

+        if is_from_storage:
+            data[-1]["myvideo._row_number"] = i + start_id

if len(data) % batch_size == 0:
yield Batch(pd.DataFrame(data))
data = []
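
For reference, a minimal sketch of the flag's effect (hypothetical usage, not part of the commit; it assumes create_dummy_batches yields at least one Batch with default arguments and that Batch.frames exposes the underlying pandas DataFrame, as the tests above suggest): only batches built with is_from_storage=True carry the storage-assigned myvideo._row_number column.

    # Hypothetical sketch: compare dummy batches with and without the new flag.
    plain_batch = next(create_dummy_batches())                         # no _row_number column
    storage_batch = next(create_dummy_batches(is_from_storage=True))   # _row_number appended

    assert "myvideo._row_number" not in plain_batch.frames.columns
    assert "myvideo._row_number" in storage_batch.frames.columns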
