Skip to content

Commit

Permalink
perf: only load spacy model once
Browse files Browse the repository at this point in the history
  • Loading branch information
BrewingWeasel committed Nov 16, 2023
1 parent b41deab commit 170edda
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 22 additions & 12 deletions spacy-parsing/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use pyo3::{exceptions::PyEnvironmentError, prelude::*};
use pyo3::prelude::*;
use std::{collections::HashMap, str::FromStr};

pub struct Token {
Expand Down Expand Up @@ -54,19 +54,29 @@ impl FromStr for PartOfSpeech {
}
}

pub fn get_spacy_info(sent: &str, model: &str) -> Result<Vec<Token>, String> {
pub fn get_spacy_model(model: &str) -> Result<Py<PyAny>, String> {
Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let spacy = PyModule::import(py, "spacy")?;
let v = spacy.getattr("load")?.call1((model,))?;
Ok(v.to_object(py))
})
.map_err(|e| e.to_string())
}

pub fn get_spacy_info(sent: &str, morphologizer: &PyObject) -> Result<Vec<Token>, String> {
Python::with_gil(|py| -> PyResult<Vec<Token>> {
let mut words = Vec::new();
let spacy = PyModule::import(py, "spacy")?;
let morphologizer = match spacy.getattr("load")?.call1((model,)) {
Ok(v) => v,
Err(_) => {
return Err(PyEnvironmentError::new_err(format!(
"Unable to load {model}"
)))
}
};
let total: Vec<PyObject> = morphologizer.call1((sent,))?.extract()?;
// let spacy = PyModule::import(py, "spacy")?;
// let morphologizerr = match spacy.getattr("load")?.call1((model,)) {
// Ok(v) => v,
// Err(_) => {
// return Err(PyEnvironmentError::new_err(format!(
// "Unable to load {model}"
// )))
// }
// };

let total: Vec<PyObject> = morphologizer.call1(py, (sent,))?.extract(py)?;
for token in total {
let text: String = token.getattr(py, "text")?.extract(py)?;
let pos_str: String = token.getattr(py, "pos_")?.extract(py)?;
Expand Down
1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ spacy-parsing = { path = "../spacy-parsing" }
toml = "0.8.2"
dirs = "5.0.1"
chrono = { version = "0.4.31", features = ["serde"] }
pyo3 = "0.20.0"


[features]
Expand Down
2 changes: 1 addition & 1 deletion src-tauri/src/language_parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub async fn parse_text(sent: &str, state: State<'_, SakinyjeState>) -> Result<V
if sent.is_empty() {
return Ok(words);
}
let parsed_words = get_spacy_info(sent, &state.settings.model)?;
let parsed_words = get_spacy_info(sent, &state.model)?;
for word in parsed_words {
let clickable = !matches!(
word.pos,
Expand Down
13 changes: 11 additions & 2 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use crate::{
use ankiconnect::get_anki_card_statuses;
use chrono::{DateTime, Utc};
use commands::run_command;
use pyo3::PyObject;
use serde::{Deserialize, Serialize};
use shared::{SakinyjeResult, Settings};
use spacy_parsing::get_spacy_model;
use std::{collections::HashMap, fs};
use tauri::{async_runtime::block_on, GlobalWindowEvent, Manager, State, WindowEvent};

Expand All @@ -26,6 +28,7 @@ struct SakinyjeState(tauri::async_runtime::Mutex<SharedInfo>);
struct SharedInfo {
settings: Settings,
to_save: ToSave,
model: PyObject,
}

#[derive(Serialize, Deserialize, Default)]
Expand All @@ -41,7 +44,7 @@ impl Default for SharedInfo {
let config_file = dirs::config_dir().unwrap().join("sakinyje.toml");

let mut to_save: ToSave = fs::read_to_string(saved_state_file)
.map(|v| toml::from_str(&v).unwrap())
.map(|v| toml::from_str(&v).unwrap_or_default())
.unwrap_or_default();

let settings: Settings = fs::read_to_string(config_file)
Expand Down Expand Up @@ -74,8 +77,14 @@ impl Default for SharedInfo {
}
}

let model = get_spacy_model(&settings.model).unwrap(); // TODO: model

to_save.last_launched = new_time;
Self { to_save, settings }
Self {
to_save,
settings,
model,
}
}
}

Expand Down

0 comments on commit 170edda

Please sign in to comment.