diff --git a/pyproject.toml b/pyproject.toml index 2c27af5a..49664812 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,9 @@ complexity = [ "radon==6.0.*", "xenon==0.9.*", ] +openai = [ + "openai~=1.0.0", +] all = [ "codemodder[test]", "codemodder[complexity]", diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 1bbf313e..ded78ef4 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -2,6 +2,7 @@ import itertools import logging +import os from pathlib import Path from textwrap import indent from typing import TYPE_CHECKING, Iterator, List @@ -20,7 +21,15 @@ from codemodder.registry import CodemodRegistry from codemodder.utils.timer import Timer +try: + from openai import Client +except ImportError: + Client = None + + if TYPE_CHECKING: + from openai import Client + from codemodder.codemods.base_codemod import BaseCodemod @@ -39,6 +48,7 @@ class CodemodExecutionContext: path_exclude: list[str] max_workers: int = 1 tool_result_files_map: dict[str, list[str]] + llm_client: Client | None = None def __init__( self, @@ -65,6 +75,18 @@ def __init__( self.path_exclude = path_exclude self.max_workers = max_workers self.tool_result_files_map = tool_result_files_map or {} + self.llm_client = self._setup_llm_client() + + def _setup_llm_client(self) -> Client | None: + if not Client: + logger.debug("OpenAI API client not available") + return None + + if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")): + logger.debug("OpenAI API key not found") + return None + + return Client(api_key=api_key) def add_results(self, codemod_name: str, change_sets: List[ChangeSet]): self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets)