Skip to content

Commit

Permalink
- Updated README.md to reflect changes in the Ruby code snippet.
Browse files Browse the repository at this point in the history
- Added dependencies `serde` and `serde_json` to `Cargo.toml` in the `ext/candle` directory.
- Modified `lib.rs` in the `ext/candle/src` directory to add new methods `new1`, `new2`, and `new3` to `RbModel`.
- Updated `rb_model.rs` in the `ext/candle/src/model` directory to implement the new methods `new1`, `new2`, and `new3` in `RbModel`.
- Added a new function `read_config` to read the config file in `rb_model.rs`.
- Made changes to `build_model` method in `rb_model.rs` to load config from the config file path.
- Added logging statements for debugging purposes in `rb_model.rs`.

As a result, we can now load the BioBERT embedding model and compute embeddings. There is now an error loading the Jina embedding model, which we will tackle next.
  • Loading branch information
cpetersen committed Jul 26, 2024
1 parent b514f36 commit dd51b3b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ And the following ruby:

```ruby
require 'candle'
model = Candle::Model.new
model = Candle::Model.new("dmis-lab/biobert-base-cased-v1.1")
embedding = model.embedding("Hi there!")
```

Expand Down
2 changes: 2 additions & 0 deletions ext/candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
candle-core = "0.4.1"
candle-nn = "0.4.1"
candle-transformers = "0.4.1"
Expand Down
2 changes: 2 additions & 0 deletions ext/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ fn init(ruby: &Ruby) -> RbResult<()> {

let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
rb_model.define_singleton_method("new1", function!(RbModel::new1, 1))?;
rb_model.define_singleton_method("new2", function!(RbModel::new2, 2))?;
rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
Expand Down
54 changes: 43 additions & 11 deletions ext/candle/src/model/rb_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ use crate::model::{
errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
rb_tensor::RbTensor,
};
use candle_core::{DType, Device, Module, Tensor};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::jina_bert::{BertModel, Config};
use candle_transformers::models::bert::{BertModel, Config};
use magnus::Error;
use crate::model::RbResult;
use serde_json;
use std::fs;
use std::path::PathBuf;
use tokenizers::Tokenizer;

#[magnus::wrap(class = "Candle::Model", free_immediately, size)]
Expand All @@ -28,10 +31,18 @@ pub struct RbModelInner {

impl RbModel {
pub fn new() -> RbResult<Self> {
Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
Self::new3(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
}

pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
pub fn new1(model_path: Option<String>) -> RbResult<Self> {
Self::new3(model_path, Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
}

pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>) -> RbResult<Self> {
Self::new3(model_path, tokenizer_path, Some(Device::Cpu))
}

pub fn new3(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
let device = device.unwrap_or(Device::Cpu);
Ok(RbModel(RbModelInner {
device: device.clone(),
Expand Down Expand Up @@ -60,25 +71,38 @@ impl RbModel {
}
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer or Model not found"))
}

}

fn build_model(model_path: String, device: Device) -> RbResult<BertModel> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model_path = Api::new()
.map_err(wrap_hf_err)?
.repo(Repo::new(
model_path,
model_path.clone(),
RepoType::Model,
))
.get("model.safetensors")
.map_err(wrap_hf_err)?;
let config = Config::v2_base();
println!("Model path: {:?}", model_path);
let config_path = model_path.parent().unwrap().join("config.json");
println!("Config path: {:?}", config_path);

// let config_path = Api::new()
// .map_err(wrap_hf_err)?
// .repo(Repo::new(
// model_path.to_str().unwrap().to_string(),
// RepoType::Model,
// ))
// .get("config.json")
// .map_err(wrap_hf_err)?;

let config: Config = read_config(config_path)?;

let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
.map_err(wrap_candle_err)?
};
let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
let model = BertModel::load(vb, &config).map_err(wrap_candle_err)?;
Ok(model)
}

Expand Down Expand Up @@ -119,8 +143,10 @@ impl RbModel {
.unsqueeze(0)
.map_err(wrap_candle_err)?;

// let start: std::time::Instant = std::time::Instant::now();
let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
let token_type_ids = Tensor::zeros(&*token_ids.shape(), DType::I64, &self.0.device)
.map_err(wrap_candle_err)?;

let result = model.forward(&token_ids, &token_type_ids).map_err(wrap_candle_err)?;

// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = result.dims3()
Expand All @@ -129,7 +155,6 @@ impl RbModel {
.map_err(wrap_candle_err)?;
let embeddings = (sum / (n_tokens as f64))
.map_err(wrap_candle_err)?;
// let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?;

Ok(embeddings)
}
Expand All @@ -148,6 +173,13 @@ impl RbModel {
}
}

/// Reads the JSON file at `config_path` and deserializes it into a BERT
/// `Config`.
///
/// Both the I/O error (file missing/unreadable) and the JSON parse error are
/// wrapped into a `magnus::Error` via `wrap_std_err` so they surface as Ruby
/// exceptions.
fn read_config(config_path: PathBuf) -> Result<Config, magnus::Error> {
    // Borrow the path so the original value stays usable if this function is
    // ever extended; `read_to_string` accepts any `AsRef<Path>`.
    let config_str = fs::read_to_string(&config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
    // NOTE: the previous version printed the entire config file to stdout on
    // every model load (leftover debug logging); that has been removed.
    let config: Config = serde_json::from_str(&config_str).map_err(|e| wrap_std_err(Box::new(e)))?;
    Ok(config)
}

// #[cfg(test)]
// mod tests {
// #[test]
Expand Down

0 comments on commit dd51b3b

Please sign in to comment.