Skip to content

Commit

Permalink
fix: forward start model parameters (#1825)
Browse files Browse the repository at this point in the history
Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Dec 26, 2024
1 parent fb72167 commit 1a73c0c
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 73 deletions.
53 changes: 15 additions & 38 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,65 +488,40 @@ void Models::StartModel(
if (!http_util::HasFieldInReq(req, callback, "model"))
return;
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
StartParameterOverride params_override;
if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
params_override.custom_prompt_template = o.asString();
}

if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
params_override.cache_enabled = o.asBool();
}

if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
params_override.ngl = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
params_override.n_parallel = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
params_override.ctx_len = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
params_override.cache_type = o.asString();
}

std::optional<std::string> mmproj;
if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) {
params_override.mmproj = o.asString();
mmproj = o.asString();
}

auto bypass_llama_model_path = false;
// Support both llama_model_path and model_path for backward compatible
// model_path has higher priority
if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
auto model_path = o.asString();
if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
// Bypass if model does not exist in DB and llama_model_path exists
if (std::filesystem::exists(params_override.model_path.value()) &&
if (std::filesystem::exists(model_path) &&
!model_service_->HasModel(model_handle)) {
CTL_INF("llama_model_path exists, bypass check model id");
params_override.bypass_llama_model_path = true;
bypass_llama_model_path = true;
}
}
}

if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
}
auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path);

