Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

directmlのダウンロードをwindowsでのみ有効にするように修正 #2

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["Nicolas Bigaouette <[email protected]>"]
edition = "2018"
name = "onnxruntime-sys"
version = "0.0.23"
version = "0.0.24"

description = "Unsafe wrapper around Microsoft's ONNX Runtime"
documentation = "https://docs.rs/onnxruntime-sys"
Expand Down
61 changes: 36 additions & 25 deletions onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ static ONNXRUNTIME_DIR_NAME: once_cell::sync::Lazy<String> =
|| format!("onnxruntime-{}-{}", TRIPLET.as_onnx_str(), ORT_VERSION,),
);

static DOWNLOAD_DIRECTML: once_cell::sync::Lazy<bool> = once_cell::sync::Lazy::new(|| {
if matches!(TRIPLET.os, Os::Windows) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build.rs上でtarget_os="windows"で判定しようとするとtargetがwindowsでも実行環境のOSで判定されてしまうため環境変数から判定されたTRIPLETでOSを判定するようにしている

#[cfg(feature = "directml")]
let downlaod_directml = true;
#[cfg(not(feature = "directml"))]
let downlaod_directml = false;
downlaod_directml
} else {
false
}
});

#[cfg(feature = "disable-sys-build-script")]
fn main() {
println!("Build script disabled!");
Expand All @@ -60,14 +72,7 @@ fn main() {
fn main() {
let libort_install_dir = prepare_libort_dir();

#[cfg(not(feature = "directml"))]
let (include_dir, lib_dir) = (
libort_install_dir.join("include"),
libort_install_dir.join("lib"),
);

#[cfg(feature = "directml")]
let (include_dir, lib_dir) = {
let (include_dir, lib_dir) = if *DOWNLOAD_DIRECTML {
let include_dir = libort_install_dir.join("build/native/include");
let runtimes_dir = libort_install_dir
.join("runtimes")
Expand All @@ -87,6 +92,11 @@ fn main() {
fs::create_dir_all(&export_lib_dir).unwrap();
copy_all_files(runtimes_dir, &export_lib_dir);
(export_include_dir, export_lib_dir)
} else {
(
libort_install_dir.join("include"),
libort_install_dir.join("lib"),
)
};

println!("Include directory: {:?}", include_dir);
Expand Down Expand Up @@ -420,19 +430,19 @@ impl OnnxPrebuiltArchive for Triplet {
}

fn prebuilt_archive_url() -> (PathBuf, String) {
#[cfg(not(feature = "directml"))]
let prebuilt_archive = format!(
"{}.{}",
&*ONNXRUNTIME_DIR_NAME,
TRIPLET.os.archive_extension()
);

#[cfg(feature = "directml")]
let prebuilt_archive = format!(
"Microsoft.ML.OnnxRuntime.DirectML.{}.{}",
ORT_VERSION,
TRIPLET.os.archive_extension()
);
let prebuilt_archive = if *DOWNLOAD_DIRECTML {
format!(
"Microsoft.ML.OnnxRuntime.DirectML.{}.{}",
ORT_VERSION,
TRIPLET.os.archive_extension()
)
} else {
format!(
"{}.{}",
&*ONNXRUNTIME_DIR_NAME,
TRIPLET.os.archive_extension()
)
};
let prebuilt_url = format!(
"{}/v{}/{}",
ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive
Expand Down Expand Up @@ -469,10 +479,11 @@ fn prepare_libort_dir_prebuilt() -> PathBuf {

// directmlの場合はzipの子ディレクトリがzipファイル名のディレクトリではないため、
// この処理は非directmlの場合のみ行う
#[cfg(not(feature = "directml"))]
let extract_dir = extract_dir.join(prebuilt_archive.file_stem().unwrap());

extract_dir
if *DOWNLOAD_DIRECTML {
extract_dir
} else {
extract_dir.join(prebuilt_archive.file_stem().unwrap())
}
}

fn prepare_libort_dir() -> PathBuf {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["Nicolas Bigaouette <[email protected]>"]
edition = "2018"
name = "onnxruntime"
version = "0.0.28"
version = "0.0.29"

description = "Wrapper around Microsoft's ONNX Runtime"
documentation = "https://docs.rs/onnxruntime"
Expand All @@ -19,7 +19,7 @@ name = "integration_tests"
required-features = ["model-fetching"]

[dependencies]
onnxruntime-sys = { path = "../onnxruntime-sys", version = "0.0.23" }
onnxruntime-sys = { path = "../onnxruntime-sys", version = "0.0.24" }

lazy_static = "1.4"
ndarray = "0.15"
Expand Down