Skip to content

Commit

Permalink
Enable use of Azure OpenAI client (#592)
Browse files Browse the repository at this point in the history
* Add support for Azure OpenAI client

* Add support for model/deployment parameters

* Handle Azure API version via environment

* Add openai as test dependency
  • Loading branch information
drdavella authored May 23, 2024
1 parent 5811e8e commit a66333c
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 21 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ test = [
"Jinja2~=3.1.2",
"jsonschema~=4.22.0",
"lxml>=4.9.3,<5.3.0",
"openai~=1.23.0",
"mock==5.1.*",
"pre-commit<4",
"Pyjwt~=2.8.0",
Expand Down
29 changes: 17 additions & 12 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.semgrep import SemgrepRuleDetector
from codemodder.codetf import CodeTF
from codemodder.context import CodemodExecutionContext
from codemodder.context import CodemodExecutionContext, MisconfiguredAIClient
from codemodder.dependency import Dependency
from codemodder.logging import configure_logger, log_list, log_section, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
Expand Down Expand Up @@ -166,17 +166,22 @@ def run(original_args) -> int:
tool_result_files_map["defectdojo"] = argv.defectdojo_findings_json or []

repo_manager = PythonRepoManager(Path(argv.directory))
context = CodemodExecutionContext(
Path(argv.directory),
argv.dry_run,
argv.verbose,
codemod_registry,
repo_manager,
argv.path_include,
argv.path_exclude,
tool_result_files_map,
argv.max_workers,
)

try:
context = CodemodExecutionContext(
Path(argv.directory),
argv.dry_run,
argv.verbose,
codemod_registry,
repo_manager,
argv.path_include,
argv.path_exclude,
tool_result_files_map,
argv.max_workers,
)
except MisconfiguredAIClient as e:
logger.error(e)
return 3 # Codemodder instructions conflicted (according to spec)

repo_manager.parse_project()

Expand Down
55 changes: 46 additions & 9 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@
from codemodder.utils.timer import Timer

try:
from openai import Client
from openai import AzureOpenAI, OpenAI
except ImportError:
Client = None
OpenAI = None
AzureOpenAI = None


if TYPE_CHECKING:
from openai import Client
from openai import OpenAI

from codemodder.codemods.base_codemod import BaseCodemod


class MisconfiguredAIClient(ValueError):
pass


MODELS = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"


class CodemodExecutionContext:
_failures_by_codemod: dict[str, list[Path]] = {}
_dependency_update_by_codemod: dict[str, PackageStore | None] = {}
Expand All @@ -49,7 +58,7 @@ class CodemodExecutionContext:
path_exclude: list[str]
max_workers: int = 1
tool_result_files_map: dict[str, list[str]]
llm_client: Client | None = None
llm_client: OpenAI | None = None

def __init__(
self,
Expand Down Expand Up @@ -80,16 +89,39 @@ def __init__(
self.tool_result_files_map = tool_result_files_map or {}
self.llm_client = self._setup_llm_client()

def _setup_llm_client(self) -> Client | None:
if not Client:
logger.debug("OpenAI API client not available")
def _setup_llm_client(self) -> OpenAI | None:
if not AzureOpenAI:
logger.info("Azure OpenAI API client not available")
return None

azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
raise MisconfiguredAIClient(
"Azure OpenAI API key and endpoint must both be set or unset"
)

if azure_openapi_key and azure_openapi_endpoint:
logger.info("Using Azure OpenAI API client")
return AzureOpenAI(
api_key=azure_openapi_key,
api_version=os.getenv(
"CODEMODDER_AZURE_OPENAI_API_VERSION",
DEFAULT_AZURE_OPENAI_API_VERSION,
),
azure_endpoint=azure_openapi_endpoint,
)

if not OpenAI:
logger.info("OpenAI API client not available")
return None

if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
logger.debug("OpenAI API key not found")
logger.info("OpenAI API key not found")
return None

return Client(api_key=api_key)
logger.info("Using OpenAI API client")
return OpenAI(api_key=api_key)

def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
Expand Down Expand Up @@ -212,3 +244,8 @@ def log_changes(self, codemod_id: str):
for change in changes:
logger.info(" - %s", change.path)
logger.debug(" diff:\n%s", indent(change.diff, " " * 6))

def __getattribute__(self, attr: str):
if (name := attr.replace("_", "-")) in MODELS:
return os.getenv(f"CODEMODDER_AZURE_OPENAI_{name.upper()}_DEPLOYMENT", name)
return super().__getattribute__(attr)
122 changes: 122 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os

import pytest
from openai import AzureOpenAI, OpenAI

from codemodder.context import DEFAULT_AZURE_OPENAI_API_VERSION
from codemodder.context import CodemodExecutionContext as Context
from codemodder.context import MisconfiguredAIClient
from codemodder.dependency import Security
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
from codemodder.registry import load_registered_codemods
Expand Down Expand Up @@ -77,3 +82,120 @@ def test_failed_dependency_description(self, mocker):
```"""
in description
)

def test_setup_llm_client_no_env_vars(self, mocker):
mocker.patch.dict(os.environ, clear=True)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert context.llm_client is None

def test_setup_openai_llm_client(self, mocker):
mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert isinstance(context.llm_client, OpenAI)

def test_setup_azure_llm_client(self, mocker):
mocker.patch.dict(
os.environ,
{
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert isinstance(context.llm_client, AzureOpenAI)
assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION

@pytest.mark.parametrize(
"env_var",
["CODEMODDER_AZURE_OPENAI_API_KEY", "CODEMODDER_AZURE_OPENAI_ENDPOINT"],
)
def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
mocker.patch.dict(os.environ, {env_var: "test"})
with pytest.raises(MisconfiguredAIClient):
Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)

def test_get_model_name(self, mocker):
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert context.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"

@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
def test_model_get_name_from_env(self, mocker, model):
name = "my-awesome-deployment"
mocker.patch.dict(
os.environ,
{
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert getattr(context, model.replace("-", "_")) == name

def test_get_api_version_from_env(self, mocker):
version = "fake-version"
mocker.patch.dict(
os.environ,
{
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
"CODEMODDER_AZURE_OPENAI_API_VERSION": version,
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert isinstance(context.llm_client, AzureOpenAI)
assert context.llm_client._api_version == version

0 comments on commit a66333c

Please sign in to comment.