Skip to content

Commit

Permalink
feat: add function to get build param from XGB (#190)
Browse files Browse the repository at this point in the history
* feat: add function to get build param from XGB

* build: update version to 2.1.0
  • Loading branch information
phamhoangtuan authored May 2, 2023
1 parent f14a4ac commit 214f6d1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
42 changes: 42 additions & 0 deletions h1st/model/ml/xgboost/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,47 @@ def __init__(
'debug': debug,
}

def __get_model_build_params(self, model: XGBRegressionModel) -> dict:
return {
'max_depth': model.max_depth,
'max_leaves': model.max_leaves,
'max_bin': model.max_bin,
'grow_policy': model.grow_policy,
'learning_rate': model.learning_rate,
'n_estimators': model.n_estimators,
'verbosity': model.verbosity,
'booster': model.booster,
'tree_method': model.tree_method,
'n_jobs': model.n_jobs,
'gamma': model.gamma,
'min_child_weight': model.min_child_weight,
'max_delta_step': model.max_delta_step,
'subsample': model.subsample,
'sampling_method': model.sampling_method,
'colsample_bytree': model.colsample_bytree,
'colsample_bylevel': model.colsample_bylevel,
'colsample_bynode': model.colsample_bynode,
'reg_alpha': model.reg_alpha,
'reg_lambda': model.reg_lambda,
'scale_pos_weight': model.scale_pos_weight,
'base_score': model.base_score,
'random_state': model.random_state,
'missing': model.missing,
'num_parallel_tree': model.num_parallel_tree,
'monotone_constraints': model.monotone_constraints,
'interaction_constraints': model.interaction_constraints,
'importance_type': model.importance_type,
'gpu_id': model.gpu_id,
'validate_parameters': model.validate_parameters,
'predictor': model.predictor,
'enable_categorical': model.enable_categorical,
'feature_types': model.feature_types,
'max_cat_to_onehot': model.max_cat_to_onehot,
'max_cat_threshold': model.max_cat_threshold,
'eval_metric': model.eval_metric,
'early_stopping_rounds': model.early_stopping_rounds
}

def train_base_model(self, input_data: dict) -> XGBRegressor:
"""
This function can be used to build and train XGBRegression model.
Expand Down Expand Up @@ -329,6 +370,7 @@ def train_base_model(self, input_data: dict) -> XGBRegressor:
}
)
self.stats['input_features'] = features
self.stats['build_params'] = self.__get_model_build_params(model)
return model

def prepare_data(self, prepared_data: dict):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "h1st"
version = "0.1.13"
version = "2.1.0"
description = "Human-First AI (H1st)"
authors = ["Aitomatic, Inc. <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

[metadata]
name = h1st
version = 0.1.12
version = 2.1.0

0 comments on commit 214f6d1

Please sign in to comment.