From 74ba6bf0876094025b1a7b949468afc4dcffe9ad Mon Sep 17 00:00:00 2001 From: Deepali Jain Date: Mon, 9 Dec 2024 04:33:08 -0800 Subject: [PATCH] Config files for Iris experiments PiperOrigin-RevId: 704233511 --- iris/configs/example_config.py | 66 ++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 iris/configs/example_config.py diff --git a/iris/configs/example_config.py b/iris/configs/example_config.py new file mode 100644 index 0000000..b1b432a --- /dev/null +++ b/iris/configs/example_config.py @@ -0,0 +1,66 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example configuration for Iris experiments.""" + + +from iris import worker +from iris.algorithms import ars_algorithm +from ml_collections import config_dict +import numpy as np + + +def get_coordinator_config(): + """Coordinator config.""" + + return config_dict.ConfigDict( + dict( + save_rate=10, + eval_rate=1, + num_iterations=400, + num_evals_per_suggestion=1)) + + +def get_worker_config(): + """Worker config.""" + + return config_dict.ConfigDict( + dict( + worker_class=worker.SimpleWorker, + worker_args={'initial_params': np.ones(10), + 'blackbox_function': lambda x: -1 * np.sum(x**2)})) + + +def get_algo_config(): + """Algorithm config.""" + + return config_dict.ConfigDict( + dict( + algorithm_class=ars_algorithm.AugmentedRandomSearch, + algorithm_args=dict( + num_suggestions=8, + num_evals=10, + top_percentage=0.5, + std=0.1, + step_size=0.1, + random_seed=7))) + + +def get_config(): + """Main config.""" + return config_dict.ConfigDict( + dict( + coordinator=get_coordinator_config(), + worker=get_worker_config(), + algo=get_algo_config()))