Merge pull request #32 from beclab/feat/bertv3-embedding
fix: fix r4userembeddingwithmodel image error, add build_model_and_tok…
bleachzou3 authored Oct 14, 2024
2 parents 335dad0 + 769e9e3 commit f3ad9ca
Showing 5 changed files with 183 additions and 45 deletions.
4 changes: 2 additions & 2 deletions Dockerfile.r4userembedding
@@ -10,7 +10,7 @@ RUN mkdir -p /userembedding && \
apt install build-essential -y && \
apt-get install libssl-dev -y && \
apt-get install pkg-config -y

ENV PATH="/root/.cargo/bin:${PATH}"
WORKDIR /userembedding
COPY user-embedding/src /userembedding/src
@@ -22,7 +22,7 @@ RUN cargo build --release
FROM ubuntu:jammy

# Import from builder.

ENV MODEL_INTERNET="true"

WORKDIR /userembedding

1 change: 1 addition & 0 deletions Dockerfile.r4userembeddingwithmodel
@@ -23,6 +23,7 @@ RUN /userembedding/target/release/downloadmodel
FROM ubuntu:jammy

# Import from builder.
ENV MODEL_INTERNET="false"


WORKDIR /userembedding
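The two images differ only in this flag: Dockerfile.r4userembedding ships without the model and sets MODEL_INTERNET="true" so the binary downloads it from the Hugging Face Hub at runtime, while Dockerfile.r4userembeddingwithmodel bakes the model into the image at build time and sets the flag to "false". A minimal sketch of how the Rust side can read the flag (the helper name is illustrative; bertcommon.rs below reads the variable inline):

use std::env;

// Illustrative helper, not part of the repo: returns true only when
// MODEL_INTERNET is exactly "true". An unset variable or any other value
// falls back to the offline path, matching the unwrap_or("false".to_string())
// default used in bertcommon.rs.
fn model_from_internet() -> bool {
    env::var("MODEL_INTERNET")
        .map(|v| v == "true")
        .unwrap_or(false)
}

fn main() {
    println!("download model at runtime: {}", model_from_internet());
}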
79 changes: 58 additions & 21 deletions user-embedding/src/bertcommon.rs
@@ -6,14 +6,15 @@ use candle_transformers::models::bert::BertModel;
use log::{debug as logdebug, error as logerror};
use ndarray_rand::rand;
use rand::seq::SliceRandom;
use tokenizers::Tokenizer;

use crate::{
embedding_common::{self, normalize_l2, TokenizerImplSimple},
knowledge_base_api,
};

// const HUGGING_FACE_MODEL_NAME: &str = "sentence-transformers/all-MiniLM-L6-v2";
// const HUGGING_FACE_MODEL_REVISION: &str = "refs/pr/21";
// const HUGGING_FACE_MODEL_REVISION: &str = "refs/pr/21";
pub const BERT_V2_EMBEDDING_DIMENSION: usize = 384;
pub const BERT_V3_EMBEDDING_DIMENSION: usize = 384;

@@ -72,26 +73,50 @@ pub async fn calculate_single_entry_pure(
async fn calculate_userembedding() -> AnyhowResult<Tensor, AnyhowError> {
let current_source_name: String = env::var("TERMINUS_RECOMMEND_SOURCE_NAME")
.expect("TERMINUS_RECOMMEND_SOURCE_NAME env not found.");
let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
let embedding_method: String =
std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField =
embedding_common::MODEL_RELATED_INFO_MAP
.get(embedding_method.as_str())
.unwrap();
let mut option_cumulative_tensor: Option<Tensor> = None;
if model_related_info.model_name == "bert_v2"{
let cumulative_embedding_data: [f32; BERT_V2_EMBEDDING_DIMENSION] =
[0f32; BERT_V2_EMBEDDING_DIMENSION];
if model_related_info.model_name == "bert_v2" {
let cumulative_embedding_data: [f32; BERT_V2_EMBEDDING_DIMENSION] =
[0f32; BERT_V2_EMBEDDING_DIMENSION];
option_cumulative_tensor = Some(Tensor::new(&cumulative_embedding_data, &Device::Cpu)?);
}else if model_related_info.model_name == "bert_v3"{
} else if model_related_info.model_name == "bert_v3" {
let cumulative_embedding_data: [f32; BERT_V3_EMBEDDING_DIMENSION] =
[0f32; BERT_V3_EMBEDDING_DIMENSION];
[0f32; BERT_V3_EMBEDDING_DIMENSION];
option_cumulative_tensor = Some(Tensor::new(&cumulative_embedding_data, &Device::Cpu)?);
}else{
} else {
tracing::error!("embedding method {} not exist", embedding_method);
return Err(AnyhowError::msg("embedding method not exist"));
}
let mut cumulative_tensor: Tensor = option_cumulative_tensor.unwrap();
let default_model: String = model_related_info.hugging_face_model_name.to_string();
let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
let (model, mut tokenizer) =
embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
let model_internet: String = env::var("MODEL_INTERNET").unwrap_or("false".to_string());
let mut model_option: Option<BertModel> = None;
let mut model_tokenizer: Option<Tokenizer> = None;

if model_internet == "true" {
logdebug!("use internet model");
let (model, mut tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
default_model,
default_revision,
)
.unwrap();
model_option = Some(model);
model_tokenizer = Some(tokenizer);
} else {
logdebug!("use local model");
let (model, mut tokenizers) =
embedding_common::build_model_and_tokenizer_from_local(model_related_info).unwrap();
model_option = Some(model);
model_tokenizer = Some(tokenizers);
}
let model: BertModel = model_option.unwrap();
let mut tokenizer: Tokenizer = model_tokenizer.unwrap();
let current_tokenizer: &TokenizerImplSimple = tokenizer
.with_padding(None)
.with_truncation(None)
@@ -146,15 +171,20 @@ async fn calculate_userembedding() -> AnyhowResult<Tensor, AnyhowError> {
}

pub async fn execute_bertv2_user_embedding() {
let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
let embedding_method: String =
std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField =
embedding_common::MODEL_RELATED_INFO_MAP
.get(embedding_method.as_str())
.unwrap();
let user_embedding: Tensor = calculate_userembedding()
.await
.expect("calculate user embedding fail");
let original_user_embedding =
embedding_common::retrieve_user_embedding_through_knowledge(model_related_info.embedding_dimension)
.await
.expect("retrieve user embedding through knowledge base fail");
let original_user_embedding = embedding_common::retrieve_user_embedding_through_knowledge(
model_related_info.embedding_dimension,
)
.await
.expect("retrieve user embedding through knowledge base fail");
let new_user_embedding_result = user_embedding.add(&original_user_embedding);
match new_user_embedding_result {
Ok(current_new_user_embedding) => {
@@ -181,13 +211,20 @@ mod bertv2test {
async fn test_calculate_single_entry() {
// cargo test bertv2test::test_calculate_single_entry
common_test_operation::init_env();
let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
let embedding_method: String =
std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
let model_related_info: &embedding_common::ModelInfoField =
embedding_common::MODEL_RELATED_INFO_MAP
.get(embedding_method.as_str())
.unwrap();

let default_model: String = model_related_info.hugging_face_model_name.to_string();
let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
let (model, mut tokenizer) =
embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
let (model, mut tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
default_model,
default_revision,
)
.unwrap();
let current_tokenizer: &TokenizerImplSimple = tokenizer
.with_padding(None)
.with_truncation(None)
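The branching added to calculate_userembedding works, but it can be stated more directly. Below is a sketch of an equivalent shape, not the committed code: since ModelInfoField already records embedding_dimension, the per-model [0f32; N] arrays collapse into one Tensor::zeros call, and returning the loader result from an if/else expression removes the mutable Option juggling. Names and signatures follow the diff above; init_and_load is an illustrative name.

use anyhow::Result as AnyhowResult;
use candle_core::{DType, Device, Tensor};
use candle_transformers::models::bert::BertModel;
use tokenizers::Tokenizer;

use crate::embedding_common::{self, ModelInfoField};

// Sketch of an alternative to the Option-based branching above.
fn init_and_load(
    model_related_info: &ModelInfoField,
) -> AnyhowResult<(Tensor, BertModel, Tokenizer)> {
    // One zero tensor, sized from the model metadata instead of
    // hard-coded per-model arrays.
    let cumulative_tensor = Tensor::zeros(
        model_related_info.embedding_dimension,
        DType::F32,
        &Device::Cpu,
    )?;

    let from_internet = std::env::var("MODEL_INTERNET")
        .map(|v| v == "true")
        .unwrap_or(false);

    let (model, tokenizer) = if from_internet {
        let (model, tokenizer, _paths) =
            embedding_common::build_model_and_tokenizer_from_internet(
                model_related_info.hugging_face_model_name.to_string(),
                model_related_info.hugging_face_model_revision.to_string(),
            )?;
        (model, tokenizer)
    } else {
        embedding_common::build_model_and_tokenizer_from_local(model_related_info)?
    };

    Ok((cumulative_tensor, model, tokenizer))
}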
29 changes: 27 additions & 2 deletions user-embedding/src/download_model.rs
@@ -1,3 +1,5 @@
use std::{fs::File, io::Write};

use hf_hub::api::sync::Api;

use userembedding::{
@@ -10,8 +12,31 @@ fn download_models() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Downloading model {}...", model_related_info.model_name);
let default_model: String = model_related_info.hugging_face_model_name.to_string();
let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
let (_, _) =
embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
let (_, _, current_bert_model_file_path) =
embedding_common::build_model_and_tokenizer_from_internet(
default_model,
default_revision,
)
.unwrap();
let serialized = serde_json::to_string(&current_bert_model_file_path).map_err(|e| {
log::error!("Error serializing model file path: {}", e);
e
})?;

let output_json_path = format!(
"/root/.cache/huggingface/{}.json",
model_related_info.model_name
);

let mut file = File::create(output_json_path).map_err(|e| {
log::error!("Error creating file: {}", e);
e
})?;
file.write_all(serialized.as_bytes()).map_err(|e| {
log::error!("Error writing to file: {}", e);
e
})?;

tracing::info!(
"Model {} downloaded successfully",
model_related_info.model_name
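downloadmodel now persists where hf_hub cached each artifact, so the offline image can locate the files later without touching the network. Below is a self-contained sketch of that round-trip under stated assumptions: ModelManifest mirrors the three-field BertModelFilePath struct from embedding_common.rs, and the paths and /tmp location are illustrative (the real code writes /root/.cache/huggingface/<model_name>.json).

use std::fs;

use serde::{Deserialize, Serialize};

// Mirrors BertModelFilePath from embedding_common.rs (illustrative copy).
#[derive(Debug, Serialize, Deserialize)]
struct ModelManifest {
    config_filename: String,
    tokenizer_filename: String,
    weights_filename: String,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build time: record wherever hf_hub placed the files (paths illustrative).
    let manifest = ModelManifest {
        config_filename: "/models/bert/config.json".into(),
        tokenizer_filename: "/models/bert/tokenizer.json".into(),
        weights_filename: "/models/bert/model.safetensors".into(),
    };
    fs::write("/tmp/bert_v2.json", serde_json::to_string(&manifest)?)?;

    // Run time: the offline loader reads the manifest back and opens the
    // recorded files instead of calling the Hub.
    let restored: ModelManifest =
        serde_json::from_str(&fs::read_to_string("/tmp/bert_v2.json")?)?;
    println!("weights at {}", restored.weights_filename);
    Ok(())
}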
115 changes: 95 additions & 20 deletions user-embedding/src/embedding_common.rs
@@ -1,4 +1,4 @@
use std::{collections::HashMap, ops::Add};
use std::{collections::HashMap, fs, ops::Add};

use anyhow::{bail, Error as AnyhowError, Result as AnyhowResult};
use candle_core::{DType, Device, Result as CandleResult, Tensor};
@@ -9,6 +9,7 @@ use log::{debug as logdebug, error as logerror};
use ndarray::Array;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use serde::{Deserialize, Serialize};
use text_splitter::TextSplitter;
use tokenizers::Tokenizer;

@@ -30,11 +31,56 @@ pub fn build_device(cpu: bool) -> CandleResult<Device> {
}
}

pub fn build_model_and_tokenizer(
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BertModelFilePath {
pub config_filename: String,
pub tokenizer_filename: String,
pub weights_filename: String,
}

pub fn build_model_and_tokenizer_from_local(
current_model_info_field: &ModelInfoField,
) -> AnyhowResult<(BertModel, Tokenizer)> {
let device = build_device(false)?;
let current_model_file_path = format!(
"/root/.cache/huggingface/{}.json",
current_model_info_field.model_name
);
let file_content = fs::read_to_string(current_model_file_path).map_err(|e| {
logerror!("read file error {}", e);
AnyhowError::msg(e)
})?;
let data: BertModelFilePath = serde_json::from_str(&file_content).map_err(|e| {
logerror!("parse json error {}", e);
AnyhowError::msg(e)
})?;
let config_filename = fs::canonicalize(data.config_filename)?;
let tokenizer_filename = fs::canonicalize(data.tokenizer_filename)?;
let weights_filename = fs::canonicalize(data.weights_filename)?;

logdebug!(
"[{}] [{}] config_filename {} tokenizer_filename {} weights_filename {}",
file!(),
line!(),
config_filename.display(),
tokenizer_filename.display(),
weights_filename.display()
);

let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(AnyhowError::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}

pub fn build_model_and_tokenizer_from_internet(
default_model: String,
default_revision: String,
) -> AnyhowResult<(BertModel, Tokenizer)> {
) -> AnyhowResult<(BertModel, Tokenizer, BertModelFilePath)> {
let device = build_device(false)?;

let repo = Repo::with_revision(default_model, RepoType::Model, default_revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
@@ -45,6 +91,11 @@

(config, tokenizer, weights)
};
let current_bert_model_file_path = BertModelFilePath {
config_filename: config_filename.display().to_string(),
tokenizer_filename: tokenizer_filename.display().to_string(),
weights_filename: weights_filename.display().to_string(),
};
logdebug!(
"[{}] [{}] config_filename {} tokenizer_filename {} weights_filename {}",
file!(),
@@ -59,7 +110,7 @@
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(AnyhowError::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
Ok((model, tokenizer, current_bert_model_file_path))
}

pub fn normalize_l2(v: &Tensor, dimension: usize) -> AnyhowResult<Tensor> {
@@ -333,7 +384,6 @@ pub struct ModelInfoField {
pub hugging_face_model_name: &'static str,
pub hugging_face_model_revision: &'static str,
pub embedding_dimension: usize,

}

lazy_static! {
@@ -347,23 +397,21 @@ lazy_static! {
hugging_face_model_revision: "refs/pr/21",
embedding_dimension: 384,
},

);
m.insert(
"bert_v3",
ModelInfoField {
model_name: "bert_v3",
hugging_face_model_name: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
hugging_face_model_revision: "refs/heads/main" ,
hugging_face_model_name:
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
hugging_face_model_revision: "refs/heads/main",
embedding_dimension: 384,
},

);
m
};
}


#[cfg(test)]
mod embeddingcommontest {

@@ -380,12 +428,9 @@ mod embeddingcommontest {
common_test_operation::init_env();
let source_name = String::from("bert_v2");

let current_tensor = retrieve_current_algorithm_impression_knowledge(
source_name,
384,
)
.await
.expect("add cumulative tensor fail");
let current_tensor = retrieve_current_algorithm_impression_knowledge(source_name, 384)
.await
.expect("add cumulative tensor fail");
logerror!("current_tensor {}", current_tensor);
}

@@ -425,14 +470,17 @@ }
}

#[test]
fn test_build_model_and_tokenizer() {
fn test_build_model_and_tokenizer_from_internet() {
// env::set_var("CUDA_COMPUTE_CAP","86");
// cargo test embeddingcommontest::test_build_model_and_tokenizer
// cargo test embeddingcommontest::test_build_model_and_tokenizer_from_internet
common::init_logger();
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model, mut tokenizer) =
embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
let (model, mut tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
default_model,
default_revision,
)
.unwrap();
let current_tokenizer: &TokenizerImplSimple = tokenizer
.with_padding(None)
.with_truncation(None)
Expand All @@ -457,4 +505,31 @@ mod embeddingcommontest {
Err(err) => println!("err {}", err),
}
}

#[test]
fn test_build_model_and_tokenizer_from_local() {
// cargo test embeddingcommontest::test_build_model_and_tokenizer_from_local
common::init_logger();
let current_model_info_field = embedding_common::MODEL_RELATED_INFO_MAP
.get("bert_v2")
.unwrap();
let (model, mut tokenizer) =
embedding_common::build_model_and_tokenizer_from_local(current_model_info_field)
.unwrap();
let current_tokenizer: &TokenizerImplSimple = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(AnyhowError::msg)
.unwrap();
let current_tensor = calculate_one_sentence(
&model,
current_tokenizer,
String::from("How beautiful the blonde girl"),
500,
)
.unwrap();
println!("*******************result {:?}", current_tensor);
let result = current_tensor.get(0).unwrap().to_vec1::<f32>().unwrap();
println!("********************** vec<f32> {:?}", result);
}
}
