diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index 55dd2901..667b2a60 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -218,7 +218,7 @@ def to_app_config(self):
         config["embedding"]["source"] = self.embed_source
         config["embedding"]["model_name"] = self.embed_model
         config["embedding"]["api_key"] = self.embed_api_key
-        config["embedding"]["embed_batch_size"] = self.embed_batch_size
+        config["embedding"]["embed_batch_size"] = int(self.embed_batch_size)

         config["llm"]["source"] = self.llm
         config["llm"]["endpoint"] = self.llm_eas_url
diff --git a/src/pai_rag/modules/embedding/embedding.py b/src/pai_rag/modules/embedding/embedding.py
index 0d15a743..8b8350aa 100644
--- a/src/pai_rag/modules/embedding/embedding.py
+++ b/src/pai_rag/modules/embedding/embedding.py
@@ -26,6 +26,9 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
         source = config["source"].lower()
         embed_batch_size = config.get("embed_batch_size", DEFAULT_EMBED_BATCH_SIZE)

+        if not isinstance(embed_batch_size, int):
+            raise TypeError("embed_batch_size must be of type int")
+
         if source == "openai":
             embed_model = OpenAIEmbedding(
                 api_key=config.get("api_key", None),
@@ -52,6 +55,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
             embed_model = HuggingFaceEmbedding(
                 model_name=model_path, embed_batch_size=embed_batch_size
             )
+
             logger.info(
                 f"Initialized HuggingFace embedding model {model_name} with {embed_batch_size} batch size."
             )
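
A minimal standalone sketch (not part of the PR) of how the two changes interact, assuming the web form delivers embed_batch_size as a string: the view-model side coerces it with int(), and the embedding-module side rejects anything that is still not an int. Function names and the default value below are illustrative, not the project's actual API.

    # Hypothetical mirror of the changed logic, for illustration only.
    DEFAULT_EMBED_BATCH_SIZE = 10  # assumed default; the real constant lives in the module


    def to_app_config_fragment(embed_batch_size_from_form: str) -> dict:
        # Mirrors the view_model.py change: coerce the form value to int.
        return {"embedding": {"embed_batch_size": int(embed_batch_size_from_form)}}


    def validate_embed_batch_size(config: dict) -> int:
        # Mirrors the embedding.py change: fail fast on non-int values.
        embed_batch_size = config.get("embed_batch_size", DEFAULT_EMBED_BATCH_SIZE)
        if not isinstance(embed_batch_size, int):
            raise TypeError("embed_batch_size must be of type int")
        return embed_batch_size


    if __name__ == "__main__":
        cfg = to_app_config_fragment("32")["embedding"]
        print(validate_embed_batch_size(cfg))  # 32

        try:
            validate_embed_batch_size({"embed_batch_size": "32"})  # uncoerced string
        except TypeError as exc:
            print(exc)  # embed_batch_size must be of type int

With only the module-side check, a string coming from the web UI would raise at embedding-model construction time; coercing in to_app_config keeps the config well-typed before it reaches the module.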