From 67987e7d49f060bfad161a1b6e2fc5a67982913c Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Thu, 26 Sep 2024 15:22:05 +0000 Subject: [PATCH 1/5] added model integration test --- build.clj | 15 +- deps.edn | 57 ++++-- .../model_integration_test.clj | 169 +++++++++++++++++ notebooks/noj_book/automl.clj | 46 ++--- notebooks/noj_book/ml_basic.clj | 2 +- poetry.lock | 179 +----------------- pyproject.toml | 3 +- python.edn | 2 +- 8 files changed, 248 insertions(+), 225 deletions(-) create mode 100644 model-integration-tests/model_integration_test.clj diff --git a/build.clj b/build.clj index 104dd5f..a4a25f5 100644 --- a/build.clj +++ b/build.clj @@ -53,7 +53,7 @@ :pom-data (pom-template version)))) (defn generate-tests [opts] - (let [basis (b/create-basis {:aliases [:dev]}) + (let [basis (b/create-basis {:aliases [:gen-tests :model-integration-tests]}) cmds (b/java-command {:basis basis @@ -80,3 +80,16 @@ (dd/deploy {:installer :remote :artifact (b/resolve-path jar-file) :pom-file (b/pom-path (select-keys opts [:lib :class-dir]))})) opts) + + +(defn models-integration-tests "Run integration tests." [opts] + (let [basis (b/create-basis { :aliases [:model-integration-tests ]}) + cmds (b/java-command + {:basis basis + :main 'clojure.main + :main-args ["-m" "cognitect.test-runner" "-d" "model-integration-tests"]})] + (b/process cmds) + )opts) + + + diff --git a/deps.edn b/deps.edn index 68d220f..5177ef5 100644 --- a/deps.edn +++ b/deps.edn @@ -5,35 +5,50 @@ org.scicloj/kindly {:mvn/version "4-beta12"} generateme/fastmath {:mvn/version "3.0.0-alpha1"} aerial.hanami/aerial.hanami {:mvn/version "0.20.0"} - org.scicloj/hanamicloth {:mvn/version "1-alpha8" - :exclusions [scicloj/metamorph.ml]} - scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml.git" - :git/sha "50f47dad934a2786b3cf025bef509f1f3d1a7e1d" - :exclusions [generateme/fastmath]} - org.scicloj/scicloj.ml.tribuo {:mvn/version "0.1.1-branch-noj-2-alpha4-SNAPSHOT" - :exclusions [scicloj/metamorph.ml]} - org.tribuo/tribuo-regression-sgd {:mvn/version "4.3.1"} - org.tribuo/tribuo-regression-tree {:mvn/version "4.3.1"} - org.tribuo/tribuo-regression-xgboost {:mvn/version "4.3.1"} - org.tribuo/tribuo-classification-sgd {:mvn/version "4.3.1"} - org.tribuo/tribuo-classification-tree {:mvn/version "4.3.1"} - org.tribuo/tribuo-classification-xgboost {:mvn/version "4.3.1"} + org.scicloj/hanamicloth {:mvn/version "1-alpha8"} + + scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml" + :git/sha "50f47dad934a2786b3cf025bef509f1f3d1a7e1d"} + + ;;scicloj/metamorph.ml + ;;{:mvn/version "0.8.2-branch-noj-2-alpha4-SNAPSHOT"} + org.scicloj/scicloj.ml.tribuo {:mvn/version "0.1.1-branch-noj-2-alpha4-SNAPSHOT"} + org.tribuo/tribuo-regression-sgd {:mvn/version "4.2.0"} + org.tribuo/tribuo-regression-tree {:mvn/version "4.2.0"} + org.tribuo/tribuo-classification-sgd {:mvn/version "4.2.0"} + org.tribuo/tribuo-classification-tree {:mvn/version "4.2.0"} clj-python/libpython-clj {:mvn/version "2.025"} org.scicloj/kind-pyplot {:mvn/version "1-beta1"} scicloj/clojisr {:mvn/version "1.0.0"}} :aliases - {:build {:deps {io.github.clojure/tools.build {:mvn/version "0.9.6"} + {:gen-tests {:extra-paths ["build" "notebooks"] + :extra-deps { org.scicloj/clay {:mvn/version "2-beta16"}}} + + :build {:deps {io.github.clojure/tools.build {:mvn/version "0.9.6"} slipset/deps-deploy {:mvn/version "0.2.1"}} :ns-default build} :test {:extra-paths ["test" "notebooks"] :extra-deps {org.clojure/test.check {:mvn/version "1.1.1"} - io.github.cognitect-labs/test-runner - {:git/tag "v0.5.1" :git/sha "dfb30dd"} + io.github.cognitect-labs/test-runner {:git/tag "v0.5.1" :git/sha "dfb30dd"} org.scicloj/clay {:mvn/version "2-beta16"}}} - :dev {:extra-paths ["notebooks" "build"] - :extra-deps {org.scicloj/clay {:mvn/version "2-beta16"} - scicloj/scicloj.ml.smile {:mvn/version "7.4.1"} - org.scicloj/sklearn-clj {:mvn/version "0.4.1"} - }}}} + + :model-integration-tests + {:extra-paths ["model-integration-tests"] + :extra-deps {org.scicloj/scicloj.ml.smile {:mvn/version "7.4.2"} + org.scicloj/sklearn-clj {:mvn/version "0.4.1"} + scicloj/scicloj.ml.xgboost {:mvn/version "6.0.0"} + + org.bytedeco/arpack-ng {:mvn/version "3.7.0-1.5.4"} + org.bytedeco/openblas-platform {:mvn/version "0.3.10-1.5.4"} + org.bytedeco/arpack-ng-platform {:mvn/version "3.7.0-1.5.4"} + org.bytedeco/openblas {:mvn/version "0.3.10-1.5.4"} + org.bytedeco/javacpp {:mvn/version "1.5.4"} + scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml" + :git/sha "60ed8aa3aa51b3653794754dbd484168876549d4"} + io.github.cognitect-labs/test-runner {:git/tag "v0.5.1" :git/sha "dfb30dd"} + }} + + :dev {:extra-paths ["notebooks"] + :extra-deps {org.scicloj/clay {:mvn/version "2-beta16"}}}}} diff --git a/model-integration-tests/model_integration_test.clj b/model-integration-tests/model_integration_test.clj new file mode 100644 index 0000000..006bf39 --- /dev/null +++ b/model-integration-tests/model_integration_test.clj @@ -0,0 +1,169 @@ +(ns model-integration-test + (:require [scicloj.metamorph.core :as mm] + [scicloj.metamorph.ml :as ml] + [scicloj.metamorph.ml.loss :as loss] + [scicloj.metamorph.ml.toydata :as data] + [tech.v3.dataset.categorical :as ds-cat] + [tablecloth.api :as tc] + [clojure.string :as str] + [clojure.set :as set] + [clojure.test :refer [is deftest]] + [tech.v3.dataset :as ds]) + (:import + (smile.base.mlp ActivationFunction Cost HiddenLayerBuilder LayerBuilder OutputFunction OutputLayerBuilder)) + ) + +(def mlp-hidden-layer-builder + (HiddenLayerBuilder. 1 (ActivationFunction/linear))) + +(def mlp-output-layer-builder + (OutputLayerBuilder. 3 OutputFunction/LINEAR Cost/MEAN_SQUARED_ERROR)) + + +(require '[scicloj.metamorph.ml.classification] + '[scicloj.ml.smile.classification] + '[scicloj.ml.tribuo] + '[scicloj.sklearn-clj.ml] + '[scicloj.ml.xgboost] + ) + +(def min-accuracies + {:smile.classification/linear-discriminant-analysis 0.85}) + +(def smile-model-specs + (map + #(vector (get min-accuracies % 0.95) + {:model-type %}) + (->> (ml/model-definition-names) + (filter #(str/starts-with? (namespace %) "smile.classification")) + set + ((fn [x] (set/difference + x + #{:smile.classification/sparse-svm + :smile.classification/maxent-binomial + :smile.classification/maxent-multinomial + :smile.classification/mlp + :smile.classification/svm + :smile.classification/sparse-logistic-regression + :smile.classification/discrete-naive-bayes})))))) + + + +(def sklearn-model-specs + (map + #(vector 0.90 + {:model-type %}) + (->> (ml/model-definition-names) + (filter #(str/starts-with? (namespace %) "sklearn.classification" )) + set + ((fn [x] (set/difference + x + #{:sklearn.classification/perceptron + :sklearn.classification/sgd-classifier + :sklearn.classification/svc + + }))) + ))) + + +(def model-specs + (concat + [ + [0.98 { + ;; :validate-parameters 1 + ;; :round 10 + ;; :silent 0 + ;; :verbosity 3 + :model-type :xgboost/classification}] + [0.30 {:model-type :smile.classification/mlp + :layer-builders [mlp-hidden-layer-builder mlp-output-layer-builder]}] + [0.95 {:model-type :sklearn.classification/decision-tree-classifier}] + [0.95 {:model-type :sklearn.classification/random-forest-classifier}] + [0.95 {:model-type :sklearn.classification/logistic-regression}] + [0.93 {:model-type :scicloj.ml.tribuo/classification + :tribuo-components [{:name "logistic" + :type "org.tribuo.classification.sgd.linear.LinearSGDTrainer" + :properties {:seed "1234" + :shuffle "false" + :epochs "10"}}] + :tribuo-trainer-name "logistic"}] + [0.94 {:model-type :scicloj.ml.tribuo/classification + :tribuo-components [{:name "random-forest" + :type "org.tribuo.classification.dtree.CARTClassificationTrainer" + :properties {:maxDepth "8" + :useRandomSplitPoints "false" + :fractionFeaturesInSplit "0.5"}}] + :tribuo-trainer-name "random-forest"}] + [0.30 {:model-type :metamorph.ml/dummy-classifier}] + ] + smile-model-specs + ;sklearn-model-specs + )) + + + + + +(defn my-classification-accuracy [lhs rhs] + ;(println :lhs (meta lhs)) + ;(println :rhs (meta rhs)) + + (loss/classification-accuracy lhs rhs) + ) + +(defn verify-classification [model-spec expected-accuracy ds] + (println :verify (:model-type model-spec)) + (let [ + train-test-split + (tc/split->seq ds :kfold {:seed 1234 :k 10}) + + pipe + (mm/pipeline + {:metamorph/id :model} + (ml/model model-spec)) + + result + (ml/evaluate-pipelines + [pipe] + train-test-split + my-classification-accuracy + :accuracy) + accuracy (-> result first first :train-transform :mean) + ] + + (is (>= accuracy expected-accuracy) + (format "%s: expect at least: %s, found : %s" + (:model-type model-spec) + expected-accuracy accuracy)))) + + +(deftest verify-classifictions-iris + (run! + (fn [[acc spec]] (verify-classification spec acc (data/iris-ds))) + model-specs)) + + +(def iris-2 + (-> + (data/iris-ds) + ds-cat/reverse-map-categorical-xforms + )) + +;; (deftest verify-classification-iris-2 +;; (run! +;; (fn [[acc spec]] (verify-classification spec acc iris-2)) +;; smile-model-specs)) + + +(def iris-3 + (-> + (data/iris-ds) + (ds/assoc-metadata [:species] :categorical-map nil) + )) + + +(deftest verify-classification-iris-3 + (run! + (fn [[acc spec]] (verify-classification spec acc iris-3)) + smile-model-specs)) + diff --git a/notebooks/noj_book/automl.clj b/notebooks/noj_book/automl.clj index 6600cd4..dcca883 100644 --- a/notebooks/noj_book/automl.clj +++ b/notebooks/noj_book/automl.clj @@ -277,11 +277,9 @@ ctx-after-train (require '[scicloj.metamorph.ml :as ml] '[scicloj.metamorph.ml.loss :as loss] '[scicloj.metamorph.core :as mm] - '[scicloj.ml.tribuo] ;; register the tribuo models - '[scicloj.ml.smile.classification] ;; register the smile classification models - '[scicloj.metamorph.ml.classification] ;; register dummy classifier - '[scicloj.sklearn-clj.ml] ;; register all sklern models classifier - ) + '[scicloj.ml.tribuo] + '[scicloj.ml.xgboost] + '[scicloj.sklearn-clj.ml]) ;; ## Finding the best model automatically @@ -312,26 +310,23 @@ ctx-after-train (-> titanic-k-fold count) ;; The list of the model types we want to try: -(def models [{:model-type :metamorph.ml/dummy-classifier} - +(def models [{ :model-type :xgboost/classification + :round 10} + {:model-type :sklearn.classification/decision-tree-classifier} {:model-type :sklearn.classification/logistic-regression} - - {:model-type :smile.classification/random-forest} - - {:model-type :scicloj.ml.tribuo/classification - :tribuo-components [{:name "logistic" - :type "org.tribuo.classification.sgd.linear.LinearSGDTrainer"}] - :tribuo-trainer-name "logistic"} - {:model-type :scicloj.ml.tribuo/classification - :tribuo-components [{:name "random-forest" - :type "org.tribuo.classification.dtree.CARTClassificationTrainer" - :properties {:maxDepth "8" - :useRandomSplitPoints "false" - :fractionFeaturesInSplit "0.5"}}] - :tribuo-trainer-name "random-forest"} - - - ]) + {:model-type :sklearn.classification/random-forest-classifier} + {:model-type :metamorph.ml/dummy-classifier} + {:model-type :scicloj.ml.tribuo/classification + :tribuo-components [{:name "logistic" + :type "org.tribuo.classification.sgd.linear.LinearSGDTrainer"}] + :tribuo-trainer-name "logistic"} + {:model-type :scicloj.ml.tribuo/classification + :tribuo-components [{:name "random-forest" + :type "org.tribuo.classification.dtree.CARTClassificationTrainer" + :properties {:maxDepth "8" + :useRandomSplitPoints "false" + :fractionFeaturesInSplit "0.5"}}] + :tribuo-trainer-name "random-forest"}]) ;; This uses models from Smile and Tribuo, but could be any @@ -383,7 +378,8 @@ ctx-after-train titanic-k-fold loss/classification-accuracy :accuracy - {:return-best-crossvalidation-only false + {:map-fn :map + :return-best-crossvalidation-only false :return-best-pipeline-only false})) diff --git a/notebooks/noj_book/ml_basic.clj b/notebooks/noj_book/ml_basic.clj index f76c953..d4a2587 100644 --- a/notebooks/noj_book/ml_basic.clj +++ b/notebooks/noj_book/ml_basic.clj @@ -191,7 +191,7 @@ split (loss/classification-accuracy (:survived (ds-cat/reverse-map-categorical-xforms (:test split))) - (:survived (ds-cat/reverse-map-categorical-xforms lreg-prediction))) + (:survived lreg-prediction)) (kindly/check = 0.7373737373737373) ;; Its performance is better, 73 % diff --git a/poetry.lock b/poetry.lock index 4cdffbd..38f0047 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,99 +1,5 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. -[[package]] -name = "attrs" -version = "24.2.0" -description = "Classes Without Boilerplate" -optional = false -python-versions = ">=3.7" -files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, -] - -[package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] - -[[package]] -name = "basilisp" -version = "0.2.3" -description = "A Clojure-like lisp written for Python" -optional = false -python-versions = "<4.0,>=3.8" -files = [ - {file = "basilisp-0.2.3-py3-none-any.whl", hash = "sha256:d0d89643971256ad8379dfd2c7552bb8a7bb894a572d3169a363d06fff9a6956"}, - {file = "basilisp-0.2.3.tar.gz", hash = "sha256:fa447bbf3030eb0e902d0b3e34bf285a3fe92dd01190ede7fe77bda2a3465702"}, -] - -[package.dependencies] -attrs = ">=22.2.0" -immutables = ">=0.20,<1.0.0" -prompt-toolkit = ">=3.0.0,<4.0.0" -pyrsistent = ">=0.18.0,<1.0.0" -typing-extensions = ">=4.7.0,<5.0.0" - -[package.extras] -pygments = ["pygments (>=2.9.0,<3.0.0)"] -pytest = ["pytest (>=7.0.0,<9.0.0)"] - -[[package]] -name = "immutables" -version = "0.20" -description = "Immutable Collections" -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "immutables-0.20-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dea0ae4d7f31b145c18c16badeebc2f039d09411be4a8febb86e1244cf7f1ce0"}, - {file = "immutables-0.20-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2dd0dcef2f8d4523d34dbe1d2b7804b3d2a51fddbd104aad13f506a838a2ea15"}, - {file = "immutables-0.20-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:393dde58ffd6b4c089ffdf4cef5fe73dad37ce4681acffade5f5d5935ec23c93"}, - {file = "immutables-0.20-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1214b5a175df783662b7de94b4a82db55cc0ee206dd072fa9e279fb8895d8df"}, - {file = "immutables-0.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:2761e3dc2a6406943ce77b3505e9b3c1187846de65d7247548dc7edaa202fcba"}, - {file = "immutables-0.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2bcea81e7516bd823b4ed16f4f794531097888675be13e833b1cc946370d5237"}, - {file = "immutables-0.20-cp310-cp310-win32.whl", hash = "sha256:d828e7580f1fa203ddeab0b5e91f44bf95706e7f283ca9fbbcf0ae08f63d3084"}, - {file = "immutables-0.20-cp310-cp310-win_amd64.whl", hash = "sha256:380e2957ba3d63422b2f3fbbff0547c7bbe6479d611d3635c6411005a4264525"}, - {file = "immutables-0.20-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:532be32c7a25dae6cade28825c76d3004cf4d166a0bfacf04bda16056d59ba26"}, - {file = "immutables-0.20-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5302ce9c7827f8300f3dc34a695abb71e4a32bab09e65e5ad6e454785383347f"}, - {file = "immutables-0.20-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b51aec54b571ae466113509d4dc79a2808dc2ae9263b71fd6b37778cb49eb292"}, - {file = "immutables-0.20-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f56aea56e597ecf6631f24a4e26007b6a5f4fe30278b96eb90bc1f60506164"}, - {file = "immutables-0.20-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:085ac48ee3eef7baf070f181cae574489bbf65930a83ec5bbd65c9940d625db3"}, - {file = "immutables-0.20-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f063f53b5c0e8f541ae381f1d828f3d05bbed766a2d6c817f9218b8b37a4cb66"}, - {file = "immutables-0.20-cp311-cp311-win32.whl", hash = "sha256:b0436cc831b47e26bef637bcf143cf0273e49946cfb7c28c44486d70513a3080"}, - {file = "immutables-0.20-cp311-cp311-win_amd64.whl", hash = "sha256:5bb32aee1ea16fbb90f58f8bd96016bca87aba0a8e574e5fa218d0d83b142851"}, - {file = "immutables-0.20-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4ba726b7a3a696b9d4b122fa2c956bc68e866f3df1b92765060c88c64410ff82"}, - {file = "immutables-0.20-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5a88adf1dcc9d8ab07dba5e74deefcd5b5e38bc677815cbf9365dc43b69f1f08"}, - {file = "immutables-0.20-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1009a4e00e2e69a9b40c2f1272795f5a06ad72c9bf4638594d518e9cbd7a721a"}, - {file = "immutables-0.20-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96899994842c37cf4b9d6d2bedf685aae7810bd73f1538f8cba5426e2d65cb85"}, - {file = "immutables-0.20-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a606410b2ccb6ae339c3f26cccc9a92bcb16dc06f935d51edfd8ca68cf687e50"}, - {file = "immutables-0.20-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8e82754f72823085643a2c0e6a4c489b806613e94af205825fa81df2ba147a0"}, - {file = "immutables-0.20-cp312-cp312-win32.whl", hash = "sha256:525fb361bd7edc8a891633928d549713af8090c79c25af5cc06eb90b48cb3c64"}, - {file = "immutables-0.20-cp312-cp312-win_amd64.whl", hash = "sha256:a82afc3945e9ceb9bcd416dc4ed9b72f92760c42787e26de50610a8b81d48120"}, - {file = "immutables-0.20-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f17f25f21e82a1c349a61191cfb13e442a348b880b74cb01b00e0d1e848b63f4"}, - {file = "immutables-0.20-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:65954eb861c61af48debb1507518d45ae7d594b4fba7282785a70b48c5f51f9b"}, - {file = "immutables-0.20-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62f8a7a22939278127b7a206d05679b268b9cf665437125625348e902617cbad"}, - {file = "immutables-0.20-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac86f4372f4cfaa00206c12472fd3a78753092279e0552b7e1880944d71b04fe"}, - {file = "immutables-0.20-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e771198edc11a9e02ffa693911b3918c6cde0b64ad2e6672b076dbe005557ad8"}, - {file = "immutables-0.20-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc739fc07cff5df2e4f31addbd48660b5ac0da56e9f719f8bb45da8ddd632c63"}, - {file = "immutables-0.20-cp38-cp38-win32.whl", hash = "sha256:c086ccb44d9d3824b9bf816365d10b1b82837efc7119f8bab56bd7a27ed805a9"}, - {file = "immutables-0.20-cp38-cp38-win_amd64.whl", hash = "sha256:9cd2ee9c10bf00be3c94eb51854bc0b761326bd0a7ea0dad4272a3f182269ae6"}, - {file = "immutables-0.20-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d4f78cb748261f852953620ed991de74972446fd484ec69377a41e2f1a1beb75"}, - {file = "immutables-0.20-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d6449186ea91b7c17ec8e7bd9bf059858298b1db5c053f5d27de8eba077578ce"}, - {file = "immutables-0.20-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85dd9765b068f7beb297553fddfcf7f904bd58a184c520830a106a58f0c9bfb4"}, - {file = "immutables-0.20-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f349a7e0327b92dcefb863e49ace086f2f26e6689a4e022c98720c6e9696e763"}, - {file = "immutables-0.20-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e3a5462f6d3549bbf7d02ce929fb0cb6df9539445f0589105de4e8b99b906e69"}, - {file = "immutables-0.20-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cc51a01a64a6d2cd7db210a49ad010c2ac2e9e026745f23fd31e0784096dcfff"}, - {file = "immutables-0.20-cp39-cp39-win32.whl", hash = "sha256:83794712f0507416f2818edc63f84305358b8656a93e5b9e2ab056d9803c7507"}, - {file = "immutables-0.20-cp39-cp39-win_amd64.whl", hash = "sha256:2837b1078abc66d9f009bee9085cf62515d5516af9a5c9ea2751847e16efd236"}, - {file = "immutables-0.20.tar.gz", hash = "sha256:1d2f83e6a6a8455466cd97b9a90e2b4f7864648616dfa6b19d18f49badac3876"}, -] - -[package.extras] -test = ["flake8 (>=5.0,<6.0)", "mypy (>=1.4,<2.0)", "pycodestyle (>=2.9,<3.0)", "pytest (>=7.4,<8.0)"] - [[package]] name = "joblib" version = "1.4.2" @@ -249,61 +155,6 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] -[[package]] -name = "prompt-toolkit" -version = "3.0.47" -description = "Library for building powerful interactive command lines in Python" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, - {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, -] - -[package.dependencies] -wcwidth = "*" - -[[package]] -name = "pyrsistent" -version = "0.20.0" -description = "Persistent/Functional/Immutable data structures" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyrsistent-0.20.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c3aba3e01235221e5b229a6c05f585f344734bd1ad42a8ac51493d74722bbce"}, - {file = "pyrsistent-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1beb78af5423b879edaf23c5591ff292cf7c33979734c99aa66d5914ead880f"}, - {file = "pyrsistent-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21cc459636983764e692b9eba7144cdd54fdec23ccdb1e8ba392a63666c60c34"}, - {file = "pyrsistent-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5ac696f02b3fc01a710427585c855f65cd9c640e14f52abe52020722bb4906b"}, - {file = "pyrsistent-0.20.0-cp310-cp310-win32.whl", hash = "sha256:0724c506cd8b63c69c7f883cc233aac948c1ea946ea95996ad8b1380c25e1d3f"}, - {file = "pyrsistent-0.20.0-cp310-cp310-win_amd64.whl", hash = "sha256:8441cf9616d642c475684d6cf2520dd24812e996ba9af15e606df5f6fd9d04a7"}, - {file = "pyrsistent-0.20.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0f3b1bcaa1f0629c978b355a7c37acd58907390149b7311b5db1b37648eb6958"}, - {file = "pyrsistent-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cdd7ef1ea7a491ae70d826b6cc64868de09a1d5ff9ef8d574250d0940e275b8"}, - {file = "pyrsistent-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cae40a9e3ce178415040a0383f00e8d68b569e97f31928a3a8ad37e3fde6df6a"}, - {file = "pyrsistent-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6288b3fa6622ad8a91e6eb759cfc48ff3089e7c17fb1d4c59a919769314af224"}, - {file = "pyrsistent-0.20.0-cp311-cp311-win32.whl", hash = "sha256:7d29c23bdf6e5438c755b941cef867ec2a4a172ceb9f50553b6ed70d50dfd656"}, - {file = "pyrsistent-0.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:59a89bccd615551391f3237e00006a26bcf98a4d18623a19909a2c48b8e986ee"}, - {file = "pyrsistent-0.20.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:09848306523a3aba463c4b49493a760e7a6ca52e4826aa100ee99d8d39b7ad1e"}, - {file = "pyrsistent-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a14798c3005ec892bbada26485c2eea3b54109cb2533713e355c806891f63c5e"}, - {file = "pyrsistent-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b14decb628fac50db5e02ee5a35a9c0772d20277824cfe845c8a8b717c15daa3"}, - {file = "pyrsistent-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e2c116cc804d9b09ce9814d17df5edf1df0c624aba3b43bc1ad90411487036d"}, - {file = "pyrsistent-0.20.0-cp312-cp312-win32.whl", hash = "sha256:e78d0c7c1e99a4a45c99143900ea0546025e41bb59ebc10182e947cf1ece9174"}, - {file = "pyrsistent-0.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:4021a7f963d88ccd15b523787d18ed5e5269ce57aa4037146a2377ff607ae87d"}, - {file = "pyrsistent-0.20.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:79ed12ba79935adaac1664fd7e0e585a22caa539dfc9b7c7c6d5ebf91fb89054"}, - {file = "pyrsistent-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f920385a11207dc372a028b3f1e1038bb244b3ec38d448e6d8e43c6b3ba20e98"}, - {file = "pyrsistent-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f5c2d012671b7391803263419e31b5c7c21e7c95c8760d7fc35602353dee714"}, - {file = "pyrsistent-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef3992833fbd686ee783590639f4b8343a57f1f75de8633749d984dc0eb16c86"}, - {file = "pyrsistent-0.20.0-cp38-cp38-win32.whl", hash = "sha256:881bbea27bbd32d37eb24dd320a5e745a2a5b092a17f6debc1349252fac85423"}, - {file = "pyrsistent-0.20.0-cp38-cp38-win_amd64.whl", hash = "sha256:6d270ec9dd33cdb13f4d62c95c1a5a50e6b7cdd86302b494217137f760495b9d"}, - {file = "pyrsistent-0.20.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ca52d1ceae015859d16aded12584c59eb3825f7b50c6cfd621d4231a6cc624ce"}, - {file = "pyrsistent-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b318ca24db0f0518630e8b6f3831e9cba78f099ed5c1d65ffe3e023003043ba0"}, - {file = "pyrsistent-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fed2c3216a605dc9a6ea50c7e84c82906e3684c4e80d2908208f662a6cbf9022"}, - {file = "pyrsistent-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e14c95c16211d166f59c6611533d0dacce2e25de0f76e4c140fde250997b3ca"}, - {file = "pyrsistent-0.20.0-cp39-cp39-win32.whl", hash = "sha256:f058a615031eea4ef94ead6456f5ec2026c19fb5bd6bfe86e9665c4158cf802f"}, - {file = "pyrsistent-0.20.0-cp39-cp39-win_amd64.whl", hash = "sha256:58b8f6366e152092194ae68fefe18b9f0b4f89227dfd86a07770c3d86097aebf"}, - {file = "pyrsistent-0.20.0-py3-none-any.whl", hash = "sha256:c55acc4733aad6560a7f5f818466631f07efc001fd023f34a6c203f8b6df0f0b"}, - {file = "pyrsistent-0.20.0.tar.gz", hash = "sha256:4c48f78f62ab596c679086084d0dd13254ae4f3d6c72a83ffdf5ebdef8f265a4"}, -] - [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -446,40 +297,18 @@ files = [ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, ] -[[package]] -name = "typing-extensions" -version = "4.12.2" -description = "Backported and Experimental Type Hints for Python 3.8+" -optional = false -python-versions = ">=3.8" -files = [ - {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, - {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, -] - [[package]] name = "tzdata" -version = "2024.1" +version = "2024.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, -] - -[[package]] -name = "wcwidth" -version = "0.2.13" -description = "Measures the displayed width of unicode strings in a terminal" -optional = false -python-versions = "*" -files = [ - {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, - {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, ] [metadata] lock-version = "2.0" python-versions = "3.11.2" -content-hash = "4f0398a0377db46c9077df7a552d3114548b0cef411710cc6a65b4a35eec1718" +content-hash = "6d20d1febb5124269bbafee8745b54b5a080f02d65654e3c7f42b526cc32289e" diff --git a/pyproject.toml b/pyproject.toml index a512543..25dd156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,8 @@ [tool.poetry] package-mode = false + [tool.poetry.dependencies] python = "3.11.2" scikit-learn = "1.5.2" -basilisp = "0.2.3" pandas = "2.2.3" + diff --git a/python.edn b/python.edn index c7cd260..e823478 100644 --- a/python.edn +++ b/python.edn @@ -1 +1 @@ -{:python-executable ".venv/bin/python3"} \ No newline at end of file +{:python-executable ".venv/bin/python3"} From fb9c65ae3017ee5617ae31725c795adec0b13e22 Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Fri, 27 Sep 2024 07:00:18 +0000 Subject: [PATCH 2/5] fixed pythoin installation --- .devcontainer/devcontainer.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 0aafe8c..2ce1221 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -28,5 +28,7 @@ }, "remoteUser": "vscode", - "postStartCommand": "poetry install --sync" + "postStartCommand": {"install python packages": "poetry install --sync" , + "add link to python executable": "ln -s /usr/bin/python3 /usr/local/bin/python" + } } From 8ff60c99c9a9a830aad05bc3df96c9d8803f02ab Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Fri, 27 Sep 2024 07:58:08 +0000 Subject: [PATCH 3/5] use latest releases --- .gitignore | 1 + build.clj | 4 ++-- deps.edn | 12 +++--------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 129ba5c..9784716 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ book test/noj_book/ .venv/ +docs/ \ No newline at end of file diff --git a/build.clj b/build.clj index a4a25f5..0de640a 100644 --- a/build.clj +++ b/build.clj @@ -52,7 +52,7 @@ :src-dirs ["src"] :pom-data (pom-template version)))) -(defn generate-tests [opts] +(defn generate-tests [_] (let [basis (b/create-basis {:aliases [:gen-tests :model-integration-tests]}) cmds (b/java-command @@ -63,7 +63,7 @@ (when-not (zero? exit) (throw (ex-info "Tests generation failed" {}))))) (def opts {}) (defn ci "Run the CI pipeline of tests (and build the JAR)." [opts] - (generate-tests (assoc opts :aliases [:dev])) + (generate-tests nil) (test (assoc opts :aliases [:dev :test])) (b/delete {:path "target"}) (let [opts (jar-opts opts)] diff --git a/deps.edn b/deps.edn index 5177ef5..fda4d5e 100644 --- a/deps.edn +++ b/deps.edn @@ -6,13 +6,8 @@ generateme/fastmath {:mvn/version "3.0.0-alpha1"} aerial.hanami/aerial.hanami {:mvn/version "0.20.0"} org.scicloj/hanamicloth {:mvn/version "1-alpha8"} - - scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml" - :git/sha "50f47dad934a2786b3cf025bef509f1f3d1a7e1d"} - - ;;scicloj/metamorph.ml - ;;{:mvn/version "0.8.2-branch-noj-2-alpha4-SNAPSHOT"} - org.scicloj/scicloj.ml.tribuo {:mvn/version "0.1.1-branch-noj-2-alpha4-SNAPSHOT"} + org.scicloj/metamorph.ml {:mvn/version "0.9.0"} + org.scicloj/scicloj.ml.tribuo {:mvn/version "0.1.2"} org.tribuo/tribuo-regression-sgd {:mvn/version "4.2.0"} org.tribuo/tribuo-regression-tree {:mvn/version "4.2.0"} org.tribuo/tribuo-classification-sgd {:mvn/version "4.2.0"} @@ -45,8 +40,7 @@ org.bytedeco/arpack-ng-platform {:mvn/version "3.7.0-1.5.4"} org.bytedeco/openblas {:mvn/version "0.3.10-1.5.4"} org.bytedeco/javacpp {:mvn/version "1.5.4"} - scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml" - :git/sha "60ed8aa3aa51b3653794754dbd484168876549d4"} + org.scicloj/metamorph.ml {:mvn/version "0.9.0"} io.github.cognitect-labs/test-runner {:git/tag "v0.5.1" :git/sha "dfb30dd"} }} From b9ab3b1e0d156bef536fdddcdaa399aa1b68cc83 Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Fri, 27 Sep 2024 08:29:59 +0000 Subject: [PATCH 4/5] fixed python setup --- .devcontainer/devcontainer.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 2ce1221..140244a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -29,6 +29,6 @@ "remoteUser": "vscode", "postStartCommand": {"install python packages": "poetry install --sync" , - "add link to python executable": "ln -s /usr/bin/python3 /usr/local/bin/python" + "add link to python executable": "sudo ln -s /usr/bin/python3 /usr/local/bin/python" } } From 1c4c6a5f677a140f316d281bf245e6a098aef6da Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Fri, 27 Sep 2024 10:27:50 +0000 Subject: [PATCH 5/5] use aproximate comparision --- build.clj | 2 +- deps.edn | 4 ++-- notebooks/noj_book/ml_basic.clj | 31 +++++++++++++++++++++++++------ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/build.clj b/build.clj index 0de640a..d6184d8 100644 --- a/build.clj +++ b/build.clj @@ -53,7 +53,7 @@ :pom-data (pom-template version)))) (defn generate-tests [_] - (let [basis (b/create-basis {:aliases [:gen-tests :model-integration-tests]}) + (let [basis (b/create-basis {:aliases [:gen-tests :model-integration-tests :test]}) cmds (b/java-command {:basis basis diff --git a/deps.edn b/deps.edn index fda4d5e..9710598 100644 --- a/deps.edn +++ b/deps.edn @@ -16,8 +16,7 @@ org.scicloj/kind-pyplot {:mvn/version "1-beta1"} scicloj/clojisr {:mvn/version "1.0.0"}} :aliases - {:gen-tests {:extra-paths ["build" "notebooks"] - :extra-deps { org.scicloj/clay {:mvn/version "2-beta16"}}} + {:gen-tests {:extra-paths ["build"]} :build {:deps {io.github.clojure/tools.build {:mvn/version "0.9.6"} slipset/deps-deploy {:mvn/version "0.2.1"}} @@ -27,6 +26,7 @@ :test {:extra-paths ["test" "notebooks"] :extra-deps {org.clojure/test.check {:mvn/version "1.1.1"} io.github.cognitect-labs/test-runner {:git/tag "v0.5.1" :git/sha "dfb30dd"} + same/ish {:mvn/version "0.1.6"} org.scicloj/clay {:mvn/version "2-beta16"}}} :model-integration-tests diff --git a/notebooks/noj_book/ml_basic.clj b/notebooks/noj_book/ml_basic.clj index d4a2587..ff3fd06 100644 --- a/notebooks/noj_book/ml_basic.clj +++ b/notebooks/noj_book/ml_basic.clj @@ -11,12 +11,29 @@ (:require [tablecloth.api :as tc] [scicloj.metamorph.ml.toydata :as data] [tech.v3.dataset :as ds] + [same.core :as same] + [same.compare :as compare] [camel-snake-kebab.core :as csk] [scicloj.kindly.v4.kind :as kind] [scicloj.kindly.v4.api :as kindly])) +(defn round + [n scale rm] + (.setScale ^java.math.BigDecimal (bigdec n) + (int scale) + ^RoundingMode (if (instance? java.math.RoundingMode rm) + rm + (java.math.RoundingMode/valueOf + (str (if (ident? rm) (symbol rm) rm)))))) +(defn set-sameish-comparator! [scale] + (same/set-comparator! (fn [a b] + (let [a-rounded (round a scale :HALF_UP) + b-rounded (round b scale :HALF_UP)] + (= a-rounded + b-rounded))))) + ;; ## Inspect data ;; @@ -208,19 +225,21 @@ split (def rf-prediction (ml/predict (:test split) rf-model)) +(set-sameish-comparator! 1) ;; First five prediction including the probability distributions ;; are (-> rf-prediction (tc/head) (tc/rows)) -(kindly/check = - [["no" 0.6470588235294118 0.35294117647058826] - ["no" 0.5714285714285714 0.42857142857142855] - ["no" 0.8529411764705882 0.14705882352941177] - ["no" 0.8879310344827587 0.11206896551724138] - ["no" 0.8879310344827587 0.11206896551724138]]) +(kindly/check same/ish? + [["no" 0.64 0.35] + ["no" 0.57 0.42] + ["no" 0.85 0.14] + ["no" 0.88 0.11] + ["no" 0.88 0.11]]) + (loss/classification-accuracy (:survived (ds-cat/reverse-map-categorical-xforms (:test split)))