Skip to content

Commit

Permalink
Fix failed catboost bind on GPU (#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
allegroai committed Mar 6, 2022
1 parent ac1750b commit e142954
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions clearml/binding/frameworks/catboost_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ def _patch_model_io():
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
except ImportError:
pass
except Exception:
pass
except Exception as e:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")

@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
Expand Down Expand Up @@ -94,7 +93,17 @@ def _load(original_fn, f, *args, **kwargs):
def _fit(original_fn, obj, *args, **kwargs):
callbacks = kwargs.get("callbacks") or []
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
return original_fn(obj, *args, **kwargs)
# noinspection PyBroadException
try:
return original_fn(obj, *args, **kwargs)
except Exception:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger.report_text(
"Catboost metrics logging is not supported for GPU. "
"See https://github.com/catboost/catboost/issues/1792"
)
del kwargs["callbacks"]
return original_fn(obj, *args, **kwargs)

@staticmethod
def _generate_training_callback_class():
Expand Down

0 comments on commit e142954

Please sign in to comment.