diff --git a/Dockerfile.r4userembedding b/Dockerfile.r4userembedding
index 367ed2f..2170d15 100644
--- a/Dockerfile.r4userembedding
+++ b/Dockerfile.r4userembedding
@@ -10,7 +10,7 @@ RUN mkdir -p /userembedding && \
     apt install build-essential -y && \
     apt-get install libssl-dev -y && \
     apt-get install pkg-config -y
-
+
 ENV PATH="/root/.cargo/bin:${PATH}"
 WORKDIR /userembedding
 COPY user-embedding/src /userembedding/src
@@ -22,7 +22,7 @@ RUN cargo build --release
 FROM ubuntu:jammy
 
 # Import from builder.
-
+ENV MODEL_INTERNET="true"
 
 WORKDIR /userembedding
diff --git a/Dockerfile.r4userembeddingwithmodel b/Dockerfile.r4userembeddingwithmodel
index 65eaf2f..c74a648 100644
--- a/Dockerfile.r4userembeddingwithmodel
+++ b/Dockerfile.r4userembeddingwithmodel
@@ -23,6 +23,7 @@ RUN /userembedding/target/release/downloadmodel
 FROM ubuntu:jammy
 
 # Import from builder.
+ENV MODEL_INTERNET="false"
 
 WORKDIR /userembedding
diff --git a/user-embedding/src/bertcommon.rs b/user-embedding/src/bertcommon.rs
index 5577a55..595f572 100644
--- a/user-embedding/src/bertcommon.rs
+++ b/user-embedding/src/bertcommon.rs
@@ -6,6 +6,7 @@ use candle_transformers::models::bert::BertModel;
 use log::{debug as logdebug, error as logerror};
 use ndarray_rand::rand;
 use rand::seq::SliceRandom;
+use tokenizers::Tokenizer;
 
 use crate::{
     embedding_common::{self, normalize_l2, TokenizerImplSimple},
@@ -13,7 +14,7 @@ use crate::{
 };
 
 // const HUGGING_FACE_MODEL_NAME: &str = "sentence-transformers/all-MiniLM-L6-v2";
-// const HUGGING_FACE_MODEL_REVISION: &str = "refs/pr/21"; 
+// const HUGGING_FACE_MODEL_REVISION: &str = "refs/pr/21";
 
 pub const BERT_V2_EMBEDDING_DIMENSION: usize = 384;
 pub const BERT_V3_EMBEDDING_DIMENSION: usize = 384;
@@ -72,26 +73,50 @@ pub async fn calculate_single_entry_pure(
 
 async fn calculate_userembedding() -> AnyhowResult<Tensor> {
     let current_source_name: String = env::var("TERMINUS_RECOMMEND_SOURCE_NAME")
         .expect("TERMINUS_RECOMMEND_SOURCE_NAME env not found.");
-    let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
-    let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
+    let embedding_method: String =
+        std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
+    let model_related_info: &embedding_common::ModelInfoField =
+        embedding_common::MODEL_RELATED_INFO_MAP
+            .get(embedding_method.as_str())
+            .unwrap();
     let mut option_cumulative_tensor: Option<Tensor> = None;
-    if model_related_info.model_name == "bert_v2"{
-        let cumulative_embedding_data: [f32; BERT_V2_EMBEDDING_DIMENSION] =
-        [0f32; BERT_V2_EMBEDDING_DIMENSION];
+    if model_related_info.model_name == "bert_v2" {
+        let cumulative_embedding_data: [f32; BERT_V2_EMBEDDING_DIMENSION] =
+            [0f32; BERT_V2_EMBEDDING_DIMENSION];
        option_cumulative_tensor = Some(Tensor::new(&cumulative_embedding_data, &Device::Cpu)?);
-    }else if model_related_info.model_name == "bert_v3"{
+    } else if model_related_info.model_name == "bert_v3" {
        let cumulative_embedding_data: [f32; BERT_V3_EMBEDDING_DIMENSION] =
-        [0f32; BERT_V3_EMBEDDING_DIMENSION];
+            [0f32; BERT_V3_EMBEDDING_DIMENSION];
        option_cumulative_tensor = Some(Tensor::new(&cumulative_embedding_data, &Device::Cpu)?);
-    }else{
+    } else {
        tracing::error!("embedding method {} not exist", embedding_method);
        return Err(AnyhowError::msg("embedding method not exist"));
    }
    let mut cumulative_tensor: Tensor = option_cumulative_tensor.unwrap();

    let default_model: String =
        model_related_info.hugging_face_model_name.to_string();
    let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
-    let (model, mut tokenizer) =
-        embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
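+    // MODEL_INTERNET is provided by the Dockerfiles: Dockerfile.r4userembedding
+    // sets it to "true" (fetch the model from Hugging Face over the network),
+    // while Dockerfile.r4userembeddingwithmodel sets it to "false" (load the
+    // files already baked into the image by downloadmodel).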
+    let model_internet: String = env::var("MODEL_INTERNET").unwrap_or("false".to_string());
+    let mut model_option: Option<BertModel> = None;
+    let mut model_tokenizer: Option<Tokenizer> = None;
+
+    if model_internet == "true" {
+        logdebug!("use internet model");
+        let (model, tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
+            default_model,
+            default_revision,
+        )
+        .unwrap();
+        model_option = Some(model);
+        model_tokenizer = Some(tokenizer);
+    } else {
+        logdebug!("use local model");
+        let (model, tokenizer) =
+            embedding_common::build_model_and_tokenizer_from_local(model_related_info).unwrap();
+        model_option = Some(model);
+        model_tokenizer = Some(tokenizer);
+    }
+    let model: BertModel = model_option.unwrap();
+    let mut tokenizer: Tokenizer = model_tokenizer.unwrap();
     let current_tokenizer: &TokenizerImplSimple = tokenizer
         .with_padding(None)
         .with_truncation(None)
@@ -146,15 +171,20 @@ async fn calculate_userembedding() -> AnyhowResult<Tensor> {
 }
 
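+/// Combines the user embedding calculated from fresh impressions with the
+/// embedding previously stored in the knowledge base.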
 pub async fn execute_bertv2_user_embedding() {
-    let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
-    let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
+    let embedding_method: String =
+        std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
+    let model_related_info: &embedding_common::ModelInfoField =
+        embedding_common::MODEL_RELATED_INFO_MAP
+            .get(embedding_method.as_str())
+            .unwrap();
     let user_embedding: Tensor = calculate_userembedding()
         .await
         .expect("calculate user embedding fail");
-    let original_user_embedding =
-        embedding_common::retrieve_user_embedding_through_knowledge(model_related_info.embedding_dimension)
-            .await
-            .expect("retrieve user embedding through knowledge base fail");
+    let original_user_embedding = embedding_common::retrieve_user_embedding_through_knowledge(
+        model_related_info.embedding_dimension,
+    )
+    .await
+    .expect("retrieve user embedding through knowledge base fail");
     let new_user_embedding_result = user_embedding.add(&original_user_embedding);
     match new_user_embedding_result {
         Ok(current_new_user_embedding) => {
@@ -181,13 +211,20 @@ mod bertv2test {
     async fn test_calculate_single_entry() {
         // cargo test bertv2test::test_calculate_single_entry
         common_test_operation::init_env();
 
-        let embedding_method: String =std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
-        let model_related_info: &embedding_common::ModelInfoField = embedding_common::MODEL_RELATED_INFO_MAP.get(embedding_method.as_str()).unwrap();
+        let embedding_method: String =
+            std::env::var("EMBEDDING_METHOD").expect("EMBEDDING_METHOD not exist");
+        let model_related_info: &embedding_common::ModelInfoField =
+            embedding_common::MODEL_RELATED_INFO_MAP
+                .get(embedding_method.as_str())
+                .unwrap();
         let default_model: String = model_related_info.hugging_face_model_name.to_string();
         let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
-        let (model, mut tokenizer) =
-            embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
+        let (model, mut tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
+            default_model,
+            default_revision,
+        )
+        .unwrap();
         let current_tokenizer: &TokenizerImplSimple = tokenizer
             .with_padding(None)
             .with_truncation(None)
diff --git a/user-embedding/src/download_model.rs b/user-embedding/src/download_model.rs
index cbfc736..a6e4e3d 100644
--- a/user-embedding/src/download_model.rs
+++ b/user-embedding/src/download_model.rs
@@ -1,3 +1,5 @@
+use std::{fs::File, io::Write};
+
 use hf_hub::api::sync::Api;
 
 use userembedding::{
@@ -10,8 +12,31 @@ fn download_models() -> Result<(), Box<dyn std::error::Error>> {
         tracing::info!("Downloading model {}...", model_related_info.model_name);
         let default_model: String = model_related_info.hugging_face_model_name.to_string();
         let default_revision: String = model_related_info.hugging_face_model_revision.to_string();
-        let (_, _) =
-            embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
+        let (_, _, current_bert_model_file_path) =
+            embedding_common::build_model_and_tokenizer_from_internet(
+                default_model,
+                default_revision,
+            )
+            .unwrap();
+        let serialized = serde_json::to_string(&current_bert_model_file_path).map_err(|e| {
+            log::error!("Error serializing model file path: {}", e);
+            e
+        })?;
+
+        let output_json_path = format!(
+            "/root/.cache/huggingface/{}.json",
+            model_related_info.model_name
+        );
+
+        let mut file = File::create(output_json_path).map_err(|e| {
+            log::error!("Error creating file: {}", e);
+            e
+        })?;
+        file.write_all(serialized.as_bytes()).map_err(|e| {
+            log::error!("Error writing to file: {}", e);
+            e
+        })?;
+
         tracing::info!(
             "Model {} downloaded successfully",
             model_related_info.model_name
diff --git a/user-embedding/src/embedding_common.rs b/user-embedding/src/embedding_common.rs
index dffe864..8a10acf 100755
--- a/user-embedding/src/embedding_common.rs
+++ b/user-embedding/src/embedding_common.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, ops::Add};
+use std::{collections::HashMap, fs, ops::Add};
 
 use anyhow::{bail, Error as AnyhowError, Result as AnyhowResult};
 use candle_core::{DType, Device, Result as CandleResult, Tensor};
@@ -9,6 +9,7 @@ use log::{debug as logdebug, error as logerror};
 use ndarray::Array;
 use ndarray_rand::rand_distr::Uniform;
 use ndarray_rand::RandomExt;
+use serde::{Deserialize, Serialize};
 use text_splitter::TextSplitter;
 use tokenizers::Tokenizer;
 
@@ -30,11 +31,56 @@ pub fn build_device(cpu: bool) -> CandleResult<Device> {
     }
 }
 
-pub fn build_model_and_tokenizer(
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct BertModelFilePath {
+    pub config_filename: String,
+    pub tokenizer_filename: String,
+    pub weights_filename: String,
+}
+
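+/// Builds the BERT model and tokenizer from files already present in the local
+/// Hugging Face cache, using the JSON manifest that downloadmodel writes to
+/// /root/.cache/huggingface/<model_name>.json. No network access is required.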
+pub fn build_model_and_tokenizer_from_local(
+    current_model_info_field: &ModelInfoField,
+) -> AnyhowResult<(BertModel, Tokenizer)> {
+    let device = build_device(false)?;
+    let current_model_file_path = format!(
+        "/root/.cache/huggingface/{}.json",
+        current_model_info_field.model_name
+    );
+    let file_content = fs::read_to_string(current_model_file_path).map_err(|e| {
+        logerror!("read file error {}", e);
+        AnyhowError::msg(e)
+    })?;
+    let data: BertModelFilePath = serde_json::from_str(&file_content).map_err(|e| {
+        logerror!("parse json error {}", e);
+        AnyhowError::msg(e)
+    })?;
+    let config_filename = fs::canonicalize(data.config_filename)?;
+    let tokenizer_filename = fs::canonicalize(data.tokenizer_filename)?;
+    let weights_filename = fs::canonicalize(data.weights_filename)?;
+
+    logdebug!(
+        "[{}] [{}] config_filename {} tokenizer_filename {} weights_filename {}",
+        file!(),
+        line!(),
+        config_filename.display(),
+        tokenizer_filename.display(),
+        weights_filename.display()
+    );
+
+    let config = std::fs::read_to_string(config_filename)?;
+    let config: Config = serde_json::from_str(&config)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(AnyhowError::msg)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+    let model = BertModel::load(vb, &config)?;
+    Ok((model, tokenizer))
+}
+
+pub fn build_model_and_tokenizer_from_internet(
     default_model: String,
     default_revision: String,
-) -> AnyhowResult<(BertModel, Tokenizer)> {
+) -> AnyhowResult<(BertModel, Tokenizer, BertModelFilePath)> {
     let device = build_device(false)?;
+
     let repo = Repo::with_revision(default_model, RepoType::Model, default_revision);
     let (config_filename, tokenizer_filename, weights_filename) = {
         let api = Api::new()?;
@@ -45,6 +91,11 @@ pub fn build_model_and_tokenizer(
         (config, tokenizer, weights)
     };
 
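+    // Record where hf-hub cached each file so downloadmodel can serialize the
+    // paths for later offline loading via build_model_and_tokenizer_from_local.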
+    let current_bert_model_file_path = BertModelFilePath {
+        config_filename: config_filename.display().to_string(),
+        tokenizer_filename: tokenizer_filename.display().to_string(),
+        weights_filename: weights_filename.display().to_string(),
+    };
     logdebug!(
         "[{}] [{}] config_filename {} tokenizer_filename {} weights_filename {}",
         file!(),
@@ -59,7 +110,7 @@ pub fn build_model_and_tokenizer(
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(AnyhowError::msg)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
     let model = BertModel::load(vb, &config)?;
-    Ok((model, tokenizer))
+    Ok((model, tokenizer, current_bert_model_file_path))
 }
 
 pub fn normalize_l2(v: &Tensor, dimension: usize) -> AnyhowResult<Tensor> {
@@ -333,7 +384,6 @@ pub struct ModelInfoField {
     pub hugging_face_model_name: &'static str,
     pub hugging_face_model_revision: &'static str,
     pub embedding_dimension: usize,
-
 }
 
 lazy_static! {
@@ -347,21 +397,19 @@ lazy_static! {
         hugging_face_model_revision: "refs/pr/21",
         embedding_dimension: 384,
     },
-
     );
     m.insert(
         "bert_v3",
         ModelInfoField {
             model_name: "bert_v3",
-            hugging_face_model_name: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
-            hugging_face_model_revision: "refs/heads/main" ,
+            hugging_face_model_name:
+                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+            hugging_face_model_revision: "refs/heads/main",
             embedding_dimension: 384,
         },
-
     );
     m
 };
}
-
 #[cfg(test)]
 mod embeddingcommontest {
@@ -380,12 +428,9 @@ mod embeddingcommontest {
         common_test_operation::init_env();
 
         let source_name = String::from("bert_v2");
-        let current_tensor = retrieve_current_algorithm_impression_knowledge(
-            source_name,
-            384,
-        )
-        .await
-        .expect("add cumulative tensor fail");
+        let current_tensor = retrieve_current_algorithm_impression_knowledge(source_name, 384)
+            .await
+            .expect("add cumulative tensor fail");
 
         logerror!("current_tensor {}", current_tensor);
     }
@@ -425,14 +470,17 @@ mod embeddingcommontest {
     }
 
     #[test]
-    fn test_build_model_and_tokenizer() {
+    fn test_build_model_and_tokenizer_from_internet() {
         // env::set_var("CUDA_COMPUTE_CAP","86");
-        // cargo test embeddingcommontest::test_build_model_and_tokenizer
+        // cargo test embeddingcommontest::test_build_model_and_tokenizer_from_internet
         common::init_logger();
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
         let default_revision = "refs/pr/21".to_string();
-        let (model, mut tokenizer) =
-            embedding_common::build_model_and_tokenizer(default_model, default_revision).unwrap();
+        let (model, mut tokenizer, _) = embedding_common::build_model_and_tokenizer_from_internet(
+            default_model,
+            default_revision,
+        )
+        .unwrap();
         let current_tokenizer: &TokenizerImplSimple = tokenizer
             .with_padding(None)
             .with_truncation(None)
@@ -457,4 +505,31 @@ mod embeddingcommontest {
             Err(err) => println!("err {}", err),
         }
     }
+
+    #[test]
+    fn test_build_model_and_tokenizer_from_local() {
+        // cargo test embeddingcommontest::test_build_model_and_tokenizer_from_local
+        common::init_logger();
+        let current_model_info_field = embedding_common::MODEL_RELATED_INFO_MAP
+            .get("bert_v2")
+            .unwrap();
+        let (model, mut tokenizer) =
+            embedding_common::build_model_and_tokenizer_from_local(current_model_info_field)
+                .unwrap();
+        let current_tokenizer: &TokenizerImplSimple = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(AnyhowError::msg)
+            .unwrap();
+
+        let current_tensor = calculate_one_sentence(
+            &model,
+            current_tokenizer,
+            String::from("How beautiful the blonde girl"),
+            500,
+        )
+        .unwrap();
+        println!("*******************result {:?}", current_tensor);
+        let result = current_tensor.get(0).unwrap().to_vec1::<f32>().unwrap();
+        println!("********************** vec {:?}", result);
+    }
 }