Skip to content

Commit

Permalink
Merge pull request #134 from edgenai/chore/issue122
Browse files Browse the repository at this point in the history
Chore/issue122: add embeddings to settings integration tests
  • Loading branch information
toschoo authored Mar 26, 2024
2 parents 45f2a7d + b5613b0 commit 7778c5f
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 8 deletions.
14 changes: 14 additions & 0 deletions crates/edgen_core/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ pub async fn create_project_dirs() -> Result<(), std::io::Error> {

let audio_transcriptions_dir = PathBuf::from(&audio_transcriptions_str);

let embeddings_str = SETTINGS
.read()
.await
.read()
.await
.embeddings_models_dir
.to_string();

let embeddings_dir = PathBuf::from(&embeddings_str);

if !config_dir.is_dir() {
std::fs::create_dir_all(config_dir)?;
}
Expand All @@ -74,6 +84,10 @@ pub async fn create_project_dirs() -> Result<(), std::io::Error> {
std::fs::create_dir_all(&audio_transcriptions_dir)?;
}

if !embeddings_dir.is_dir() {
std::fs::create_dir_all(&embeddings_str)?;
}

Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/edgen_server/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ async fn observe_progress(
let mut m = tokio::fs::metadata(&f.path()).await;
let mut last_size = 0;
let mut timestamp = Instant::now();
while m.is_ok() {
let s = m.unwrap().len() as u64;
while let Ok(d) = m {
let s = d.len() as u64;
let p = (s * 100) / size;

if s > last_size {
Expand Down
54 changes: 50 additions & 4 deletions crates/edgen_server/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ pub const SMALL_LLM_REPO: &str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF";
/// File name of the small whisper model (paired with `SMALL_WHISPER_REPO`).
pub const SMALL_WHISPER_NAME: &str = "ggml-distil-small.en.bin";
/// Repository identifier of the small whisper model.
pub const SMALL_WHISPER_REPO: &str = "distil-whisper/distil-small.en";

/// File name of the small model the tests configure for the embeddings endpoint.
pub const SMALL_EMBEDDINGS_NAME: &str = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf";
/// Repository identifier of the small model configured for the embeddings endpoint.
pub const SMALL_EMBEDDINGS_REPO: &str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF";

// URL building blocks. The tests combine `BASE_URL` with one or more of the
// path segments below via `make_url`, e.g.
// `make_url(&[BASE_URL, EMBEDDINGS_URL, STATUS_URL])`.

/// Base URL of the server under test; all other segments are appended to it.
pub const BASE_URL: &str = "http://localhost:33322/v1";
/// Path segment for the chat endpoints.
pub const CHAT_URL: &str = "/chat";
/// Path segment for completions.
pub const COMPLETIONS_URL: &str = "/completions";
/// Path segment for the audio endpoints.
pub const AUDIO_URL: &str = "/audio";
/// Path segment for transcriptions (used after `AUDIO_URL`).
pub const TRANSCRIPTIONS_URL: &str = "/transcriptions";
/// Path segment for the embeddings endpoint.
pub const EMBEDDINGS_URL: &str = "/embeddings";
/// Path segment appended to an endpoint URL to query its status.
pub const STATUS_URL: &str = "/status";
/// Path segment for miscellaneous endpoints.
pub const MISC_URL: &str = "/misc";
/// Path segment for the version endpoint.
pub const VERSION_URL: &str = "/version";
Expand Down Expand Up @@ -241,6 +245,10 @@ pub fn data_exists() {
let transcriptions = audio.join("transcriptions");
println!("exists: {:?}", transcriptions);
assert!(transcriptions.exists());

let embeddings = models.join("embeddings");
println!("exists: {:?}", embeddings);
assert!(embeddings.exists());
}

/// Edit the config file: set another model dir for the indicated endpoint.
Expand All @@ -256,7 +264,9 @@ pub fn set_model_dir(ep: Endpoint, model_dir: &str) {
Endpoint::AudioTranscriptions => {
config.audio_transcriptions_models_dir = model_dir.to_string();
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
config.embeddings_models_dir = model_dir.to_string();
}
}
write_config(&config).unwrap();

Expand All @@ -280,7 +290,10 @@ pub fn set_model(ep: Endpoint, model_name: &str, model_repo: &str) {
config.audio_transcriptions_model_name = model_name.to_string();
config.audio_transcriptions_model_repo = model_repo.to_string();
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
config.embeddings_model_name = model_name.to_string();
config.embeddings_model_repo = model_repo.to_string();
}
}
write_config(&config).unwrap();

Expand All @@ -292,7 +305,7 @@ pub fn set_model(ep: Endpoint, model_name: &str, model_repo: &str) {
Endpoint::AudioTranscriptions => {
make_url(&[BASE_URL, AUDIO_URL, TRANSCRIPTIONS_URL, STATUS_URL])
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => make_url(&[BASE_URL, EMBEDDINGS_URL, STATUS_URL]),
};
let stat: status::AIStatus = blocking::get(url).unwrap().json().unwrap();
assert_eq!(stat.active_model, model_name);
Expand Down Expand Up @@ -335,13 +348,22 @@ pub fn chat_completions_custom_body(model: &str) -> String {
.expect("cannot convert JSON to String")
}

/// Build the JSON request body for the embeddings endpoint.
///
/// Only the `model` field varies with the argument; the `input` text is fixed.
/// Panics if the JSON value cannot be serialized to a string.
pub fn embeddings_custom_body(model: &str) -> String {
    let payload = json!({
        "model": model,
        "input": "what is the capital of idaho?",
    });
    serde_json::to_string(&payload).expect("cannot convert JSON to String")
}

/// Spawn a thread to send a request to the indicated endpoint.
/// This allows the caller to perform another task in the caller thread.
///
/// `body` is the request body handed to the chat completions and embeddings
/// requests; `model` is handed to the audio transcriptions request.
/// The returned handle yields the boolean result of the spawned request.
pub fn spawn_request(ep: Endpoint, body: &str, model: &str) -> thread::JoinHandle<bool> {
    // Each endpoint has exactly one arm. A leftover `todo!()` arm for
    // `Endpoint::Embeddings` would shadow the real one and panic at runtime.
    match ep {
        Endpoint::ChatCompletions => spawn_chat_completions_request(body),
        Endpoint::AudioTranscriptions => spawn_audio_transcriptions_request(model),
        Endpoint::Embeddings => spawn_embeddings_request(body),
    }
}

Expand Down Expand Up @@ -369,6 +391,30 @@ pub fn spawn_chat_completions_request(body: &str) -> thread::JoinHandle<bool> {
})
}

/// Post `body` as JSON to the embeddings endpoint on a separate thread.
///
/// The thread's result is `true` when the server answers with a success
/// status, and `false` when sending the request fails or the status is not
/// a success. The request is given a timeout of 180 seconds.
pub fn spawn_embeddings_request(body: &str) -> thread::JoinHandle<bool> {
    let payload = body.to_string();
    thread::spawn(move || {
        let url = make_url(&[BASE_URL, EMBEDDINGS_URL]);
        println!("requesting {}", url);

        let outcome = blocking::Client::new()
            .post(&url)
            .header("Content-Type", "application/json")
            .body(payload)
            .timeout(Duration::from_secs(180))
            .send();

        // Bail out early on a transport-level failure; otherwise report
        // whether the HTTP status signals success.
        let response = match outcome {
            Ok(response) => response,
            Err(err) => {
                eprintln!("cannot connect: {:?}", err);
                return false;
            }
        };
        println!("Got {:?}", response);
        response.status().is_success()
    })
}

pub fn spawn_audio_transcriptions_request(model: &str) -> thread::JoinHandle<bool> {
let model = model.to_string();
let frost = Path::new("resources").join("frost.wav");
Expand Down
69 changes: 67 additions & 2 deletions crates/edgen_server/tests/settings_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ fn test_battery() {

chat_completions_status_reachable();
audio_transcriptions_status_reachable();
embeddings_status_reachable();

// ================================
common::test_message("SCENARIO 2");
Expand All @@ -85,14 +86,22 @@ fn test_battery() {
common::SMALL_WHISPER_NAME,
common::SMALL_WHISPER_REPO,
);
common::set_model(
Endpoint::Embeddings,
common::SMALL_EMBEDDINGS_NAME,
common::SMALL_EMBEDDINGS_REPO,
);

// test ai endpoint and download
test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");

Check warning on line 97 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");


// we have downloaded, we should not download again
test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 3");
Expand Down Expand Up @@ -122,17 +131,26 @@ fn test_battery() {
"transcriptions",

Check warning on line 131 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
);

common::set_model_dir(Endpoint::ChatCompletions, &new_chat_completions_dir);
let new_embeddings_dir = my_models_dir.clone()
+ &format!(
"{}{}",
path::MAIN_SEPARATOR,
"embeddings",
);

common::set_model_dir(Endpoint::ChatCompletions, &new_chat_completions_dir);
common::set_model_dir(Endpoint::AudioTranscriptions, &new_audio_transcriptions_dir);
common::set_model_dir(Endpoint::Embeddings, &new_embeddings_dir);

test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");

assert!(path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 4");
Expand All @@ -142,11 +160,13 @@ fn test_battery() {

test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");

assert!(path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 5");
Expand All @@ -163,28 +183,37 @@ fn test_battery() {
common::SMALL_WHISPER_NAME,
common::SMALL_WHISPER_REPO,
);
common::set_model(
Endpoint::Embeddings,
common::SMALL_EMBEDDINGS_NAME,
common::SMALL_EMBEDDINGS_REPO,
);

// make sure we read from the old directories again
remove_dir_all(&my_models_dir).unwrap();
assert!(!path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 6");
// ================================
let chat_model = "TheBloke/phi-2-GGUF/phi-2.Q2_K.gguf";
let audio_model = "distil-whisper/distil-medium.en/ggml-medium-32-2.en.bin";
let embeddings_model = "TheBloke/phi-2-GGUF/phi-2.Q2_K.gguf";

test_ai_endpoint_with_download(Endpoint::ChatCompletions, chat_model);
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, audio_model);
test_ai_endpoint_with_download(Endpoint::Embeddings, embeddings_model);

// ================================
common::test_message("SCENARIO 7");
// ================================
test_ai_endpoint_no_download(Endpoint::ChatCompletions, chat_model);
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, audio_model);
test_ai_endpoint_no_download(Endpoint::Embeddings, embeddings_model);

// ================================
common::test_message("SCENARIO 8");
Expand All @@ -200,6 +229,10 @@ fn test_battery() {
"audio/transcriptions",
);
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, ".whisper-medium-32-2.en.bin");

let source = "models--TheBloke--phi-2-GGUF/blobs";
common::copy_model(source, ".phi-2.Q2_K.gguf", "embeddings");
test_ai_endpoint_no_download(Endpoint::Embeddings, ".phi-2.Q2_K.gguf");
})
}

Expand Down Expand Up @@ -243,6 +276,25 @@ fn audio_transcriptions_status_reachable() {
});
}

/// Check that the status URL of the embeddings endpoint can be reached.
///
/// Panics (failing the test) if the request cannot be sent or the response
/// status is not a success; otherwise prints the response text.
fn embeddings_status_reachable() {
    common::test_message("embeddings status is reachable");

    let url = common::make_url(&[
        common::BASE_URL,
        common::EMBEDDINGS_URL,
        common::STATUS_URL,
    ]);

    let reachable = match blocking::get(url) {
        Ok(response) => {
            assert!(response.status().is_success());
            println!("have: '{}'", response.text().unwrap());
            true
        }
        Err(err) => {
            eprintln!("cannot connect: {:?}", err);
            false
        }
    };
    assert!(reachable);
}

fn test_config_reset() {
common::test_message("test resetting config");
common::reset_config();
Expand Down Expand Up @@ -291,7 +343,20 @@ fn test_ai_endpoint(endpoint: Endpoint, model: &str, download: bool) {
"".to_string(),
)

Check warning on line 344 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
common::test_message(&format!(
"embeddints endpoint with download: {}",
download
));
(
common::make_url(&[
common::BASE_URL,
common::EMBEDDINGS_URL,
common::STATUS_URL,
]),
common::embeddings_custom_body(model),
)
}
};
let handle = common::spawn_request(endpoint, &body, model);
if download {
Expand Down

0 comments on commit 7778c5f

Please sign in to comment.