Skip to content

Commit

Permalink
PUBLIC: Update METADATA.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638791929
  • Loading branch information
CLRSDev authored and copybara-github committed May 30, 2024
1 parent 5635129 commit f35d71e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 19 deletions.
9 changes: 6 additions & 3 deletions clrs/_src/algorithms/searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[j], A.shape[0]),
'i_rank': (i + 1) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[r], A.shape[0]),
})

tmp = A[i + 1]
Expand All @@ -199,8 +200,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[r], A.shape[0]),
'i_rank': (i + 1 - p) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
})
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[i + 1], A.shape[0]),
},
)

return i + 1

Expand Down
2 changes: 1 addition & 1 deletion clrs/_src/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def finalize(probes: ProbesDict):
else:
# Only one instance of input/output exists. Remove leading axis.
probes[stage][loc][name]['data'] = np.squeeze(
np.array(probes[stage][loc][name]['data']))
np.array(probes[stage][loc][name]['data']), axis=0)


def split_stages(
Expand Down
47 changes: 33 additions & 14 deletions clrs/_src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
num_samples: int,
*args,
seed: Optional[int] = None,
track_max_steps: bool = True,
**kwargs,
):
"""Initializes a `Sampler`.
Expand All @@ -85,6 +86,13 @@ def __init__(
If -1, samples are generated on the fly with each call to `next`.
*args: Algorithm args.
seed: RNG seed.
track_max_steps: if True and sampling on the fly (`num_samples`==-1),
we keep track of the maximum unroll length seen so far and pad batched
samples to that length. This is desirable when batches are used in compiled
functions that need recompilation every time the batch shape changes.
To get an initial value for max_steps we also generate 1000 samples up
front, which slows down initialization. If uniform shape of the batches
is not a concern, set `track_max_steps` to False.
**kwargs: Algorithm kwargs.
"""

Expand All @@ -95,19 +103,21 @@ def __init__(
self._algorithm = algorithm
self._args = args
self._kwargs = kwargs
self._track_max_steps = track_max_steps

if num_samples < 0:
logging.warning('Sampling dataset on-the-fly, unlimited samples.')
# Just get an initial estimate of max hint length
self.max_steps = -1
for _ in range(1000):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
_, _, hint = probing.split_stages(probes, spec)
for dp in hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > self.max_steps:
self.max_steps = dp.data.shape[0]
if track_max_steps:
# Get an initial estimate of max hint length
self.max_steps = -1
for _ in range(1000):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
_, _, hint = probing.split_stages(probes, spec)
for dp in hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > self.max_steps:
self.max_steps = dp.data.shape[0]
else:
logging.info('Creating a dataset with %i samples.', num_samples)
(self._inputs, self._outputs, self._hints,
Expand Down Expand Up @@ -148,10 +158,11 @@ def next(self, batch_size: Optional[int] = None) -> Feedback:
"""
if batch_size:
if self._num_samples < 0: # generate on the fly
min_length = self.max_steps if self._track_max_steps else 0
inputs, outputs, hints, lengths = self._make_batch(
batch_size, self._spec, self.max_steps,
batch_size, self._spec, min_length,
self._algorithm, *self._args, **self._kwargs)
if hints[0].data.shape[0] > self.max_steps:
if self._track_max_steps and hints[0].data.shape[0] > self.max_steps:
logging.warning('Increasing hint lengh from %i to %i',
self.max_steps, hints[0].data.shape[0])
self.max_steps = hints[0].data.shape[0]
Expand Down Expand Up @@ -261,6 +272,7 @@ def build_sampler(
num_samples: int,
*args,
seed: Optional[int] = None,
track_max_steps: bool = True,
**kwargs,
) -> Tuple[Sampler, specs.Spec]:
"""Builds a sampler. See `Sampler` documentation."""
Expand All @@ -276,8 +288,15 @@ def build_sampler(
if set(clean_kwargs) != set(kwargs):
logging.warning('Ignoring kwargs %s when building sampler class %s',
set(kwargs).difference(clean_kwargs), sampler_class)
sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
*args, **clean_kwargs)
sampler = sampler_class(
algorithm,
spec,
num_samples,
seed=seed,
track_max_steps=track_max_steps,
*args,
**clean_kwargs,
)
return sampler, spec


Expand Down
3 changes: 2 additions & 1 deletion clrs/_src/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class OutputClass:
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
'i_rank': (Stage.HINT, Location.GRAPH, Type.SCALAR),
'target': (Stage.HINT, Location.GRAPH, Type.SCALAR)
'target': (Stage.HINT, Location.GRAPH, Type.SCALAR),
'pivot': (Stage.HINT, Location.NODE, Type.MASK_ONE),
},
'minimum': {
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
Expand Down

0 comments on commit f35d71e

Please sign in to comment.