auto model_entry = model_service_->GetDownloadedModel(model_handle);
if (!model_entry.has_value() && !params_override.bypass_model_check()) {
if (!model_entry.has_value() && !bypass_model_check) {
Json::Value ret;
ret["message"] = "Cannot find model: " + model_handle;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
std::string engine_name = params_override.bypass_model_check()
? kLlamaEngine
: model_entry.value().engine;
std::string engine_name =
bypass_model_check ? kLlamaEngine : model_entry.value().engine;
auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -565,7 +540,9 @@ void Models::StartModel(
return;
}

auto result = model_service_->StartModel(model_handle, params_override);
auto result = model_service_->StartModel(
model_handle, *(req->getJsonObject()) /*params_override*/,
bypass_model_check);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
Expand Down Expand Up @@ -668,7 +645,7 @@ void Models::AddRemoteModel(

auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto engine_name = (*(req->getJsonObject())).get("engine", "").asString();

auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -687,7 +664,7 @@ void Models::AddRemoteModel(
callback(resp);
return;
}

config::RemoteModelConfig model_config;
model_config.LoadFromJson(*(req->getJsonObject()));
cortex::db::Models modellist_utils_obj;
Expand Down
35 changes: 17 additions & 18 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,19 +749,28 @@ cpp::result<void, std::string> ModelService::DeleteModel(
}

cpp::result<StartModelResult, std::string> ModelService::StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override) {
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
std::optional<std::string> custom_prompt_template;
std::optional<int> ctx_len;
if (auto& o = params_override["prompt_template"]; !o.isNull()) {
custom_prompt_template = o.asString();
}

if (auto& o = params_override["ctx_len"]; !o.isNull()) {
ctx_len = o.asInt();
}

try {
constexpr const int kDefautlContextLength = 8192;
int max_model_context_length = kDefautlContextLength;
Json::Value json_data;
// Currently we don't support download vision models, so we need to bypass check
if (!params_override.bypass_model_check()) {
if (!bypass_model_check) {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CTL_WRN("Error: " + model_entry.error());
Expand Down Expand Up @@ -839,29 +848,19 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
}

json_data["model"] = model_handle;
if (auto& cpt = params_override.custom_prompt_template;
!cpt.value_or("").empty()) {
if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) {
auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
json_data["system_prompt"] = parse_prompt_result.system_prompt;
json_data["user_prompt"] = parse_prompt_result.user_prompt;
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
}

#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
if (param_override.param_name) { \
json_obj[#param_name] = param_override.param_name.value(); \
}
json_helper::MergeJson(json_data, params_override);

ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
ASSIGN_IF_PRESENT(json_data, params_override, ngl);
ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
ASSIGN_IF_PRESENT(json_data, params_override, mmproj);
ASSIGN_IF_PRESENT(json_data, params_override, model_path);
#undef ASSIGN_IF_PRESENT
if (params_override.ctx_len) {
// Set the latest ctx_len
if (ctx_len) {
json_data["ctx_len"] =
std::min(params_override.ctx_len.value(), max_model_context_length);
std::min(ctx_len.value(), max_model_context_length);
}
CTL_INF(json_data.toStyledString());
auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(),
Expand Down
19 changes: 2 additions & 17 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ struct ModelPullInfo {
std::string download_url;
};

struct StartParameterOverride {
std::optional<bool> cache_enabled;
std::optional<int> ngl;
std::optional<int> n_parallel;
std::optional<int> ctx_len;
std::optional<std::string> custom_prompt_template;
std::optional<std::string> cache_type;
std::optional<std::string> mmproj;
std::optional<std::string> model_path;
bool bypass_llama_model_path = false;
bool bypass_model_check() const {
return mmproj.has_value() || bypass_llama_model_path;
}
};

struct StartModelResult {
bool success;
std::optional<std::string> warning;
Expand Down Expand Up @@ -82,8 +67,8 @@ class ModelService {
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);

cpp::result<StartModelResult, std::string> StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override);
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check);

cpp::result<bool, std::string> StopModel(const std::string& model_handle);

Expand Down
58 changes: 58 additions & 0 deletions engine/test/components/test_json_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) {

EXPECT_TRUE(result.isNull());
}

TEST(MergeJsonTest, MergeSimpleObjects) {
Json::Value json1, json2;
json1["name"] = "John";
json1["age"] = 30;

json2["age"] = 31;
json2["email"] = "[email protected]";

json_helper::MergeJson(json1, json2);

Json::Value expected;
expected["name"] = "John";
expected["age"] = 31;
expected["email"] = "[email protected]";

EXPECT_EQ(json1, expected);
}

TEST(MergeJsonTest, MergeNestedObjects) {
Json::Value json1, json2;
json1["person"]["name"] = "John";
json1["person"]["age"] = 30;

json2["person"]["age"] = 31;
json2["person"]["email"] = "[email protected]";

json_helper::MergeJson(json1, json2);

Json::Value expected;
expected["person"]["name"] = "John";
expected["person"]["age"] = 31;
expected["person"]["email"] = "[email protected]";

EXPECT_EQ(json1, expected);
}

TEST(MergeJsonTest, MergeArrays) {
Json::Value json1, json2;
json1["hobbies"] = Json::Value(Json::arrayValue);
json1["hobbies"].append("reading");
json1["hobbies"].append("painting");

json2["hobbies"] = Json::Value(Json::arrayValue);
json2["hobbies"].append("hiking");
json2["hobbies"].append("painting");

json_helper::MergeJson(json1, json2);

Json::Value expected;
expected["hobbies"] = Json::Value(Json::arrayValue);
expected["hobbies"].append("reading");
expected["hobbies"].append("painting");
expected["hobbies"].append("hiking");
expected["hobbies"].append("painting");

EXPECT_EQ(json1, expected);
}
24 changes: 24 additions & 0 deletions engine/utils/json_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,28 @@ inline std::string DumpJsonString(const Json::Value& json) {
builder["indentation"] = "";
return Json::writeString(builder, json);
}

inline void MergeJson(Json::Value& target, const Json::Value& source) {
for (const auto& member : source.getMemberNames()) {
if (target.isMember(member)) {
// If the member exists in both objects, recursively merge the values
if (target[member].type() == Json::objectValue &&
source[member].type() == Json::objectValue) {
MergeJson(target[member], source[member]);
} else if (target[member].type() == Json::arrayValue &&
source[member].type() == Json::arrayValue) {
// If the member is an array in both objects, merge the arrays
for (const auto& value : source[member]) {
target[member].append(value);
}
} else {
// Otherwise, overwrite the value in the target with the value from the source
target[member] = source[member];
}
} else {
// If the member doesn't exist in the target, add it
target[member] = source[member];
}
}
}
} // namespace json_helper

0 comments on commit 1a73c0c

Please sign in to comment.