How to load a .pt model in Scala? #233

Open
zaryabRiasat opened this issue Aug 10, 2024 · 0 comments


I've downloaded the pre-trained model 20180402-114759-vggface2.pt from there. I've used it in Python and it works well, with good accuracy.

python.py:

from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import torch

mtcnn = MTCNN(image_size=160, margin=0)
resnet = InceptionResnetV1(pretrained='vggface2').eval()

resnet.load_state_dict(torch.load('../20180402-114759-vggface2.pt'), strict=False)

img1 = Image.open('../img1')
img2 = Image.open('../img2')

img1_cropped = mtcnn(img1)
img2_cropped = mtcnn(img2)

if img1_cropped is not None and img2_cropped is not None:
    img1_embedding = resnet(img1_cropped.unsqueeze(0))
    img2_embedding = resnet(img2_cropped.unsqueeze(0))

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    similarity = cos(img1_embedding, img2_embedding)
    
    print(f"Cosine Similarity: {similarity.item()}")
    
    threshold = 0.6  
    if similarity > threshold:
        print("The faces are similar!")
    else:
        print("The faces are different!")
else:
    print("Face not detected in one or both images.")

Now I want to use it in Scala (JVM environment). After searching a lot, I found that a .pt model can be used from Scala with DJL (Deep Java Library). This is the code I tried in Scala:

Dependencies in build.sbt:

libraryDependencies ++= Seq(
  "ai.djl" % "api" % "0.29.0",
  "ai.djl.pytorch" % "pytorch-engine" % "0.29.0" % "runtime",
  "ai.djl.pytorch" % "pytorch-model-zoo" % "0.29.0",
  "ai.djl.pytorch" % "pytorch-native-cpu" % "2.3.1" % "runtime" classifier "linux-x86_64",
  "ai.djl.pytorch" % "pytorch-jni" % "2.3.1-0.29.0" % "runtime"
)

main:

import ai.djl.Model
import ai.djl.modality.cv.{Image, ImageFactory}
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.DataType
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}

import java.nio.file.Paths

object FaceRecognitionDJL {

  def main(args: Array[String]): Unit = {
    val image1Path = Paths.get("../img_1.png")
    val image2Path = Paths.get("../img_2.png")

    // Assumes both images are already cropped face images; unlike the
    // Python version, no MTCNN face detection is done here.
    val image1 = ImageFactory.getInstance().fromFile(image1Path)
    val image2 = ImageFactory.getInstance().fromFile(image2Path)

    val model = Model.newInstance("face_recognition_model")
    try {
      model.load(Paths.get("../20180402-114759-vggface2.pt"))

      val embeddings1 = getEmbeddings(model, image1)
      val embeddings2 = getEmbeddings(model, image2)

      val similarity = compareEmbeddings(embeddings1, embeddings2)
      println(s"Similarity between faces: $similarity")

      if (similarity > 0.7) {
        println("Faces belong to the same person.")
      } else {
        println("Faces do not belong to the same person.")
      }
    } finally {
      model.close() // release the native model
    }
  }

  def getEmbeddings(model: Model, image: Image): Array[Float] = {
    // A Predictor holds native resources, so close it after use.
    val predictor = model.newPredictor(new MyTranslator)
    try predictor.predict(image)
    finally predictor.close()
  }

  // Cosine similarity: dot(a, b) / (|a| * |b|)
  def compareEmbeddings(embedding1: Array[Float], embedding2: Array[Float]): Double = {
    val dotProduct = embedding1.zip(embedding2).map { case (a, b) => a * b }.sum
    val norm1 = Math.sqrt(embedding1.map(x => x * x).sum)
    val norm2 = Math.sqrt(embedding2.map(x => x * x).sum)
    dotProduct / (norm1 * norm2)
  }
}

class MyTranslator extends Translator[Image, Array[Float]] {
  override def processInput(ctx: TranslatorContext, input: Image): NDList = {
    // Use the context's manager so the arrays are freed after each prediction.
    val imgArray = input.toNDArray(ctx.getNDManager) // HWC, uint8, RGB

    // Resize to 160x160 (reshape cannot change the image size), then apply the
    // same standardization facenet-pytorch's MTCNN uses: (x - 127.5) / 128.
    val resized = NDImageUtils.resize(imgArray, 160, 160)
    val normalized = resized.toType(DataType.FLOAT32, false).sub(127.5f).div(128f)

    // The network expects channels first (CHW); the batchifier adds the batch axis.
    new NDList(normalized.transpose(2, 0, 1))
  }

  override def processOutput(ctx: TranslatorContext, list: NDList): Array[Float] = {
    list.get(0).toFloatArray
  }

  // STACK adds the leading batch dimension, giving the model a 1x3x160x160 input.
  override def getBatchifier: Batchifier = Batchifier.STACK
}

I put the above code together after searching on several websites, but it fails with this error:

[error] Exception in thread "main" ai.djl.engine.EngineException: PytorchStreamReader failed reading zip archive: failed finding central directory

The same .pt model works fine in Python, but I can't get it to run in Scala. What am I doing wrong?
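A likely explanation, offered as a hedged suggestion rather than a confirmed diagnosis: DJL's PyTorch engine only loads TorchScript models, and the error message comes from PyTorch's PytorchStreamReader, which expects the ZIP archive that `torch.jit.save` writes. The Python code above does not load a model from the file. It rebuilds the network with `InceptionResnetV1` and passes the file to `load_state_dict`, so the file holds only weights (a state_dict), not a TorchScript module. If the file was also written in PyTorch's older non-ZIP serialization, that would produce exactly this "failed finding central directory" message.

The sketch below assumes torch and facenet_pytorch are installed. It checks whether a file is a ZIP archive and traces the model to TorchScript. The function names `is_zip_archive` and `export_torchscript` and the output name `facenet_vggface2_traced.pt` are hypothetical.

```python
import zipfile
from pathlib import Path


def is_zip_archive(path):
    """Return True if the file at `path` is a ZIP archive.

    TorchScript files written by torch.jit.save are ZIP archives; the
    PytorchStreamReader error above indicates the file it read was not one.
    """
    return zipfile.is_zipfile(path)


def export_torchscript(out_path="facenet_vggface2_traced.pt"):
    """Trace facenet-pytorch's InceptionResnetV1 (VGGFace2 weights) and save
    it as a TorchScript module (hypothetical helper, output name assumed).

    torch and facenet_pytorch are imported here, not at module level, so
    is_zip_archive() also works where they are not installed.
    """
    import torch
    from facenet_pytorch import InceptionResnetV1

    # pretrained='vggface2' already loads the VGGFace2 weights, so no
    # separate load_state_dict call is needed.
    resnet = InceptionResnetV1(pretrained="vggface2").eval()

    # Example input: one 3x160x160 face crop in NCHW float32 layout, the
    # shape MTCNN(image_size=160) output has after unsqueeze(0).
    example = torch.rand(1, 3, 160, 160)
    with torch.no_grad():
        traced = torch.jit.trace(resnet, example)
    traced.save(str(out_path))
    return Path(out_path)
```

Running `export_torchscript()` once in the Python environment that already works, then checking the result with `is_zip_archive("facenet_vggface2_traced.pt")`, should give a file that can be passed to `model.load(...)` in Scala in place of the original weights file. The Scala translator still has to supply input preprocessed the way MTCNN does. `torch.jit.trace` is used here instead of `torch.jit.script` as the simpler option. It records one concrete run, so it assumes the forward pass does not branch on input values.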
