Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed passing of generation config params to VLM generate. #1180

Merged
Merged
101 changes: 64 additions & 37 deletions src/python/py_utils.cpp
Wovchena marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ ov::genai::OptionalGenerationConfig update_config_from_kwargs(const ov::genai::O
ov::genai::GenerationConfig res_config;
if(config.has_value())
res_config = *config;

ov::AnyMap map;
for (const auto& item : kwargs) {
std::string key = py::cast<std::string>(item.first);
py::object value = py::cast<py::object>(item.second);
Expand All @@ -240,47 +240,74 @@ ov::genai::OptionalGenerationConfig update_config_from_kwargs(const ov::genai::O
// us from reading such configs, e.g. {"typical_p": None, 'top_p': 1.0,...}
return res_config;
}
if (key == "max_new_tokens") {
res_config.max_new_tokens = py::cast<int>(item.second);
} else if (key == "max_length") {
res_config.max_length = py::cast<int>(item.second);
} else if (key == "ignore_eos") {
res_config.ignore_eos = py::cast<bool>(item.second);
} else if (key == "num_beam_groups") {
res_config.num_beam_groups = py::cast<int>(item.second);
} else if (key == "num_beams") {
res_config.num_beams = py::cast<int>(item.second);
} else if (key == "diversity_penalty") {
res_config.diversity_penalty = py::cast<float>(item.second);
} else if (key == "length_penalty") {
res_config.length_penalty = py::cast<float>(item.second);
} else if (key == "num_return_sequences") {
res_config.num_return_sequences = py::cast<int>(item.second);
} else if (key == "no_repeat_ngram_size") {
res_config.no_repeat_ngram_size = py::cast<int>(item.second);
} else if (key == "stop_criteria") {
res_config.stop_criteria = py::cast<StopCriteria>(item.second);
} else if (key == "temperature") {
res_config.temperature = py::cast<float>(item.second);
} else if (key == "top_p") {
res_config.top_p = py::cast<float>(item.second);
} else if (key == "top_k") {
res_config.top_k = py::cast<int>(item.second);
} else if (key == "do_sample") {
res_config.do_sample = py::cast<bool>(item.second);
} else if (key == "repetition_penalty") {
res_config.repetition_penalty = py::cast<float>(item.second);
} else if (key == "eos_token_id") {
res_config.set_eos_token_id(py::cast<int>(item.second));
} else if (key == "adapters") {
res_config.adapters = py::cast<ov::genai::AdapterConfig>(item.second);
} else {
if (!generation_config_param_to_property(key, value, map)) {
throw(std::invalid_argument("'" + key + "' is incorrect GenerationConfig parameter name. "
"Use help(openvino_genai.GenerationConfig) to get list of acceptable parameters."));
}
}

res_config.update_generation_config(map);
return res_config;
}


bool generation_config_param_to_property(std::string key, py::object value, ov::AnyMap& map) {
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved
if (key == "max_new_tokens") {
map.insert(ov::genai::max_new_tokens(py::cast<int>(value)));
} else if (key == "max_length") {
map.insert(ov::genai::max_length(py::cast<int>(value)));
} else if (key == "ignore_eos") {
map.insert(ov::genai::ignore_eos(py::cast<bool>(value)));
} else if (key == "min_new_tokens") {
map.insert(ov::genai::min_new_tokens(py::cast<int>(value)));
} else if (key == "stop_strings") {
map.insert(ov::genai::stop_strings(py::cast<std::vector<std::string>>(value)));
} else if (key == "include_stop_str_in_output") {
map.insert(ov::genai::include_stop_str_in_output(py::cast<bool>(value)));
} else if (key == "include_stop_str_in_output") {
map.insert(ov::genai::stop_token_ids(py::cast<std::vector<std::vector<int64_t>>>(value)));
} else if (key == "num_beam_groups") {
map.insert(ov::genai::num_beam_groups(py::cast<int>(value)));
} else if (key == "num_beams") {
map.insert(ov::genai::num_beams(py::cast<int>(value)));
} else if (key == "diversity_penalty") {
map.insert(ov::genai::diversity_penalty(py::cast<float>(value)));
} else if (key == "length_penalty") {
map.insert(ov::genai::length_penalty(py::cast<float>(value)));
} else if (key == "num_return_sequences") {
map.insert(ov::genai::num_return_sequences(py::cast<int>(value)));
} else if (key == "no_repeat_ngram_size") {
map.insert(ov::genai::no_repeat_ngram_size(py::cast<int>(value)));
} else if (key == "stop_criteria") {
map.insert(ov::genai::stop_criteria(py::cast<StopCriteria>(value)));
} else if (key == "temperature") {
map.insert(ov::genai::temperature(py::cast<float>(value)));
} else if (key == "top_p") {
map.insert(ov::genai::top_p(py::cast<float>(value)));
} else if (key == "top_k") {
map.insert(ov::genai::top_k(py::cast<int>(value)));
} else if (key == "do_sample") {
map.insert(ov::genai::do_sample(py::cast<bool>(value)));
} else if (key == "repetition_penalty") {
map.insert(ov::genai::repetition_penalty(py::cast<float>(value)));
} else if (key == "presence_penalty") {
map.insert(ov::genai::presence_penalty(py::cast<float>(value)));
} else if (key == "frequency_penalty") {
map.insert(ov::genai::frequency_penalty(py::cast<float>(value)));
} else if (key == "rng_seed") {
map.insert(ov::genai::rng_seed(py::cast<int>(value)));
} else if (key == "eos_token_id") {
map.insert(ov::genai::eos_token_id(py::cast<int>(value)));
} else if (key == "assistant_confidence_threshold") {
map.insert(ov::genai::assistant_confidence_threshold(py::cast<float>(value)));
} else if (key == "num_assistant_tokens") {
map.insert(ov::genai::num_assistant_tokens(py::cast<int>(value)));
} else if (key == "adapters") {
map.insert(ov::genai::adapters(py::cast<ov::genai::AdapterConfig>(value)));
} else {
return false;
}
return true;
}


} // namespace ov::genai::pybind::utils
2 changes: 2 additions & 0 deletions src/python/py_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ ov::genai::OptionalGenerationConfig update_config_from_kwargs(const ov::genai::O

ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& py_streamer);

