Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: forward start model parameters #1825

Merged
merged 2 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 15 additions & 38 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,65 +488,40 @@ void Models::StartModel(
if (!http_util::HasFieldInReq(req, callback, "model"))
return;
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
StartParameterOverride params_override;
if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
params_override.custom_prompt_template = o.asString();
}

if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
params_override.cache_enabled = o.asBool();
}

if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
params_override.ngl = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
params_override.n_parallel = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
params_override.ctx_len = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
params_override.cache_type = o.asString();
}

std::optional<std::string> mmproj;
if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) {
params_override.mmproj = o.asString();
mmproj = o.asString();
}

auto bypass_llama_model_path = false;
// Support both llama_model_path and model_path for backward compatibility;
// model_path has higher priority.
if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
auto model_path = o.asString();
if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
// Bypass if model does not exist in DB and llama_model_path exists
if (std::filesystem::exists(params_override.model_path.value()) &&
if (std::filesystem::exists(model_path) &&
!model_service_->HasModel(model_handle)) {
CTL_INF("llama_model_path exists, bypass check model id");
params_override.bypass_llama_model_path = true;
bypass_llama_model_path = true;
}
}
}

if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
}
auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path);

auto model_entry = model_service_->GetDownloadedModel(model_handle);
if (!model_entry.has_value() && !params_override.bypass_model_check()) {
if (!model_entry.has_value() && !bypass_model_check) {
Json::Value ret;
ret["message"] = "Cannot find model: " + model_handle;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
std::string engine_name = params_override.bypass_model_check()
? kLlamaEngine
: model_entry.value().engine;
std::string engine_name =
bypass_model_check ? kLlamaEngine : model_entry.value().engine;
auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -565,7 +540,9 @@ void Models::StartModel(
return;
}

auto result = model_service_->StartModel(model_handle, params_override);
auto result = model_service_->StartModel(
model_handle, *(req->getJsonObject()) /*params_override*/,
bypass_model_check);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
Expand Down Expand Up @@ -668,7 +645,7 @@ void Models::AddRemoteModel(

auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto engine_name = (*(req->getJsonObject())).get("engine", "").asString();

auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -687,7 +664,7 @@ void Models::AddRemoteModel(
callback(resp);
return;
}

config::RemoteModelConfig model_config;
model_config.LoadFromJson(*(req->getJsonObject()));
cortex::db::Models modellist_utils_obj;
Expand Down
35 changes: 17 additions & 18 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,19 +749,28 @@ cpp::result<void, std::string> ModelService::DeleteModel(
}

cpp::result<StartModelResult, std::string> ModelService::StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override) {
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
std::optional<std::string> custom_prompt_template;
std::optional<int> ctx_len;
if (auto& o = params_override["prompt_template"]; !o.isNull()) {
custom_prompt_template = o.asString();
}

if (auto& o = params_override["ctx_len"]; !o.isNull()) {
ctx_len = o.asInt();
}

try {
constexpr const int kDefautlContextLength = 8192;
int max_model_context_length = kDefautlContextLength;
Json::Value json_data;
// Currently we don't support downloading vision models, so we need to bypass the check
if (!params_override.bypass_model_check()) {
if (!bypass_model_check) {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CTL_WRN("Error: " + model_entry.error());
Expand Down Expand Up @@ -839,29 +848,19 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
}

json_data["model"] = model_handle;
if (auto& cpt = params_override.custom_prompt_template;
!cpt.value_or("").empty()) {
if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) {
auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
json_data["system_prompt"] = parse_prompt_result.system_prompt;
json_data["user_prompt"] = parse_prompt_result.user_prompt;
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
}

#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
if (param_override.param_name) { \
json_obj[#param_name] = param_override.param_name.value(); \
}
json_helper::MergeJson(json_data, params_override);

ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
ASSIGN_IF_PRESENT(json_data, params_override, ngl);
ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
ASSIGN_IF_PRESENT(json_data, params_override, mmproj);
ASSIGN_IF_PRESENT(json_data, params_override, model_path);
#undef ASSIGN_IF_PRESENT
if (params_override.ctx_len) {
// Set the latest ctx_len
if (ctx_len) {
json_data["ctx_len"] =
std::min(params_override.ctx_len.value(), max_model_context_length);
std::min(ctx_len.value(), max_model_context_length);
}
CTL_INF(json_data.toStyledString());
auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(),
Expand Down
19 changes: 2 additions & 17 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ struct ModelPullInfo {
std::string download_url;
};

/// Optional per-request overrides supplied when starting a model.
///
/// Each std::optional member stays empty until the caller sets it, so
/// consumers can distinguish "not provided" from an explicit value.
struct StartParameterOverride {
  std::optional<bool> cache_enabled = std::nullopt;
  std::optional<int> ngl = std::nullopt;
  std::optional<int> n_parallel = std::nullopt;
  std::optional<int> ctx_len = std::nullopt;
  std::optional<std::string> custom_prompt_template = std::nullopt;
  std::optional<std::string> cache_type = std::nullopt;
  std::optional<std::string> mmproj = std::nullopt;
  std::optional<std::string> model_path = std::nullopt;
  // Off by default; callers flip it on explicitly.
  bool bypass_llama_model_path{false};

  /// Returns true when either `bypass_llama_model_path` is set or an
  /// `mmproj` value has been provided; false otherwise.
  bool bypass_model_check() const {
    if (bypass_llama_model_path) {
      return true;
    }
    return mmproj.has_value();
  }
};

struct StartModelResult {
bool success;
std::optional<std::string> warning;
Expand Down Expand Up @@ -82,8 +67,8 @@ class ModelService {
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);

cpp::result<StartModelResult, std::string> StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override);
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check);

cpp::result<bool, std::string> StopModel(const std::string& model_handle);

Expand Down
58 changes: 58 additions & 0 deletions engine/test/components/test_json_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) {

EXPECT_TRUE(result.isNull());
}

// Flat objects: a key present on both sides takes the source's value, and a
// key present only in the source is added to the target.
TEST(MergeJsonTest, MergeSimpleObjects) {
  Json::Value merged;
  merged["name"] = "John";
  merged["age"] = 30;

  Json::Value overrides;
  overrides["age"] = 31;
  overrides["email"] = "[email protected]";

  // Build the expectation up front; it does not depend on the merge call.
  Json::Value expected;
  expected["name"] = "John";
  expected["age"] = 31;
  expected["email"] = "[email protected]";

  json_helper::MergeJson(merged, overrides);

  EXPECT_EQ(merged, expected);
}

// Nested objects: the merge descends into "person" instead of replacing it
// wholesale, so untouched nested keys survive.
TEST(MergeJsonTest, MergeNestedObjects) {
  Json::Value merged;
  Json::Value& merged_person = merged["person"];
  merged_person["name"] = "John";
  merged_person["age"] = 30;

  Json::Value overrides;
  Json::Value& override_person = overrides["person"];
  override_person["age"] = 31;
  override_person["email"] = "[email protected]";

  // Build the expectation up front; it does not depend on the merge call.
  Json::Value expected;
  Json::Value& expected_person = expected["person"];
  expected_person["name"] = "John";
  expected_person["age"] = 31;
  expected_person["email"] = "[email protected]";

  json_helper::MergeJson(merged, overrides);

  EXPECT_EQ(merged, expected);
}

// Arrays: source elements are appended after the target's elements, and
// duplicates ("painting") are kept rather than de-duplicated.
TEST(MergeJsonTest, MergeArrays) {
  Json::Value merged;
  Json::Value& merged_hobbies = merged["hobbies"];
  merged_hobbies = Json::Value(Json::arrayValue);
  merged_hobbies.append("reading");
  merged_hobbies.append("painting");

  Json::Value extra;
  Json::Value& extra_hobbies = extra["hobbies"];
  extra_hobbies = Json::Value(Json::arrayValue);
  extra_hobbies.append("hiking");
  extra_hobbies.append("painting");

  // Expected order: the target's original entries, then the source's entries.
  Json::Value expected;
  Json::Value& expected_hobbies = expected["hobbies"];
  expected_hobbies = Json::Value(Json::arrayValue);
  expected_hobbies.append("reading");
  expected_hobbies.append("painting");
  expected_hobbies.append("hiking");
  expected_hobbies.append("painting");

  json_helper::MergeJson(merged, extra);

  EXPECT_EQ(merged, expected);
}
24 changes: 24 additions & 0 deletions engine/utils/json_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,28 @@ inline std::string DumpJsonString(const Json::Value& json) {
builder["indentation"] = "";
return Json::writeString(builder, json);
}

/// Recursively merges `source` into `target`, modifying `target` in place.
///
/// For each member of `source`:
///  - if both sides hold an object, the two objects are merged recursively;
///  - if both sides hold an array, the source elements are appended to the
///    target array (existing elements are kept, duplicates are not removed);
///  - otherwise the target member is created or overwritten with the source
///    value.
///
/// @param target Value that receives the merged result. NOTE(review): jsoncpp's
///               string operator[] expects this to be an object or null;
///               callers must not pass an array or scalar here.
/// @param source Value whose members are merged in. A source that is not a
///               JSON object (null, scalar, array) has no members to merge and
///               leaves `target` unchanged.
inline void MergeJson(Json::Value& target, const Json::Value& source) {
  // getMemberNames() is only valid on object/null values. Bail out early so a
  // non-object source cannot trip jsoncpp's precondition failure.
  if (!source.isObject()) {
    return;
  }

  for (const auto& member : source.getMemberNames()) {
    const Json::Value& incoming = source[member];
    // Non-const operator[] inserts a null entry when the member is absent.
    // Null is neither an object nor an array, so that case falls through to
    // the plain assignment below, i.e. "add the member".
    Json::Value& existing = target[member];

    if (existing.isObject() && incoming.isObject()) {
      MergeJson(existing, incoming);
    } else if (existing.isArray() && incoming.isArray()) {
      for (const auto& value : incoming) {
        existing.append(value);
      }
    } else {
      existing = incoming;
    }
  }
}
} // namespace json_helper