Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Windows MSVC support #78

Merged
merged 3 commits
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion .github/workflows/llama-cpp-rs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,16 @@ jobs:
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build
run: cargo build
windows:
name: Check that it builds on windows
runs-on: windows-latest
steps:
- name: checkout
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11
with:
submodules: recursive
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build
24 changes: 12 additions & 12 deletions llama-cpp-2/src/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ impl ParseState {
rest = r;
rule.push(llama_grammar_element {
type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR,
value: c,
value: c as _,
});
}
rest = Self::consume_whitespace_and_comments(&rest[1..], nested);
Expand All @@ -292,14 +292,14 @@ impl ParseState {
};
rule.push(llama_grammar_element {
type_: gre_type,
value: c,
value: c as _,
});
if rest.starts_with("-]") {
let (c, r) = Self::parse_char(rest)?;
rest = r;
rule.push(llama_grammar_element {
type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER,
value: c,
value: c as _,
});
}
}
Expand Down Expand Up @@ -386,7 +386,7 @@ impl ParseState {
error,
})?;

Ok((value, rest))
Ok((value as llama_gretype, rest))
}

fn parse_char(rest: &str) -> Result<(llama_gretype, &str), GrammarParseError> {
Expand All @@ -401,17 +401,17 @@ impl ParseState {
'x' => Self::parse_hex(rest, 2),
'u' => Self::parse_hex(rest, 4),
'U' => Self::parse_hex(rest, 8),
't' => Ok((u32::from('\t'), rest)),
'r' => Ok((u32::from('\r'), rest)),
'n' => Ok((u32::from('\n'), rest)),
'\\' => Ok((u32::from('\\'), rest)),
'"' => Ok((u32::from('"'), rest)),
'[' => Ok((u32::from('['), rest)),
']' => Ok((u32::from(']'), rest)),
't' => Ok((u32::from('\t') as llama_gretype, rest)),
'r' => Ok((u32::from('\r') as llama_gretype, rest)),
'n' => Ok((u32::from('\n') as llama_gretype, rest)),
'\\' => Ok((u32::from('\\') as llama_gretype, rest)),
'"' => Ok((u32::from('"') as llama_gretype, rest)),
'[' => Ok((u32::from('[') as llama_gretype, rest)),
']' => Ok((u32::from(']') as llama_gretype, rest)),
c => Err(GrammarParseError::UnknownEscape { escape: c }),
}
} else if let Some(c) = rest.chars().next() {
Ok((u32::from(c), &rest[c.len_utf8()..]))
Ok((u32::from(c) as llama_gretype, &rest[c.len_utf8()..]))
} else {
Err(GrammarParseError::UnexpectedEndOfInput {
parse_stage: "char",
Expand Down
4 changes: 2 additions & 2 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ impl Drop for LlamaModel {
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
/// Byte Pair Encoding
BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE,
BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
/// Sentence Piece Tokenizer
SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM,
SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}

/// There was an error converting a `llama_vocab_type` to a `VocabType`.
Expand Down
16 changes: 8 additions & 8 deletions llama-cpp-2/src/token_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
#[allow(clippy::module_name_repetitions)]
pub enum LlamaTokenType {
/// An undefined token type.
Undefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED,
Undefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED as _,
/// A normal token type.
Normal = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL,
Normal = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL as _,
/// An unknown token type.
Unknown = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN,
Unknown = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN as _,
/// A control token type.
Control = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL,
Control = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL as _,
/// A user defined token type.
UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED,
UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED as _,
/// An unused token type.
Unused = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED,
Unused = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED as _,
/// A byte token type.
Byte = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE,
Byte = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE as _,
}

/// A safe wrapper for converting potentially deceptive `llama_token_type` values into
Expand Down Expand Up @@ -52,7 +52,7 @@ impl TryFrom<llama_cpp_sys_2::llama_token_type> for LlamaTokenType {
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED => Ok(LlamaTokenType::UserDefined),
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED => Ok(LlamaTokenType::Unused),
llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE => Ok(LlamaTokenType::Byte),
_ => Err(LlamaTokenTypeFromIntError::UnknownValue(value)),
_ => Err(LlamaTokenTypeFromIntError::UnknownValue(value as _)),
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions llama-cpp-sys-2/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ fn main() {

// https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
if let Some(ggml_cuda) = &mut ggml_cuda {
for lib in [
"cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
] {
for lib in ["cuda", "cublas", "cudart", "cublasLt"] {
println!("cargo:rustc-link-lib={}", lib);
}
if !ggml_cuda.get_compiler().is_like_msvc() {
for lib in ["culibos", "pthread", "dl", "rt"] {
println!("cargo:rustc-link-lib={}", lib);
}
}

println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");

Expand All @@ -46,10 +49,15 @@ fn main() {

ggml_cuda
.cuda(true)
.std("c++17")
.flag("-arch=all")
.file("llama.cpp/ggml-cuda.cu");

if ggml_cuda.get_compiler().is_like_msvc() {
ggml_cuda.std("c++14");
} else {
ggml_cuda.std("c++17");
}

ggml.define("GGML_USE_CUBLAS", None);
ggml_cuda.define("GGML_USE_CUBLAS", None);
llama_cpp.define("GGML_USE_CUBLAS", None);
Expand Down