diff --git a/onnxruntime-sys/Cargo.toml b/onnxruntime-sys/Cargo.toml index 419bdf29..60bfe290 100644 --- a/onnxruntime-sys/Cargo.toml +++ b/onnxruntime-sys/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Nicolas Bigaouette "] edition = "2018" name = "onnxruntime-sys" -version = "0.0.23" +version = "0.0.24" description = "Unsafe wrapper around Microsoft's ONNX Runtime" documentation = "https://docs.rs/onnxruntime-sys" diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 5ab5c5bf..91507551 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -51,6 +51,18 @@ static ONNXRUNTIME_DIR_NAME: once_cell::sync::Lazy = || format!("onnxruntime-{}-{}", TRIPLET.as_onnx_str(), ORT_VERSION,), ); +static DOWNLOAD_DIRECTML: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { + if matches!(TRIPLET.os, Os::Windows) { + #[cfg(feature = "directml")] + let downlaod_directml = true; + #[cfg(not(feature = "directml"))] + let downlaod_directml = false; + downlaod_directml + } else { + false + } +}); + #[cfg(feature = "disable-sys-build-script")] fn main() { println!("Build script disabled!"); @@ -60,14 +72,7 @@ fn main() { fn main() { let libort_install_dir = prepare_libort_dir(); - #[cfg(not(feature = "directml"))] - let (include_dir, lib_dir) = ( - libort_install_dir.join("include"), - libort_install_dir.join("lib"), - ); - - #[cfg(feature = "directml")] - let (include_dir, lib_dir) = { + let (include_dir, lib_dir) = if *DOWNLOAD_DIRECTML { let include_dir = libort_install_dir.join("build/native/include"); let runtimes_dir = libort_install_dir .join("runtimes") @@ -87,6 +92,11 @@ fn main() { fs::create_dir_all(&export_lib_dir).unwrap(); copy_all_files(runtimes_dir, &export_lib_dir); (export_include_dir, export_lib_dir) + } else { + ( + libort_install_dir.join("include"), + libort_install_dir.join("lib"), + ) }; println!("Include directory: {:?}", include_dir); @@ -420,19 +430,19 @@ impl OnnxPrebuiltArchive for Triplet { } fn prebuilt_archive_url() -> (PathBuf, String) { - #[cfg(not(feature = "directml"))] - let prebuilt_archive = format!( - "{}.{}", - &*ONNXRUNTIME_DIR_NAME, - TRIPLET.os.archive_extension() - ); - - #[cfg(feature = "directml")] - let prebuilt_archive = format!( - "Microsoft.ML.OnnxRuntime.DirectML.{}.{}", - ORT_VERSION, - TRIPLET.os.archive_extension() - ); + let prebuilt_archive = if *DOWNLOAD_DIRECTML { + format!( + "Microsoft.ML.OnnxRuntime.DirectML.{}.{}", + ORT_VERSION, + TRIPLET.os.archive_extension() + ) + } else { + format!( + "{}.{}", + &*ONNXRUNTIME_DIR_NAME, + TRIPLET.os.archive_extension() + ) + }; let prebuilt_url = format!( "{}/v{}/{}", ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive @@ -469,10 +479,11 @@ fn prepare_libort_dir_prebuilt() -> PathBuf { // directmlの場合はzipの子ディレクトリがzipファイル名のディレクトリではないため、 // この処理は非directmlの場合のみ行う - #[cfg(not(feature = "directml"))] - let extract_dir = extract_dir.join(prebuilt_archive.file_stem().unwrap()); - - extract_dir + if !*DOWNLOAD_DIRECTML { + extract_dir.join(prebuilt_archive.file_stem().unwrap()) + } else { + extract_dir + } } fn prepare_libort_dir() -> PathBuf { diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 08353bae..24071b0c 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Nicolas Bigaouette "] edition = "2018" name = "onnxruntime" -version = "0.0.28" +version = "0.0.29" description = "Wrapper around Microsoft's ONNX Runtime" documentation = "https://docs.rs/onnxruntime" @@ -19,7 +19,7 @@ name = "integration_tests" required-features = ["model-fetching"] [dependencies] -onnxruntime-sys = { path = "../onnxruntime-sys", version = "0.0.23" } +onnxruntime-sys = { path = "../onnxruntime-sys", version = "0.0.24" } lazy_static = "1.4" ndarray = "0.15"