bool generation_config_param_to_property(std::string key, py::object value, ov::AnyMap& map);
popovaan marked this conversation as resolved.
Show resolved Hide resolved

} // namespace ov::genai::pybind::utils
35 changes: 19 additions & 16 deletions src/python/py_vlm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,18 @@ py::object call_vlm_generate(
return py::cast(pipe.generate(prompt, images, updated_config, streamer));
}

ov::AnyMap vlm_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_compile_properties=true) {
ov::AnyMap vlm_kwargs_to_any_map(const py::kwargs& kwargs, bool called_from_init=true) {
ov::AnyMap params = {};

for (const auto& item : kwargs) {
std::string key = py::cast<std::string>(item.first);
py::object value = py::cast<py::object>(item.second);

if (key == "images") {
params.insert({ov::genai::images(std::move(py::cast<std::vector<ov::Tensor>>(value)))});
} else if (key == "image") {
params.insert({ov::genai::image(std::move(py::cast<ov::Tensor>(value)))});
} else if (key == "generation_config") {
params.insert({ov::genai::generation_config(std::move(py::cast<ov::genai::GenerationConfig>(value)))});
} else if (key == "streamer") {
auto py_streamer = py::cast<pyutils::PyBindStreamerVariant>(value);
params.insert({ov::genai::streamer(std::move(pyutils::pystreamer_to_streamer(py_streamer)))});

}
if (pyutils::generation_config_param_to_property(key, value, params)) {
continue;
}
else {
if (allow_compile_properties) {
if (called_from_init) {
// convert arbitrary objects to ov::Any
// not supported properties are not checked, as these properties are passed to compile(), which will throw exception in case of unsupported property
if (pyutils::py_object_is_any_map(value)) {
Expand All @@ -102,9 +94,20 @@ ov::AnyMap vlm_kwargs_to_any_map(const py::kwargs& kwargs, bool allow_compile_pr
}
}
else {
// generate doesn't run compile(), so only VLMPipeline specific properties are allowed
throw(std::invalid_argument("'" + key + "' is unexpected parameter name. "
"Use help(openvino_genai.VLMPipeline.generate) to get list of acceptable parameters."));
if (key == "images") {
params.insert({ov::genai::images(std::move(py::cast<std::vector<ov::Tensor>>(value)))});
} else if (key == "image") {
params.insert({ov::genai::image(std::move(py::cast<ov::Tensor>(value)))});
} else if (key == "generation_config") {
params.insert({ov::genai::generation_config(std::move(py::cast<ov::genai::GenerationConfig>(value)))});
} else if (key == "streamer") {
auto py_streamer = py::cast<pyutils::PyBindStreamerVariant>(value);
params.insert({ov::genai::streamer(std::move(pyutils::pystreamer_to_streamer(py_streamer)))});
} else {
// generate doesn't run compile(), so only VLMPipeline specific properties are allowed
throw(std::invalid_argument("'" + key + "' is unexpected parameter name. "
"Use help(openvino_genai.VLMPipeline.generate) to get list of acceptable parameters."));
}
}
}
}
Expand Down
Loading