add NVIDIA NIM inference adapter #355
Conversation
Force-pushed from a5760c0 to 2a25ace.
# the root directory of this source tree.

from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
this should be a dynamic import within get_adapter_impl() -- we want configs to be manipulated without needing implementation dependencies.
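A rough sketch of the dynamic-import shape being asked for (the `_deps` parameter and the `initialize()` call are assumptions borrowed from how other remote adapters are typically wired, not this PR's actual code):

```python
# __init__.py -- hypothetical sketch: only the config is imported at module
# scope, so NVIDIAConfig can be used without implementation dependencies.
from ._config import NVIDIAConfig


async def get_adapter_impl(config: NVIDIAConfig, _deps):
    # Defer the heavy import until an adapter instance is actually constructed.
    from ._nvidia import NVIDIAInferenceAdapter

    impl = NVIDIAInferenceAdapter(config)
    await impl.initialize()
    return impl
```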
ptal
Thank you for this PR. So good!
Re: testing, I'd like to have a reproducible e2e test (ala what we have in providers/tests/inference/test_text_inference.py and providers/tests/inference/test_vision_inference.py) -- just having an nvidia specific fixture there which could then be invoked as `pytest -s -v --providers inference=nvidia test_text_inference.py --env ...` would be great.
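Roughly what such a fixture might look like, modeled on the existing remote-provider fixtures; the import paths, the ProviderFixture/Provider helpers, and the config field name are assumptions, not taken from this PR:

```python
# Hypothetical sketch of an nvidia-specific fixture for
# llama_stack/providers/tests/inference/fixtures.py.
import os

import pytest

from llama_stack.distribution.datatypes import Provider  # path assumed

from ..fixtures import ProviderFixture  # location of ProviderFixture assumed


@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    # The API key comes from the environment, e.g. --env NVIDIA_API_KEY=...
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="nvidia",
                provider_type="remote::nvidia",
                config={"api_key": os.environ.get("NVIDIA_API_KEY", "")},  # field name assumed
            )
        ],
    )
```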
    )

    @property
    def is_hosted(self) -> bool:
this could be is_nvidia_hosted perhaps?
it's really an internal thing. i've removed it from the NVIDIAConfig api entirely.
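For illustration only, such an internal check could live as a module-level helper rather than a public property, assuming the config carries a `url` field and that hosted NIMs are served from integrate.api.nvidia.com, per the PR description:

```python
# _nvidia.py -- hypothetical sketch; keeps the hosted/self-hosted distinction
# out of the public NVIDIAConfig API.
from ._config import NVIDIAConfig


def _is_hosted(config: NVIDIAConfig) -> bool:
    # Hosted NIMs are served from integrate.api.nvidia.com; anything else is
    # treated as a downloaded/self-hosted NIM container. (url field assumed)
    return "integrate.api.nvidia.com" in config.url
```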
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
this is a nitpick, you can ignore it if you feel strongly. we don't usually do underscores in files in the repo - at least not yet. we don't even strongly enforce what symbols get exported out of a module (that part is a bit sad admittedly.) could you make the files not have starting underscores?
my inclination is to be cautious about the exported symbols, but it's important to be cohesive w/ the project. i'll change these. ptal.
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
thanks for the explicit imports. we will be code-modding all our other code to do this sane thing soon :)
i spent so much time trying to figure out which classes were coming from which packages 😆
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
is there a "base" llama model this model would correspond most closely with? we like to know it because we try to format tools, etc. in a way which the model will work best with. this isn't strictly necessary if the provider / API works very robustly with tool calling, etc. but so far given our experience with various "openai" wrapper APIs, it has been spotty.
nvidia/llama-3.1-nemotron-51b-instruct (typo in my comment) is https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct/modelcard
there's now a 70b variant at https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        raise NotImplementedError()
any chance this could be done? it's OK if not, but we have gone back and filled up many of the missing completion() methods also now
let me come back and add it in another PR, same for embedding
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
❤️
i've updated the PR description to note that this does not cover structured output, vision models, embedding or completion apis. if it's ok, i'll follow up with PRs to add those features.
@ashwinb i find
it's an accuracy test because it checks the value of first/last name, birth year, and num seasons. i find that -
suggestions (not mutually exclusive) -
@mattf I agree with your comments on
Looks good to me. Merging!
# What does this PR do?

this PR adds a basic inference adapter to NVIDIA NIMs

what it does -
- chat completion api
- tool calls
- streaming
- structured output
- logprobs
- support hosted NIM on integrate.api.nvidia.com
- support downloaded NIM containers

what it does not do -
- completion api
- embedding api
- vision models
- builtin tools
- have certainty that sampling strategies are correct

## Feature/Issue validation/testing/test plan

`pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=...`

all tests should pass. there are pydantic v1 warnings.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [x] Did you write any new necessary tests?

Thanks for contributing 🎉!