From 80946d76c86b17b6ed6f48d5155ad03bca059694 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 11:07:07 +0200 Subject: [PATCH] done? --- CHANGELOG.md | 6 ++++++ Project.toml | 3 ++- src/JointEnergyModels.jl | 4 ++-- src/samplers.jl | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e56cdc3..24980b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.1.4]. +## Version [0.1.6] - 2024-09-09 + +### Changed + +- Now depending on new `EnergySamplers` package for energy-based sampling. [#27] + ## Version [0.1.5] - 2024-06-07 ### Changed diff --git a/Project.toml b/Project.toml index ab07394..dfc01f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "JointEnergyModels" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" authors = ["Patrick Altmeyer"] -version = "0.1.5" +version = "0.1.6" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +EnergySamplers = "f446124b-5d5e-4171-a6dd-a1d99768d3ce" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/JointEnergyModels.jl b/src/JointEnergyModels.jl index 127df19..fa80d52 100644 --- a/src/JointEnergyModels.jl +++ b/src/JointEnergyModels.jl @@ -1,11 +1,11 @@ module JointEnergyModels +using EnergySamplers using Flux using TaijaBase -using TaijaBase.Samplers using Reexport -@reexport import TaijaBase.Samplers: ConditionalSampler, UnconditionalSampler, JointSampler +@reexport import EnergySamplers: ConditionalSampler, UnconditionalSampler, JointSampler include("utils.jl") export _energy diff --git a/src/samplers.jl b/src/samplers.jl index e48a06b..cba3f2f 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -10,7 +10,7 @@ using Distributions Outer constructor for `ConditionalSampler`. """ -function TaijaBase.Samplers.ConditionalSampler( +function EnergySamplers.ConditionalSampler( X::Union{Tables.MatrixTable,AbstractMatrix}, y::Union{CategoricalArray,AbstractMatrix}; batch_size::Int = 1,