Skip to content

Commit

Permalink
[api] Improve Sam2Translator for PyTorch traced model
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 4, 2024
1 parent f88f2d6 commit f94e98e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class Sam2Translator implements NoBatchifyTranslator<Sam2Input, DetectedO
private Pipeline pipeline;
private Predictor<NDList, NDList> predictor;
private String encoderPath;
private String encodeMethod;

/** Constructs a {@code Sam2Translator} instance. */
public Sam2Translator(Builder builder) {
Expand All @@ -66,12 +67,19 @@ public Sam2Translator(Builder builder) {
pipeline.add(new ToTensor());
pipeline.add(new Normalize(MEAN, STD));
this.encoderPath = builder.encoderPath;
this.encodeMethod = builder.encodeMethod;
}

/** {@inheritDoc} */
@Override
public void prepare(TranslatorContext ctx) throws IOException, ModelException {
if (encoderPath == null) {
// PyTorch model
if (encodeMethod != null) {
Model model = ctx.getModel();
predictor = model.newPredictor(new NoopTranslator(null));
model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor);
}
return;
}
Model model = ctx.getModel();
Expand Down Expand Up @@ -111,7 +119,15 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except
return new NDList(array, locations, labels);
}

NDList embeddings = predictor.predict(new NDList(array));
NDList embeddings;
if (encodeMethod == null) {
embeddings = predictor.predict(new NDList(array));
} else {
NDArray placeholder = manager.create("");
placeholder.setName("module_method:" + encodeMethod);
embeddings = predictor.predict(new NDList(placeholder, array));
}

