PUBLIC: Update METADATA. #132

Open · wants to merge 1 commit into base: master
9 changes: 6 additions & 3 deletions clrs/_src/algorithms/searching.py
@@ -179,7 +179,8 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[j], A.shape[0]),
'i_rank': (i + 1) * 1.0 / A.shape[0],
-'target': target * 1.0 / A.shape[0]
+'target': target * 1.0 / A.shape[0],
+'pivot': probing.mask_one(A_pos[r], A.shape[0]),
})

tmp = A[i + 1]
@@ -199,8 +200,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[r], A.shape[0]),
'i_rank': (i + 1 - p) * 1.0 / A.shape[0],
-'target': target * 1.0 / A.shape[0]
-})
+'target': target * 1.0 / A.shape[0],
+'pivot': probing.mask_one(A_pos[i + 1], A.shape[0]),
+},
+)

return i + 1

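The two hunks above add a `pivot` hint to the `partition` probes: a one-hot mask over nodes marking the pivot's current position (A_pos[r] while scanning, A_pos[i + 1] after the final swap). A minimal sketch of what that hint encodes, assuming probing.mask_one(i, n) one-hot encodes index i over n node slots, consistent with how it is called above (the stand-in helper and the values below are illustrative, not from this PR):

import numpy as np

# Stand-in for probing.mask_one, assumed to one-hot encode index i over n slots.
def mask_one(i: int, n: int) -> np.ndarray:
  probe = np.zeros(n)
  probe[i] = 1.0
  return probe

A_pos = np.array([3, 0, 1, 4, 2])  # hypothetical node ordering of 5 elements
r = 4                              # index of the pivot in the working array
pivot_hint = mask_one(A_pos[r], A_pos.shape[0])
print(pivot_hint)                  # [0. 0. 1. 0. 0.] -- node 2 holds the pivot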
2 changes: 1 addition & 1 deletion clrs/_src/probing.py
@@ -145,7 +145,7 @@ def finalize(probes: ProbesDict):
else:
# Only one instance of input/output exist. Remove leading axis.
probes[stage][loc][name]['data'] = np.squeeze(
-    np.array(probes[stage][loc][name]['data']))
+    np.array(probes[stage][loc][name]['data']), axis=0)


def split_stages(
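The axis=0 argument makes this squeeze safe: without it, np.squeeze removes every singleton dimension, so a probe whose data has another legitimate size-1 axis would silently change shape. A small illustration with hypothetical shapes (not taken from this PR):

import numpy as np

data = np.zeros((1, 5, 1))             # leading axis of 1 plus a size-1 feature axis
print(np.squeeze(data).shape)          # (5,)   -- both singleton axes dropped
print(np.squeeze(data, axis=0).shape)  # (5, 1) -- only the leading axis dropped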
47 changes: 33 additions & 14 deletions clrs/_src/samplers.py
@@ -72,6 +72,7 @@ def __init__(
num_samples: int,
*args,
seed: Optional[int] = None,
+track_max_steps: bool = True,
**kwargs,
):
"""Initializes a `Sampler`.
@@ -85,6 +86,13 @@ def __init__(
If -1, samples are generated on the fly with each call to `next`.
*args: Algorithm args.
seed: RNG seed.
+track_max_steps: if True and sampling on the fly (`num_samples` == -1),
+  keep track of the maximum unroll length seen so far and pad batched
+  samples to that length. This is desirable when batches are fed to
+  compiled functions that would otherwise recompile every time the shape
+  of the batch changes. Also, an initial value for max_steps is obtained
+  by generating 1000 samples, which slows down initialization. If a
+  uniform batch shape is not a concern, set `track_max_steps` to False.
**kwargs: Algorithm kwargs.
"""

@@ -95,19 +103,21 @@ def __init__(
self._algorithm = algorithm
self._args = args
self._kwargs = kwargs
+self._track_max_steps = track_max_steps

if num_samples < 0:
logging.warning('Sampling dataset on-the-fly, unlimited samples.')
-# Just get an initial estimate of max hint length
-self.max_steps = -1
-for _ in range(1000):
-  data = self._sample_data(*args, **kwargs)
-  _, probes = algorithm(*data)
-  _, _, hint = probing.split_stages(probes, spec)
-  for dp in hint:
-    assert dp.data.shape[1] == 1  # batching axis
-    if dp.data.shape[0] > self.max_steps:
-      self.max_steps = dp.data.shape[0]
+if track_max_steps:
+  # Get an initial estimate of max hint length
+  self.max_steps = -1
+  for _ in range(1000):
+    data = self._sample_data(*args, **kwargs)
+    _, probes = algorithm(*data)
+    _, _, hint = probing.split_stages(probes, spec)
+    for dp in hint:
+      assert dp.data.shape[1] == 1  # batching axis
+      if dp.data.shape[0] > self.max_steps:
+        self.max_steps = dp.data.shape[0]
else:
logging.info('Creating a dataset with %i samples.', num_samples)
(self._inputs, self._outputs, self._hints,
@@ -148,10 +158,11 @@ def next(self, batch_size: Optional[int] = None) -> Feedback:
"""
if batch_size:
if self._num_samples < 0: # generate on the fly
+min_length = self.max_steps if self._track_max_steps else 0
inputs, outputs, hints, lengths = self._make_batch(
-batch_size, self._spec, self.max_steps,
+batch_size, self._spec, min_length,
self._algorithm, *self._args, **self._kwargs)
-if hints[0].data.shape[0] > self.max_steps:
+if self._track_max_steps and hints[0].data.shape[0] > self.max_steps:
logging.warning('Increasing hint lengh from %i to %i',
self.max_steps, hints[0].data.shape[0])
self.max_steps = hints[0].data.shape[0]
@@ -261,6 +272,7 @@ def build_sampler(
num_samples: int,
*args,
seed: Optional[int] = None,
+track_max_steps: bool = True,
**kwargs,
) -> Tuple[Sampler, specs.Spec]:
"""Builds a sampler. See `Sampler` documentation."""
@@ -276,8 +288,15 @@ def build_sampler(
if set(clean_kwargs) != set(kwargs):
logging.warning('Ignoring kwargs %s when building sampler class %s',
set(kwargs).difference(clean_kwargs), sampler_class)
-sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
-                        *args, **clean_kwargs)
+sampler = sampler_class(
+    algorithm,
+    spec,
+    num_samples,
+    seed=seed,
+    track_max_steps=track_max_steps,
+    *args,
+    **clean_kwargs,
+)
return sampler, spec


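Taken together, the samplers.py changes make the 1000-sample warm-up estimate of max_steps optional. A hedged usage sketch of the new flag, assuming the package exposes build_sampler at the top level and that the quickselect sampler accepts a length kwarg (both are assumptions, not shown in this diff):

import clrs

# Sample on the fly without tracking max_steps: no 1000-sample warm-up, and each
# batch is padded only to its own longest hint (compiled callers may recompile
# when batch shapes change).
sampler, spec = clrs.build_sampler(
    'quickselect',          # illustrative algorithm name
    num_samples=-1,         # generate samples on the fly
    length=16,              # illustrative problem-size kwarg
    track_max_steps=False,
)

feedback = sampler.next(batch_size=32)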
3 changes: 2 additions & 1 deletion clrs/_src/specs.py
@@ -161,7 +161,8 @@ class OutputClass:
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
'i_rank': (Stage.HINT, Location.GRAPH, Type.SCALAR),
-'target': (Stage.HINT, Location.GRAPH, Type.SCALAR)
+'target': (Stage.HINT, Location.GRAPH, Type.SCALAR),
+'pivot': (Stage.HINT, Location.NODE, Type.MASK_ONE),
},
'minimum': {
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),