Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a feature registry for models #267

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Conversation

lorenzoh
Copy link
Member

@lorenzoh lorenzoh commented Nov 26, 2022

Implements #246.

PR Checklist

  • Tests are added
  • Documentation (this is an internal change for now, docs will be added in follow-up when functionality is made available and used in domain package)

Usage examples

From #269:

using FastAI: models
# loading this adds the models to registry
using FastVision


# Load original model, 1000 output classes, no weights (`ResNet(18)`):
load(models()["metalhead/resnet18"]);

# Load original model, 1000 output classes, with weights (`ResNet(18), pretrain=true`):
load(models()["metalhead/resnet18"], pretrained = true);

# Load only backbone, without weights:
load(models()["metalhead/resnet18"], variant = "backbone");

# Load only backbone, with weights:
load(models()["metalhead/resnet18"], pretrained = true, variant = "backbone");

# Load model for task, adapting layers as necessary:
task = ImageClassificationSingle((256, 256), 1:5, C = Gray{N0f8}) # input with 1 color channel, 5 classes
load(models()["metalhead/resnet18"], input = task.blocks.x, output = task.blocks.y)
# Also works with pretrained weights
load(models()["metalhead/resnet18"], pretrained = true, input = task.blocks.x, output = task.blocks.y)

# Correct variants are selected automatically given the blocks:
load(models()["metalhead/resnet18"], output = FastAI.ConvFeatures)  # uses backbone variant


# Support for multiple checkpoints, selectable by name:
load(models()["metalhead/resnet18"], checkpoint = "imagenet1k")

Docs

The proposed interface is well-described by the registry description, pasted below:

A FeatureRegistry for models. Allows you to find and load models for various learning
tasks using a unified interface. Call models() to see a table view of available models:

using FastAI
models()

Which models are available depends on the loaded packages. For example, FastVision.jl adds
vision models from Metalhead to the registry. Index the registry with a model ID to get more
information about that model:

using FastAI: models
using FastVision  # loading the package extends the list of available models

models()["metalhead/resnet18"]

If you've selected a model, call load to then instantiate a model:

model = load("metalhead/resnet18")

By default, load loads a default version of the model without any pretrained weights.

load(model) also accepts keyword arguments that allow you to specify variants of the model and
weight checkpoints that should be loaded.

Loading a checkpoint of pretrained weights:

  • load(entry; pretrained = true): Use any pretrained weights, if they are available.
  • load(entry; checkpoint = "checkpoint-name"): Use the weights with given name. See
    entry.checkpoints for available checkpoints (if any).
  • load(entry; pretrained = false): Don't use pretrained weights

Loading a model variant for a specific task:

  • load(entry; input = ImageTensor, output = OneHotLabel): Load a model variant matching
    an input and output block.
  • load(entry; variant = "backbone"): Load a model variant by name. See entry.variants` for
    available variants.

@github-actions
Copy link
Contributor

A documentation preview has been successfully built, view it here: Documentation preview PR-267

@lorenzoh
Copy link
Member Author

@darsnack @theabhirath would love to get feedback on the API. Anything unclear or an important feature missing, let me know!

Now handles both loading checkpoints and possible transformations.
This makes it easier to ntegrate with third-party model libraries
that likewise handle both with a single function.
A `loadfn([checkpoint])` holds the default loading function for a model. As a result,
the :variants field no longer has to be
populated.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant