Skip to content

Commit

Permalink
add parralel solving option to pstore.solve_models
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvonk committed Sep 30, 2024
1 parent f476d2a commit 4fe532b
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions pastastore/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import warnings
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Literal, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1185,6 +1186,8 @@ def solve_models(
ignore_solve_errors: bool = False,
store_result: bool = True,
progressbar: bool = True,
parallel: bool = False,
max_workers: Optional[int] = None,
**kwargs,
) -> None:
"""Solves the models in the store.
Expand All @@ -1204,7 +1207,12 @@ def solve_models(
store_result : bool, optional
if True save optimized models, default is True
progressbar : bool, optional
show progressbar, default is True
show progressbar, default is True. Does not work (yet) for parallel.
parralel: bool, optional
if True, solve models in parallel using ProcessPoolExecutor
max_workers: int, optional
maximum number of workers to use in parallel solving, default is
None which will use the number of cores available on the machine
**kwargs :
arguments are passed to the solve method.
"""
Expand All @@ -1213,10 +1221,7 @@ def solve_models(
elif isinstance(mls, ps.Model):
mls = [mls.name]

desc = "Solving models"
for ml_name in tqdm(mls, desc=desc) if progressbar else mls:
ml = self.conn.get_models(ml_name)

def solve_model(ml: ps.Model) -> None:
m_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, pd.Series):
Expand All @@ -1239,6 +1244,15 @@ def solve_models(
else:
raise e

if parallel:
models = self.conn.get_models(mls, progressbar=False)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
executor.map(solve_model, models)
else:
for ml_name in tqdm(mls, desc="Solving models") if progressbar else mls:
ml = self.conn.get_models(ml_name, progressbar=False)
solve_model(ml)

def model_results(
self,
mls: Optional[Union[ps.Model, list, str]] = None,
Expand Down

0 comments on commit 4fe532b

Please sign in to comment.