NDArray mask = manager.zeros(new Shape(1, 1, 256, 256));
NDArray hasMask = manager.zeros(new Shape(1));
return new NDList(
Expand Down Expand Up @@ -173,9 +189,11 @@ public static Builder builder(Map<String, ?> arguments) {
public static class Builder {

String encoderPath;
String encodeMethod;

Builder(Map<String, ?> arguments) {
encoderPath = ArgumentsUtil.stringValue(arguments, "encoder");
encodeMethod = ArgumentsUtil.stringValue(arguments, "encode_method");
}

/**
Expand All @@ -189,6 +207,17 @@ public Builder optEncoderPath(String encoderPath) {
return this;
}

/**
* Sets the module name for encode method.
*
* @param encodeMethod the module name for encode method
* @return the builder
*/
public Builder optEncodeMethod(String encodeMethod) {
this.encodeMethod = encodeMethod;
return this;
}

/**
* Builds the translator.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ public class PtModelZoo extends ModelZoo {
addModel(
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny-gpu", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large-gpu", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.2"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.2"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
Expand Down
48 changes: 21 additions & 27 deletions examples/docs/segment_anything_2.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,21 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
from torch import nn


class SAM2ImageEncoder(nn.Module):
class Sam2Wrapper(nn.Module):

def __init__(self, sam_model: SAM2Base) -> None:
def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
super().__init__()
self.model = sam_model
self.image_encoder = sam_model.image_encoder
self.no_mem_embed = sam_model.no_mem_embed
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.img_size = sam_model.image_size
self.multimask_output = multimask_output
self.sparse_embedding = None

def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
@torch.no_grad()
def encode(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
backbone_out = self.image_encoder(x)
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
backbone_out["backbone_fpn"][0])
Expand All @@ -106,18 +112,6 @@ class SAM2ImageEncoder(nn.Module):

return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):

def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
super().__init__()
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
self.img_size = sam_model.image_size
self.multimask_output = multimask_output
self.sparse_embedding = None

@torch.no_grad()
def forward(
self,
Expand Down Expand Up @@ -205,17 +199,13 @@ def trace_model(model_id: str):
device = torch.device("cpu")

model_name = f"{model_id[9:]}"
os.makedirs(model_name)
os.makedirs(model_name, exist_ok=True)

predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
encoder = SAM2ImageEncoder(predictor.model)
decoder = SAM2ImageDecoder(predictor.model, True)
model = Sam2Wrapper(predictor.model, True)

input_image = torch.ones(1, 3, 1024, 1024).to(device)
high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image)

converted = torch.jit.trace(encoder, input_image)
torch.jit.save(converted, f"model_name/encoder.pt")
high_res_feats_0, high_res_feats_1, image_embed = model.encode(input_image)

# trace decoder model
embed_size = (
Expand All @@ -232,10 +222,14 @@ def trace_model(model_id: str):
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([1], dtype=torch.float)

converted = torch.jit.trace(
decoder, (image_embed, high_res_feats_0, high_res_feats_1,
point_coords, point_labels, mask_input, has_mask_input))
torch.jit.save(converted, f"model_name/model_name.pt")
converted = torch.jit.trace_module(
model, {
"encode":
input_image,
"forward": (image_embed, high_res_feats_0, high_res_feats_1,
point_coords, point_labels, mask_input, has_mask_input)
})
torch.jit.save(converted, f"{model_name}/{model_name}.pt")

# save serving.properties
serving_file = os.path.join(model_name, "serving.properties")
Expand All @@ -244,7 +238,7 @@ def trace_model(model_id: str):
f"engine=PyTorch\n"
f"option.modelName={model_name}\n"
f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n"
f"encoder=encoder.pt")
f"encode_method=encode\n")


if __name__ == '__main__':
Expand Down
48 changes: 21 additions & 27 deletions examples/docs/trace_sam2_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@
from torch import nn


class SAM2ImageEncoder(nn.Module):
class Sam2Wrapper(nn.Module):

def __init__(self, sam_model: SAM2Base) -> None:
def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
super().__init__()
self.model = sam_model
self.image_encoder = sam_model.image_encoder
self.no_mem_embed = sam_model.no_mem_embed
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.img_size = sam_model.image_size
self.multimask_output = multimask_output
self.sparse_embedding = None

def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
@torch.no_grad()
def encode(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
backbone_out = self.image_encoder(x)
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
backbone_out["backbone_fpn"][0])
Expand All @@ -53,18 +59,6 @@ def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:

return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):

def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
super().__init__()
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
self.img_size = sam_model.image_size
self.multimask_output = multimask_output
self.sparse_embedding = None

@torch.no_grad()
def forward(
self,
Expand Down Expand Up @@ -152,17 +146,13 @@ def trace_model(model_id: str):
device = torch.device("cpu")

model_name = f"{model_id[9:]}"
os.makedirs(model_name)
os.makedirs(model_name, exist_ok=True)

predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
encoder = SAM2ImageEncoder(predictor.model)
decoder = SAM2ImageDecoder(predictor.model, True)
model = Sam2Wrapper(predictor.model, True)

input_image = torch.ones(1, 3, 1024, 1024).to(device)
high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image)

converted = torch.jit.trace(encoder, input_image)
torch.jit.save(converted, f"model_name/encoder.pt")
high_res_feats_0, high_res_feats_1, image_embed = model.encode(input_image)

# trace decoder model
embed_size = (
Expand All @@ -179,10 +169,14 @@ def trace_model(model_id: str):
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([1], dtype=torch.float)

converted = torch.jit.trace(
decoder, (image_embed, high_res_feats_0, high_res_feats_1,
point_coords, point_labels, mask_input, has_mask_input))
torch.jit.save(converted, f"model_name/model_name.pt")
converted = torch.jit.trace_module(
model, {
"encode":
input_image,
"forward": (image_embed, high_res_feats_0, high_res_feats_1,
point_coords, point_labels, mask_input, has_mask_input)
})
torch.jit.save(converted, f"{model_name}/{model_name}.pt")

# save serving.properties
serving_file = os.path.join(model_name, "serving.properties")
Expand All @@ -191,7 +185,7 @@ def trace_model(model_id: str):
f"engine=PyTorch\n"
f"option.modelName={model_name}\n"
f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n"
f"encoder=encoder.pt")
f"encode_method=encode\n")


if __name__ == '__main__':
Expand Down

0 comments on commit f94e98e

Please sign in to comment.