Skip to content

Commit

Permalink
fix up jax tutorial and bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Dec 8, 2024
1 parent d96eb8b commit 725630c
Show file tree
Hide file tree
Showing 22 changed files with 609 additions and 1,237 deletions.
587 changes: 587 additions & 0 deletions docs/basic_tutorial/basic_tutorial_jax.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,epoch,val_loss
0,0.0,0.29503539204597473
1,1.0,0.21996137499809265
2,2.0,0.1790909618139267
3,3.0,0.10524507611989975
4,4.0,0.07911499589681625
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,epoch,val_loss
0,0.0,0.3224255442619324
1,1.0,0.24499981105327606
2,2.0,0.24103233218193054
3,3.0,0.1361701339483261
4,4.0,0.11248210817575455
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/overrides/main.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Navigate the site here!
</span>
<span class="right-margin">
v0.4.7 is out!
v0.4.8 is out!
</span>
<span>
<span class="twemoji">
Expand Down
2 changes: 1 addition & 1 deletion lan_pipeline_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ dependencies:
- jupyterlab-server==2.25.0
- jupyterlab-widgets==3.0.9
- kiwisolver==1.4.5
- lanfactory==0.4.7
- lanfactory==0.4.8
- markdown-it-py==3.0.0
- markupsafe==2.1.2
- matplotlib==3.8.0
Expand Down
2 changes: 1 addition & 1 deletion lanfactory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.7"
__version__ = "0.4.8"

from . import config
from . import trainers
Expand Down
3 changes: 0 additions & 3 deletions lanfactory/trainers/jax_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,10 @@ def __call__(self, inputs):
x = self.activation_funs[i](x)

if (not self.train) and (self.train_output_type == "logprob"):
print("passing through identity")
x = x # just for pedagogy
elif (not self.train) and (self.train_output_type == "logits"):
print("passing through transform")
x = -jnp.log((1 + jnp.exp(-x)))
elif not self.train:
print("passing through identity 2")
x = x # just for pedagogy

return x
Expand Down
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ nav:
- Overview: index.md
- Basic Tutorial:
- Installation: basic_tutorial/basic_tutorial.ipynb
- Basic Tutorial / JAX:
- Installation: basic_tutorial_jax/basic_tutorial_jax.ipynb
- API:
- lanfactory: api/lanfactory.md
- config: api/config.md
Expand Down
Loading

0 comments on commit 725630c

Please sign in to comment.