From f35d71e3f1c3e673c60072950f2cdd4354a7eabb Mon Sep 17 00:00:00 2001 From: CLRSDev Date: Thu, 30 May 2024 14:53:06 -0700 Subject: [PATCH] PUBLIC: Update METADATA. PiperOrigin-RevId: 638791929 --- clrs/_src/algorithms/searching.py | 9 ++++-- clrs/_src/probing.py | 2 +- clrs/_src/samplers.py | 47 ++++++++++++++++++++++--------- clrs/_src/specs.py | 3 +- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/clrs/_src/algorithms/searching.py b/clrs/_src/algorithms/searching.py index c998c655..b6916833 100644 --- a/clrs/_src/algorithms/searching.py +++ b/clrs/_src/algorithms/searching.py @@ -179,7 +179,8 @@ def partition(A, A_pos, p, r, target, probes): 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 'j': probing.mask_one(A_pos[j], A.shape[0]), 'i_rank': (i + 1) * 1.0 / A.shape[0], - 'target': target * 1.0 / A.shape[0] + 'target': target * 1.0 / A.shape[0], + 'pivot': probing.mask_one(A_pos[r], A.shape[0]), }) tmp = A[i + 1] @@ -199,8 +200,10 @@ def partition(A, A_pos, p, r, target, probes): 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 'j': probing.mask_one(A_pos[r], A.shape[0]), 'i_rank': (i + 1 - p) * 1.0 / A.shape[0], - 'target': target * 1.0 / A.shape[0] - }) + 'target': target * 1.0 / A.shape[0], + 'pivot': probing.mask_one(A_pos[i + 1], A.shape[0]), + }, + ) return i + 1 diff --git a/clrs/_src/probing.py b/clrs/_src/probing.py index d9e9ad4a..b1aaccbb 100644 --- a/clrs/_src/probing.py +++ b/clrs/_src/probing.py @@ -145,7 +145,7 @@ def finalize(probes: ProbesDict): else: # Only one instance of input/output exist. Remove leading axis. probes[stage][loc][name]['data'] = np.squeeze( - np.array(probes[stage][loc][name]['data'])) + np.array(probes[stage][loc][name]['data']), axis=0) def split_stages( diff --git a/clrs/_src/samplers.py b/clrs/_src/samplers.py index 9cf32e27..f77b4ed1 100644 --- a/clrs/_src/samplers.py +++ b/clrs/_src/samplers.py @@ -72,6 +72,7 @@ def __init__( num_samples: int, *args, seed: Optional[int] = None, + track_max_steps: bool = True, **kwargs, ): """Initializes a `Sampler`. @@ -85,6 +86,13 @@ def __init__( If -1, samples are generated on the fly with each call to `next`. *args: Algorithm args. seed: RNG seed. + track_max_steps: if True and sampling on the fly (`num_samples`==-1), + we keep track of the maximum unroll length so far to pad batched samples + to that length. This is desirable when batches are used in compiled + functions that need recompilation every time the batch size changes. + Also, we get an initial value for max_steps by generating 1000 samples, + which will slow down initialization. If uniform shape of the batches + is not a concern, set `track_max_steps` to False. **kwargs: Algorithm kwargs. """ @@ -95,19 +103,21 @@ def __init__( self._algorithm = algorithm self._args = args self._kwargs = kwargs + self._track_max_steps = track_max_steps if num_samples < 0: logging.warning('Sampling dataset on-the-fly, unlimited samples.') - # Just get an initial estimate of max hint length - self.max_steps = -1 - for _ in range(1000): - data = self._sample_data(*args, **kwargs) - _, probes = algorithm(*data) - _, _, hint = probing.split_stages(probes, spec) - for dp in hint: - assert dp.data.shape[1] == 1 # batching axis - if dp.data.shape[0] > self.max_steps: - self.max_steps = dp.data.shape[0] + if track_max_steps: + # Get an initial estimate of max hint length + self.max_steps = -1 + for _ in range(1000): + data = self._sample_data(*args, **kwargs) + _, probes = algorithm(*data) + _, _, hint = probing.split_stages(probes, spec) + for dp in hint: + assert dp.data.shape[1] == 1 # batching axis + if dp.data.shape[0] > self.max_steps: + self.max_steps = dp.data.shape[0] else: logging.info('Creating a dataset with %i samples.', num_samples) (self._inputs, self._outputs, self._hints, @@ -148,10 +158,11 @@ def next(self, batch_size: Optional[int] = None) -> Feedback: """ if batch_size: if self._num_samples < 0: # generate on the fly + min_length = self.max_steps if self._track_max_steps else 0 inputs, outputs, hints, lengths = self._make_batch( - batch_size, self._spec, self.max_steps, + batch_size, self._spec, min_length, self._algorithm, *self._args, **self._kwargs) - if hints[0].data.shape[0] > self.max_steps: + if self._track_max_steps and hints[0].data.shape[0] > self.max_steps: logging.warning('Increasing hint lengh from %i to %i', self.max_steps, hints[0].data.shape[0]) self.max_steps = hints[0].data.shape[0] @@ -261,6 +272,7 @@ def build_sampler( num_samples: int, *args, seed: Optional[int] = None, + track_max_steps: bool = True, **kwargs, ) -> Tuple[Sampler, specs.Spec]: """Builds a sampler. See `Sampler` documentation.""" @@ -276,8 +288,15 @@ def build_sampler( if set(clean_kwargs) != set(kwargs): logging.warning('Ignoring kwargs %s when building sampler class %s', set(kwargs).difference(clean_kwargs), sampler_class) - sampler = sampler_class(algorithm, spec, num_samples, seed=seed, - *args, **clean_kwargs) + sampler = sampler_class( + algorithm, + spec, + num_samples, + seed=seed, + track_max_steps=track_max_steps, + *args, + **clean_kwargs, + ) return sampler, spec diff --git a/clrs/_src/specs.py b/clrs/_src/specs.py index 5de850cb..c6efa0a8 100644 --- a/clrs/_src/specs.py +++ b/clrs/_src/specs.py @@ -161,7 +161,8 @@ class OutputClass: 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE), 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE), 'i_rank': (Stage.HINT, Location.GRAPH, Type.SCALAR), - 'target': (Stage.HINT, Location.GRAPH, Type.SCALAR) + 'target': (Stage.HINT, Location.GRAPH, Type.SCALAR), + 'pivot': (Stage.HINT, Location.NODE, Type.MASK_ONE), }, 'minimum': { 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),