Skip to content

Commit

Permalink
More tribuo models (#72)
Browse files Browse the repository at this point in the history
* added libsvm and liblinear artifacts
* added regression tests
* added adaboost test
  • Loading branch information
behrica authored Dec 3, 2024
1 parent fdf82d4 commit 0d02257
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 55 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
All notable changes to this project will be documented in this file. This change log follows the conventions of [keepachangelog.com](http://keepachangelog.com/).

## [???] - unreleased
- added libsvm and liblinear Tribuo models
- updated deps (tech.ml.dataset, metamorph.ml, scicloj.ml.tribuo)

## [2-alpha12.1] - 2024-11-16
Expand Down
13 changes: 10 additions & 3 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@
techascent/tech.ml.dataset {:mvn/version "7.032"} ;; no JVM crash
;;techascent/tech.ml.dataset {:mvn/version "7.033"} ;; JVM crash
;;techascent/tech.ml.dataset {:mvn/version "7.034"} ;; JVM crash

org.tribuo/tribuo-regression-liblinear {:mvn/version "4.3.1"}
org.tribuo/tribuo-regression-libsvm {:mvn/version "4.3.1"}
org.tribuo/tribuo-regression-sgd {:mvn/version "4.3.1"}
org.tribuo/tribuo-regression-tree {:mvn/version "4.3.1"}
org.tribuo/tribuo-regression-xgboost {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-sgd {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-tree {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-xgboost {:mvn/version "4.3.1"}

org.tribuo/tribuo-classification-liblinear {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-libsvm {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-sgd {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-tree {:mvn/version "4.3.1"}
org.tribuo/tribuo-classification-xgboost {:mvn/version "4.3.1"}


clj-python/libpython-clj {:mvn/version "2.025"}
org.scicloj/kind-pyplot {:mvn/version "1-beta2.1"}
Expand Down
201 changes: 149 additions & 52 deletions model-integration-tests/model_integration_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
[tablecloth.api :as tc]
[taoensso.nippy :as nippy]
[tech.v3.dataset :as ds]
[tech.v3.dataset.categorical :as ds-cat])
[tech.v3.dataset.categorical :as ds-cat]
[tech.v3.dataset.modelling :as ds-mod])
(:import
[org.tribuo.classification.libsvm SVMClassificationType$SVMMode]
[org.slf4j.bridge SLF4JBridgeHandler]
(smile.base.mlp
ActivationFunction
Expand All @@ -38,7 +40,10 @@ warnings.simplefilter('ignore')")


(require '[scicloj.metamorph.ml.classification]
'[scicloj.metamorph.ml.regression]
'[scicloj.ml.smile.classification]
'[scicloj.ml.smile.regression]

'[scicloj.ml.tribuo]
'[scicloj.sklearn-clj.ml]
'[scicloj.ml.xgboost])
Expand Down Expand Up @@ -107,22 +112,53 @@ warnings.simplefilter('ignore')")
not-working-with-iris-data-or-default-params-or-no-probab))))))

;; Tribuo classification specs, one `[accuracy-threshold model-options]` pair
;; per trainer. Each options map lists the Tribuo components to configure
;; (`:tribuo-components`) and the name of the component used as trainer
;; (`:tribuo-trainer-name`).
(def tribuo-model-specs
  [;; AdaBoost ensemble; its :innerTrainer property names the "logistic"
   ;; component declared alongside it.
   [0.95 {:model-type :scicloj.ml.tribuo/classification
          :tribuo-components [{:name "logistic"
                               :type "org.tribuo.classification.sgd.linear.LinearSGDTrainer"
                               :properties {:seed "1234"
                                            :shuffle "false"
                                            :epochs "10"}}
                              {:name "ada"
                               :type "org.tribuo.classification.ensemble.AdaBoostTrainer"
                               :properties {:innerTrainer "logistic"
                                            :numMembers "5"
                                            :seed "1234"}}]
          :tribuo-trainer-name "ada"}]

   ;; plain linear SGD trainer
   [0.93 {:model-type :scicloj.ml.tribuo/classification
          :tribuo-components [{:name "logistic"
                               :type "org.tribuo.classification.sgd.linear.LinearSGDTrainer"
                               :properties {:seed "1234"
                                            :shuffle "false"
                                            :epochs "10"}}]
          :tribuo-trainer-name "logistic"}]

   ;; liblinear
   [0.93 {:model-type :scicloj.ml.tribuo/classification
          :tribuo-components [{:name "liblinear"
                               :type "org.tribuo.classification.liblinear.LibLinearClassificationTrainer"
                               :properties {:seed "1234"}}]
          :tribuo-trainer-name "liblinear"}]

   ;; libsvm; the trainer's :svmType value matches the name of the
   ;; SVMClassificationType component declared first.
   [0.93 {:model-type :scicloj.ml.tribuo/classification
          :tribuo-components [{:name "C_SVC"
                               :type "org.tribuo.classification.libsvm.SVMClassificationType"
                               :properties {:type "C_SVC"}}
                              {:name "libsvm"
                               :type "org.tribuo.classification.libsvm.LibSVMClassificationTrainer"
                               :properties {:seed "1234"
                                            :svmType "C_SVC"}}]
          :tribuo-trainer-name "libsvm"}]

   ;; CART tree trainer (registered under the component name "random-forest")
   [0.93 {:model-type :scicloj.ml.tribuo/classification
          :tribuo-components [{:name "random-forest"
                               :type "org.tribuo.classification.dtree.CARTClassificationTrainer"
                               :properties {:maxDepth "8"
                                            :useRandomSplitPoints "false"
                                            :fractionFeaturesInSplit "0.5"}}]
          :tribuo-trainer-name "random-forest"}]])
(def xgboost-specs
  "XGBoost classification specs as `[accuracy-threshold model-options]` pairs,
  the same shape as the other `*-specs` vectors in this file."
  [[0.94 {:model-type :xgboost/classification
          :num-class  3}]])
Expand All @@ -138,16 +174,14 @@ warnings.simplefilter('ignore')")
other-specs
tribuo-model-specs
smile-model-specs
sklearn-model-specs))
sklearn-model-specs)
)





