-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add types for contrib.stochastic_support and improve the HSGP module #1907
Merged
fehiepsi
merged 8 commits into
pyro-ppl:master
from
juanitorduz:types-stochastic-support
Nov 15, 2024
Merged
Changes from 3 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f609a75
init
juanitorduz e2267ba
better return hints
juanitorduz 520f524
type improvements
juanitorduz a75f50c
better hints for jax arrays
juanitorduz 22496b1
initializaze ell_
juanitorduz de05465
undo change
juanitorduz d26ec57
fix condition
juanitorduz 495610c
OMG stupid typo #facepalm!
juanitorduz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know what the implications of using
ArrayImpl
vsjax.numpy.ndarray
are?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great question! Thanks for the feedback! So we started trying with type-hints with
ArrayImpl
in #1819 (comment)On the other hand in https://jax.readthedocs.io/en/latest/jax.typing.html its suggested to use
jax.typing.ArrayLike
. However when I use this forx: ArrayLike
and have inside the functionx.ndim
,mypy
complains withHence, this is why I am continuing with
ArrayImpl
and waiting for feedback :)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to avoid importing stuff from jaxlib. I guess it is better to use
jax.Array
(or jnp.ndarray)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re ndim: you can use jnp.ndim(x), which will accept ArrayLike instances, including python scalars.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great input! I'll try that then 🙂!