Skip to content

Commit

Permalink
better tracking of batch counting. (this can be tricky for parallel queueing, since batches can be sampled directly if there are none in the queue).
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Dec 29, 2024
1 parent 8beb2a0 commit 1e7afa9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
41 changes: 20 additions & 21 deletions sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,14 @@ def queue_shape(self):
@property
def queue_len(self):
"""Get number of batches in the queue."""
return self.queue.size().numpy()
return self.queue.size().numpy() + self.queue_futures

@property
def queue_futures(self):
    """Count of enqueue tasks still pending in the thread pool.

    Each pending work item will eventually add one batch to the
    queue, so callers can treat this as batches "in flight" when
    sizing the effective queue length.
    """
    # NOTE(review): relies on ThreadPoolExecutor private internals
    # (_work_queue) -- there is no public API for the pending-task
    # count; confirm this survives the executor implementation in use.
    pending = self._thread_pool._work_queue
    return pending.qsize()

@property
def queue_free(self):
"""Get number of free spots in the queue."""
return self.queue_cap - self.queue_len

def get_queue(self):
"""Return FIFO queue for storing batches."""
return tf.queue.FIFOQueue(
Expand Down Expand Up @@ -232,16 +227,16 @@ def __len__(self):
return self.n_batches

def __iter__(self):
self._batch_count = 0
self.start()
self._batch_count = 0
return self

def get_batch(self) -> DsetTuple:
"""Get batch from queue or directly from a ``Sampler`` through
``sample_batch``."""
if self.queue_len > 0 or self.queue_futures > 0:
return self.queue.dequeue()
return self.sample_batch()
if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
return self.sample_batch()
return self.queue.dequeue()

@property
def running(self):
Expand Down Expand Up @@ -272,19 +267,26 @@ def sample_batches(self, n_batches) -> None:
)
return [task.result() for task in tasks]

@property
def needed_batches(self):
    """Number of batches needed to either fill the queue or hit the
    epoch limit, whichever is smaller.

    The epoch budget subtracts batches already delivered this epoch
    (``self._batch_count``) and batches already waiting in the queue
    (``self.queue_len``) from the epoch total (``self.n_batches``).
    May be negative or zero, in which case no batches are needed.
    """
    # NOTE(review): the trailing -1 appears to reserve one batch to be
    # drawn directly via sample_batch() rather than through the queue
    # -- confirm against get_batch()'s fallback path.
    remaining = self.n_batches - self._batch_count - self.queue_len - 1
    return min(self.queue_cap - self.queue_len, remaining)

def enqueue_batches(self) -> None:
"""Callback function for queue thread. While training, the queue is
checked for empty spots and filled. In the training thread, batches are
removed from the queue."""
log_time = time.time()
while self.running:
needed = min(
self.queue_free - self.queue_futures,
self.n_batches - self._batch_count
)
# no point in getting more than one batch at a time if
# max_workers == 1
needed = 1 if needed > 0 and self.max_workers == 1 else needed
needed = (
1
if self.needed_batches > 0 and self.max_workers == 1
else self.needed_batches
)

if needed > 0:
for batch in self.sample_batches(n_batches=needed):
Expand All @@ -307,14 +309,14 @@ def __next__(self) -> DsetTuple:
if self._batch_count < self.n_batches:
self.timer.start()
samples = self.get_batch()
self._batch_count += 1
if self.sample_shape[2] == 1:
if isinstance(samples, (list, tuple)):
samples = tuple(s[..., 0, :] for s in samples)
else:
samples = samples[..., 0, :]
batch = self.post_proc(samples)
self.timer.stop()
self._batch_count += 1
if self.verbose:
logger.debug(
'Batch step %s finished in %s.',
Expand Down Expand Up @@ -348,11 +350,8 @@ def sample_batch(self):

def log_queue_info(self):
"""Log info about queue size."""
return '{} queue length: {} / {}, with {} futures'.format(
self._thread_name.title(),
self.queue_len,
self.queue_cap,
self.queue_futures
return '{} queue length: {} / {}'.format(
self._thread_name.title(), self.queue_len, self.queue_cap
)

@property
Expand Down
4 changes: 2 additions & 2 deletions tests/batch_handlers/test_bh_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_batch_handler_workers():
2 * sample_shape[-1],
)
n_obs = 40
max_workers = 5
n_batches = 40
max_workers = 20
n_batches = 20

lons, lats = np.meshgrid(
np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons)
Expand Down

0 comments on commit 1e7afa9

Please sign in to comment.