(defn my-classification-accuracy
  "Named wrapper around `loss/classification-accuracy`; passes `lhs` and `rhs`
  through unchanged and returns whatever that function returns."
  [lhs rhs]
  ;; Disabled debugging aids for inspecting the metadata of both arguments:
  ;; (println :lhs (meta lhs))
  ;; (println :rhs (meta rhs))
  (loss/classification-accuracy lhs rhs))

(defn- validate-nippy-round-trip [model-spec result val-ds]
Expand All @@ -173,10 +207,16 @@ warnings.simplefilter('ignore')")
new-accurcay
(loss/classification-accuracy
new-prediction
new-trueth)]
(is (<
(get min-accuracies (:model-type model-spec) 0.7)
new-accurcay))))
new-trueth)
min-accurcay (get min-accuracies (:model-type model-spec) 0.7)
]
(is (< min-accurcay

new-accurcay)
(format "min accurcay (%s) validation failed for: %s"
min-accurcay
model-spec)
)))

(defn classify [model-spec ds]
(println :verify (:model-type model-spec))
Expand Down Expand Up @@ -215,18 +255,16 @@ warnings.simplefilter('ignore')")
(defn- verify-fn
  "Checks one `[expected-acc spec]` pair against the `iris` dataset.

  Calls `classify` with `spec` and `iris`, prints the resulting accuracy and
  asserts that it is at least `expected-acc`. An exception thrown while
  classifying is caught and recorded as a failed assertion (with the
  exception as message) instead of propagating to the caller."
  [[expected-acc spec] iris]
  (try
    (let [acc (classify spec iris)]
      (println :acc acc)
      (is (>= acc expected-acc)
          (format "%s: expect at least: %s, found : %s"
                  (:model-type spec)
                  expected-acc acc)))
    ;; exactly one catch clause; the form is closed right after it
    (catch Exception e
      (is false e))))


(deftest verify-classification-iris-int-catmap
Expand All @@ -253,7 +291,7 @@ warnings.simplefilter('ignore')")
)]
(run!
#(verify-fn % iris)
;; only tribuo can deal with "string" target
;; only tribuo can deal with "string" target column
;;https://github.com/scicloj/noj/issues/36
(concat
;other-specs
Expand Down Expand Up @@ -290,29 +328,88 @@ warnings.simplefilter('ignore')")
(-> model-specs
;;https://github.com/scicloj/scicloj.ml.smile/issues/19
(remove-model-type :smile.classification/mlp)
;;https://github.com/scicloj/scicloj.ml.xgboost/issues/1
(remove-model-type :xgboost/classification)
))))



