Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add modelzoo metadata, improve integrity #1200

Open
wants to merge 8 commits into
base: v3_develop
Choose a base branch
from
106 changes: 97 additions & 9 deletions src/modelzoo/Zoo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ZooManager {
/**
* @brief Download model from model zoo
*/
void downloadModel();
void downloadModel(const nlohmann::json& responseJson);
Comment on lines 81 to +84
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps nicer to return the object instead since the function is now void anyhow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting downloadModel returns a path to the downloaded file?


/**
* @brief Return path to model in cache
Expand All @@ -89,6 +89,34 @@ class ZooManager {
*/
std::string loadModelFromCache() const;

/**
* @brief Get path to metadata file
*
* @return std::string: Path to metadata file
*/
std::string getMetadataFilePath() const;

/**
* @brief Fetch model download links from Hub
*
* @return nlohmann::json: JSON with download links
*/
nlohmann::json fetchModelDownloadLinks();

/**
* @brief Get files in folder
*
* @return std::vector<std::string>: Files in folder
*/
std::vector<std::string> getFilesInFolder(const std::string& folder) const;

/**
* @brief Check if internet is available
*
* @return bool: True if internet is available
*/
bool internetIsAvailable() const;

private:
// Description of the model
NNModelDescription modelDescription;
Expand Down Expand Up @@ -150,10 +178,14 @@ bool checkIsErrorHub(const cpr::Response& response) {
return false;
}

std::vector<std::string> getFilesInFolder(const std::string& folder) {
std::vector<std::string> ZooManager::getFilesInFolder(const std::string& folder) const {
auto metadata = utility::loadYaml(getMetadataFilePath());
auto downloadedFiles = utility::yamlGet<std::vector<std::string>>(metadata, "downloaded_files");
std::vector<std::string> files;
for(const auto& entry : std::filesystem::directory_iterator(folder)) {
files.push_back(entry.path().string());
for(const auto& downloadedFile : downloadedFiles) {
if(std::filesystem::exists(combinePaths(folder, downloadedFile))) {
files.push_back(combinePaths(folder, downloadedFile));
}
}
return files;
}
Expand Down Expand Up @@ -207,7 +239,7 @@ bool ZooManager::isModelCached() const {
return std::filesystem::exists(getModelCacheFolderPath(cacheDirectory));
}

void ZooManager::downloadModel() {
nlohmann::json ZooManager::fetchModelDownloadLinks() {
// Add request parameters
cpr::Parameters params;

Expand Down Expand Up @@ -254,7 +286,18 @@ void ZooManager::downloadModel() {

// Extract download links from response
nlohmann::json responseJson = nlohmann::json::parse(response.text);
return responseJson;
}

void ZooManager::downloadModel(const nlohmann::json& responseJson) {
// Extract download links from response
auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>();
auto downloadHash = responseJson["hash"].get<std::string>();

// Metadata
YAML::Node metadata;
metadata["hash"] = downloadHash;
metadata["downloaded_files"] = std::vector<std::string>();

// Download all files and store them in cache folder
for(const auto& downloadLink : downloadLinks) {
Expand All @@ -270,7 +313,13 @@ void ZooManager::downloadModel() {
std::ofstream file(filepath, std::ios::binary);
file.write(downloadResponse.text.c_str(), downloadResponse.text.size());
file.close();

// Add filename to metadata
metadata["downloaded_files"].push_back(filename);
}

// Save metadata to file
utility::saveYaml(metadata, getMetadataFilePath());
}

std::string ZooManager::loadModelFromCache() const {
Expand All @@ -296,25 +345,49 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
// Check if model is cached
bool modelIsCached = zooManager.isModelCached();
bool useCachedModel = useCached && modelIsCached;
bool internetIsAvailable = zooManager.internetIsAvailable();
nlohmann::json responseJson;

if(internetIsAvailable) {
responseJson = zooManager.fetchModelDownloadLinks();
}

// Use cached model if present and useCached is true
if(useCachedModel) {
std::string modelPath = zooManager.loadModelFromCache();
Logging::getInstance().logger.info("Using cached model located at {}", modelPath);
return modelPath;
if(!internetIsAvailable) {
std::string modelPath = zooManager.loadModelFromCache();
Logging::getInstance().logger.info("Using cached model located at {}", modelPath);
return modelPath;
}

auto responseHash = responseJson["hash"].get<std::string>();
auto metadata = utility::loadYaml(zooManager.getMetadataFilePath());
auto metadataHash = utility::yamlGet<std::string>(metadata, "hash");

if(responseHash == metadataHash) {
std::string modelPath = zooManager.loadModelFromCache();
Logging::getInstance().logger.info("Using cached model located at {}", modelPath);
return modelPath;
}

Logging::getInstance().logger.warn("Cached model hash does not match response hash, downloading anew ...");
}

// Remove cached model if present
if(modelIsCached) {
zooManager.removeModelCacheFolder();
}

if(!internetIsAvailable) {
throw std::runtime_error("No internet connection available. Please check your network settings and try again.");
}

// Create cache folder
zooManager.createCacheFolder();

// Download model
Logging::getInstance().logger.info("Downloading model from model zoo");
zooManager.downloadModel();
zooManager.downloadModel(responseJson);

// Find path to model in cache
std::string modelPath = zooManager.loadModelFromCache();
Expand Down Expand Up @@ -351,6 +424,21 @@ void downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
}
}

bool ZooManager::internetIsAvailable() const {
constexpr int timeout_ms = 5000;
lnotspotl marked this conversation as resolved.
Show resolved Hide resolved
constexpr std::string_view host = "http://example.com";
try {
cpr::Response r = cpr::Get(cpr::Url{host}, cpr::Timeout{timeout_ms});
return r.status_code == cpr::status::HTTP_OK;
} catch(const std::exception& e) {
return false;
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of no internet this will slow down startup by 5 seconds, I think we should reduce this to ~500ms instead.

Would it maybe make sense to try pinging the URL we actually need instead of http://example.com?
Some customers might setup the network in a way where only the Zoo is accessible for the model updates, etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I performed some empirical tests and 500 milliseconds turn out not to be enough. At times, depthai would not receive a response from the hub fast enough and thus conclude that there is no internet connection despite my PC being connected to the network. A timeout of 1 second turns out to be enough.

Copy link
Member Author

@lnotspotl lnotspotl Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out maybe even that 1 second is not sufficient. Changing to a two second timeout.

image

It might be of value do add a flag/environment variable that will disable this initial network checking. I would set it to True by default but if a user is certain such a check is not needed, they might set it to False. Moreover, I would add an environment variable defining the timeout in milliseconds to make things a little more modifiable, so as to not have these hardcoded.


std::string ZooManager::getMetadataFilePath() const {
return combinePaths(getModelCacheFolderPath(cacheDirectory), "metadata.yaml");
}

} // namespace dai

#else
Expand Down
Loading