Use an already trained Torch model to predict on lots of data #111
cc @muammar In addition to this example, I'd also link to integration with skorch (a scikit-learn wrapper for PyTorch) and Dask-ML's ParallelPostFit.
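A rough sketch of how those pieces could fit together — `sk_net` stands in for an already-fitted skorch estimator (an assumption, not from this thread):

```python
import dask.array as da
from dask_ml.wrappers import ParallelPostFit

# Wrap an already-fitted, scikit-learn-compatible estimator (for example a
# skorch NeuralNetClassifier) so predict() runs block-wise over Dask arrays.
clf = ParallelPostFit(estimator=sk_net)

X = da.random.random((10_000, 100), chunks=(1_000, 100))
predictions = clf.predict(X)   # lazy Dask array, one task per chunk
result = predictions.compute()
```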
Should have an example ready tomorrow.
I think we'll want to go into some detail about data loading, because it's not 100% straightforward how to get the data loaded onto workers. We can write a `Dataset` like

```python
import glob
import os

import torch
from PIL import Image


def default_loader(path, fs=__builtins__):
    # `fs` is anything with an `open` method: the builtin `open` by default,
    # or an fsspec-style filesystem (e.g. s3fs) for remote data.
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
    return img


class FileDataset(torch.utils.data.Dataset):
    def __init__(self, files, transform=None, target_transform=None,
                 classes=None, loader=default_loader):
        self.files = files
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        if classes is None:
            # Infer the class names from each file's parent directory.
            classes = list(sorted(set(x.split(os.path.sep)[-2] for x in files)))
        else:
            classes = list(classes)
        self.classes = classes

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        filename = self.files[index]
        img = self.loader(filename)
        target = self.classes.index(filename.split(os.path.sep)[-2])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
```

and use it as

```python
files = glob.glob("hymenoptera_data/val/*/*.jpg")
dataset = FileDataset(files, transform=data_transforms['val'])
```
For s3, the usage would be much the same, passing an s3fs filesystem into the loader (sketched below). Things seem to be working out well after that. PyTorch models seem to (de)serialize much better than TensorFlow's did last time I tried.
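A minimal sketch of that s3 usage, assuming s3fs; the bucket path and the `functools.partial` binding are illustrative, not from the original comment:

```python
import functools

import s3fs

fs = s3fs.S3FileSystem()  # credentials are picked up from the environment

# List the images directly on s3 and bind the filesystem into the loader.
files = fs.glob("my-bucket/hymenoptera_data/val/*/*.jpg")
loader = functools.partial(default_loader, fs=fs)
dataset = FileDataset(files, transform=data_transforms['val'], loader=loader)
```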
Do we need to use the Torch Dataset API here?

> because it's not 100% straightforward how to get the data loaded onto workers

I guess my hope is that, for image data at least, we could just pass around NumPy arrays. So we might create dask.delayed objects using skimage.io.imread or something similar (maybe like https://blog.dask.org/2019/06/20/load-image-data, but before the dask array bit).
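A sketch of that delayed-imread approach; the tiny `model` here is a stand-in for a real trained network, and the file layout is assumed:

```python
import glob

import dask
import torch
from skimage.io import imread

# Stand-in for an already-trained torch.nn.Module.
model = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(3, 2)
)

filenames = glob.glob("hymenoptera_data/val/*/*.jpg")

# One delayed NumPy array per image; nothing is read until compute time.
lazy_images = [dask.delayed(imread)(f) for f in filenames]

@dask.delayed
def predict(image):
    # HWC uint8 -> NCHW float, then score a single image.
    tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0)
    with torch.no_grad():
        return model(tensor).numpy()

predictions = dask.compute(*[predict(img) for img in lazy_images])
```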
Also, if you haven't seen it, this video is nice: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/video/S9198/s9198-dask-and-v100s-for-fast-distributed-batch-scoring-of-computer-vision-workloads.mp4
Ahh, yes if we’re doing prediction only we can probably do that. A Dataset would only be necessary for training.
Distributed training would also be interesting of course, but my guess is that that's more of an open problem. It's not clear to me which way is the right way.
I'm curious about inputs of Dask Arrays and outputs of model predictions too. I think PyTorch … It's also mentioned in dask/distributed#2581.
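For the Dask Array path, prediction can be mapped over blocks — a sketch, where the small `model` again stands in for a real trained network:

```python
import dask.array as da
import numpy as np
import torch

# Stand-in for a real trained model: collapses each image to one score.
model = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(3, 1)
)

def predict_block(block):
    # Score one NumPy block of images and return a 1-D NumPy array.
    tensor = torch.from_numpy(block).float()
    with torch.no_grad():
        scores = model(tensor)       # shape (batch, 1)
    return scores.numpy()[:, 0]      # shape (batch,)

X = da.random.random((10_000, 3, 64, 64), chunks=(500, 3, 64, 64))
predictions = X.map_blocks(predict_block, drop_axis=[1, 2, 3], dtype=np.float32)
result = predictions.compute()
```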
Skorch looks interesting to me. Can the wrapper be used after loading the model from disk, where the wrapper was not used for training? I've tried applying the Dask-ML ParallelPostFit wrapper to a pre-trained model, and I remember having to do a few manual steps before running predictions. I need to dig up that code.
Yup. The underlying model is an attribute (`module_`):

```python
import torch
from skorch import NeuralNetClassifier


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ...  # define layers here


model = Net()
# Train model
# Save the trained model with PyTorch
torch.save(model.state_dict(), "trained_model.pt")

# Use skorch later (not necessarily in the training session)
sk_net = NeuralNetClassifier(Net)
sk_net.initialize()
# Load the parameters saved with PyTorch
sk_net.module_.load_state_dict(torch.load("trained_model.pt"))
```
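Once restored this way, the skorch net exposes the usual scikit-learn interface, so batch prediction is just a method call — a sketch, with the input shape assumed:

```python
import numpy as np

X_test = np.random.rand(16, 10).astype("float32")  # illustrative input batch
y_pred = sk_net.predict(X_test)
proba = sk_net.predict_proba(X_test)
```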
Extending on #35, it would be nice to have an example using Dask and Torch together to parallelize prediction. This should be a simple, embarrassingly parallel use case, but I suspect it would be practical for lots of folks.

The challenge, I think, is constructing a simple example that hopefully doesn't get too deep into Torch or a particular dataset. In my ideal world this would be something like the sketch below.
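A rough sketch of that ideal, with every name illustrative (a stand-in model and random input batches):

```python
import dask
import torch

# Stand-in for a real pre-trained model loaded from disk.
model = torch.nn.Conv2d(3, 2, kernel_size=3)

@dask.delayed
def predict(batch):
    # One task per batch; Dask ships `model` to the workers with the task.
    with torch.no_grad():
        return model(batch)

batches = [torch.randn(8, 3, 32, 32) for _ in range(10)]  # stand-in input data
results = dask.compute(*[predict(b) for b in batches])
```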
Does anyone have good pointers to such a simple case?
cc @stsievert @TomAugspurger @AlbertDeFusco