(comment
;; inspect trainer
(import '[ com.oracle.labs.mlrg.olcut.config DescribeConfigurable
ConfigurationManager]
'[com.oracle.labs.mlrg.olcut.config.edn EdnConfigFactory]
'[com.oracle.labs.mlrg.olcut.config.json JsonConfigFactory])
(def iris-ds-regression
  "The iris dataset with the :species column dropped and :sepal_length set as
  the inference target."
  (let [iris            (data/iris-ds)
        without-species (tc/drop-columns iris [:species])]
    (ds-mod/set-inference-target without-species :sepal_length)))


(def split
  "First (and only requested) :holdout split of `iris-ds-regression`; a map
  whose :train and :test entries are read by the defs below.

  The split is seeded so that every test run evaluates the regression models
  on the same train/test partition; without a seed the partition is random
  and the fixed mae threshold would be checked against different data on
  every run."
  (-> iris-ds-regression
      (tc/split->seq :holdout {:seed 1234})
      first))

(def iris-ds-regression--train
  "The :train entry of `split`."
  (get split :train))

(def iris-ds-regression--test
  "The :test entry of `split`."
  (get split :test))


(defn validate-regression
  "Trains `model-map` on `iris-ds-regression--train`, predicts
  `iris-ds-regression--test` and asserts that the mean absolute error between
  the actual and the predicted :sepal_length column is below 0.4.

  Prints the measured mae; returns the result of the `is` assertion."
  [model-map]
  (let [trained   (ml/train iris-ds-regression--train model-map)
        actual    (:sepal_length iris-ds-regression--test)
        predicted (:sepal_length (ml/predict iris-ds-regression--test trained))
        mae       (loss/mae actual predicted)]
    (println :mae mae)
    ;; threshold rationale from the original author: a dummy model has a mae of 0.69
    (is (> 0.4 mae)
        (format "mae validation failed: %s" model-map))))

;; Runs `validate-regression` once per listed model type, using only the
;; :model-type key (i.e. each model's default options).
(deftest regression-works
  (doseq [model-type [:metamorph.ml/ols
                      :fastmath/ols
                      :smile.regression/ordinary-least-square
                      :smile.regression/elastic-net
                      :smile.regression/lasso
                      :smile.regression/ridge
                      :smile.regression/gradient-tree-boost
                      :smile.regression/random-forest
                      :xgboost/linear-regression
                      :xgboost/regression
                      :sklearn.regression/linear-regression
                      :sklearn.regression/decision-tree-regressor
                      :sklearn.regression/random-forest-regressor]]
    (validate-regression {:model-type model-type})))

;; Runs `validate-regression` for several Tribuo regression trainers. Every
;; component list registers its trainer under the name "reg", which is the
;; name passed as :tribuo-trainer-name.
(deftest tribuo-regression-works
  (doseq [components
          [;; linear SGD with an absolute-loss objective
           [{:name "loss"
             :type "org.tribuo.regression.sgd.objectives.AbsoluteLoss"}
            {:name "reg"
             :type "org.tribuo.regression.sgd.linear.LinearSGDTrainer"
             :properties {:objective "loss"}}]

           ;; CART regression tree
           [{:name "reg"
             :type "org.tribuo.regression.rtree.CARTRegressionTrainer"}]

           ;; XGBoost
           [{:name "reg"
             :type "org.tribuo.regression.xgboost.XGBoostRegressionTrainer"
             :properties {:numTrees "10"}}]

           ;; libsvm; :svmType refers to the "nu" component
           [{:name "nu"
             :type "org.tribuo.regression.libsvm.SVMRegressionType"
             :properties {:type "NU_SVR"}}
            {:name "reg"
             :type "org.tribuo.regression.libsvm.LibSVMRegressionTrainer"
             :properties {:svmType "nu"}}]]]
    (validate-regression
     {:model-type :scicloj.ml.tribuo/regression
      :tribuo-trainer-name "reg"
      :tribuo-components components})))

(ConfigurationManager/addFileFormatFactory (EdnConfigFactory.))

(ConfigurationManager/addFileFormatFactory (JsonConfigFactory.))


(DescribeConfigurable/writeExampleConfig
(io/output-stream "/tmp/sdg.edn")
"edn"
org.tribuo.classification.sgd.linear.LinearSGDTrainer
(DescribeConfigurable/generateFieldInfo org.tribuo.classification.sgd.linear.LinearSGDTrainer)
)
)

0 comments on commit 0d02257

Please sign in to comment.