diff --git a/api/src/main/java/ai/djl/repository/AbstractRepository.java b/api/src/main/java/ai/djl/repository/AbstractRepository.java index 2ca087fdf86..c28a3b16887 100644 --- a/api/src/main/java/ai/djl/repository/AbstractRepository.java +++ b/api/src/main/java/ai/djl/repository/AbstractRepository.java @@ -14,13 +14,10 @@ import ai.djl.util.Hex; import ai.djl.util.Progress; +import ai.djl.util.TarUtils; import ai.djl.util.Utils; import ai.djl.util.ZipUtils; -import org.apache.commons.compress.archivers.tar.TarArchiveEntry; -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; -import org.apache.commons.compress.utils.CloseShieldFilterInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -212,9 +209,9 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr if ("zip".equals(extension)) { ZipUtils.unzip(pis, dir); } else if ("tgz".equals(extension)) { - untar(pis, dir, true); + TarUtils.untar(pis, dir, true); } else if ("tar".equals(extension)) { - untar(pis, dir, false); + TarUtils.untar(pis, dir, false); } else { throw new IOException("File type is not supported: " + extension); } @@ -233,36 +230,6 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr pis.validateChecksum(item); } - private void untar(InputStream is, Path dir, boolean gzip) throws IOException { - InputStream bis; - if (gzip) { - bis = new GzipCompressorInputStream(new BufferedInputStream(is)); - } else { - bis = new BufferedInputStream(is); - } - bis = new CloseShieldFilterInputStream(bis); - try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { - TarArchiveEntry entry; - while ((entry = tis.getNextEntry()) != null) { - String entryName = entry.getName(); - if (entryName.contains("..")) { - throw new IOException("Malicious zip entry: " + entryName); - } - Path file = dir.resolve(entryName).toAbsolutePath(); - if (entry.isDirectory()) { - Files.createDirectories(file); - } else { - Path parentFile = file.getParent(); - if (parentFile == null) { - throw new AssertionError("Parent path should never be null: " + file); - } - Files.createDirectories(parentFile); - Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); - } - } - } - } - private static Map parseQueryString(URI uri) { try { Map map = new ConcurrentHashMap<>(); diff --git a/api/src/main/java/ai/djl/util/TarUtils.java b/api/src/main/java/ai/djl/util/TarUtils.java new file mode 100644 index 00000000000..c02b278788f --- /dev/null +++ b/api/src/main/java/ai/djl/util/TarUtils.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; +import org.apache.commons.compress.utils.CloseShieldFilterInputStream; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; + +/** Utilities for working with zip files. */ +public final class TarUtils { + + private TarUtils() {} + + /** + * Un-compress a tar ball from InputStream. + * + * @param is the InputStream + * @param dir the target directory + * @param gzip if the bar ball is gzip + * @throws IOException for failures to untar the input directory + */ + public static void untar(InputStream is, Path dir, boolean gzip) throws IOException { + InputStream bis; + if (gzip) { + bis = new GzipCompressorInputStream(new BufferedInputStream(is)); + } else { + bis = new BufferedInputStream(is); + } + bis = new CloseShieldFilterInputStream(bis); + try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { + TarArchiveEntry entry; + while ((entry = tis.getNextEntry()) != null) { + String entryName = entry.getName(); + if (entryName.contains("..")) { + throw new IOException("Malicious zip entry: " + entryName); + } + Path file = dir.resolve(entryName).toAbsolutePath(); + if (entry.isDirectory()) { + Files.createDirectories(file); + } else { + Path parentFile = file.getParent(); + if (parentFile == null) { + throw new AssertionError("Parent path should never be null: " + file); + } + Files.createDirectories(parentFile); + Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); + } + } + } + } +}