diff --git a/.semversioner/next-release/patch-20241025031711368197.json b/.semversioner/next-release/patch-20241025031711368197.json new file mode 100644 index 0000000000..d0b083f7f1 --- /dev/null +++ b/.semversioner/next-release/patch-20241025031711368197.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "move import statements out of init files" +} diff --git a/.semversioner/next-release/patch-20241031180003172666.json b/.semversioner/next-release/patch-20241031180003172666.json new file mode 100644 index 0000000000..890da217a6 --- /dev/null +++ b/.semversioner/next-release/patch-20241031180003172666.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "fix autocompletion of existing files/directory paths." +} diff --git a/docs/prompt_tuning/auto_prompt_tuning.md b/docs/prompt_tuning/auto_prompt_tuning.md index 6279fc5e39..c371f8fe1f 100644 --- a/docs/prompt_tuning/auto_prompt_tuning.md +++ b/docs/prompt_tuning/auto_prompt_tuning.md @@ -20,9 +20,9 @@ Before running auto tuning, ensure you have already initialized your workspace w You can run the main script from the command line with various options: ```bash -graphrag prompt-tune [--root ROOT] [--domain DOMAIN] [--method METHOD] [--limit LIMIT] [--language LANGUAGE] \ +graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selection-method METHOD] [--limit LIMIT] [--language LANGUAGE] \ [--max-tokens MAX_TOKENS] [--chunk-size CHUNK_SIZE] [--n-subset-max N_SUBSET_MAX] [--k K] \ -[--min-examples-required MIN_EXAMPLES_REQUIRED] [--no-entity-types] [--output OUTPUT] +[--min-examples-required MIN_EXAMPLES_REQUIRED] [--discover-entity-types] [--output OUTPUT] ``` ## Command-Line Options @@ -49,7 +49,7 @@ graphrag prompt-tune [--root ROOT] [--domain DOMAIN] [--method METHOD] [--limit - `--min-examples-required` (optional): The minimum number of examples required for entity extraction prompts. Default is 2. -- `--no-entity-types` (optional): Use untyped entity extraction generation. 
We recommend using this when your data covers a lot of topics or it is highly randomized. +- `--discover-entity-types` (optional): Allow the LLM to discover and extract entities automatically. We recommend using this when your data covers a lot of topics or it is highly randomized. - `--output` (optional): The folder to save the generated prompts. Default is "prompts". diff --git a/examples/custom_input/run.py b/examples/custom_input/run.py index ba39033e12..debb022379 100644 --- a/examples/custom_input/run.py +++ b/examples/custom_input/run.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.index import run_pipeline_with_config +from graphrag.index.run import run_pipeline_with_config pipeline_file = os.path.join( os.path.dirname(os.path.abspath(__file__)), "./pipeline.yml" diff --git a/examples/single_verb/run.py b/examples/single_verb/run.py index bc56158543..99f8137a98 100644 --- a/examples/single_verb/run.py +++ b/examples/single_verb/run.py @@ -5,8 +5,8 @@ import pandas as pd -from graphrag.index import run_pipeline, run_pipeline_with_config -from graphrag.index.config import PipelineWorkflowReference +from graphrag.index.config.workflow import PipelineWorkflowReference +from graphrag.index.run import run_pipeline, run_pipeline_with_config # our fake dataset dataset = pd.DataFrame([{"col1": 2, "col2": 4}, {"col1": 5, "col2": 10}]) diff --git a/examples/use_built_in_workflows/run.py b/examples/use_built_in_workflows/run.py index def3a0a67e..7212126d0b 100644 --- a/examples/use_built_in_workflows/run.py +++ b/examples/use_built_in_workflows/run.py @@ -3,9 +3,10 @@ import asyncio import os -from graphrag.index import run_pipeline, run_pipeline_with_config -from graphrag.index.config import PipelineCSVInputConfig, PipelineWorkflowReference -from graphrag.index.input import load_input +from graphrag.index.config.input import PipelineCSVInputConfig +from graphrag.index.config.workflow import PipelineWorkflowReference +from graphrag.index.input.load_input import 
load_input +from graphrag.index.run import run_pipeline, run_pipeline_with_config sample_data_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../_sample_data/" diff --git a/graphrag/__main__.py b/graphrag/__main__.py index ae2421478c..faafaee5b2 100644 --- a/graphrag/__main__.py +++ b/graphrag/__main__.py @@ -3,6 +3,6 @@ """The GraphRAG package.""" -from .cli.main import app +from graphrag.cli.main import app app(prog_name="graphrag") diff --git a/graphrag/api/__init__.py b/graphrag/api/__init__.py index 49059f3e2c..6165122e5c 100644 --- a/graphrag/api/__init__.py +++ b/graphrag/api/__init__.py @@ -8,7 +8,7 @@ """ from graphrag.api.index import build_index -from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts +from graphrag.api.prompt_tune import generate_indexing_prompts from graphrag.api.query import ( drift_search, global_search, @@ -16,6 +16,7 @@ local_search, local_search_streaming, ) +from graphrag.prompt_tune.types import DocSelectionType __all__ = [ # noqa: RUF022 # index API diff --git a/graphrag/api/index.py b/graphrag/api/index.py index 77ba1e5a8e..90fb306942 100644 --- a/graphrag/api/index.py +++ b/graphrag/api/index.py @@ -10,13 +10,14 @@ from pathlib import Path -from graphrag.config import CacheType, GraphRagConfig +from graphrag.config.enums import CacheType +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.index.create_pipeline_config import create_pipeline_config from graphrag.index.emit.types import TableEmitterType from graphrag.index.run import run_pipeline_with_config from graphrag.index.typing import PipelineRunResult -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter from graphrag.vector_stores.factory import VectorStoreType diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index ff0cfdcfb5..917727214c 100644 --- 
a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -15,25 +15,32 @@ from pydantic import PositiveInt, validate_call from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.llm import load_llm -from graphrag.logging import PrintProgressReporter -from graphrag.prompt_tune.generator import ( - MAX_TOKEN_COUNT, - create_community_summarization_prompt, - create_entity_extraction_prompt, - create_entity_summarization_prompt, - detect_language, +from graphrag.index.llm.load_llm import load_llm +from graphrag.logging.print_progress import PrintProgressReporter +from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT +from graphrag.prompt_tune.generator.community_report_rating import ( generate_community_report_rating, +) +from graphrag.prompt_tune.generator.community_report_summarization import ( + create_community_summarization_prompt, +) +from graphrag.prompt_tune.generator.community_reporter_role import ( generate_community_reporter_role, - generate_domain, +) +from graphrag.prompt_tune.generator.domain import generate_domain +from graphrag.prompt_tune.generator.entity_extraction_prompt import ( + create_entity_extraction_prompt, +) +from graphrag.prompt_tune.generator.entity_relationship import ( generate_entity_relationship_examples, - generate_entity_types, - generate_persona, ) -from graphrag.prompt_tune.loader import ( - MIN_CHUNK_SIZE, - load_docs_in_chunks, +from graphrag.prompt_tune.generator.entity_summarization_prompt import ( + create_entity_summarization_prompt, ) +from graphrag.prompt_tune.generator.entity_types import generate_entity_types +from graphrag.prompt_tune.generator.language import detect_language +from graphrag.prompt_tune.generator.persona import generate_persona +from graphrag.prompt_tune.loader.input import MIN_CHUNK_SIZE, load_docs_in_chunks from graphrag.prompt_tune.types import DocSelectionType diff --git a/graphrag/api/query.py b/graphrag/api/query.py index 21648a12a8..7149211c83 100644 
--- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -24,12 +24,12 @@ import pandas as pd from pydantic import validate_call -from graphrag.config import GraphRagConfig +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.config.embeddings import ( community_full_content_embedding, entity_description_embedding, ) -from graphrag.logging import PrintProgressReporter +from graphrag.logging.print_progress import PrintProgressReporter from graphrag.query.factories import ( get_drift_search_engine, get_global_search_engine, @@ -47,8 +47,8 @@ from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 from graphrag.utils.cli import redact from graphrag.utils.embeddings import create_collection_name -from graphrag.vector_stores import VectorStoreFactory, VectorStoreType from graphrag.vector_stores.base import BaseVectorStore +from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType reporter = PrintProgressReporter("") diff --git a/graphrag/callbacks/factories.py b/graphrag/callbacks/factories.py index 3f3b64788f..257a0d0a6c 100644 --- a/graphrag/callbacks/factories.py +++ b/graphrag/callbacks/factories.py @@ -8,17 +8,16 @@ from datashaper import WorkflowCallbacks -from graphrag.config import ReportingType -from graphrag.index.config import ( +from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks +from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks +from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks +from graphrag.config.enums import ReportingType +from graphrag.index.config.reporting import ( PipelineBlobReportingConfig, PipelineFileReportingConfig, PipelineReportingConfig, ) -from .blob_workflow_callbacks import BlobWorkflowCallbacks -from .console_workflow_callbacks import ConsoleWorkflowCallbacks -from .file_workflow_callbacks import FileWorkflowCallbacks - def create_pipeline_reporter( config: 
PipelineReportingConfig | None, root_dir: str | None diff --git a/graphrag/callbacks/global_search_callbacks.py b/graphrag/callbacks/global_search_callbacks.py index 32c6fc8668..c8f1395bd9 100644 --- a/graphrag/callbacks/global_search_callbacks.py +++ b/graphrag/callbacks/global_search_callbacks.py @@ -3,10 +3,9 @@ """GlobalSearch LLM Callbacks.""" +from graphrag.callbacks.llm_callbacks import BaseLLMCallback from graphrag.query.structured_search.base import SearchResult -from .llm_callbacks import BaseLLMCallback - class GlobalSearchLLMCallback(BaseLLMCallback): """GlobalSearch LLM Callbacks.""" diff --git a/graphrag/callbacks/progress_workflow_callbacks.py b/graphrag/callbacks/progress_workflow_callbacks.py index 31c29543a5..d4c9407a58 100644 --- a/graphrag/callbacks/progress_workflow_callbacks.py +++ b/graphrag/callbacks/progress_workflow_callbacks.py @@ -7,7 +7,7 @@ from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): diff --git a/graphrag/cli/index.py b/graphrag/cli/index.py index c9ec2bdc05..90a720fab4 100644 --- a/graphrag/cli/index.py +++ b/graphrag/cli/index.py @@ -11,15 +11,15 @@ from pathlib import Path import graphrag.api as api -from graphrag.config import ( - CacheType, - enable_logging_with_config, - load_config, - resolve_paths, -) +from graphrag.config.enums import CacheType +from graphrag.config.load_config import load_config +from graphrag.config.logging import enable_logging_with_config +from graphrag.config.resolve_path import resolve_paths from graphrag.index.emit.types import TableEmitterType from graphrag.index.validate_config import validate_config_names -from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter +from graphrag.logging.base import ProgressReporter +from graphrag.logging.factories import create_progress_reporter 
+from graphrag.logging.types import ReporterType from graphrag.utils.cli import redact # Ignore warnings from numba diff --git a/graphrag/cli/initialize.py b/graphrag/cli/initialize.py index 46bf6167df..992d38f71e 100644 --- a/graphrag/cli/initialize.py +++ b/graphrag/cli/initialize.py @@ -6,7 +6,8 @@ from pathlib import Path from graphrag.config.init_content import INIT_DOTENV, INIT_YAML -from graphrag.logging import ReporterType, create_progress_reporter +from graphrag.logging.factories import create_progress_reporter +from graphrag.logging.types import ReporterType from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT from graphrag.prompts.index.community_report import ( COMMUNITY_REPORT_PROMPT, diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index 27bb5be34b..919015ae31 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -3,23 +3,24 @@ """CLI entrypoint.""" -import asyncio +import os +import re +from collections.abc import Callable from enum import Enum from pathlib import Path from typing import Annotated import typer -from graphrag.api import DocSelectionType from graphrag.index.emit.types import TableEmitterType -from graphrag.logging import ReporterType -from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT -from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE - -from .index import index_cli, update_cli -from .initialize import initialize_project_at -from .prompt_tune import prompt_tune -from .query import run_drift_search, run_global_search, run_local_search +from graphrag.logging.types import ReporterType +from graphrag.prompt_tune.defaults import ( + MAX_TOKEN_COUNT, + MIN_CHUNK_SIZE, + N_SUBSET_MAX, + K, +) +from graphrag.prompt_tune.types import DocSelectionType INVALID_METHOD_ERROR = "Invalid method" @@ -29,6 +30,48 @@ ) +# A workaround for typer's lack of support for proper autocompletion of file/directory paths +# For more detail, watch +# https://github.com/fastapi/typer/discussions/682 +# 
https://github.com/fastapi/typer/issues/951 +def path_autocomplete( + file_okay: bool = True, + dir_okay: bool = True, + readable: bool = True, + writable: bool = False, + match_wildcard: str | None = None, +) -> Callable[[str], list[str]]: + """Autocomplete file and directory paths.""" + + def wildcard_match(string: str, pattern: str) -> bool: + regex = re.escape(pattern).replace(r"\?", ".").replace(r"\*", ".*") + return re.fullmatch(regex, string) is not None + + def completer(incomplete: str) -> list[str]: + items = os.listdir() + completions = [] + for item in items: + if not file_okay and Path(item).is_file(): + continue + if not dir_okay and Path(item).is_dir(): + continue + if readable and not os.access(item, os.R_OK): + continue + if writable and not os.access(item, os.W_OK): + continue + completions.append(item) + if match_wildcard: + completions = filter( + lambda i: wildcard_match(i, match_wildcard) + if match_wildcard + else False, + completions, + ) + return [i for i in completions if i.startswith(incomplete)] + + return completer + + class SearchType(Enum): """The type of search to run.""" @@ -50,10 +93,15 @@ def _initialize_cli( dir_okay=True, writable=True, resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" + ), ), ], ): """Generate a default configuration file.""" + from graphrag.cli.initialize import initialize_project_at + initialize_project_at(path=root) @@ -73,6 +121,9 @@ def _index_cli( dir_okay=True, writable=True, resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" + ), ), ] = Path(), # set default to current directory verbose: Annotated[ @@ -114,6 +165,8 @@ def _index_cli( ] = None, ): """Build a knowledge graph index.""" + from graphrag.cli.index import index_cli + index_cli( root_dir=root, verbose=verbose, @@ -181,6 +234,8 @@ def _update_cli( Applies a default storage configuration (if not provided 
by config), saving the new index to the local file system in the `update_output` folder. """ + from graphrag.cli.index import update_cli + update_cli( root_dir=root, verbose=verbose, @@ -204,12 +259,21 @@ def _prompt_tune_cli( dir_okay=True, writable=True, resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" + ), ), ] = Path(), # set default to current directory config: Annotated[ Path | None, typer.Option( - help="The configuration to use.", exists=True, file_okay=True, readable=True + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + autocompletion=path_autocomplete( + file_okay=True, dir_okay=False, match_wildcard="*" + ), ), ] = None, domain: Annotated[ @@ -226,13 +290,13 @@ def _prompt_tune_cli( typer.Option( help="The number of text chunks to embed when --selection-method=auto." ), - ] = 300, + ] = N_SUBSET_MAX, k: Annotated[ int, typer.Option( help="The maximum number of documents to select from each centroid when --selection-method=auto." ), - ] = 15, + ] = K, limit: Annotated[ int, typer.Option( @@ -271,6 +335,10 @@ def _prompt_tune_cli( ] = Path("prompts"), ): """Generate custom graphrag prompts with your own data (i.e. 
auto templating).""" + import asyncio + + from graphrag.cli.prompt_tune import prompt_tune + loop = asyncio.get_event_loop() loop.run_until_complete( prompt_tune( @@ -298,7 +366,13 @@ def _query_cli( config: Annotated[ Path | None, typer.Option( - help="The configuration to use.", exists=True, file_okay=True, readable=True + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + autocompletion=path_autocomplete( + file_okay=True, dir_okay=False, match_wildcard="*" + ), ), ] = None, data: Annotated[ @@ -309,6 +383,9 @@ def _query_cli( dir_okay=True, readable=True, resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, match_wildcard="*" + ), ), ] = None, root: Annotated[ @@ -319,6 +396,9 @@ def _query_cli( dir_okay=True, writable=True, resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, match_wildcard="*" + ), ), ] = Path(), # set default to current directory community_level: Annotated[ @@ -342,6 +422,8 @@ def _query_cli( ] = False, ): """Query a knowledge graph index.""" + from graphrag.cli.query import run_drift_search, run_global_search, run_local_search + match method: case SearchType.LOCAL: run_local_search( diff --git a/graphrag/cli/prompt_tune.py b/graphrag/cli/prompt_tune.py index cbb36e00ba..feaa08c32a 100644 --- a/graphrag/cli/prompt_tune.py +++ b/graphrag/cli/prompt_tune.py @@ -6,8 +6,8 @@ from pathlib import Path import graphrag.api as api -from graphrag.config import load_config -from graphrag.logging import PrintProgressReporter +from graphrag.config.load_config import load_config +from graphrag.logging.print_progress import PrintProgressReporter from graphrag.prompt_tune.generator.community_report_summarization import ( COMMUNITY_SUMMARIZATION_FILENAME, ) diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index 815313a2d5..ea9116c695 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -10,9 +10,11 @@ import pandas as pd import 
graphrag.api as api -from graphrag.config import GraphRagConfig, load_config, resolve_paths +from graphrag.config.load_config import load_config +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.resolve_path import resolve_paths from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.logging import PrintProgressReporter +from graphrag.logging.print_progress import PrintProgressReporter from graphrag.utils.storage import _create_storage, _load_table_from_storage reporter = PrintProgressReporter("") diff --git a/graphrag/config/__init__.py b/graphrag/config/__init__.py index 3354f2ccd8..90e6e010af 100644 --- a/graphrag/config/__init__.py +++ b/graphrag/config/__init__.py @@ -2,134 +2,3 @@ # Licensed under the MIT License """The Indexing Engine default config package root.""" - -from .config_file_loader import load_config_from_file, search_for_config_in_root_dir -from .create_graphrag_config import ( - create_graphrag_config, -) -from .enums import ( - CacheType, - InputFileType, - InputType, - LLMType, - ReportingType, - StorageType, - TextEmbeddingTarget, -) -from .errors import ( - ApiKeyMissingError, - AzureApiBaseMissingError, - AzureDeploymentNameMissingError, -) -from .input_models import ( - CacheConfigInput, - ChunkingConfigInput, - ClaimExtractionConfigInput, - ClusterGraphConfigInput, - CommunityReportsConfigInput, - EmbedGraphConfigInput, - EntityExtractionConfigInput, - GlobalSearchConfigInput, - GraphRagConfigInput, - InputConfigInput, - LLMConfigInput, - LLMParametersInput, - LocalSearchConfigInput, - ParallelizationParametersInput, - ReportingConfigInput, - SnapshotsConfigInput, - StorageConfigInput, - SummarizeDescriptionsConfigInput, - TextEmbeddingConfigInput, - UmapConfigInput, -) -from .load_config import load_config -from .logging import enable_logging_with_config -from .models import ( - CacheConfig, - ChunkingConfig, - ClaimExtractionConfig, - ClusterGraphConfig, - 
CommunityReportsConfig, - DRIFTSearchConfig, - EmbedGraphConfig, - EntityExtractionConfig, - GlobalSearchConfig, - GraphRagConfig, - InputConfig, - LLMConfig, - LLMParameters, - LocalSearchConfig, - ParallelizationParameters, - ReportingConfig, - SnapshotsConfig, - StorageConfig, - SummarizeDescriptionsConfig, - TextEmbeddingConfig, - UmapConfig, -) -from .read_dotenv import read_dotenv -from .resolve_path import resolve_path, resolve_paths - -__all__ = [ - "ApiKeyMissingError", - "AzureApiBaseMissingError", - "AzureDeploymentNameMissingError", - "CacheConfig", - "CacheConfigInput", - "CacheType", - "ChunkingConfig", - "ChunkingConfigInput", - "ClaimExtractionConfig", - "ClaimExtractionConfigInput", - "ClusterGraphConfig", - "ClusterGraphConfigInput", - "CommunityReportsConfig", - "CommunityReportsConfigInput", - "DRIFTSearchConfig", - "EmbedGraphConfig", - "EmbedGraphConfigInput", - "EntityExtractionConfig", - "EntityExtractionConfigInput", - "GlobalSearchConfig", - "GlobalSearchConfigInput", - "GraphRagConfig", - "GraphRagConfigInput", - "InputConfig", - "InputConfigInput", - "InputFileType", - "InputType", - "LLMConfig", - "LLMConfigInput", - "LLMParameters", - "LLMParametersInput", - "LLMType", - "LocalSearchConfig", - "LocalSearchConfigInput", - "ParallelizationParameters", - "ParallelizationParametersInput", - "ReportingConfig", - "ReportingConfigInput", - "ReportingType", - "SnapshotsConfig", - "SnapshotsConfigInput", - "StorageConfig", - "StorageConfigInput", - "StorageType", - "StorageType", - "SummarizeDescriptionsConfig", - "SummarizeDescriptionsConfigInput", - "TextEmbeddingConfig", - "TextEmbeddingConfigInput", - "TextEmbeddingTarget", - "UmapConfig", - "UmapConfigInput", - "create_graphrag_config", - "enable_logging_with_config", - "load_config", - "load_config_from_file", - "read_dotenv", - "resolve_path", - "resolve_paths", - "search_for_config_in_root_dir", -] diff --git a/graphrag/config/config_file_loader.py 
b/graphrag/config/config_file_loader.py index 667fbe8807..4ab930a374 100644 --- a/graphrag/config/config_file_loader.py +++ b/graphrag/config/config_file_loader.py @@ -9,8 +9,8 @@ import yaml -from .create_graphrag_config import create_graphrag_config -from .models.graph_rag_config import GraphRagConfig +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig _default_config_files = ["settings.yaml", "settings.yml", "settings.json"] diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 42b250088e..2e0f1005dd 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -13,8 +13,7 @@ from pydantic import TypeAdapter import graphrag.config.defaults as defs - -from .enums import ( +from graphrag.config.enums import ( CacheType, InputFileType, InputType, @@ -23,39 +22,37 @@ StorageType, TextEmbeddingTarget, ) -from .environment_reader import EnvironmentReader -from .errors import ( +from graphrag.config.environment_reader import EnvironmentReader +from graphrag.config.errors import ( ApiKeyMissingError, AzureApiBaseMissingError, AzureDeploymentNameMissingError, ) -from .input_models import ( - GraphRagConfigInput, - LLMConfigInput, -) -from .models import ( - CacheConfig, - ChunkingConfig, - ClaimExtractionConfig, - ClusterGraphConfig, - CommunityReportsConfig, - DRIFTSearchConfig, - EmbedGraphConfig, - EntityExtractionConfig, - GlobalSearchConfig, - GraphRagConfig, - InputConfig, - LLMParameters, - LocalSearchConfig, - ParallelizationParameters, - ReportingConfig, - SnapshotsConfig, - StorageConfig, +from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput +from graphrag.config.models.cache_config import CacheConfig +from graphrag.config.models.chunking_config import ChunkingConfig +from 
graphrag.config.models.claim_extraction_config import ClaimExtractionConfig +from graphrag.config.models.cluster_graph_config import ClusterGraphConfig +from graphrag.config.models.community_reports_config import CommunityReportsConfig +from graphrag.config.models.drift_search_config import DRIFTSearchConfig +from graphrag.config.models.embed_graph_config import EmbedGraphConfig +from graphrag.config.models.entity_extraction_config import EntityExtractionConfig +from graphrag.config.models.global_search_config import GlobalSearchConfig +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.llm_parameters import LLMParameters +from graphrag.config.models.local_search_config import LocalSearchConfig +from graphrag.config.models.parallelization_parameters import ParallelizationParameters +from graphrag.config.models.reporting_config import ReportingConfig +from graphrag.config.models.snapshots_config import SnapshotsConfig +from graphrag.config.models.storage_config import StorageConfig +from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, - TextEmbeddingConfig, - UmapConfig, ) -from .read_dotenv import read_dotenv +from graphrag.config.models.text_embedding_config import TextEmbeddingConfig +from graphrag.config.models.umap_config import UmapConfig +from graphrag.config.read_dotenv import read_dotenv InputModelValidator = TypeAdapter(GraphRagConfigInput) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 41ec8fc892..ecfe632eaa 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -7,9 +7,7 @@ from datashaper import AsyncType -from graphrag.vector_stores import VectorStoreType - -from .enums import ( +from graphrag.config.enums import ( CacheType, InputFileType, InputType, @@ -18,6 +16,7 @@ StorageType, TextEmbeddingTarget, ) +from graphrag.vector_stores.factory import 
VectorStoreType ASYNC_MODE = AsyncType.Threaded ENCODING_MODEL = "cl100k_base" diff --git a/graphrag/config/input_models/__init__.py b/graphrag/config/input_models/__init__.py index f905ae38b2..6c5862a947 100644 --- a/graphrag/config/input_models/__init__.py +++ b/graphrag/config/input_models/__init__.py @@ -2,49 +2,3 @@ # Licensed under the MIT License """Interfaces for Default Config parameterization.""" - -from .cache_config_input import CacheConfigInput -from .chunking_config_input import ChunkingConfigInput -from .claim_extraction_config_input import ClaimExtractionConfigInput -from .cluster_graph_config_input import ClusterGraphConfigInput -from .community_reports_config_input import CommunityReportsConfigInput -from .embed_graph_config_input import EmbedGraphConfigInput -from .entity_extraction_config_input import EntityExtractionConfigInput -from .global_search_config_input import GlobalSearchConfigInput -from .graphrag_config_input import GraphRagConfigInput -from .input_config_input import InputConfigInput -from .llm_config_input import LLMConfigInput -from .llm_parameters_input import LLMParametersInput -from .local_search_config_input import LocalSearchConfigInput -from .parallelization_parameters_input import ParallelizationParametersInput -from .reporting_config_input import ReportingConfigInput -from .snapshots_config_input import SnapshotsConfigInput -from .storage_config_input import StorageConfigInput -from .summarize_descriptions_config_input import ( - SummarizeDescriptionsConfigInput, -) -from .text_embedding_config_input import TextEmbeddingConfigInput -from .umap_config_input import UmapConfigInput - -__all__ = [ - "CacheConfigInput", - "ChunkingConfigInput", - "ClaimExtractionConfigInput", - "ClusterGraphConfigInput", - "CommunityReportsConfigInput", - "EmbedGraphConfigInput", - "EntityExtractionConfigInput", - "GlobalSearchConfigInput", - "GraphRagConfigInput", - "InputConfigInput", - "LLMConfigInput", - "LLMParametersInput", - 
"LocalSearchConfigInput", - "ParallelizationParametersInput", - "ReportingConfigInput", - "SnapshotsConfigInput", - "StorageConfigInput", - "SummarizeDescriptionsConfigInput", - "TextEmbeddingConfigInput", - "UmapConfigInput", -] diff --git a/graphrag/config/input_models/claim_extraction_config_input.py b/graphrag/config/input_models/claim_extraction_config_input.py index f23e31d0a7..42ff60ea14 100644 --- a/graphrag/config/input_models/claim_extraction_config_input.py +++ b/graphrag/config/input_models/claim_extraction_config_input.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired -from .llm_config_input import LLMConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput class ClaimExtractionConfigInput(LLMConfigInput): diff --git a/graphrag/config/input_models/community_reports_config_input.py b/graphrag/config/input_models/community_reports_config_input.py index 79ae3152e7..4f8297ae33 100644 --- a/graphrag/config/input_models/community_reports_config_input.py +++ b/graphrag/config/input_models/community_reports_config_input.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired -from .llm_config_input import LLMConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput class CommunityReportsConfigInput(LLMConfigInput): diff --git a/graphrag/config/input_models/entity_extraction_config_input.py b/graphrag/config/input_models/entity_extraction_config_input.py index f1d3587e99..dcc2770c21 100644 --- a/graphrag/config/input_models/entity_extraction_config_input.py +++ b/graphrag/config/input_models/entity_extraction_config_input.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired -from .llm_config_input import LLMConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput class EntityExtractionConfigInput(LLMConfigInput): diff --git a/graphrag/config/input_models/graphrag_config_input.py b/graphrag/config/input_models/graphrag_config_input.py index 
7c04dea2e3..9d3094edd7 100644 --- a/graphrag/config/input_models/graphrag_config_input.py +++ b/graphrag/config/input_models/graphrag_config_input.py @@ -5,25 +5,39 @@ from typing_extensions import NotRequired -from .cache_config_input import CacheConfigInput -from .chunking_config_input import ChunkingConfigInput -from .claim_extraction_config_input import ClaimExtractionConfigInput -from .cluster_graph_config_input import ClusterGraphConfigInput -from .community_reports_config_input import CommunityReportsConfigInput -from .embed_graph_config_input import EmbedGraphConfigInput -from .entity_extraction_config_input import EntityExtractionConfigInput -from .global_search_config_input import GlobalSearchConfigInput -from .input_config_input import InputConfigInput -from .llm_config_input import LLMConfigInput -from .local_search_config_input import LocalSearchConfigInput -from .reporting_config_input import ReportingConfigInput -from .snapshots_config_input import SnapshotsConfigInput -from .storage_config_input import StorageConfigInput -from .summarize_descriptions_config_input import ( +from graphrag.config.input_models.cache_config_input import CacheConfigInput +from graphrag.config.input_models.chunking_config_input import ChunkingConfigInput +from graphrag.config.input_models.claim_extraction_config_input import ( + ClaimExtractionConfigInput, +) +from graphrag.config.input_models.cluster_graph_config_input import ( + ClusterGraphConfigInput, +) +from graphrag.config.input_models.community_reports_config_input import ( + CommunityReportsConfigInput, +) +from graphrag.config.input_models.embed_graph_config_input import EmbedGraphConfigInput +from graphrag.config.input_models.entity_extraction_config_input import ( + EntityExtractionConfigInput, +) +from graphrag.config.input_models.global_search_config_input import ( + GlobalSearchConfigInput, +) +from graphrag.config.input_models.input_config_input import InputConfigInput +from 
graphrag.config.input_models.llm_config_input import LLMConfigInput +from graphrag.config.input_models.local_search_config_input import ( + LocalSearchConfigInput, +) +from graphrag.config.input_models.reporting_config_input import ReportingConfigInput +from graphrag.config.input_models.snapshots_config_input import SnapshotsConfigInput +from graphrag.config.input_models.storage_config_input import StorageConfigInput +from graphrag.config.input_models.summarize_descriptions_config_input import ( SummarizeDescriptionsConfigInput, ) -from .text_embedding_config_input import TextEmbeddingConfigInput -from .umap_config_input import UmapConfigInput +from graphrag.config.input_models.text_embedding_config_input import ( + TextEmbeddingConfigInput, +) +from graphrag.config.input_models.umap_config_input import UmapConfigInput class GraphRagConfigInput(LLMConfigInput): diff --git a/graphrag/config/input_models/llm_config_input.py b/graphrag/config/input_models/llm_config_input.py index 67231371b8..35b3b342b4 100644 --- a/graphrag/config/input_models/llm_config_input.py +++ b/graphrag/config/input_models/llm_config_input.py @@ -6,8 +6,10 @@ from datashaper import AsyncType from typing_extensions import NotRequired, TypedDict -from .llm_parameters_input import LLMParametersInput -from .parallelization_parameters_input import ParallelizationParametersInput +from graphrag.config.input_models.llm_parameters_input import LLMParametersInput +from graphrag.config.input_models.parallelization_parameters_input import ( + ParallelizationParametersInput, +) class LLMConfigInput(TypedDict): diff --git a/graphrag/config/input_models/summarize_descriptions_config_input.py b/graphrag/config/input_models/summarize_descriptions_config_input.py index 6ce756e558..b71a465aef 100644 --- a/graphrag/config/input_models/summarize_descriptions_config_input.py +++ b/graphrag/config/input_models/summarize_descriptions_config_input.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired -from 
.llm_config_input import LLMConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput class SummarizeDescriptionsConfigInput(LLMConfigInput): diff --git a/graphrag/config/input_models/text_embedding_config_input.py b/graphrag/config/input_models/text_embedding_config_input.py index a7e176c658..de72612e34 100644 --- a/graphrag/config/input_models/text_embedding_config_input.py +++ b/graphrag/config/input_models/text_embedding_config_input.py @@ -8,8 +8,7 @@ from graphrag.config.enums import ( TextEmbeddingTarget, ) - -from .llm_config_input import LLMConfigInput +from graphrag.config.input_models.llm_config_input import LLMConfigInput class TextEmbeddingConfigInput(LLMConfigInput): diff --git a/graphrag/config/load_config.py b/graphrag/config/load_config.py index c4133a7196..63c2c6967b 100644 --- a/graphrag/config/load_config.py +++ b/graphrag/config/load_config.py @@ -5,9 +5,12 @@ from pathlib import Path -from .config_file_loader import load_config_from_file, search_for_config_in_root_dir -from .create_graphrag_config import create_graphrag_config -from .models.graph_rag_config import GraphRagConfig +from graphrag.config.config_file_loader import ( + load_config_from_file, + search_for_config_in_root_dir, +) +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig def load_config( diff --git a/graphrag/config/logging.py b/graphrag/config/logging.py index 99ee459a27..3c626b625c 100644 --- a/graphrag/config/logging.py +++ b/graphrag/config/logging.py @@ -6,8 +6,8 @@ import logging from pathlib import Path -from .enums import ReportingType -from .models.graph_rag_config import GraphRagConfig +from graphrag.config.enums import ReportingType +from graphrag.config.models.graph_rag_config import GraphRagConfig def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None: diff --git a/graphrag/config/models/__init__.py 
b/graphrag/config/models/__init__.py index 887d4ad653..6c5862a947 100644 --- a/graphrag/config/models/__init__.py +++ b/graphrag/config/models/__init__.py @@ -2,49 +2,3 @@ # Licensed under the MIT License """Interfaces for Default Config parameterization.""" - -from .cache_config import CacheConfig -from .chunking_config import ChunkingConfig -from .claim_extraction_config import ClaimExtractionConfig -from .cluster_graph_config import ClusterGraphConfig -from .community_reports_config import CommunityReportsConfig -from .drift_search_config import DRIFTSearchConfig -from .embed_graph_config import EmbedGraphConfig -from .entity_extraction_config import EntityExtractionConfig -from .global_search_config import GlobalSearchConfig -from .graph_rag_config import GraphRagConfig -from .input_config import InputConfig -from .llm_config import LLMConfig -from .llm_parameters import LLMParameters -from .local_search_config import LocalSearchConfig -from .parallelization_parameters import ParallelizationParameters -from .reporting_config import ReportingConfig -from .snapshots_config import SnapshotsConfig -from .storage_config import StorageConfig -from .summarize_descriptions_config import SummarizeDescriptionsConfig -from .text_embedding_config import TextEmbeddingConfig -from .umap_config import UmapConfig - -__all__ = [ - "CacheConfig", - "ChunkingConfig", - "ClaimExtractionConfig", - "ClusterGraphConfig", - "CommunityReportsConfig", - "DRIFTSearchConfig", - "EmbedGraphConfig", - "EntityExtractionConfig", - "GlobalSearchConfig", - "GraphRagConfig", - "InputConfig", - "LLMConfig", - "LLMParameters", - "LocalSearchConfig", - "ParallelizationParameters", - "ReportingConfig", - "SnapshotsConfig", - "StorageConfig", - "SummarizeDescriptionsConfig", - "TextEmbeddingConfig", - "UmapConfig", -] diff --git a/graphrag/config/models/claim_extraction_config.py b/graphrag/config/models/claim_extraction_config.py index 6a4de8e3c6..716f64480e 100644 --- 
a/graphrag/config/models/claim_extraction_config.py +++ b/graphrag/config/models/claim_extraction_config.py @@ -8,8 +8,7 @@ from pydantic import Field import graphrag.config.defaults as defs - -from .llm_config import LLMConfig +from graphrag.config.models.llm_config import LLMConfig class ClaimExtractionConfig(LLMConfig): diff --git a/graphrag/config/models/community_reports_config.py b/graphrag/config/models/community_reports_config.py index 0eafa81c29..104c77eca6 100644 --- a/graphrag/config/models/community_reports_config.py +++ b/graphrag/config/models/community_reports_config.py @@ -8,8 +8,7 @@ from pydantic import Field import graphrag.config.defaults as defs - -from .llm_config import LLMConfig +from graphrag.config.models.llm_config import LLMConfig class CommunityReportsConfig(LLMConfig): diff --git a/graphrag/config/models/entity_extraction_config.py b/graphrag/config/models/entity_extraction_config.py index 08055d510b..40f155d0e4 100644 --- a/graphrag/config/models/entity_extraction_config.py +++ b/graphrag/config/models/entity_extraction_config.py @@ -8,8 +8,7 @@ from pydantic import Field import graphrag.config.defaults as defs - -from .llm_config import LLMConfig +from graphrag.config.models.llm_config import LLMConfig class EntityExtractionConfig(LLMConfig): diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index adcf64452d..3dc5f71d82 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -7,27 +7,26 @@ from pydantic import Field import graphrag.config.defaults as defs - -from .cache_config import CacheConfig -from .chunking_config import ChunkingConfig -from .claim_extraction_config import ClaimExtractionConfig -from .cluster_graph_config import ClusterGraphConfig -from .community_reports_config import CommunityReportsConfig -from .drift_search_config import DRIFTSearchConfig -from .embed_graph_config import EmbedGraphConfig -from 
.entity_extraction_config import EntityExtractionConfig -from .global_search_config import GlobalSearchConfig -from .input_config import InputConfig -from .llm_config import LLMConfig -from .local_search_config import LocalSearchConfig -from .reporting_config import ReportingConfig -from .snapshots_config import SnapshotsConfig -from .storage_config import StorageConfig -from .summarize_descriptions_config import ( +from graphrag.config.models.cache_config import CacheConfig +from graphrag.config.models.chunking_config import ChunkingConfig +from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig +from graphrag.config.models.cluster_graph_config import ClusterGraphConfig +from graphrag.config.models.community_reports_config import CommunityReportsConfig +from graphrag.config.models.drift_search_config import DRIFTSearchConfig +from graphrag.config.models.embed_graph_config import EmbedGraphConfig +from graphrag.config.models.entity_extraction_config import EntityExtractionConfig +from graphrag.config.models.global_search_config import GlobalSearchConfig +from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.llm_config import LLMConfig +from graphrag.config.models.local_search_config import LocalSearchConfig +from graphrag.config.models.reporting_config import ReportingConfig +from graphrag.config.models.snapshots_config import SnapshotsConfig +from graphrag.config.models.storage_config import StorageConfig +from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) -from .text_embedding_config import TextEmbeddingConfig -from .umap_config import UmapConfig +from graphrag.config.models.text_embedding_config import TextEmbeddingConfig +from graphrag.config.models.umap_config import UmapConfig class GraphRagConfig(LLMConfig): diff --git a/graphrag/config/models/llm_config.py b/graphrag/config/models/llm_config.py index 62c193b0c5..3759bd949e 100644 --- 
a/graphrag/config/models/llm_config.py +++ b/graphrag/config/models/llm_config.py @@ -7,9 +7,8 @@ from pydantic import BaseModel, Field import graphrag.config.defaults as defs - -from .llm_parameters import LLMParameters -from .parallelization_parameters import ParallelizationParameters +from graphrag.config.models.llm_parameters import LLMParameters +from graphrag.config.models.parallelization_parameters import ParallelizationParameters class LLMConfig(BaseModel): diff --git a/graphrag/config/models/summarize_descriptions_config.py b/graphrag/config/models/summarize_descriptions_config.py index 9104a60ac2..c1acb9b381 100644 --- a/graphrag/config/models/summarize_descriptions_config.py +++ b/graphrag/config/models/summarize_descriptions_config.py @@ -8,8 +8,7 @@ from pydantic import Field import graphrag.config.defaults as defs - -from .llm_config import LLMConfig +from graphrag.config.models.llm_config import LLMConfig class SummarizeDescriptionsConfig(LLMConfig): diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/text_embedding_config.py index 815263bbcf..1fde9a41dd 100644 --- a/graphrag/config/models/text_embedding_config.py +++ b/graphrag/config/models/text_embedding_config.py @@ -7,8 +7,7 @@ import graphrag.config.defaults as defs from graphrag.config.enums import TextEmbeddingTarget - -from .llm_config import LLMConfig +from graphrag.config.models.llm_config import LLMConfig class TextEmbeddingConfig(LLMConfig): diff --git a/graphrag/config/resolve_path.py b/graphrag/config/resolve_path.py index 7ff60c4562..237c7f7edb 100644 --- a/graphrag/config/resolve_path.py +++ b/graphrag/config/resolve_path.py @@ -7,8 +7,8 @@ from pathlib import Path from string import Template -from .enums import ReportingType, StorageType -from .models.graph_rag_config import GraphRagConfig +from graphrag.config.enums import ReportingType, StorageType +from graphrag.config.models.graph_rag_config import GraphRagConfig def 
_resolve_timestamp_path_with_value(path: str | Path, timestamp_value: str) -> Path: diff --git a/graphrag/index/__init__.py b/graphrag/index/__init__.py index c97c290a94..c5acc43b65 100644 --- a/graphrag/index/__init__.py +++ b/graphrag/index/__init__.py @@ -1,78 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""The Indexing Engine package root.""" - -from .cache import PipelineCache -from .config import ( - PipelineBlobCacheConfig, - PipelineBlobReportingConfig, - PipelineBlobStorageConfig, - PipelineCacheConfig, - PipelineCacheConfigTypes, - PipelineConfig, - PipelineConsoleReportingConfig, - PipelineCSVInputConfig, - PipelineFileCacheConfig, - PipelineFileReportingConfig, - PipelineFileStorageConfig, - PipelineInputConfig, - PipelineInputConfigTypes, - PipelineMemoryCacheConfig, - PipelineMemoryStorageConfig, - PipelineNoneCacheConfig, - PipelineReportingConfig, - PipelineReportingConfigTypes, - PipelineStorageConfig, - PipelineStorageConfigTypes, - PipelineTextInputConfig, - PipelineWorkflowConfig, - PipelineWorkflowReference, - PipelineWorkflowStep, -) -from .create_pipeline_config import create_pipeline_config -from .errors import ( - NoWorkflowsDefinedError, - UndefinedWorkflowError, - UnknownWorkflowError, -) -from .load_pipeline_config import load_pipeline_config -from .run import run_pipeline, run_pipeline_with_config -from .storage import PipelineStorage - -__all__ = [ - "NoWorkflowsDefinedError", - "PipelineBlobCacheConfig", - "PipelineBlobCacheConfig", - "PipelineBlobReportingConfig", - "PipelineBlobStorageConfig", - "PipelineCSVInputConfig", - "PipelineCache", - "PipelineCacheConfig", - "PipelineCacheConfigTypes", - "PipelineConfig", - "PipelineConsoleReportingConfig", - "PipelineFileCacheConfig", - "PipelineFileReportingConfig", - "PipelineFileStorageConfig", - "PipelineInputConfig", - "PipelineInputConfigTypes", - "PipelineMemoryCacheConfig", - "PipelineMemoryStorageConfig", - "PipelineNoneCacheConfig", - 
"PipelineReportingConfig", - "PipelineReportingConfigTypes", - "PipelineStorage", - "PipelineStorageConfig", - "PipelineStorageConfigTypes", - "PipelineTextInputConfig", - "PipelineWorkflowConfig", - "PipelineWorkflowReference", - "PipelineWorkflowStep", - "UndefinedWorkflowError", - "UnknownWorkflowError", - "create_pipeline_config", - "load_pipeline_config", - "run_pipeline", - "run_pipeline_with_config", -] +"""The indexing engine package root.""" diff --git a/graphrag/index/cache/__init__.py b/graphrag/index/cache/__init__.py index 42ebb22994..ece87659fd 100644 --- a/graphrag/index/cache/__init__.py +++ b/graphrag/index/cache/__init__.py @@ -2,17 +2,3 @@ # Licensed under the MIT License """The Indexing Engine cache package root.""" - -from .json_pipeline_cache import JsonPipelineCache -from .load_cache import load_cache -from .memory_pipeline_cache import InMemoryCache -from .noop_pipeline_cache import NoopPipelineCache -from .pipeline_cache import PipelineCache - -__all__ = [ - "InMemoryCache", - "JsonPipelineCache", - "NoopPipelineCache", - "PipelineCache", - "load_cache", -] diff --git a/graphrag/index/cache/json_pipeline_cache.py b/graphrag/index/cache/json_pipeline_cache.py index b9e85889ad..13d5212394 100644 --- a/graphrag/index/cache/json_pipeline_cache.py +++ b/graphrag/index/cache/json_pipeline_cache.py @@ -6,9 +6,8 @@ import json from typing import Any -from graphrag.index.storage import PipelineStorage - -from .pipeline_cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.storage.pipeline_storage import PipelineStorage class JsonPipelineCache(PipelineCache): diff --git a/graphrag/index/cache/load_cache.py b/graphrag/index/cache/load_cache.py index 4e0e6324fb..91f633367c 100644 --- a/graphrag/index/cache/load_cache.py +++ b/graphrag/index/cache/load_cache.py @@ -12,16 +12,17 @@ PipelineBlobCacheConfig, PipelineFileCacheConfig, ) -from graphrag.index.storage import BlobPipelineStorage, 
FilePipelineStorage +from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.index.storage.file_pipeline_storage import FilePipelineStorage if TYPE_CHECKING: - from graphrag.index.config import ( + from graphrag.index.config.cache import ( PipelineCacheConfig, ) -from .json_pipeline_cache import JsonPipelineCache -from .memory_pipeline_cache import create_memory_cache -from .noop_pipeline_cache import NoopPipelineCache +from graphrag.index.cache.json_pipeline_cache import JsonPipelineCache +from graphrag.index.cache.memory_pipeline_cache import create_memory_cache +from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache def load_cache(config: PipelineCacheConfig | None, root_dir: str | None): diff --git a/graphrag/index/cache/memory_pipeline_cache.py b/graphrag/index/cache/memory_pipeline_cache.py index fa42f3f921..2a9e19c9c0 100644 --- a/graphrag/index/cache/memory_pipeline_cache.py +++ b/graphrag/index/cache/memory_pipeline_cache.py @@ -5,7 +5,7 @@ from typing import Any -from .pipeline_cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache class InMemoryCache(PipelineCache): diff --git a/graphrag/index/cache/noop_pipeline_cache.py b/graphrag/index/cache/noop_pipeline_cache.py index b7c3e60fdd..738787ad35 100644 --- a/graphrag/index/cache/noop_pipeline_cache.py +++ b/graphrag/index/cache/noop_pipeline_cache.py @@ -5,7 +5,7 @@ from typing import Any -from .pipeline_cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache class NoopPipelineCache(PipelineCache): diff --git a/graphrag/index/config/__init__.py b/graphrag/index/config/__init__.py index 847659cd24..7a8b16c91d 100644 --- a/graphrag/index/config/__init__.py +++ b/graphrag/index/config/__init__.py @@ -2,90 +2,3 @@ # Licensed under the MIT License """The Indexing Engine config typing package root.""" - -from .cache import ( - PipelineBlobCacheConfig, - PipelineCacheConfig, - 
PipelineCacheConfigTypes, - PipelineFileCacheConfig, - PipelineMemoryCacheConfig, - PipelineNoneCacheConfig, -) -from .embeddings import ( - all_embeddings, - community_full_content_embedding, - community_summary_embedding, - community_title_embedding, - document_text_embedding, - entity_description_embedding, - entity_title_embedding, - relationship_description_embedding, - required_embeddings, - text_unit_text_embedding, -) -from .input import ( - PipelineCSVInputConfig, - PipelineInputConfig, - PipelineInputConfigTypes, - PipelineTextInputConfig, -) -from .pipeline import PipelineConfig -from .reporting import ( - PipelineBlobReportingConfig, - PipelineConsoleReportingConfig, - PipelineFileReportingConfig, - PipelineReportingConfig, - PipelineReportingConfigTypes, -) -from .storage import ( - PipelineBlobStorageConfig, - PipelineFileStorageConfig, - PipelineMemoryStorageConfig, - PipelineStorageConfig, - PipelineStorageConfigTypes, -) -from .workflow import ( - PipelineWorkflowConfig, - PipelineWorkflowReference, - PipelineWorkflowStep, -) - -__all__ = [ - "PipelineBlobCacheConfig", - "PipelineBlobReportingConfig", - "PipelineBlobStorageConfig", - "PipelineCSVInputConfig", - "PipelineCacheConfig", - "PipelineCacheConfigTypes", - "PipelineCacheConfigTypes", - "PipelineCacheConfigTypes", - "PipelineConfig", - "PipelineConsoleReportingConfig", - "PipelineFileCacheConfig", - "PipelineFileReportingConfig", - "PipelineFileStorageConfig", - "PipelineInputConfig", - "PipelineInputConfigTypes", - "PipelineMemoryCacheConfig", - "PipelineMemoryCacheConfig", - "PipelineMemoryStorageConfig", - "PipelineNoneCacheConfig", - "PipelineReportingConfig", - "PipelineReportingConfigTypes", - "PipelineStorageConfig", - "PipelineStorageConfigTypes", - "PipelineTextInputConfig", - "PipelineWorkflowConfig", - "PipelineWorkflowReference", - "PipelineWorkflowStep", - "all_embeddings", - "community_full_content_embedding", - "community_summary_embedding", - "community_title_embedding", - 
"document_text_embedding", - "entity_description_embedding", - "entity_title_embedding", - "relationship_description_embedding", - "required_embeddings", - "text_unit_text_embedding", -] diff --git a/graphrag/index/config/input.py b/graphrag/index/config/input.py index 35db357599..b3e4e89e8b 100644 --- a/graphrag/index/config/input.py +++ b/graphrag/index/config/input.py @@ -11,8 +11,7 @@ from pydantic import Field as pydantic_Field from graphrag.config.enums import InputFileType, InputType - -from .workflow import PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowStep T = TypeVar("T") diff --git a/graphrag/index/config/pipeline.py b/graphrag/index/config/pipeline.py index 7fa68c7aae..8ee420cc28 100644 --- a/graphrag/index/config/pipeline.py +++ b/graphrag/index/config/pipeline.py @@ -9,11 +9,11 @@ from pydantic import BaseModel from pydantic import Field as pydantic_Field -from .cache import PipelineCacheConfigTypes -from .input import PipelineInputConfigTypes -from .reporting import PipelineReportingConfigTypes -from .storage import PipelineStorageConfigTypes -from .workflow import PipelineWorkflowReference +from graphrag.index.config.cache import PipelineCacheConfigTypes +from graphrag.index.config.input import PipelineInputConfigTypes +from graphrag.index.config.reporting import PipelineReportingConfigTypes +from graphrag.index.config.storage import PipelineStorageConfigTypes +from graphrag.index.config.workflow import PipelineWorkflowReference class PipelineConfig(BaseModel): diff --git a/graphrag/index/context.py b/graphrag/index/context.py index 94934e4909..fa4c9b2728 100644 --- a/graphrag/index/context.py +++ b/graphrag/index/context.py @@ -7,8 +7,8 @@ from dataclasses import dataclass as dc_dataclass from dataclasses import field -from .cache import PipelineCache -from .storage.pipeline_storage import PipelineStorage +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.storage.pipeline_storage 
import PipelineStorage @dc_dataclass diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index e0ab305fea..76a7e9b118 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -14,7 +14,9 @@ StorageType, TextEmbeddingTarget, ) -from graphrag.config.models import GraphRagConfig, StorageConfig, TextEmbeddingConfig +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.storage_config import StorageConfig +from graphrag.config.models.text_embedding_config import TextEmbeddingConfig from graphrag.index.config.cache import ( PipelineBlobCacheConfig, PipelineCacheConfigTypes, diff --git a/graphrag/index/emit/__init__.py b/graphrag/index/emit/__init__.py index 354989e338..7ae6eea9f1 100644 --- a/graphrag/index/emit/__init__.py +++ b/graphrag/index/emit/__init__.py @@ -2,20 +2,3 @@ # Licensed under the MIT License """Definitions for emitting pipeline artifacts to storage.""" - -from .csv_table_emitter import CSVTableEmitter -from .factories import create_table_emitter, create_table_emitters -from .json_table_emitter import JsonTableEmitter -from .parquet_table_emitter import ParquetTableEmitter -from .table_emitter import TableEmitter -from .types import TableEmitterType - -__all__ = [ - "CSVTableEmitter", - "JsonTableEmitter", - "ParquetTableEmitter", - "TableEmitter", - "TableEmitterType", - "create_table_emitter", - "create_table_emitters", -] diff --git a/graphrag/index/emit/csv_table_emitter.py b/graphrag/index/emit/csv_table_emitter.py index c0305c254b..3ba976b8df 100644 --- a/graphrag/index/emit/csv_table_emitter.py +++ b/graphrag/index/emit/csv_table_emitter.py @@ -7,9 +7,8 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage - -from .table_emitter import TableEmitter +from graphrag.index.emit.table_emitter import TableEmitter +from graphrag.index.storage.pipeline_storage import PipelineStorage log = 
logging.getLogger(__name__) diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index 84afa68443..9a83e7185b 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -3,15 +3,14 @@ """Table Emitter Factories.""" -from graphrag.index.storage import PipelineStorage +from graphrag.index.emit.csv_table_emitter import CSVTableEmitter +from graphrag.index.emit.json_table_emitter import JsonTableEmitter +from graphrag.index.emit.parquet_table_emitter import ParquetTableEmitter +from graphrag.index.emit.table_emitter import TableEmitter +from graphrag.index.emit.types import TableEmitterType +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.typing import ErrorHandlerFn -from .csv_table_emitter import CSVTableEmitter -from .json_table_emitter import JsonTableEmitter -from .parquet_table_emitter import ParquetTableEmitter -from .table_emitter import TableEmitter -from .types import TableEmitterType - def create_table_emitter( emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn diff --git a/graphrag/index/emit/json_table_emitter.py b/graphrag/index/emit/json_table_emitter.py index 0b18c717a6..ceadc414d9 100644 --- a/graphrag/index/emit/json_table_emitter.py +++ b/graphrag/index/emit/json_table_emitter.py @@ -7,9 +7,8 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage - -from .table_emitter import TableEmitter +from graphrag.index.emit.table_emitter import TableEmitter +from graphrag.index.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/index/emit/parquet_table_emitter.py b/graphrag/index/emit/parquet_table_emitter.py index 753915a79a..e649f283c3 100644 --- a/graphrag/index/emit/parquet_table_emitter.py +++ b/graphrag/index/emit/parquet_table_emitter.py @@ -9,11 +9,10 @@ import pandas as pd from pyarrow.lib import ArrowInvalid, ArrowTypeError -from graphrag.index.storage 
import PipelineStorage +from graphrag.index.emit.table_emitter import TableEmitter +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.typing import ErrorHandlerFn -from .table_emitter import TableEmitter - log = logging.getLogger(__name__) diff --git a/graphrag/index/flows/create_base_entity_graph.py b/graphrag/index/flows/create_base_entity_graph.py index bded3ca6fb..fe429d336b 100644 --- a/graphrag/index/flows/create_base_entity_graph.py +++ b/graphrag/index/flows/create_base_entity_graph.py @@ -11,7 +11,7 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.operations.cluster_graph import cluster_graph from graphrag.index.operations.embed_graph import embed_graph from graphrag.index.operations.extract_entities import extract_entities @@ -22,7 +22,7 @@ from graphrag.index.operations.summarize_descriptions import ( summarize_descriptions, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage async def create_base_entity_graph( diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index 51a9dba643..cca55d3a17 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -16,8 +16,8 @@ from graphrag.index.operations.chunk_text import chunk_text from graphrag.index.operations.snapshot import snapshot -from graphrag.index.storage import PipelineStorage -from graphrag.index.utils import gen_md5_hash +from graphrag.index.storage.pipeline_storage import PipelineStorage +from graphrag.index.utils.hashing import gen_md5_hash async def create_base_text_units( diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index 001844b5b0..754d66d6e5 100644 --- 
a/graphrag/index/flows/create_final_community_reports.py +++ b/graphrag/index/flows/create_final_community_reports.py @@ -11,7 +11,7 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.community_reports.schemas import ( CLAIM_DESCRIPTION, CLAIM_DETAILS, diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py index e04e7fe926..ad25445bf7 100644 --- a/graphrag/index/flows/create_final_covariates.py +++ b/graphrag/index/flows/create_final_covariates.py @@ -12,7 +12,7 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.operations.extract_covariates import ( extract_covariates, ) diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index d2adcb34c8..c966e673ed 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -13,7 +13,7 @@ from graphrag.index.operations.layout_graph import layout_graph from graphrag.index.operations.snapshot import snapshot from graphrag.index.operations.unpack_graph import unpack_graph -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage async def create_final_nodes( diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index 258d01c768..23fee842df 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -10,7 +10,7 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.config.embeddings import ( community_full_content_embedding, community_summary_embedding, @@ -23,7 +23,7 @@ ) from 
graphrag.index.operations.embed_text import embed_text from graphrag.index.operations.snapshot import snapshot -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/index/graph/embedding/__init__.py b/graphrag/index/graph/embedding/__init__.py index 0ea2d085f1..ff075875a5 100644 --- a/graphrag/index/graph/embedding/__init__.py +++ b/graphrag/index/graph/embedding/__init__.py @@ -3,6 +3,6 @@ """The Indexing Engine graph embedding package root.""" -from .embedding import NodeEmbeddings, embed_nod2vec +from graphrag.index.graph.embedding.embedding import NodeEmbeddings, embed_nod2vec __all__ = ["NodeEmbeddings", "embed_nod2vec"] diff --git a/graphrag/index/graph/embedding/embedding.py b/graphrag/index/graph/embedding/embedding.py index 267a190f91..ff6f86e72a 100644 --- a/graphrag/index/graph/embedding/embedding.py +++ b/graphrag/index/graph/embedding/embedding.py @@ -5,7 +5,6 @@ from dataclasses import dataclass -import graspologic as gc import networkx as nx import numpy as np @@ -28,6 +27,9 @@ def embed_nod2vec( random_seed: int = 86, ) -> NodeEmbeddings: """Generate node embeddings using Node2Vec.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import graspologic as gc + # generate embedding lcc_tensors = gc.embed.node2vec_embed( # type: ignore graph=graph, diff --git a/graphrag/index/graph/extractors/__init__.py b/graphrag/index/graph/extractors/__init__.py index 511695aea6..42ad16b89c 100644 --- a/graphrag/index/graph/extractors/__init__.py +++ b/graphrag/index/graph/extractors/__init__.py @@ -3,11 +3,11 @@ """The Indexing Engine graph extractors package root.""" -from .claims import ClaimExtractor -from .community_reports import ( +from graphrag.index.graph.extractors.claims import ClaimExtractor +from graphrag.index.graph.extractors.community_reports import ( CommunityReportsExtractor, ) 
-from .graph import GraphExtractionResult, GraphExtractor +from graphrag.index.graph.extractors.graph import GraphExtractionResult, GraphExtractor __all__ = [ "ClaimExtractor", diff --git a/graphrag/index/graph/extractors/claims/__init__.py b/graphrag/index/graph/extractors/claims/__init__.py index 3a5a22fdb1..897cdd1125 100644 --- a/graphrag/index/graph/extractors/claims/__init__.py +++ b/graphrag/index/graph/extractors/claims/__init__.py @@ -3,6 +3,6 @@ """The Indexing Engine graph extractors claims package root.""" -from .claim_extractor import ClaimExtractor +from graphrag.index.graph.extractors.claims.claim_extractor import ClaimExtractor __all__ = ["ClaimExtractor"] diff --git a/graphrag/index/graph/extractors/community_reports/__init__.py b/graphrag/index/graph/extractors/community_reports/__init__.py index da3bf8396a..bac91674c2 100644 --- a/graphrag/index/graph/extractors/community_reports/__init__.py +++ b/graphrag/index/graph/extractors/community_reports/__init__.py @@ -4,12 +4,17 @@ """The Indexing Engine community reports package root.""" import graphrag.index.graph.extractors.community_reports.schemas as schemas - -from .build_mixed_context import build_mixed_context -from .community_reports_extractor import CommunityReportsExtractor -from .prep_community_report_context import prep_community_report_context -from .sort_context import sort_context -from .utils import ( +from graphrag.index.graph.extractors.community_reports.build_mixed_context import ( + build_mixed_context, +) +from graphrag.index.graph.extractors.community_reports.community_reports_extractor import ( + CommunityReportsExtractor, +) +from graphrag.index.graph.extractors.community_reports.prep_community_report_context import ( + prep_community_report_context, +) +from graphrag.index.graph.extractors.community_reports.sort_context import sort_context +from graphrag.index.graph.extractors.community_reports.utils import ( filter_claims_to_nodes, filter_edges_to_nodes, 
filter_nodes_to_level, diff --git a/graphrag/index/graph/extractors/community_reports/build_mixed_context.py b/graphrag/index/graph/extractors/community_reports/build_mixed_context.py index ad9e2a8447..ca10ca948d 100644 --- a/graphrag/index/graph/extractors/community_reports/build_mixed_context.py +++ b/graphrag/index/graph/extractors/community_reports/build_mixed_context.py @@ -5,10 +5,9 @@ import pandas as pd import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.index.graph.extractors.community_reports.sort_context import sort_context from graphrag.query.llm.text_utils import num_tokens -from .sort_context import sort_context - def build_mixed_context(context: list[dict], max_tokens: int) -> str: """ diff --git a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py index 291e61af69..a78064bd9b 100644 --- a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py +++ b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py @@ -9,7 +9,7 @@ from typing import Any from graphrag.index.typing import ErrorHandlerFn -from graphrag.index.utils import dict_has_keys_with_types +from graphrag.index.utils.dicts import dict_has_keys_with_types from graphrag.llm import CompletionLLM from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT diff --git a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py b/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py index 72cdac9b4e..a4df7d533a 100644 --- a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py +++ b/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py @@ -9,6 +9,11 @@ import pandas as pd import graphrag.index.graph.extractors.community_reports.schemas as schemas +from 
graphrag.index.graph.extractors.community_reports.build_mixed_context import ( + build_mixed_context, +) +from graphrag.index.graph.extractors.community_reports.sort_context import sort_context +from graphrag.index.graph.extractors.community_reports.utils import set_context_size from graphrag.index.utils.dataframes import ( antijoin, drop_columns, @@ -19,10 +24,6 @@ where_column_equals, ) -from .build_mixed_context import build_mixed_context -from .sort_context import sort_context -from .utils import set_context_size - log = logging.getLogger(__name__) diff --git a/graphrag/index/graph/extractors/graph/__init__.py b/graphrag/index/graph/extractors/graph/__init__.py index 7f8d19c9ca..c3f14bfa2f 100644 --- a/graphrag/index/graph/extractors/graph/__init__.py +++ b/graphrag/index/graph/extractors/graph/__init__.py @@ -3,7 +3,7 @@ """The Indexing Engine unipartite graph package root.""" -from .graph_extractor import ( +from graphrag.index.graph.extractors.graph.graph_extractor import ( DEFAULT_ENTITY_TYPES, GraphExtractionResult, GraphExtractor, diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index b669cfa004..7374e77c24 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -15,7 +15,7 @@ import graphrag.config.defaults as defs from graphrag.index.typing import ErrorHandlerFn -from graphrag.index.utils import clean_str +from graphrag.index.utils.string import clean_str from graphrag.llm import CompletionLLM from graphrag.prompts.index.entity_extraction import ( CONTINUE_PROMPT, diff --git a/graphrag/index/graph/extractors/summarize/__init__.py b/graphrag/index/graph/extractors/summarize/__init__.py index 17fe5095aa..54661d0f1c 100644 --- a/graphrag/index/graph/extractors/summarize/__init__.py +++ b/graphrag/index/graph/extractors/summarize/__init__.py @@ -3,7 +3,7 @@ """The Indexing Engine unipartite graph 
package root.""" -from .description_summary_extractor import ( +from graphrag.index.graph.extractors.summarize.description_summary_extractor import ( SummarizationResult, SummarizeExtractor, ) diff --git a/graphrag/index/graph/utils/__init__.py b/graphrag/index/graph/utils/__init__.py index 6d4479283a..2f6971186d 100644 --- a/graphrag/index/graph/utils/__init__.py +++ b/graphrag/index/graph/utils/__init__.py @@ -3,7 +3,7 @@ """The Indexing Engine graph utils package root.""" -from .normalize_node_names import normalize_node_names -from .stable_lcc import stable_largest_connected_component +from graphrag.index.graph.utils.normalize_node_names import normalize_node_names +from graphrag.index.graph.utils.stable_lcc import stable_largest_connected_component __all__ = ["normalize_node_names", "stable_largest_connected_component"] diff --git a/graphrag/index/graph/utils/stable_lcc.py b/graphrag/index/graph/utils/stable_lcc.py index 7d602a6ba7..cbd5243513 100644 --- a/graphrag/index/graph/utils/stable_lcc.py +++ b/graphrag/index/graph/utils/stable_lcc.py @@ -6,13 +6,15 @@ from typing import Any, cast import networkx as nx -from graspologic.utils import largest_connected_component -from .normalize_node_names import normalize_node_names +from graphrag.index.graph.utils.normalize_node_names import normalize_node_names def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: """Return the largest connected component of the graph, with nodes and edges sorted in a stable way.""" + # NOTE: The import is done here to reduce the initial import time of the module + from graspologic.utils import largest_connected_component + graph = graph.copy() graph = cast(nx.Graph, largest_connected_component(graph)) graph = normalize_node_names(graph) diff --git a/graphrag/index/graph/visualization/__init__.py b/graphrag/index/graph/visualization/__init__.py index f7780e4e9c..090acdec32 100644 --- a/graphrag/index/graph/visualization/__init__.py +++ 
b/graphrag/index/graph/visualization/__init__.py @@ -3,8 +3,11 @@ """The Indexing Engine graph visualization package root.""" -from .compute_umap_positions import compute_umap_positions, get_zero_positions -from .typing import GraphLayout, NodePosition +from graphrag.index.graph.visualization.compute_umap_positions import ( + compute_umap_positions, + get_zero_positions, +) +from graphrag.index.graph.visualization.typing import GraphLayout, NodePosition __all__ = [ "GraphLayout", diff --git a/graphrag/index/graph/visualization/compute_umap_positions.py b/graphrag/index/graph/visualization/compute_umap_positions.py index 569b7b309d..36ac354b72 100644 --- a/graphrag/index/graph/visualization/compute_umap_positions.py +++ b/graphrag/index/graph/visualization/compute_umap_positions.py @@ -3,13 +3,11 @@ """A module containing compute_umap_positions and visualize_embedding method definition.""" -import graspologic as gc import matplotlib.pyplot as plt import networkx as nx import numpy as np -import umap -from .typing import NodePosition +from graphrag.index.graph.visualization.typing import NodePosition def get_zero_positions( @@ -61,6 +59,9 @@ def compute_umap_positions( random_state: int = 86, ) -> list[NodePosition]: """Project embedding vectors down to 2D/3D using UMAP.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import umap + embedding_positions = umap.UMAP( min_dist=min_dist, n_neighbors=n_neighbors, @@ -105,6 +106,9 @@ def visualize_embedding( umap_positions: list[dict], ): """Project embedding down to 2D using UMAP and visualize.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import graspologic as gc + # rendering plt.clf() figure = plt.gcf() diff --git a/graphrag/index/input/__init__.py b/graphrag/index/input/__init__.py index 91421867de..15177c91db 100644 --- a/graphrag/index/input/__init__.py +++ b/graphrag/index/input/__init__.py @@ -2,7 +2,3 @@ # 
Licensed under the MIT License """The Indexing Engine input package root.""" - -from .load_input import load_input - -__all__ = ["load_input"] diff --git a/graphrag/index/input/csv.py b/graphrag/index/input/csv.py index 9c93fca8f4..ceec80b506 100644 --- a/graphrag/index/input/csv.py +++ b/graphrag/index/input/csv.py @@ -10,10 +10,10 @@ import pandas as pd -from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig -from graphrag.index.storage import PipelineStorage -from graphrag.index.utils import gen_md5_hash -from graphrag.logging import ProgressReporter +from graphrag.index.config.input import PipelineCSVInputConfig, PipelineInputConfig +from graphrag.index.storage.pipeline_storage import PipelineStorage +from graphrag.index.utils.hashing import gen_md5_hash +from graphrag.logging.base import ProgressReporter log = logging.getLogger(__name__) diff --git a/graphrag/index/input/load_input.py b/graphrag/index/input/load_input.py index 100caf982a..4bfc82cbb8 100644 --- a/graphrag/index/input/load_input.py +++ b/graphrag/index/input/load_input.py @@ -10,18 +10,17 @@ import pandas as pd -from graphrag.config import InputConfig, InputType -from graphrag.index.config import PipelineInputConfig -from graphrag.index.storage import ( - BlobPipelineStorage, - FilePipelineStorage, -) -from graphrag.logging import NullProgressReporter, ProgressReporter - -from .csv import input_type as csv -from .csv import load as load_csv -from .text import input_type as text -from .text import load as load_text +from graphrag.config.enums import InputType +from graphrag.config.models.input_config import InputConfig +from graphrag.index.config.input import PipelineInputConfig +from graphrag.index.input.csv import input_type as csv +from graphrag.index.input.csv import load as load_csv +from graphrag.index.input.text import input_type as text +from graphrag.index.input.text import load as load_text +from graphrag.index.storage.blob_pipeline_storage import 
BlobPipelineStorage +from graphrag.index.storage.file_pipeline_storage import FilePipelineStorage +from graphrag.logging.base import ProgressReporter +from graphrag.logging.null_progress import NullProgressReporter log = logging.getLogger(__name__) loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index 7e76bfe1e7..45814ee3ff 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -10,10 +10,10 @@ import pandas as pd -from graphrag.index.config import PipelineInputConfig -from graphrag.index.storage import PipelineStorage -from graphrag.index.utils import gen_md5_hash -from graphrag.logging import ProgressReporter +from graphrag.index.config.input import PipelineInputConfig +from graphrag.index.storage.pipeline_storage import PipelineStorage +from graphrag.index.utils.hashing import gen_md5_hash +from graphrag.logging.base import ProgressReporter DEFAULT_FILE_PATTERN = re.compile( r".*[\\/](?P[^\\/]+)[\\/](?P\d{4})-(?P\d{2})-(?P\d{2})_(?P[^_]+)_\d+\.txt" diff --git a/graphrag/index/llm/__init__.py b/graphrag/index/llm/__init__.py index 008ef07ccd..6644fd912f 100644 --- a/graphrag/index/llm/__init__.py +++ b/graphrag/index/llm/__init__.py @@ -2,13 +2,3 @@ # Licensed under the MIT License """The Indexing Engine LLM package root.""" - -from .load_llm import load_llm, load_llm_embeddings -from .types import TextListSplitter, TextSplitter - -__all__ = [ - "TextListSplitter", - "TextSplitter", - "load_llm", - "load_llm_embeddings", -] diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index a7eda31a4e..f06c89ad9d 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from datashaper import VerbCallbacks - from graphrag.index.cache import PipelineCache + from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.typing import ErrorHandlerFn log = 
logging.getLogger(__name__) diff --git a/graphrag/index/load_pipeline_config.py b/graphrag/index/load_pipeline_config.py index dfcf321b3b..77893b9535 100644 --- a/graphrag/index/load_pipeline_config.py +++ b/graphrag/index/load_pipeline_config.py @@ -9,10 +9,9 @@ import yaml from pyaml_env import parse_config as parse_config_with_env -from graphrag.config import create_graphrag_config, read_dotenv -from graphrag.index.config import PipelineConfig - -from .create_pipeline_config import create_pipeline_config +from graphrag.config.create_graphrag_config import create_graphrag_config, read_dotenv +from graphrag.index.config.pipeline import PipelineConfig +from graphrag.index.create_pipeline_config import create_pipeline_config def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig: diff --git a/graphrag/index/operations/chunk_text/__init__.py b/graphrag/index/operations/chunk_text/__init__.py index 273ff0abaf..d84b4c0c38 100644 --- a/graphrag/index/operations/chunk_text/__init__.py +++ b/graphrag/index/operations/chunk_text/__init__.py @@ -3,6 +3,10 @@ """The Indexing Engine text chunk package root.""" -from .chunk_text import ChunkStrategy, ChunkStrategyType, chunk_text +from graphrag.index.operations.chunk_text.chunk_text import ( + ChunkStrategy, + ChunkStrategyType, + chunk_text, +) __all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk_text"] diff --git a/graphrag/index/operations/chunk_text/chunk_text.py b/graphrag/index/operations/chunk_text/chunk_text.py index bbcc750c59..60211ef88b 100644 --- a/graphrag/index/operations/chunk_text/chunk_text.py +++ b/graphrag/index/operations/chunk_text/chunk_text.py @@ -12,7 +12,11 @@ progress_ticker, ) -from .typing import ChunkInput, ChunkStrategy, ChunkStrategyType +from graphrag.index.operations.chunk_text.typing import ( + ChunkInput, + ChunkStrategy, + ChunkStrategyType, +) def chunk_text( @@ -117,14 +121,13 @@ def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy: """Load strategy 
method definition.""" match strategy: case ChunkStrategyType.tokens: - from .strategies import run_tokens + from graphrag.index.operations.chunk_text.strategies import run_tokens return run_tokens case ChunkStrategyType.sentence: # NLTK from graphrag.index.bootstrap import bootstrap - - from .strategies import run_sentences + from graphrag.index.operations.chunk_text.strategies import run_sentences bootstrap() return run_sentences diff --git a/graphrag/index/operations/chunk_text/strategies.py b/graphrag/index/operations/chunk_text/strategies.py index 7507784be3..35c32585c0 100644 --- a/graphrag/index/operations/chunk_text/strategies.py +++ b/graphrag/index/operations/chunk_text/strategies.py @@ -11,9 +11,8 @@ from datashaper import ProgressTicker import graphrag.config.defaults as defs -from graphrag.index.text_splitting import Tokenizer - -from .typing import TextChunk +from graphrag.index.operations.chunk_text.typing import TextChunk +from graphrag.index.text_splitting.text_splitting import Tokenizer def run_tokens( diff --git a/graphrag/index/operations/cluster_graph.py b/graphrag/index/operations/cluster_graph.py index b993789dcd..295c78e6fb 100644 --- a/graphrag/index/operations/cluster_graph.py +++ b/graphrag/index/operations/cluster_graph.py @@ -11,10 +11,9 @@ import networkx as nx import pandas as pd from datashaper import VerbCallbacks, progress_iterable -from graspologic.partition import hierarchical_leiden from graphrag.index.graph.utils import stable_largest_connected_component -from graphrag.index.utils import gen_uuid +from graphrag.index.utils.uuid import gen_uuid Communities = list[tuple[int, str, list[str]]] @@ -187,6 +186,9 @@ def _compute_leiden_communities( seed=0xDEADBEEF, ) -> dict[int, dict[str, int]]: """Return Leiden root communities.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + from graspologic.partition import hierarchical_leiden + if use_lcc: graph = 
stable_largest_connected_component(graph) diff --git a/graphrag/index/operations/embed_graph/__init__.py b/graphrag/index/operations/embed_graph/__init__.py index a47441b425..07c91c3be8 100644 --- a/graphrag/index/operations/embed_graph/__init__.py +++ b/graphrag/index/operations/embed_graph/__init__.py @@ -3,7 +3,10 @@ """The Indexing Engine graph embed package root.""" -from .embed_graph import EmbedGraphStrategyType, embed_graph -from .typing import NodeEmbeddings +from graphrag.index.operations.embed_graph.embed_graph import ( + EmbedGraphStrategyType, + embed_graph, +) +from graphrag.index.operations.embed_graph.typing import NodeEmbeddings __all__ = ["EmbedGraphStrategyType", "NodeEmbeddings", "embed_graph"] diff --git a/graphrag/index/operations/embed_graph/embed_graph.py b/graphrag/index/operations/embed_graph/embed_graph.py index ab125a9315..d6d345a124 100644 --- a/graphrag/index/operations/embed_graph/embed_graph.py +++ b/graphrag/index/operations/embed_graph/embed_graph.py @@ -12,9 +12,8 @@ from graphrag.index.graph.embedding import embed_nod2vec from graphrag.index.graph.utils import stable_largest_connected_component -from graphrag.index.utils import load_graph - -from .typing import NodeEmbeddings +from graphrag.index.operations.embed_graph.typing import NodeEmbeddings +from graphrag.index.utils.load_graph import load_graph class EmbedGraphStrategyType(str, Enum): diff --git a/graphrag/index/operations/embed_text/__init__.py b/graphrag/index/operations/embed_text/__init__.py index 30819b9954..214f064c96 100644 --- a/graphrag/index/operations/embed_text/__init__.py +++ b/graphrag/index/operations/embed_text/__init__.py @@ -3,6 +3,9 @@ """The Indexing Engine text embed package root.""" -from .embed_text import TextEmbedStrategyType, embed_text +from graphrag.index.operations.embed_text.embed_text import ( + TextEmbedStrategyType, + embed_text, +) __all__ = ["TextEmbedStrategyType", "embed_text"] diff --git 
a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index d6f0b3b9ae..627d1728b2 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -11,15 +11,11 @@ import pandas as pd from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy from graphrag.utils.embeddings import create_collection_name -from graphrag.vector_stores import ( - BaseVectorStore, - VectorStoreDocument, - VectorStoreFactory, -) - -from .strategies.typing import TextEmbeddingStrategy +from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument +from graphrag.vector_stores.factory import VectorStoreFactory log = logging.getLogger(__name__) @@ -242,11 +238,15 @@ def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: """Load strategy method definition.""" match strategy: case TextEmbedStrategyType.openai: - from .strategies.openai import run as run_openai + from graphrag.index.operations.embed_text.strategies.openai import ( + run as run_openai, + ) return run_openai case TextEmbedStrategyType.mock: - from .strategies.mock import run as run_mock + from graphrag.index.operations.embed_text.strategies.mock import ( + run as run_mock, + ) return run_mock case _: diff --git a/graphrag/index/operations/embed_text/strategies/mock.py b/graphrag/index/operations/embed_text/strategies/mock.py index 1be4ab0f9f..a32eceb386 100644 --- a/graphrag/index/operations/embed_text/strategies/mock.py +++ b/graphrag/index/operations/embed_text/strategies/mock.py @@ -9,9 +9,8 @@ from datashaper import ProgressTicker, VerbCallbacks, progress_ticker -from graphrag.index.cache import PipelineCache - -from .typing import TextEmbeddingResult +from graphrag.index.cache.pipeline_cache import 
PipelineCache +from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult async def run( # noqa RUF029 async is required for interface diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/strategies/openai.py index fb443ec83e..f0445d8480 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/strategies/openai.py @@ -11,14 +11,13 @@ from datashaper import ProgressTicker, VerbCallbacks, progress_ticker import graphrag.config.defaults as defs -from graphrag.index.cache import PipelineCache -from graphrag.index.llm import load_llm_embeddings -from graphrag.index.text_splitting import TokenTextSplitter -from graphrag.index.utils import is_null +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.llm.load_llm import load_llm_embeddings +from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult +from graphrag.index.text_splitting.text_splitting import TokenTextSplitter +from graphrag.index.utils.is_null import is_null from graphrag.llm import EmbeddingLLM, OpenAIConfiguration -from .typing import TextEmbeddingResult - log = logging.getLogger(__name__) diff --git a/graphrag/index/operations/embed_text/strategies/typing.py b/graphrag/index/operations/embed_text/strategies/typing.py index 1b25256497..79998f72eb 100644 --- a/graphrag/index/operations/embed_text/strategies/typing.py +++ b/graphrag/index/operations/embed_text/strategies/typing.py @@ -8,7 +8,7 @@ from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache @dataclass diff --git a/graphrag/index/operations/extract_covariates/__init__.py b/graphrag/index/operations/extract_covariates/__init__.py index 53d357bb46..315f503c3e 100644 --- a/graphrag/index/operations/extract_covariates/__init__.py +++ 
b/graphrag/index/operations/extract_covariates/__init__.py @@ -3,6 +3,9 @@ """The Indexing Engine text extract claims package root.""" -from .extract_covariates import ExtractClaimsStrategyType, extract_covariates +from graphrag.index.operations.extract_covariates.extract_covariates import ( + ExtractClaimsStrategyType, + extract_covariates, +) __all__ = ["ExtractClaimsStrategyType", "extract_covariates"] diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index 1ee5f51cc6..48b3b04c31 100644 --- a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -14,9 +14,12 @@ derive_from_rows, ) -from graphrag.index.cache import PipelineCache - -from .typing import Covariate, CovariateExtractStrategy, ExtractClaimsStrategyType +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.extract_covariates.typing import ( + Covariate, + CovariateExtractStrategy, + ExtractClaimsStrategyType, +) log = logging.getLogger(__name__) @@ -72,7 +75,9 @@ def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractS """Load strategy method definition.""" match strategy_type: case ExtractClaimsStrategyType.graph_intelligence: - from .strategies import run_graph_intelligence + from graphrag.index.operations.extract_covariates.strategies import ( + run_graph_intelligence, + ) return run_graph_intelligence case _: diff --git a/graphrag/index/operations/extract_covariates/strategies.py b/graphrag/index/operations/extract_covariates/strategies.py index 2ef83e513a..4d7729961c 100644 --- a/graphrag/index/operations/extract_covariates/strategies.py +++ b/graphrag/index/operations/extract_covariates/strategies.py @@ -9,15 +9,14 @@ from datashaper import VerbCallbacks import graphrag.config.defaults as defs -from graphrag.index.cache import PipelineCache +from 
graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.claims import ClaimExtractor -from graphrag.index.llm import load_llm -from graphrag.llm import CompletionLLM - -from .typing import ( +from graphrag.index.llm.load_llm import load_llm +from graphrag.index.operations.extract_covariates.typing import ( Covariate, CovariateExtractionResult, ) +from graphrag.llm import CompletionLLM async def run_graph_intelligence( diff --git a/graphrag/index/operations/extract_covariates/typing.py b/graphrag/index/operations/extract_covariates/typing.py index c0cb96633f..a208ea3e9f 100644 --- a/graphrag/index/operations/extract_covariates/typing.py +++ b/graphrag/index/operations/extract_covariates/typing.py @@ -10,7 +10,7 @@ from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache @dataclass diff --git a/graphrag/index/operations/extract_entities/__init__.py b/graphrag/index/operations/extract_entities/__init__.py index 579b57dfd4..5092df4a52 100644 --- a/graphrag/index/operations/extract_entities/__init__.py +++ b/graphrag/index/operations/extract_entities/__init__.py @@ -3,6 +3,9 @@ """The Indexing Engine entities extraction package root.""" -from .extract_entities import ExtractEntityStrategyType, extract_entities +from graphrag.index.operations.extract_entities.extract_entities import ( + ExtractEntityStrategyType, + extract_entities, +) __all__ = ["ExtractEntityStrategyType", "extract_entities"] diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index 96bec73b25..522d4d98e5 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -16,9 +16,11 @@ ) from graphrag.index.bootstrap import bootstrap -from graphrag.index.cache import PipelineCache - -from .strategies.typing 
import Document, EntityExtractStrategy +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.extract_entities.strategies.typing import ( + Document, + EntityExtractStrategy, +) log = logging.getLogger(__name__) @@ -162,14 +164,18 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr """Load strategy method definition.""" match strategy_type: case ExtractEntityStrategyType.graph_intelligence: - from .strategies.graph_intelligence import run_graph_intelligence + from graphrag.index.operations.extract_entities.strategies.graph_intelligence import ( + run_graph_intelligence, + ) return run_graph_intelligence case ExtractEntityStrategyType.nltk: bootstrap() # dynamically import nltk strategy to avoid dependency if not used - from .strategies.nltk import run as run_nltk + from graphrag.index.operations.extract_entities.strategies.nltk import ( + run as run_nltk, + ) return run_nltk case _: diff --git a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py index 072d5bed0a..18fcc97444 100644 --- a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py +++ b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py @@ -6,22 +6,21 @@ from datashaper import VerbCallbacks import graphrag.config.defaults as defs -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors import GraphExtractor -from graphrag.index.llm import load_llm -from graphrag.index.text_splitting import ( - NoopTextSplitter, - TextSplitter, - TokenTextSplitter, -) -from graphrag.llm import CompletionLLM - -from .typing import ( +from graphrag.index.llm.load_llm import load_llm +from graphrag.index.operations.extract_entities.strategies.typing import ( Document, EntityExtractionResult, EntityTypes, StrategyConfig, 
) +from graphrag.index.text_splitting.text_splitting import ( + NoopTextSplitter, + TextSplitter, + TokenTextSplitter, +) +from graphrag.llm import CompletionLLM async def run_graph_intelligence( diff --git a/graphrag/index/operations/extract_entities/strategies/nltk.py b/graphrag/index/operations/extract_entities/strategies/nltk.py index 8f9aefa0ee..08f447004a 100644 --- a/graphrag/index/operations/extract_entities/strategies/nltk.py +++ b/graphrag/index/operations/extract_entities/strategies/nltk.py @@ -8,9 +8,13 @@ from datashaper import VerbCallbacks from nltk.corpus import words -from graphrag.index.cache import PipelineCache - -from .typing import Document, EntityExtractionResult, EntityTypes, StrategyConfig +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.extract_entities.strategies.typing import ( + Document, + EntityExtractionResult, + EntityTypes, + StrategyConfig, +) # Need to do this cause we're potentially multithreading, and nltk doesn't like that words.ensure_loaded() diff --git a/graphrag/index/operations/extract_entities/strategies/typing.py b/graphrag/index/operations/extract_entities/strategies/typing.py index e1c548b0c4..57df220d9b 100644 --- a/graphrag/index/operations/extract_entities/strategies/typing.py +++ b/graphrag/index/operations/extract_entities/strategies/typing.py @@ -10,7 +10,7 @@ import networkx as nx from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache ExtractedEntity = dict[str, Any] StrategyConfig = dict[str, Any] diff --git a/graphrag/index/operations/layout_graph/__init__.py b/graphrag/index/operations/layout_graph/__init__.py index 74584f83ed..7638c5a3f6 100644 --- a/graphrag/index/operations/layout_graph/__init__.py +++ b/graphrag/index/operations/layout_graph/__init__.py @@ -3,6 +3,6 @@ """The Indexing Engine graph layout package root.""" -from .layout_graph import layout_graph +from 
graphrag.index.operations.layout_graph.layout_graph import layout_graph __all__ = ["layout_graph"] diff --git a/graphrag/index/operations/layout_graph/layout_graph.py b/graphrag/index/operations/layout_graph/layout_graph.py index d2b232660d..356511f4fd 100644 --- a/graphrag/index/operations/layout_graph/layout_graph.py +++ b/graphrag/index/operations/layout_graph/layout_graph.py @@ -12,7 +12,7 @@ from graphrag.index.graph.visualization import GraphLayout from graphrag.index.operations.embed_graph import NodeEmbeddings -from graphrag.index.utils import load_graph +from graphrag.index.utils.load_graph import load_graph class LayoutGraphStrategyType(str, Enum): @@ -102,7 +102,9 @@ def _run_layout( graph = load_graph(graphml_or_graph) match strategy: case LayoutGraphStrategyType.umap: - from .methods.umap import run as run_umap + from graphrag.index.operations.layout_graph.methods.umap import ( + run as run_umap, + ) return run_umap( graph, @@ -111,7 +113,9 @@ def _run_layout( lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d), ) case LayoutGraphStrategyType.zero: - from .methods.zero import run as run_zero + from graphrag.index.operations.layout_graph.methods.zero import ( + run as run_zero, + ) return run_zero( graph, diff --git a/graphrag/index/operations/merge_graphs/__init__.py b/graphrag/index/operations/merge_graphs/__init__.py index f3b957dd9d..f5f463520d 100644 --- a/graphrag/index/operations/merge_graphs/__init__.py +++ b/graphrag/index/operations/merge_graphs/__init__.py @@ -3,7 +3,7 @@ """merge_graphs operation.""" -from .merge_graphs import merge_graphs +from graphrag.index.operations.merge_graphs.merge_graphs import merge_graphs __all__ = [ "merge_graphs", diff --git a/graphrag/index/operations/merge_graphs/merge_graphs.py b/graphrag/index/operations/merge_graphs/merge_graphs.py index 80ab20ef41..9aee37cf3e 100644 --- a/graphrag/index/operations/merge_graphs/merge_graphs.py +++ b/graphrag/index/operations/merge_graphs/merge_graphs.py @@ 
-8,7 +8,7 @@ import networkx as nx from datashaper import VerbCallbacks, progress_iterable -from .typing import ( +from graphrag.index.operations.merge_graphs.typing import ( BasicMergeOperation, DetailedAttributeMergeOperation, NumericOperation, diff --git a/graphrag/index/operations/snapshot.py b/graphrag/index/operations/snapshot.py index c889595649..7ae7f0ca09 100644 --- a/graphrag/index/operations/snapshot.py +++ b/graphrag/index/operations/snapshot.py @@ -5,7 +5,7 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage async def snapshot( diff --git a/graphrag/index/operations/snapshot_graphml.py b/graphrag/index/operations/snapshot_graphml.py index 07a174fad6..feda376f97 100644 --- a/graphrag/index/operations/snapshot_graphml.py +++ b/graphrag/index/operations/snapshot_graphml.py @@ -5,7 +5,7 @@ import networkx as nx -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage async def snapshot_graphml( diff --git a/graphrag/index/operations/snapshot_rows.py b/graphrag/index/operations/snapshot_rows.py index 5abd771bc3..0050f98061 100644 --- a/graphrag/index/operations/snapshot_rows.py +++ b/graphrag/index/operations/snapshot_rows.py @@ -9,7 +9,7 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @dataclass diff --git a/graphrag/index/operations/summarize_communities/__init__.py b/graphrag/index/operations/summarize_communities/__init__.py index d3065198b6..6b74b9f3a9 100644 --- a/graphrag/index/operations/summarize_communities/__init__.py +++ b/graphrag/index/operations/summarize_communities/__init__.py @@ -3,10 +3,18 @@ """Community summarization modules.""" -from .prepare_community_reports import prepare_community_reports -from .restore_community_hierarchy import restore_community_hierarchy -from 
.summarize_communities import summarize_communities -from .typing import CreateCommunityReportsStrategyType +from graphrag.index.operations.summarize_communities.prepare_community_reports import ( + prepare_community_reports, +) +from graphrag.index.operations.summarize_communities.restore_community_hierarchy import ( + restore_community_hierarchy, +) +from graphrag.index.operations.summarize_communities.summarize_communities import ( + summarize_communities, +) +from graphrag.index.operations.summarize_communities.typing import ( + CreateCommunityReportsStrategyType, +) __all__ = [ "CreateCommunityReportsStrategyType", diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index 2653e41f47..33900a9852 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -9,18 +9,17 @@ from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.community_reports import ( CommunityReportsExtractor, ) -from graphrag.index.llm import load_llm -from graphrag.index.utils.rate_limiter import RateLimiter -from graphrag.llm import CompletionLLM - -from .typing import ( +from graphrag.index.llm.load_llm import load_llm +from graphrag.index.operations.summarize_communities.typing import ( CommunityReport, StrategyConfig, ) +from graphrag.index.utils.rate_limiter import RateLimiter +from graphrag.llm import CompletionLLM DEFAULT_CHUNK_SIZE = 3000 diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index ad8c44077e..de35981d0a 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py 
@@ -16,13 +16,12 @@ import graphrag.config.defaults as defaults import graphrag.index.graph.extractors.community_reports.schemas as schemas -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.community_reports import ( get_levels, prep_community_report_context, ) - -from .typing import ( +from graphrag.index.operations.summarize_communities.typing import ( CommunityReport, CommunityReportsStrategy, CreateCommunityReportsStrategyType, @@ -104,7 +103,9 @@ def load_strategy( """Load strategy method definition.""" match strategy: case CreateCommunityReportsStrategyType.graph_intelligence: - from .strategies import run_graph_intelligence + from graphrag.index.operations.summarize_communities.strategies import ( + run_graph_intelligence, + ) return run_graph_intelligence case _: diff --git a/graphrag/index/operations/summarize_communities/typing.py b/graphrag/index/operations/summarize_communities/typing.py index a5a01cbbc2..46a1a09573 100644 --- a/graphrag/index/operations/summarize_communities/typing.py +++ b/graphrag/index/operations/summarize_communities/typing.py @@ -10,7 +10,7 @@ from datashaper import VerbCallbacks from typing_extensions import TypedDict -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache ExtractedEntity = dict[str, Any] StrategyConfig = dict[str, Any] diff --git a/graphrag/index/operations/summarize_descriptions/__init__.py b/graphrag/index/operations/summarize_descriptions/__init__.py index 55f818d11c..1b67f6dc72 100644 --- a/graphrag/index/operations/summarize_descriptions/__init__.py +++ b/graphrag/index/operations/summarize_descriptions/__init__.py @@ -3,8 +3,13 @@ """Root package for description summarization.""" -from .summarize_descriptions import summarize_descriptions -from .typing import SummarizationStrategy, SummarizeStrategyType +from 
graphrag.index.operations.summarize_descriptions.summarize_descriptions import ( + summarize_descriptions, +) +from graphrag.index.operations.summarize_descriptions.typing import ( + SummarizationStrategy, + SummarizeStrategyType, +) __all__ = [ "SummarizationStrategy", diff --git a/graphrag/index/operations/summarize_descriptions/strategies.py b/graphrag/index/operations/summarize_descriptions/strategies.py index 91ff0d3141..fd6ea5a849 100644 --- a/graphrag/index/operations/summarize_descriptions/strategies.py +++ b/graphrag/index/operations/summarize_descriptions/strategies.py @@ -5,15 +5,14 @@ from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.graph.extractors.summarize import SummarizeExtractor -from graphrag.index.llm import load_llm -from graphrag.llm import CompletionLLM - -from .typing import ( +from graphrag.index.llm.load_llm import load_llm +from graphrag.index.operations.summarize_descriptions.typing import ( StrategyConfig, SummarizedDescriptionResult, ) +from graphrag.llm import CompletionLLM async def run_graph_intelligence( diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index 0775447369..612f2c6795 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -14,9 +14,8 @@ progress_ticker, ) -from graphrag.index.cache import PipelineCache - -from .typing import ( +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.operations.summarize_descriptions.typing import ( SummarizationStrategy, SummarizeStrategyType, ) @@ -140,7 +139,9 @@ def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy """Load strategy method definition.""" match strategy_type: case 
SummarizeStrategyType.graph_intelligence: - from .strategies import run_graph_intelligence + from graphrag.index.operations.summarize_descriptions.strategies import ( + run_graph_intelligence, + ) return run_graph_intelligence case _: diff --git a/graphrag/index/operations/summarize_descriptions/typing.py b/graphrag/index/operations/summarize_descriptions/typing.py index 4e957cf49f..c7ba9ceb74 100644 --- a/graphrag/index/operations/summarize_descriptions/typing.py +++ b/graphrag/index/operations/summarize_descriptions/typing.py @@ -10,7 +10,7 @@ from datashaper import VerbCallbacks -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache StrategyConfig = dict[str, Any] diff --git a/graphrag/index/operations/unpack_graph.py b/graphrag/index/operations/unpack_graph.py index ad9f738125..f64ce43e0b 100644 --- a/graphrag/index/operations/unpack_graph.py +++ b/graphrag/index/operations/unpack_graph.py @@ -9,7 +9,7 @@ import pandas as pd from datashaper import VerbCallbacks, progress_iterable -from graphrag.index.utils import load_graph +from graphrag.index.utils.load_graph import load_graph default_copy = ["level"] diff --git a/graphrag/index/run/cache.py b/graphrag/index/run/cache.py index 0b77576ef4..d1c57e4f80 100644 --- a/graphrag/index/run/cache.py +++ b/graphrag/index/run/cache.py @@ -3,7 +3,7 @@ """Cache functions for the GraphRAG update module.""" -from graphrag.index.cache import load_cache +from graphrag.index.cache.load_cache import load_cache from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.config.cache import ( PipelineCacheConfigTypes, diff --git a/graphrag/index/run/run.py b/graphrag/index/run/run.py index 8db214bf93..d2dfa19df5 100644 --- a/graphrag/index/run/run.py +++ b/graphrag/index/run/run.py @@ -15,13 +15,14 @@ from datashaper import NoopVerbCallbacks, WorkflowCallbacks from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks -from 
graphrag.index.cache import PipelineCache -from graphrag.index.config import ( +from graphrag.index.cache.pipeline_cache import PipelineCache +from graphrag.index.config.pipeline import ( PipelineConfig, PipelineWorkflowReference, - PipelineWorkflowStep, ) -from graphrag.index.emit import TableEmitterType, create_table_emitters +from graphrag.index.config.workflow import PipelineWorkflowStep +from graphrag.index.emit.factories import create_table_emitters +from graphrag.index.emit.types import TableEmitterType from graphrag.index.load_pipeline_config import load_pipeline_config from graphrag.index.run.cache import _create_cache from graphrag.index.run.postprocess import ( @@ -40,7 +41,7 @@ _create_callback_chain, _process_workflow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.typing import PipelineRunResult from graphrag.index.update.incremental_index import ( get_delta_docs, @@ -51,10 +52,8 @@ WorkflowDefinitions, load_workflows, ) -from graphrag.logging import ( - NullProgressReporter, - ProgressReporter, -) +from graphrag.logging.base import ProgressReporter +from graphrag.logging.null_progress import NullProgressReporter from graphrag.utils.storage import _create_storage log = logging.getLogger(__name__) diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py index e672763513..ab5b4989f3 100644 --- a/graphrag/index/run/utils.py +++ b/graphrag/index/run/utils.py @@ -31,10 +31,10 @@ PipelineFileStorageConfig, ) from graphrag.index.context import PipelineRunContext, PipelineRunStats -from graphrag.index.input import load_input +from graphrag.index.input.load_input import load_input from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.index.storage.pipeline_storage import PipelineStorage -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter log = 
logging.getLogger(__name__) diff --git a/graphrag/index/run/workflow.py b/graphrag/index/run/workflow.py index a59244a20d..4497910262 100644 --- a/graphrag/index/run/workflow.py +++ b/graphrag/index/run/workflow.py @@ -22,7 +22,7 @@ from graphrag.index.run.profiling import _write_workflow_stats from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.typing import PipelineRunResult -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter from graphrag.utils.storage import _load_table_from_storage log = logging.getLogger(__name__) diff --git a/graphrag/index/storage/__init__.py b/graphrag/index/storage/__init__.py index d1025d24d8..51f34cbd3d 100644 --- a/graphrag/index/storage/__init__.py +++ b/graphrag/index/storage/__init__.py @@ -2,18 +2,3 @@ # Licensed under the MIT License """The Indexing Engine storage package root.""" - -from .blob_pipeline_storage import BlobPipelineStorage, create_blob_storage -from .file_pipeline_storage import FilePipelineStorage -from .load_storage import load_storage -from .memory_pipeline_storage import MemoryPipelineStorage -from .pipeline_storage import PipelineStorage - -__all__ = [ - "BlobPipelineStorage", - "FilePipelineStorage", - "MemoryPipelineStorage", - "PipelineStorage", - "create_blob_storage", - "load_storage", -] diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py index bdf25c9970..da75a10082 100644 --- a/graphrag/index/storage/blob_pipeline_storage.py +++ b/graphrag/index/storage/blob_pipeline_storage.py @@ -13,9 +13,8 @@ from azure.storage.blob import BlobServiceClient from datashaper import Progress -from graphrag.logging import ProgressReporter - -from .pipeline_storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage +from graphrag.logging.base import ProgressReporter log = logging.getLogger(__name__) diff --git 
a/graphrag/index/storage/file_pipeline_storage.py b/graphrag/index/storage/file_pipeline_storage.py index a3a18cf436..dbbca4e18e 100644 --- a/graphrag/index/storage/file_pipeline_storage.py +++ b/graphrag/index/storage/file_pipeline_storage.py @@ -16,9 +16,8 @@ from aiofiles.ospath import exists from datashaper import Progress -from graphrag.logging import ProgressReporter - -from .pipeline_storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage +from graphrag.logging.base import ProgressReporter log = logging.getLogger(__name__) diff --git a/graphrag/index/storage/load_storage.py b/graphrag/index/storage/load_storage.py index 33d61ee97f..fc847e06f8 100644 --- a/graphrag/index/storage/load_storage.py +++ b/graphrag/index/storage/load_storage.py @@ -7,16 +7,15 @@ from typing import cast -from graphrag.config import StorageType +from graphrag.config.enums import StorageType from graphrag.index.config.storage import ( PipelineBlobStorageConfig, PipelineFileStorageConfig, PipelineStorageConfig, ) - -from .blob_pipeline_storage import create_blob_storage -from .file_pipeline_storage import create_file_storage -from .memory_pipeline_storage import create_memory_storage +from graphrag.index.storage.blob_pipeline_storage import create_blob_storage +from graphrag.index.storage.file_pipeline_storage import create_file_storage +from graphrag.index.storage.memory_pipeline_storage import create_memory_storage def load_storage(config: PipelineStorageConfig): diff --git a/graphrag/index/storage/memory_pipeline_storage.py b/graphrag/index/storage/memory_pipeline_storage.py index 3f9f9c9be9..80245c387e 100644 --- a/graphrag/index/storage/memory_pipeline_storage.py +++ b/graphrag/index/storage/memory_pipeline_storage.py @@ -5,8 +5,8 @@ from typing import Any -from .file_pipeline_storage import FilePipelineStorage -from .pipeline_storage import PipelineStorage +from graphrag.index.storage.file_pipeline_storage import FilePipelineStorage 
+from graphrag.index.storage.pipeline_storage import PipelineStorage class MemoryPipelineStorage(FilePipelineStorage): diff --git a/graphrag/index/storage/pipeline_storage.py b/graphrag/index/storage/pipeline_storage.py index 554c2ffd74..ec67b43234 100644 --- a/graphrag/index/storage/pipeline_storage.py +++ b/graphrag/index/storage/pipeline_storage.py @@ -8,7 +8,7 @@ from collections.abc import Iterator from typing import Any -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter class PipelineStorage(metaclass=ABCMeta): diff --git a/graphrag/index/text_splitting/__init__.py b/graphrag/index/text_splitting/__init__.py index 4653adb22b..e6f3b31b4c 100644 --- a/graphrag/index/text_splitting/__init__.py +++ b/graphrag/index/text_splitting/__init__.py @@ -2,33 +2,3 @@ # Licensed under the MIT License """The Indexing Engine Text Splitting package root.""" - -from .check_token_limit import check_token_limit -from .text_splitting import ( - DecodeFn, - EncodedText, - EncodeFn, - LengthFn, - NoopTextSplitter, - TextListSplitter, - TextListSplitterType, - TextSplitter, - Tokenizer, - TokenTextSplitter, - split_text_on_tokens, -) - -__all__ = [ - "DecodeFn", - "EncodeFn", - "EncodedText", - "LengthFn", - "NoopTextSplitter", - "TextListSplitter", - "TextListSplitterType", - "TextSplitter", - "TokenTextSplitter", - "Tokenizer", - "check_token_limit", - "split_text_on_tokens", -] diff --git a/graphrag/index/text_splitting/check_token_limit.py b/graphrag/index/text_splitting/check_token_limit.py index 1a5f862254..7b6a139e02 100644 --- a/graphrag/index/text_splitting/check_token_limit.py +++ b/graphrag/index/text_splitting/check_token_limit.py @@ -3,7 +3,7 @@ """Token limit method definition.""" -from .text_splitting import TokenTextSplitter +from graphrag.index.text_splitting.text_splitting import TokenTextSplitter def check_token_limit(text, max_token): diff --git a/graphrag/index/text_splitting/text_splitting.py 
b/graphrag/index/text_splitting/text_splitting.py index c65515da5b..7ef5dd874c 100644 --- a/graphrag/index/text_splitting/text_splitting.py +++ b/graphrag/index/text_splitting/text_splitting.py @@ -14,7 +14,7 @@ import pandas as pd import tiktoken -from graphrag.index.utils import num_tokens_from_string +from graphrag.index.utils.tokens import num_tokens_from_string EncodedText = list[int] DecodeFn = Callable[[EncodedText], str] diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py index 89acb285e0..abebfd7e0d 100644 --- a/graphrag/index/update/incremental_index.py +++ b/graphrag/index/update/incremental_index.py @@ -24,7 +24,7 @@ _run_entity_summarization, ) from graphrag.index.update.relationships import _update_and_merge_relationships -from graphrag.logging.types import ProgressReporter +from graphrag.logging.base import ProgressReporter from graphrag.utils.storage import _load_table_from_storage diff --git a/graphrag/index/utils/__init__.py b/graphrag/index/utils/__init__.py index 7cbbb53d75..d1737fc9be 100644 --- a/graphrag/index/utils/__init__.py +++ b/graphrag/index/utils/__init__.py @@ -2,24 +2,3 @@ # Licensed under the MIT License """Utils methods definition.""" - -from .dicts import dict_has_keys_with_types -from .hashing import gen_md5_hash -from .is_null import is_null -from .load_graph import load_graph -from .string import clean_str -from .tokens import num_tokens_from_string, string_from_tokens -from .topological_sort import topological_sort -from .uuid import gen_uuid - -__all__ = [ - "clean_str", - "dict_has_keys_with_types", - "gen_md5_hash", - "gen_uuid", - "is_null", - "load_graph", - "num_tokens_from_string", - "string_from_tokens", - "topological_sort", -] diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index bc3b8a0ed6..11d7fd8390 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -8,9 +8,9 @@ from datashaper
import NoopVerbCallbacks -from graphrag.config.models import GraphRagConfig -from graphrag.index.llm import load_llm, load_llm_embeddings -from graphrag.logging import ProgressReporter +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings +from graphrag.logging.base import ProgressReporter def validate_config_names( diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py index ed580309a8..db1cb74c7b 100644 --- a/graphrag/index/workflows/__init__.py +++ b/graphrag/index/workflows/__init__.py @@ -3,8 +3,8 @@ """The Indexing Engine workflows package root.""" -from .load import create_workflow, load_workflows -from .typing import ( +from graphrag.index.workflows.load import create_workflow, load_workflows +from graphrag.index.workflows.typing import ( StepDefinition, VerbDefinitions, VerbTiming, diff --git a/graphrag/index/workflows/default_workflows.py b/graphrag/index/workflows/default_workflows.py index 5a3d176b56..536423c4e3 100644 --- a/graphrag/index/workflows/default_workflows.py +++ b/graphrag/index/workflows/default_workflows.py @@ -4,74 +4,74 @@ """A package containing default workflows definitions.""" # load and register all subflows -from .v1.subflows import * # noqa +from graphrag.index.workflows.v1.subflows import * # noqa -from .typing import WorkflowDefinitions -from .v1.create_base_entity_graph import ( +from graphrag.index.workflows.typing import WorkflowDefinitions +from graphrag.index.workflows.v1.create_base_entity_graph import ( build_steps as build_create_base_entity_graph_steps, ) -from .v1.create_base_entity_graph import ( +from graphrag.index.workflows.v1.create_base_entity_graph import ( workflow_name as create_base_entity_graph, ) -from .v1.create_base_text_units import ( +from graphrag.index.workflows.v1.create_base_text_units import ( build_steps as build_create_base_text_units_steps, ) -from
.v1.create_base_text_units import ( +from graphrag.index.workflows.v1.create_base_text_units import ( workflow_name as create_base_text_units, ) -from .v1.create_final_communities import ( +from graphrag.index.workflows.v1.create_final_communities import ( build_steps as build_create_final_communities_steps, ) -from .v1.create_final_communities import ( +from graphrag.index.workflows.v1.create_final_communities import ( workflow_name as create_final_communities, ) -from .v1.create_final_community_reports import ( +from graphrag.index.workflows.v1.create_final_community_reports import ( build_steps as build_create_final_community_reports_steps, ) -from .v1.create_final_community_reports import ( +from graphrag.index.workflows.v1.create_final_community_reports import ( workflow_name as create_final_community_reports, ) -from .v1.create_final_covariates import ( +from graphrag.index.workflows.v1.create_final_covariates import ( build_steps as build_create_final_covariates_steps, ) -from .v1.create_final_covariates import ( +from graphrag.index.workflows.v1.create_final_covariates import ( workflow_name as create_final_covariates, ) -from .v1.create_final_documents import ( +from graphrag.index.workflows.v1.create_final_documents import ( build_steps as build_create_final_documents_steps, ) -from .v1.create_final_documents import ( +from graphrag.index.workflows.v1.create_final_documents import ( workflow_name as create_final_documents, ) -from .v1.create_final_entities import ( +from graphrag.index.workflows.v1.create_final_entities import ( build_steps as build_create_final_entities_steps, ) -from .v1.create_final_entities import ( +from graphrag.index.workflows.v1.create_final_entities import ( workflow_name as create_final_entities, ) -from .v1.create_final_nodes import ( +from graphrag.index.workflows.v1.create_final_nodes import ( build_steps as build_create_final_nodes_steps, ) -from .v1.create_final_nodes import ( +from 
graphrag.index.workflows.v1.create_final_nodes import ( workflow_name as create_final_nodes, ) -from .v1.create_final_relationships import ( +from graphrag.index.workflows.v1.create_final_relationships import ( build_steps as build_create_final_relationships_steps, ) -from .v1.create_final_relationships import ( +from graphrag.index.workflows.v1.create_final_relationships import ( workflow_name as create_final_relationships, ) -from .v1.create_final_text_units import ( +from graphrag.index.workflows.v1.create_final_text_units import ( build_steps as build_create_final_text_units, ) -from .v1.create_final_text_units import ( +from graphrag.index.workflows.v1.create_final_text_units import ( workflow_name as create_final_text_units, ) -from .v1.generate_text_embeddings import ( +from graphrag.index.workflows.v1.generate_text_embeddings import ( build_steps as build_generate_text_embeddings_steps, ) -from .v1.generate_text_embeddings import ( +from graphrag.index.workflows.v1.generate_text_embeddings import ( workflow_name as generate_text_embeddings, ) diff --git a/graphrag/index/workflows/load.py b/graphrag/index/workflows/load.py index a9f65b86d1..236642d165 100644 --- a/graphrag/index/workflows/load.py +++ b/graphrag/index/workflows/load.py @@ -16,13 +16,18 @@ UndefinedWorkflowError, UnknownWorkflowError, ) -from graphrag.index.utils import topological_sort - -from .default_workflows import default_workflows as _default_workflows -from .typing import VerbDefinitions, WorkflowDefinitions, WorkflowToRun +from graphrag.index.utils.topological_sort import topological_sort +from graphrag.index.workflows.default_workflows import ( + default_workflows as _default_workflows, +) +from graphrag.index.workflows.typing import ( + VerbDefinitions, + WorkflowDefinitions, + WorkflowToRun, +) if TYPE_CHECKING: - from graphrag.index.config import ( + from graphrag.index.config.workflow import ( PipelineWorkflowConfig, PipelineWorkflowReference, PipelineWorkflowStep, diff --git 
a/graphrag/index/workflows/v1/create_base_entity_graph.py b/graphrag/index/workflows/v1/create_base_entity_graph.py index bb0c41ac57..0e8a9a4fb7 100644 --- a/graphrag/index/workflows/v1/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -7,7 +7,7 @@ AsyncType, ) -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_base_entity_graph" diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py index efd7f7eab1..40250b62d2 100644 --- a/graphrag/index/workflows/v1/create_base_text_units.py +++ b/graphrag/index/workflows/v1/create_base_text_units.py @@ -5,7 +5,7 @@ from datashaper import DEFAULT_INPUT_NAME -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_base_text_units" diff --git a/graphrag/index/workflows/v1/create_final_communities.py b/graphrag/index/workflows/v1/create_final_communities.py index 96ca5215eb..b5296b4bfc 100644 --- a/graphrag/index/workflows/v1/create_final_communities.py +++ b/graphrag/index/workflows/v1/create_final_communities.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_communities" diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py index 8a56583d84..6b8d110fe1 100644 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/create_final_community_reports.py @@ -3,7 +3,7 @@ """A module containing build_steps 
method definition.""" -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_community_reports" diff --git a/graphrag/index/workflows/v1/create_final_covariates.py b/graphrag/index/workflows/v1/create_final_covariates.py index 6bdf32e4bb..b730a1737d 100644 --- a/graphrag/index/workflows/v1/create_final_covariates.py +++ b/graphrag/index/workflows/v1/create_final_covariates.py @@ -7,7 +7,7 @@ AsyncType, ) -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_covariates" diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index 5160cc5165..ad0a1f036e 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -5,7 +5,7 @@ from datashaper import DEFAULT_INPUT_NAME -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_documents" diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py index 50ee56d8e5..d36d5bb331 100644 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ b/graphrag/index/workflows/v1/create_final_entities.py @@ -5,7 +5,7 @@ import logging -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_entities" log = logging.getLogger(__name__) diff --git a/graphrag/index/workflows/v1/create_final_nodes.py b/graphrag/index/workflows/v1/create_final_nodes.py index 
dedac2870e..ff22adbc0f 100644 --- a/graphrag/index/workflows/v1/create_final_nodes.py +++ b/graphrag/index/workflows/v1/create_final_nodes.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_nodes" diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py index 3eaff05b0e..c3947ba860 100644 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ b/graphrag/index/workflows/v1/create_final_relationships.py @@ -5,7 +5,7 @@ import logging -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_relationships" diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py index 31015b0d01..a39e22d2e2 100644 --- a/graphrag/index/workflows/v1/create_final_text_units.py +++ b/graphrag/index/workflows/v1/create_final_text_units.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_text_units" diff --git a/graphrag/index/workflows/v1/generate_text_embeddings.py b/graphrag/index/workflows/v1/generate_text_embeddings.py index e919f8790d..58464b33a8 100644 --- a/graphrag/index/workflows/v1/generate_text_embeddings.py +++ b/graphrag/index/workflows/v1/generate_text_embeddings.py @@ -5,7 +5,7 @@ import logging -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep +from graphrag.index.config.workflow import PipelineWorkflowConfig, 
PipelineWorkflowStep log = logging.getLogger(__name__) diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index 8e080c5c71..1002a31af9 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -3,19 +3,37 @@ """The Indexing Engine workflows -> subflows package root.""" -from .create_base_entity_graph import create_base_entity_graph -from .create_base_text_units import create_base_text_units -from .create_final_communities import create_final_communities -from .create_final_community_reports import create_final_community_reports -from .create_final_covariates import create_final_covariates -from .create_final_documents import create_final_documents -from .create_final_entities import create_final_entities -from .create_final_nodes import create_final_nodes -from .create_final_relationships import ( +from graphrag.index.workflows.v1.subflows.create_base_entity_graph import ( + create_base_entity_graph, +) +from graphrag.index.workflows.v1.subflows.create_base_text_units import ( + create_base_text_units, +) +from graphrag.index.workflows.v1.subflows.create_final_communities import ( + create_final_communities, +) +from graphrag.index.workflows.v1.subflows.create_final_community_reports import ( + create_final_community_reports, +) +from graphrag.index.workflows.v1.subflows.create_final_covariates import ( + create_final_covariates, +) +from graphrag.index.workflows.v1.subflows.create_final_documents import ( + create_final_documents, +) +from graphrag.index.workflows.v1.subflows.create_final_entities import ( + create_final_entities, +) +from graphrag.index.workflows.v1.subflows.create_final_nodes import create_final_nodes +from graphrag.index.workflows.v1.subflows.create_final_relationships import ( create_final_relationships, ) -from .create_final_text_units import create_final_text_units -from .generate_text_embeddings import 
generate_text_embeddings +from graphrag.index.workflows.v1.subflows.create_final_text_units import ( + create_final_text_units, +) +from graphrag.index.workflows.v1.subflows.generate_text_embeddings import ( + generate_text_embeddings, +) __all__ = [ "create_base_entity_graph", diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py index 09c8e8067c..3f1199e219 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py @@ -14,11 +14,11 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.flows.create_base_entity_graph import ( create_base_entity_graph as create_base_entity_graph_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb( diff --git a/graphrag/index/workflows/v1/subflows/create_base_text_units.py b/graphrag/index/workflows/v1/subflows/create_base_text_units.py index 598c6cf10c..e9b3f43938 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_text_units.py +++ b/graphrag/index/workflows/v1/subflows/create_base_text_units.py @@ -17,7 +17,7 @@ from graphrag.index.flows.create_base_text_units import ( create_base_text_units as create_base_text_units_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb(name="create_base_text_units", treats_input_tables_as_immutable=True) diff --git a/graphrag/index/workflows/v1/subflows/create_final_communities.py b/graphrag/index/workflows/v1/subflows/create_final_communities.py index a70ecbf1f7..d66a327593 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_communities.py +++ 
b/graphrag/index/workflows/v1/subflows/create_final_communities.py @@ -15,7 +15,7 @@ from graphrag.index.flows.create_final_communities import ( create_final_communities as create_final_communities_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb(name="create_final_communities", treats_input_tables_as_immutable=True) diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py index b8f5984e8c..ff6d9ef8a2 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py @@ -15,7 +15,7 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.flows.create_final_community_reports import ( create_final_community_reports as create_final_community_reports_flow, ) diff --git a/graphrag/index/workflows/v1/subflows/create_final_covariates.py b/graphrag/index/workflows/v1/subflows/create_final_covariates.py index d6b83ed70f..0ab54b1d85 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_covariates.py +++ b/graphrag/index/workflows/v1/subflows/create_final_covariates.py @@ -13,11 +13,11 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.flows.create_final_covariates import ( create_final_covariates as create_final_covariates_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb(name="create_final_covariates", treats_input_tables_as_immutable=True) diff --git 
a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py index 9b7a4e7559..4ac4e24dcd 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -16,7 +16,7 @@ from graphrag.index.flows.create_final_documents import ( create_final_documents as create_final_documents_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb( diff --git a/graphrag/index/workflows/v1/subflows/create_final_entities.py b/graphrag/index/workflows/v1/subflows/create_final_entities.py index bd5e735a12..968fa0d24b 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_entities.py +++ b/graphrag/index/workflows/v1/subflows/create_final_entities.py @@ -15,7 +15,7 @@ from graphrag.index.flows.create_final_entities import ( create_final_entities as create_final_entities_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb( diff --git a/graphrag/index/workflows/v1/subflows/create_final_nodes.py b/graphrag/index/workflows/v1/subflows/create_final_nodes.py index 2060e59eb3..e8266b76b5 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_nodes.py +++ b/graphrag/index/workflows/v1/subflows/create_final_nodes.py @@ -15,7 +15,7 @@ from graphrag.index.flows.create_final_nodes import ( create_final_nodes as create_final_nodes_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage @verb(name="create_final_nodes", treats_input_tables_as_immutable=True) diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py index 4769735487..995eeec834 100644 --- 
a/graphrag/index/workflows/v1/subflows/create_final_relationships.py +++ b/graphrag/index/workflows/v1/subflows/create_final_relationships.py @@ -17,7 +17,7 @@ from graphrag.index.flows.create_final_relationships import ( create_final_relationships as create_final_relationships_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.utils.ds_util import get_required_input_table diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py index a950dcb974..8cc69839a9 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_text_units.py +++ b/graphrag/index/workflows/v1/subflows/create_final_text_units.py @@ -17,7 +17,7 @@ from graphrag.index.flows.create_final_text_units import ( create_final_text_units as create_final_text_units_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table diff --git a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py index b0da04badf..1ac256a90e 100644 --- a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py @@ -16,11 +16,11 @@ verb, ) -from graphrag.index.cache import PipelineCache +from graphrag.index.cache.pipeline_cache import PipelineCache from graphrag.index.flows.generate_text_embeddings import ( generate_text_embeddings as generate_text_embeddings_flow, ) -from graphrag.index.storage import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.utils.ds_util import get_required_input_table log = logging.getLogger(__name__) diff --git a/graphrag/logging/__init__.py 
b/graphrag/logging/__init__.py index 31afc94387..9f5fe0de3a 100644 --- a/graphrag/logging/__init__.py +++ b/graphrag/logging/__init__.py @@ -2,26 +2,3 @@ # Licensed under the MIT License """Logging utilities and implementations.""" - -from .console import ConsoleReporter -from .factories import create_progress_reporter -from .null_progress import NullProgressReporter -from .print_progress import PrintProgressReporter -from .rich_progress import RichProgressReporter -from .types import ( - ProgressReporter, - ReporterType, - StatusLogger, -) - -__all__ = [ - # Progress Reporters - "ConsoleReporter", - "NullProgressReporter", - "PrintProgressReporter", - "ProgressReporter", - "ReporterType", - "RichProgressReporter", - "StatusLogger", - "create_progress_reporter", -] diff --git a/graphrag/logging/base.py b/graphrag/logging/base.py new file mode 100644 index 0000000000..3afb69aaba --- /dev/null +++ b/graphrag/logging/base.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base classes for logging and progress reporting.""" + +from abc import ABC, abstractmethod +from typing import Any + +from datashaper.progress.types import Progress + + +class StatusLogger(ABC): + """Provides a way to report status updates from the pipeline.""" + + @abstractmethod + def error(self, message: str, details: dict[str, Any] | None = None): + """Report an error.""" + + @abstractmethod + def warning(self, message: str, details: dict[str, Any] | None = None): + """Report a warning.""" + + @abstractmethod + def log(self, message: str, details: dict[str, Any] | None = None): + """Report a log.""" + + +class ProgressReporter(ABC): + """ + Abstract base class for progress reporters. + + This is used to report workflow processing progress via mechanisms like progress-bars. 
+ """ + + @abstractmethod + def __call__(self, update: Progress): + """Update progress.""" + + @abstractmethod + def dispose(self): + """Dispose of the progress reporter.""" + + @abstractmethod + def child(self, prefix: str, transient=True) -> "ProgressReporter": + """Create a child progress bar.""" + + @abstractmethod + def force_refresh(self) -> None: + """Force a refresh.""" + + @abstractmethod + def stop(self) -> None: + """Stop the progress reporter.""" + + @abstractmethod + def error(self, message: str) -> None: + """Report an error.""" + + @abstractmethod + def warning(self, message: str) -> None: + """Report a warning.""" + + @abstractmethod + def info(self, message: str) -> None: + """Report information.""" + + @abstractmethod + def success(self, message: str) -> None: + """Report success.""" diff --git a/graphrag/logging/console.py b/graphrag/logging/console.py index b00a7e8d9c..e42269e645 100644 --- a/graphrag/logging/console.py +++ b/graphrag/logging/console.py @@ -5,7 +5,7 @@ from typing import Any -from .types import StatusLogger +from graphrag.logging.base import StatusLogger class ConsoleReporter(StatusLogger): diff --git a/graphrag/logging/factories.py b/graphrag/logging/factories.py index efd69b7550..9deefb3207 100644 --- a/graphrag/logging/factories.py +++ b/graphrag/logging/factories.py @@ -3,13 +3,11 @@ """Factory functions for creating loggers.""" -from .null_progress import NullProgressReporter -from .print_progress import PrintProgressReporter -from .rich_progress import RichProgressReporter -from .types import ( - ProgressReporter, - ReporterType, -) +from graphrag.logging.base import ProgressReporter +from graphrag.logging.null_progress import NullProgressReporter +from graphrag.logging.print_progress import PrintProgressReporter +from graphrag.logging.rich_progress import RichProgressReporter +from graphrag.logging.types import ReporterType def create_progress_reporter( diff --git a/graphrag/logging/null_progress.py 
b/graphrag/logging/null_progress.py index 0539c5c014..4d46400170 100644 --- a/graphrag/logging/null_progress.py +++ b/graphrag/logging/null_progress.py @@ -3,7 +3,7 @@ """Null Progress Reporter.""" -from .types import Progress, ProgressReporter +from graphrag.logging.base import Progress, ProgressReporter class NullProgressReporter(ProgressReporter): diff --git a/graphrag/logging/print_progress.py b/graphrag/logging/print_progress.py index d529e0dfd6..20c45dd38b 100644 --- a/graphrag/logging/print_progress.py +++ b/graphrag/logging/print_progress.py @@ -3,7 +3,7 @@ """Print Progress Reporter.""" -from .types import Progress, ProgressReporter +from graphrag.logging.base import Progress, ProgressReporter class PrintProgressReporter(ProgressReporter): diff --git a/graphrag/logging/rich_progress.py b/graphrag/logging/rich_progress.py index 362b64f0c8..f83261dbd6 100644 --- a/graphrag/logging/rich_progress.py +++ b/graphrag/logging/rich_progress.py @@ -13,7 +13,7 @@ from rich.spinner import Spinner from rich.tree import Tree -from .types import ProgressReporter +from graphrag.logging.base import ProgressReporter # https://stackoverflow.com/a/34325723 diff --git a/graphrag/logging/types.py b/graphrag/logging/types.py index 3ba50e5bd4..d852a47e8d 100644 --- a/graphrag/logging/types.py +++ b/graphrag/logging/types.py @@ -3,11 +3,7 @@ """Types for status reporting.""" -from abc import ABC, abstractmethod from enum import Enum -from typing import Any - -from datashaper import Progress class ReporterType(str, Enum): @@ -20,63 +16,3 @@ class ReporterType(str, Enum): def __str__(self): """Return the string representation of the enum value.""" return self.value - - -class StatusLogger(ABC): - """Provides a way to report status updates from the pipeline.""" - - @abstractmethod - def error(self, message: str, details: dict[str, Any] | None = None): - """Report an error.""" - - @abstractmethod - def warning(self, message: str, details: dict[str, Any] | None = None): - """Report a 
warning.""" - - @abstractmethod - def log(self, message: str, details: dict[str, Any] | None = None): - """Report a log.""" - - -class ProgressReporter(ABC): - """ - Abstract base class for progress reporters. - - This is used to report workflow processing progress via mechanisms like progress-bars. - """ - - @abstractmethod - def __call__(self, update: Progress): - """Update progress.""" - - @abstractmethod - def dispose(self): - """Dispose of the progress reporter.""" - - @abstractmethod - def child(self, prefix: str, transient=True) -> "ProgressReporter": - """Create a child progress bar.""" - - @abstractmethod - def force_refresh(self) -> None: - """Force a refresh.""" - - @abstractmethod - def stop(self) -> None: - """Stop the progress reporter.""" - - @abstractmethod - def error(self, message: str) -> None: - """Report an error.""" - - @abstractmethod - def warning(self, message: str) -> None: - """Report a warning.""" - - @abstractmethod - def info(self, message: str) -> None: - """Report information.""" - - @abstractmethod - def success(self, message: str) -> None: - """Report success.""" diff --git a/graphrag/model/__init__.py b/graphrag/model/__init__.py index 9dbec3d1dd..6523d9d916 100644 --- a/graphrag/model/__init__.py +++ b/graphrag/model/__init__.py @@ -1,31 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -""" -GraphRAG knowledge model package root. - -The GraphRAG knowledge model contains a set of classes that represent the target datamodels for our pipelines and analytics tools. -These models can be augmented and integrated into your own data infrastructure to suit your needs. 
-""" - -from .community import Community -from .community_report import CommunityReport -from .covariate import Covariate -from .document import Document -from .entity import Entity -from .identified import Identified -from .named import Named -from .relationship import Relationship -from .text_unit import TextUnit - -__all__ = [ - "Community", - "CommunityReport", - "Covariate", - "Document", - "Entity", - "Identified", - "Named", - "Relationship", - "TextUnit", -] +"""GraphRAG knowledge model package root.""" diff --git a/graphrag/model/community.py b/graphrag/model/community.py index 041aaa5e47..43d6c4033a 100644 --- a/graphrag/model/community.py +++ b/graphrag/model/community.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .named import Named +from graphrag.model.named import Named @dataclass diff --git a/graphrag/model/community_report.py b/graphrag/model/community_report.py index 53c35a5117..9216fb68d9 100644 --- a/graphrag/model/community_report.py +++ b/graphrag/model/community_report.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .named import Named +from graphrag.model.named import Named @dataclass diff --git a/graphrag/model/covariate.py b/graphrag/model/covariate.py index 484ea16fae..0ce188e497 100644 --- a/graphrag/model/covariate.py +++ b/graphrag/model/covariate.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .identified import Identified +from graphrag.model.identified import Identified @dataclass diff --git a/graphrag/model/document.py b/graphrag/model/document.py index 2980318376..ec2b2d4523 100644 --- a/graphrag/model/document.py +++ b/graphrag/model/document.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import Any -from .named import Named +from graphrag.model.named import Named @dataclass diff --git a/graphrag/model/entity.py b/graphrag/model/entity.py index a152abf2a5..4c45dfbc40 100644 --- a/graphrag/model/entity.py 
+++ b/graphrag/model/entity.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .named import Named +from graphrag.model.named import Named @dataclass diff --git a/graphrag/model/named.py b/graphrag/model/named.py index 5352c77c96..245fb5d333 100644 --- a/graphrag/model/named.py +++ b/graphrag/model/named.py @@ -5,7 +5,7 @@ from dataclasses import dataclass -from .identified import Identified +from graphrag.model.identified import Identified @dataclass diff --git a/graphrag/model/relationship.py b/graphrag/model/relationship.py index 54fb20c31c..ee9b24c1f6 100644 --- a/graphrag/model/relationship.py +++ b/graphrag/model/relationship.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .identified import Identified +from graphrag.model.identified import Identified @dataclass diff --git a/graphrag/model/text_unit.py b/graphrag/model/text_unit.py index b54ee9e5f8..4ad3b9e8d4 100644 --- a/graphrag/model/text_unit.py +++ b/graphrag/model/text_unit.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from .identified import Identified +from graphrag.model.identified import Identified @dataclass diff --git a/graphrag/prompt_tune/__init__.py b/graphrag/prompt_tune/__init__.py index 2384b5793c..6997787e61 100644 --- a/graphrag/prompt_tune/__init__.py +++ b/graphrag/prompt_tune/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Command line interface for the fine_tune module.""" +"""The prompt-tuning package root.""" diff --git a/graphrag/prompt_tune/defaults.py b/graphrag/prompt_tune/defaults.py new file mode 100644 index 0000000000..e72c82e1f6 --- /dev/null +++ b/graphrag/prompt_tune/defaults.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Default values for the prompt-tuning module. + +Note: These values get accessed from the CLI to set default behavior. 
+To maintain fast responsiveness from the CLI, do not add long-running code in this file and be mindful of imports. +""" + +DEFAULT_TASK = """ +Identify the relations and structure of the community of interest, specifically within the {domain} domain. +""" + +K = 15 +MAX_TOKEN_COUNT = 2000 +MIN_CHUNK_SIZE = 200 +N_SUBSET_MAX = 300 +MIN_CHUNK_OVERLAP = 0 diff --git a/graphrag/prompt_tune/generator/__init__.py b/graphrag/prompt_tune/generator/__init__.py index df45b46033..8f144052ac 100644 --- a/graphrag/prompt_tune/generator/__init__.py +++ b/graphrag/prompt_tune/generator/__init__.py @@ -2,29 +2,3 @@ # Licensed under the MIT License """Prompt generation module.""" - -from .community_report_rating import generate_community_report_rating -from .community_report_summarization import create_community_summarization_prompt -from .community_reporter_role import generate_community_reporter_role -from .defaults import MAX_TOKEN_COUNT -from .domain import generate_domain -from .entity_extraction_prompt import create_entity_extraction_prompt -from .entity_relationship import generate_entity_relationship_examples -from .entity_summarization_prompt import create_entity_summarization_prompt -from .entity_types import generate_entity_types -from .language import detect_language -from .persona import generate_persona - -__all__ = [ - "MAX_TOKEN_COUNT", - "create_community_summarization_prompt", - "create_entity_extraction_prompt", - "create_entity_summarization_prompt", - "detect_language", - "generate_community_report_rating", - "generate_community_reporter_role", - "generate_domain", - "generate_entity_relationship_examples", - "generate_entity_types", - "generate_persona", -] diff --git a/graphrag/prompt_tune/generator/community_report_rating.py b/graphrag/prompt_tune/generator/community_report_rating.py index 59f94d5698..23d7cc6832 100644 --- a/graphrag/prompt_tune/generator/community_report_rating.py +++ b/graphrag/prompt_tune/generator/community_report_rating.py @@ -4,7 +4,7 
@@ # Licensed under the MIT License from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.prompt import ( +from graphrag.prompt_tune.prompt.community_report_rating import ( GENERATE_REPORT_RATING_PROMPT, ) diff --git a/graphrag/prompt_tune/generator/community_report_summarization.py b/graphrag/prompt_tune/generator/community_report_summarization.py index b0c0b614d2..4d2d0da846 100644 --- a/graphrag/prompt_tune/generator/community_report_summarization.py +++ b/graphrag/prompt_tune/generator/community_report_summarization.py @@ -5,7 +5,9 @@ from pathlib import Path -from graphrag.prompt_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT +from graphrag.prompt_tune.template.community_report_summarization import ( + COMMUNITY_REPORT_SUMMARIZATION_PROMPT, +) COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt" diff --git a/graphrag/prompt_tune/generator/community_reporter_role.py b/graphrag/prompt_tune/generator/community_reporter_role.py index 9abd5ed83f..f16a6c3dd4 100644 --- a/graphrag/prompt_tune/generator/community_reporter_role.py +++ b/graphrag/prompt_tune/generator/community_reporter_role.py @@ -4,7 +4,7 @@ """Generate a community reporter role for community summarization.""" from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.prompt import ( +from graphrag.prompt_tune.prompt.community_reporter_role import ( GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT, ) diff --git a/graphrag/prompt_tune/generator/defaults.py b/graphrag/prompt_tune/generator/defaults.py deleted file mode 100644 index 5b42f81332..0000000000 --- a/graphrag/prompt_tune/generator/defaults.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Default values for the fine-tuning module.""" - -DEFAULT_TASK = """ -Identify the relations and structure of the community of interest, specifically within the {domain} domain. 
-""" - -MAX_TOKEN_COUNT = 2000 diff --git a/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/graphrag/prompt_tune/generator/entity_extraction_prompt.py index 806e310915..a170642995 100644 --- a/graphrag/prompt_tune/generator/entity_extraction_prompt.py +++ b/graphrag/prompt_tune/generator/entity_extraction_prompt.py @@ -7,7 +7,7 @@ import graphrag.config.defaults as defs from graphrag.index.utils.tokens import num_tokens_from_string -from graphrag.prompt_tune.template import ( +from graphrag.prompt_tune.template.entity_extraction import ( EXAMPLE_EXTRACTION_TEMPLATE, GRAPH_EXTRACTION_JSON_PROMPT, GRAPH_EXTRACTION_PROMPT, diff --git a/graphrag/prompt_tune/generator/entity_relationship.py b/graphrag/prompt_tune/generator/entity_relationship.py index 2733e3bd74..f8862bd6ef 100644 --- a/graphrag/prompt_tune/generator/entity_relationship.py +++ b/graphrag/prompt_tune/generator/entity_relationship.py @@ -7,7 +7,7 @@ import json from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.prompt import ( +from graphrag.prompt_tune.prompt.entity_relationship import ( ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, ENTITY_RELATIONSHIPS_GENERATION_PROMPT, UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT, diff --git a/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/graphrag/prompt_tune/generator/entity_summarization_prompt.py index 736df830d6..979e5b0a6a 100644 --- a/graphrag/prompt_tune/generator/entity_summarization_prompt.py +++ b/graphrag/prompt_tune/generator/entity_summarization_prompt.py @@ -5,7 +5,9 @@ from pathlib import Path -from graphrag.prompt_tune.template import ENTITY_SUMMARIZATION_PROMPT +from graphrag.prompt_tune.template.entity_summarization import ( + ENTITY_SUMMARIZATION_PROMPT, +) ENTITY_SUMMARIZATION_FILENAME = "summarize_descriptions.txt" diff --git a/graphrag/prompt_tune/generator/entity_types.py b/graphrag/prompt_tune/generator/entity_types.py index 42518acd8c..51ac0020e0 100644 --- 
a/graphrag/prompt_tune/generator/entity_types.py +++ b/graphrag/prompt_tune/generator/entity_types.py @@ -4,7 +4,7 @@ """Entity type generation module for fine-tuning.""" from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK +from graphrag.prompt_tune.defaults import DEFAULT_TASK from graphrag.prompt_tune.prompt.entity_types import ( ENTITY_TYPE_GENERATION_JSON_PROMPT, ENTITY_TYPE_GENERATION_PROMPT, diff --git a/graphrag/prompt_tune/generator/language.py b/graphrag/prompt_tune/generator/language.py index 38de531ca3..d803df9c54 100644 --- a/graphrag/prompt_tune/generator/language.py +++ b/graphrag/prompt_tune/generator/language.py @@ -4,7 +4,7 @@ """Language detection for GraphRAG prompts.""" from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.prompt import DETECT_LANGUAGE_PROMPT +from graphrag.prompt_tune.prompt.language import DETECT_LANGUAGE_PROMPT async def detect_language(llm: CompletionLLM, docs: str | list[str]) -> str: diff --git a/graphrag/prompt_tune/generator/persona.py b/graphrag/prompt_tune/generator/persona.py index cdd57a655d..c66cc4a717 100644 --- a/graphrag/prompt_tune/generator/persona.py +++ b/graphrag/prompt_tune/generator/persona.py @@ -4,8 +4,8 @@ """Persona generating module for fine-tuning GraphRAG prompts.""" from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK -from graphrag.prompt_tune.prompt import GENERATE_PERSONA_PROMPT +from graphrag.prompt_tune.defaults import DEFAULT_TASK +from graphrag.prompt_tune.prompt.persona import GENERATE_PERSONA_PROMPT async def generate_persona( diff --git a/graphrag/prompt_tune/loader/__init__.py b/graphrag/prompt_tune/loader/__init__.py index bc8026e92d..7c7a6f88ff 100644 --- a/graphrag/prompt_tune/loader/__init__.py +++ b/graphrag/prompt_tune/loader/__init__.py @@ -2,11 +2,3 @@ # Licensed under the MIT License """Fine-tuning config and data 
loader module.""" - -from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks - -__all__ = [ - "MIN_CHUNK_OVERLAP", - "MIN_CHUNK_SIZE", - "load_docs_in_chunks", -] diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index 5fd5719666..eb8ad8f8f6 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -9,18 +9,19 @@ import graphrag.config.defaults as defs from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.input import load_input -from graphrag.index.llm import load_llm_embeddings +from graphrag.index.input.load_input import load_input +from graphrag.index.llm.load_llm import load_llm_embeddings from graphrag.index.operations.chunk_text import chunk_text from graphrag.llm.types.llm_types import EmbeddingLLM -from graphrag.logging import ProgressReporter +from graphrag.logging.base import ProgressReporter +from graphrag.prompt_tune.defaults import ( + MIN_CHUNK_OVERLAP, + MIN_CHUNK_SIZE, + N_SUBSET_MAX, + K, +) from graphrag.prompt_tune.types import DocSelectionType -MIN_CHUNK_OVERLAP = 0 -MIN_CHUNK_SIZE = 200 -N_SUBSET_MAX = 300 -K = 15 - async def _embed_chunks( text_chunks: pd.DataFrame, diff --git a/graphrag/prompt_tune/prompt/__init__.py b/graphrag/prompt_tune/prompt/__init__.py index 991d52856e..497c56dda5 100644 --- a/graphrag/prompt_tune/prompt/__init__.py +++ b/graphrag/prompt_tune/prompt/__init__.py @@ -1,32 +1,4 @@ -"""Persona, entity type, relationships and domain generation prompts module.""" - # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from .community_report_rating import GENERATE_REPORT_RATING_PROMPT -from .community_reporter_role import GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT -from .domain import GENERATE_DOMAIN_PROMPT -from .entity_relationship import ( - ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, - ENTITY_RELATIONSHIPS_GENERATION_PROMPT, - UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT, -) -from .entity_types import ( - ENTITY_TYPE_GENERATION_JSON_PROMPT, - ENTITY_TYPE_GENERATION_PROMPT, -) -from .language import DETECT_LANGUAGE_PROMPT -from .persona import GENERATE_PERSONA_PROMPT - -__all__ = [ - "DETECT_LANGUAGE_PROMPT", - "ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT", - "ENTITY_RELATIONSHIPS_GENERATION_PROMPT", - "ENTITY_TYPE_GENERATION_JSON_PROMPT", - "ENTITY_TYPE_GENERATION_PROMPT", - "GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT", - "GENERATE_DOMAIN_PROMPT", - "GENERATE_PERSONA_PROMPT", - "GENERATE_REPORT_RATING_PROMPT", - "UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT", -] +"""Persona, entity type, relationships and domain generation prompts module.""" diff --git a/graphrag/prompt_tune/template/__init__.py b/graphrag/prompt_tune/template/__init__.py index e056762ff7..f830ce2a9e 100644 --- a/graphrag/prompt_tune/template/__init__.py +++ b/graphrag/prompt_tune/template/__init__.py @@ -2,23 +2,3 @@ # Licensed under the MIT License """Fine-tuning prompts for entity extraction, entity summarization, and community report summarization.""" - -from .community_report_summarization import COMMUNITY_REPORT_SUMMARIZATION_PROMPT -from .entity_extraction import ( - EXAMPLE_EXTRACTION_TEMPLATE, - GRAPH_EXTRACTION_JSON_PROMPT, - GRAPH_EXTRACTION_PROMPT, - UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE, - UNTYPED_GRAPH_EXTRACTION_PROMPT, -) -from .entity_summarization import ENTITY_SUMMARIZATION_PROMPT - -__all__ = [ - "COMMUNITY_REPORT_SUMMARIZATION_PROMPT", - "ENTITY_SUMMARIZATION_PROMPT", - "EXAMPLE_EXTRACTION_TEMPLATE", - "GRAPH_EXTRACTION_JSON_PROMPT", - 
"GRAPH_EXTRACTION_PROMPT", - "UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE", - "UNTYPED_GRAPH_EXTRACTION_PROMPT", -] diff --git a/graphrag/query/__init__.py b/graphrag/query/__init__.py index 58a557f8a2..effd81e123 100644 --- a/graphrag/query/__init__.py +++ b/graphrag/query/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""GraphRAG Orchestration Module.""" +"""The query engine package root.""" diff --git a/graphrag/query/context_builder/community_context.py b/graphrag/query/context_builder/community_context.py index b9357b9854..287edb48e2 100644 --- a/graphrag/query/context_builder/community_context.py +++ b/graphrag/query/context_builder/community_context.py @@ -10,7 +10,8 @@ import pandas as pd import tiktoken -from graphrag.model import CommunityReport, Entity +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity from graphrag.query.llm.text_utils import num_tokens log = logging.getLogger(__name__) diff --git a/graphrag/query/context_builder/dynamic_community_selection.py b/graphrag/query/context_builder/dynamic_community_selection.py index 12a92ef7f0..a17cfdc27a 100644 --- a/graphrag/query/context_builder/dynamic_community_selection.py +++ b/graphrag/query/context_builder/dynamic_community_selection.py @@ -12,7 +12,8 @@ import tiktoken -from graphrag.model import Community, CommunityReport +from graphrag.model.community import Community +from graphrag.model.community_report import CommunityReport from graphrag.query.context_builder.rate_prompt import RATE_QUERY from graphrag.query.context_builder.rate_relevancy import rate_relevancy from graphrag.query.llm.base import BaseLLM diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index f7e1fbfe18..dd2ec63c27 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -5,14 +5,15 @@ 
from enum import Enum -from graphrag.model import Entity, Relationship +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship from graphrag.query.input.retrieval.entities import ( get_entity_by_id, get_entity_by_key, get_entity_by_name, ) from graphrag.query.llm.base import BaseTextEmbedding -from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.base import BaseVectorStore class EntityVectorStoreKey(str, Enum): diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py index 78522af5a9..7cb61e3e8c 100644 --- a/graphrag/query/context_builder/local_context.py +++ b/graphrag/query/context_builder/local_context.py @@ -9,7 +9,9 @@ import pandas as pd import tiktoken -from graphrag.model import Covariate, Entity, Relationship +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship from graphrag.query.input.retrieval.covariates import ( get_candidate_covariates, to_covariate_dataframe, diff --git a/graphrag/query/context_builder/source_context.py b/graphrag/query/context_builder/source_context.py index 4d63db58e9..b8ba86aee4 100644 --- a/graphrag/query/context_builder/source_context.py +++ b/graphrag/query/context_builder/source_context.py @@ -9,7 +9,8 @@ import pandas as pd import tiktoken -from graphrag.model import Relationship, TextUnit +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.query.llm.text_utils import num_tokens """ diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 2e8a3afd4d..b854569b92 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -7,15 +7,13 @@ import tiktoken -from graphrag.config import GraphRagConfig -from graphrag.model import ( - Community, - CommunityReport, - Covariate, - Entity, - Relationship, - TextUnit, -) +from 
graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.model.community import Community +from graphrag.model.community_report import CommunityReport +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.llm.get_client import get_llm, get_text_embedder from graphrag.query.structured_search.drift_search.drift_context import ( @@ -30,7 +28,7 @@ LocalSearchMixedContext, ) from graphrag.query.structured_search.local_search.search import LocalSearch -from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.base import BaseVectorStore def get_local_search_engine( diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py index 9bd73d6531..478fe385d2 100644 --- a/graphrag/query/indexer_adapters.py +++ b/graphrag/query/indexer_adapters.py @@ -13,14 +13,12 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.summarize_communities import restore_community_hierarchy -from graphrag.model import ( - Community, - CommunityReport, - Covariate, - Entity, - Relationship, - TextUnit, -) +from graphrag.model.community import Community +from graphrag.model.community_report import CommunityReport +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.query.factories import get_text_embedder from graphrag.query.input.loaders.dfs import ( read_communities, diff --git a/graphrag/query/input/loaders/dfs.py b/graphrag/query/input/loaders/dfs.py index f144ad8a47..55af4c2d3f 100644 --- a/graphrag/query/input/loaders/dfs.py +++ b/graphrag/query/input/loaders/dfs.py @@ -5,14 +5,12 @@ import pandas 
as pd -from graphrag.model import ( - Community, - CommunityReport, - Covariate, - Entity, - Relationship, - TextUnit, -) +from graphrag.model.community import Community +from graphrag.model.community_report import CommunityReport +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.query.input.loaders.utils import ( to_optional_dict, to_optional_float, @@ -21,7 +19,7 @@ to_optional_str, to_str, ) -from graphrag.vector_stores import BaseVectorStore, VectorStoreDocument +from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument def read_entities( diff --git a/graphrag/query/input/retrieval/community_reports.py b/graphrag/query/input/retrieval/community_reports.py index bd4933f1f9..7fb38b4f96 100644 --- a/graphrag/query/input/retrieval/community_reports.py +++ b/graphrag/query/input/retrieval/community_reports.py @@ -7,7 +7,8 @@ import pandas as pd -from graphrag.model import CommunityReport, Entity +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity def get_candidate_communities( diff --git a/graphrag/query/input/retrieval/covariates.py b/graphrag/query/input/retrieval/covariates.py index 1c45203d01..4ca5ba13f1 100644 --- a/graphrag/query/input/retrieval/covariates.py +++ b/graphrag/query/input/retrieval/covariates.py @@ -7,7 +7,8 @@ import pandas as pd -from graphrag.model import Covariate, Entity +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity def get_candidate_covariates( diff --git a/graphrag/query/input/retrieval/entities.py b/graphrag/query/input/retrieval/entities.py index 41c92fab31..0384ceaf90 100644 --- a/graphrag/query/input/retrieval/entities.py +++ b/graphrag/query/input/retrieval/entities.py @@ -9,7 +9,7 @@ import pandas as pd -from graphrag.model import Entity +from graphrag.model.entity 
import Entity def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None: diff --git a/graphrag/query/input/retrieval/relationships.py b/graphrag/query/input/retrieval/relationships.py index 2dec596ff3..86ddd9efd1 100644 --- a/graphrag/query/input/retrieval/relationships.py +++ b/graphrag/query/input/retrieval/relationships.py @@ -7,7 +7,8 @@ import pandas as pd -from graphrag.model import Entity, Relationship +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship def get_in_network_relationships( diff --git a/graphrag/query/input/retrieval/text_units.py b/graphrag/query/input/retrieval/text_units.py index a00dc20a0a..1a0305bc3d 100644 --- a/graphrag/query/input/retrieval/text_units.py +++ b/graphrag/query/input/retrieval/text_units.py @@ -7,7 +7,8 @@ import pandas as pd -from graphrag.model import Entity, TextUnit +from graphrag.model.entity import Entity +from graphrag.model.text_unit import TextUnit def get_candidate_text_units( diff --git a/graphrag/query/llm/get_client.py b/graphrag/query/llm/get_client.py index 12baf5d1cf..5b9dbfbbc2 100644 --- a/graphrag/query/llm/get_client.py +++ b/graphrag/query/llm/get_client.py @@ -5,7 +5,8 @@ from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from graphrag.config import GraphRagConfig, LLMType +from graphrag.config.enums import LLMType +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType diff --git a/graphrag/query/llm/oai/__init__.py b/graphrag/query/llm/oai/__init__.py index cbb257905e..910766bdfb 100644 --- a/graphrag/query/llm/oai/__init__.py +++ b/graphrag/query/llm/oai/__init__.py @@ -2,20 +2,3 @@ # Licensed under the MIT License """GraphRAG Orchestration OpenAI Wrappers.""" - -from .base import BaseOpenAILLM, OpenAILLMImpl, 
OpenAITextEmbeddingImpl -from .chat_openai import ChatOpenAI -from .embedding import OpenAIEmbedding -from .openai import OpenAI -from .typing import OPENAI_RETRY_ERROR_TYPES, OpenaiApiType - -__all__ = [ - "OPENAI_RETRY_ERROR_TYPES", - "BaseOpenAILLM", - "ChatOpenAI", - "OpenAI", - "OpenAIEmbedding", - "OpenAILLMImpl", - "OpenAITextEmbeddingImpl", - "OpenaiApiType", -] diff --git a/graphrag/query/llm/oai/base.py b/graphrag/query/llm/oai/base.py index 08a90d98a4..0bdea9c1b3 100644 --- a/graphrag/query/llm/oai/base.py +++ b/graphrag/query/llm/oai/base.py @@ -8,7 +8,8 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI -from graphrag.logging import ConsoleReporter, StatusLogger +from graphrag.logging.base import StatusLogger +from graphrag.logging.console import ConsoleReporter from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType diff --git a/graphrag/query/llm/oai/chat_openai.py b/graphrag/query/llm/oai/chat_openai.py index 621ebecebe..8daadd750b 100644 --- a/graphrag/query/llm/oai/chat_openai.py +++ b/graphrag/query/llm/oai/chat_openai.py @@ -15,7 +15,7 @@ wait_exponential_jitter, ) -from graphrag.logging import StatusLogger +from graphrag.logging.base import StatusLogger from graphrag.query.llm.base import BaseLLM, BaseLLMCallback from graphrag.query.llm.oai.base import OpenAILLMImpl from graphrag.query.llm.oai.typing import ( diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py index 6b39a0017f..006a9588b6 100644 --- a/graphrag/query/llm/oai/embedding.py +++ b/graphrag/query/llm/oai/embedding.py @@ -18,7 +18,7 @@ wait_exponential_jitter, ) -from graphrag.logging import StatusLogger +from graphrag.logging.base import StatusLogger from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.base import OpenAILLMImpl from graphrag.query.llm.oai.typing import ( diff --git a/graphrag/query/structured_search/drift_search/drift_context.py 
b/graphrag/query/structured_search/drift_search/drift_context.py index 9da939d3c4..2ab19fc63b 100644 --- a/graphrag/query/structured_search/drift_search/drift_context.py +++ b/graphrag/query/structured_search/drift_search/drift_context.py @@ -12,13 +12,11 @@ import tiktoken from graphrag.config.models.drift_search_config import DRIFTSearchConfig -from graphrag.model import ( - CommunityReport, - Covariate, - Entity, - Relationship, - TextUnit, -) +from graphrag.model.community_report import CommunityReport +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.prompts.query.drift_search_system_prompt import ( DRIFT_LOCAL_SYSTEM_PROMPT, ) @@ -30,7 +28,7 @@ from graphrag.query.structured_search.local_search.mixed_context import ( LocalSearchMixedContext, ) -from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.base import BaseVectorStore log = logging.getLogger(__name__) diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 5d74ff8f0d..8f3895a25c 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -14,7 +14,7 @@ from tqdm.asyncio import tqdm_asyncio from graphrag.config.models.drift_search_config import DRIFTSearchConfig -from graphrag.model import CommunityReport +from graphrag.model.community_report import CommunityReport from graphrag.prompts.query.drift_search_system_prompt import ( DRIFT_PRIMER_PROMPT, ) diff --git a/graphrag/query/structured_search/global_search/community_context.py b/graphrag/query/structured_search/global_search/community_context.py index 0e22ad60c8..1bc17655bf 100644 --- a/graphrag/query/structured_search/global_search/community_context.py +++ 
b/graphrag/query/structured_search/global_search/community_context.py @@ -7,7 +7,9 @@ import tiktoken -from graphrag.model import Community, CommunityReport, Entity +from graphrag.model.community import Community +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity from graphrag.query.context_builder.builders import ContextBuilderResult from graphrag.query.context_builder.community_context import ( build_community_context, diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index ebebfd34aa..1dbf38e66b 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -9,13 +9,11 @@ import pandas as pd import tiktoken -from graphrag.model import ( - CommunityReport, - Covariate, - Entity, - Relationship, - TextUnit, -) +from graphrag.model.community_report import CommunityReport +from graphrag.model.covariate import Covariate +from graphrag.model.entity import Entity +from graphrag.model.relationship import Relationship +from graphrag.model.text_unit import TextUnit from graphrag.query.context_builder.builders import ContextBuilderResult from graphrag.query.context_builder.community_context import ( build_community_context, @@ -44,7 +42,7 @@ from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import LocalContextBuilder -from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.base import BaseVectorStore log = logging.getLogger(__name__) diff --git a/graphrag/utils/storage.py b/graphrag/utils/storage.py index 60d08b6309..8b2e0ce2ed 100644 --- a/graphrag/utils/storage.py +++ b/graphrag/utils/storage.py @@ -13,7 +13,7 @@ PipelineFileStorageConfig, PipelineStorageConfigTypes, ) -from graphrag.index.storage import load_storage +from 
graphrag.index.storage.load_storage import load_storage from graphrag.index.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/vector_stores/__init__.py b/graphrag/vector_stores/__init__.py index 560db06349..c1d54b741b 100644 --- a/graphrag/vector_stores/__init__.py +++ b/graphrag/vector_stores/__init__.py @@ -2,18 +2,3 @@ # Licensed under the MIT License """A module containing vector storage implementations.""" - -from graphrag.vector_stores.base import ( - BaseVectorStore, - VectorStoreDocument, - VectorStoreSearchResult, -) -from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType - -__all__ = [ - "BaseVectorStore", - "VectorStoreDocument", - "VectorStoreFactory", - "VectorStoreSearchResult", - "VectorStoreType", -] diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index f35e41dfcb..eebf2fa05e 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -25,8 +25,7 @@ from azure.search.documents.models import VectorizedQuery from graphrag.model.types import TextEmbedder - -from .base import ( +from graphrag.vector_stores.base import ( DEFAULT_VECTOR_SIZE, BaseVectorStore, VectorStoreDocument, diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py index 564533bacb..174a5c98d0 100644 --- a/graphrag/vector_stores/factory.py +++ b/graphrag/vector_stores/factory.py @@ -6,8 +6,8 @@ from enum import Enum from typing import ClassVar -from .azure_ai_search import AzureAISearch -from .lancedb import LanceDBVectorStore +from graphrag.vector_stores.azure_ai_search import AzureAISearch +from graphrag.vector_stores.lancedb import LanceDBVectorStore class VectorStoreType(str, Enum): diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 3cc3ea20e5..b334f4753b 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ 
-10,7 +10,7 @@ from graphrag.model.types import TextEmbedder -from .base import ( +from graphrag.vector_stores.base import ( BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult, diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py index 6e57ce3bb8..3c636f5cd3 100644 --- a/tests/unit/config/test_default_config.py +++ b/tests/unit/config/test_default_config.py @@ -13,64 +13,82 @@ from pydantic import ValidationError import graphrag.config.defaults as defs -from graphrag.config import ( +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.enums import ( + CacheType, + InputFileType, + InputType, + ReportingType, + StorageType, +) +from graphrag.config.errors import ( ApiKeyMissingError, AzureApiBaseMissingError, AzureDeploymentNameMissingError, - CacheConfig, - CacheConfigInput, - CacheType, - ChunkingConfig, - ChunkingConfigInput, - ClaimExtractionConfig, +) +from graphrag.config.input_models.cache_config_input import CacheConfigInput +from graphrag.config.input_models.chunking_config_input import ChunkingConfigInput +from graphrag.config.input_models.claim_extraction_config_input import ( ClaimExtractionConfigInput, - ClusterGraphConfig, +) +from graphrag.config.input_models.cluster_graph_config_input import ( ClusterGraphConfigInput, - CommunityReportsConfig, +) +from graphrag.config.input_models.community_reports_config_input import ( CommunityReportsConfigInput, - DRIFTSearchConfig, - EmbedGraphConfig, - EmbedGraphConfigInput, - EntityExtractionConfig, +) +from graphrag.config.input_models.embed_graph_config_input import EmbedGraphConfigInput +from graphrag.config.input_models.entity_extraction_config_input import ( EntityExtractionConfigInput, - GlobalSearchConfig, - GraphRagConfig, - GraphRagConfigInput, - InputConfig, - InputConfigInput, - InputFileType, - InputType, - LLMParameters, - LLMParametersInput, - LocalSearchConfig, - ParallelizationParameters, - 
ReportingConfig, - ReportingConfigInput, - ReportingType, - SnapshotsConfig, - SnapshotsConfigInput, - StorageConfig, - StorageConfigInput, - StorageType, - SummarizeDescriptionsConfig, +) +from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput +from graphrag.config.input_models.input_config_input import InputConfigInput +from graphrag.config.input_models.llm_parameters_input import LLMParametersInput +from graphrag.config.input_models.reporting_config_input import ReportingConfigInput +from graphrag.config.input_models.snapshots_config_input import SnapshotsConfigInput +from graphrag.config.input_models.storage_config_input import StorageConfigInput +from graphrag.config.input_models.summarize_descriptions_config_input import ( SummarizeDescriptionsConfigInput, - TextEmbeddingConfig, +) +from graphrag.config.input_models.text_embedding_config_input import ( TextEmbeddingConfigInput, - UmapConfig, - UmapConfigInput, - create_graphrag_config, ) -from graphrag.index import ( - PipelineConfig, +from graphrag.config.input_models.umap_config_input import UmapConfigInput +from graphrag.config.models.cache_config import CacheConfig +from graphrag.config.models.chunking_config import ChunkingConfig +from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig +from graphrag.config.models.cluster_graph_config import ClusterGraphConfig +from graphrag.config.models.community_reports_config import CommunityReportsConfig +from graphrag.config.models.drift_search_config import DRIFTSearchConfig +from graphrag.config.models.embed_graph_config import EmbedGraphConfig +from graphrag.config.models.entity_extraction_config import EntityExtractionConfig +from graphrag.config.models.global_search_config import GlobalSearchConfig +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.llm_parameters import LLMParameters +from 
graphrag.config.models.local_search_config import LocalSearchConfig +from graphrag.config.models.parallelization_parameters import ParallelizationParameters +from graphrag.config.models.reporting_config import ReportingConfig +from graphrag.config.models.snapshots_config import SnapshotsConfig +from graphrag.config.models.storage_config import StorageConfig +from graphrag.config.models.summarize_descriptions_config import ( + SummarizeDescriptionsConfig, +) +from graphrag.config.models.text_embedding_config import TextEmbeddingConfig +from graphrag.config.models.umap_config import UmapConfig +from graphrag.index.config.cache import PipelineFileCacheConfig +from graphrag.index.config.input import ( PipelineCSVInputConfig, - PipelineFileCacheConfig, - PipelineFileReportingConfig, - PipelineFileStorageConfig, PipelineInputConfig, PipelineTextInputConfig, +) +from graphrag.index.config.pipeline import ( + PipelineConfig, PipelineWorkflowReference, - create_pipeline_config, ) +from graphrag.index.config.reporting import PipelineFileReportingConfig +from graphrag.index.config.storage import PipelineFileStorageConfig +from graphrag.index.create_pipeline_config import create_pipeline_config current_dir = os.path.dirname(__file__) diff --git a/tests/unit/indexing/cache/test_file_pipeline_cache.py b/tests/unit/indexing/cache/test_file_pipeline_cache.py index ada3239602..ff63056edf 100644 --- a/tests/unit/indexing/cache/test_file_pipeline_cache.py +++ b/tests/unit/indexing/cache/test_file_pipeline_cache.py @@ -4,9 +4,7 @@ import os import unittest -from graphrag.index.cache import ( - JsonPipelineCache, -) +from graphrag.index.cache.json_pipeline_cache import JsonPipelineCache from graphrag.index.storage.file_pipeline_storage import ( FilePipelineStorage, ) diff --git a/tests/unit/indexing/config/helpers.py b/tests/unit/indexing/config/helpers.py index 580d0a9fa7..f70b9af81e 100644 --- a/tests/unit/indexing/config/helpers.py +++ b/tests/unit/indexing/config/helpers.py @@ -4,8 
+4,8 @@ import unittest from typing import Any -from graphrag.config import create_graphrag_config -from graphrag.index import PipelineConfig, create_pipeline_config +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.create_pipeline_config import PipelineConfig, create_pipeline_config def assert_contains_default_config( diff --git a/tests/unit/indexing/config/test_load.py b/tests/unit/indexing/config/test_load.py index 78f6a93a0a..636525b320 100644 --- a/tests/unit/indexing/config/test_load.py +++ b/tests/unit/indexing/config/test_load.py @@ -7,12 +7,10 @@ from typing import Any from unittest import mock -from graphrag.config import create_graphrag_config -from graphrag.index import ( - PipelineConfig, - create_pipeline_config, - load_pipeline_config, -) +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.config.pipeline import PipelineConfig +from graphrag.index.create_pipeline_config import create_pipeline_config +from graphrag.index.load_pipeline_config import load_pipeline_config current_dir = os.path.dirname(__file__) diff --git a/tests/unit/indexing/test_exports.py b/tests/unit/indexing/test_exports.py index 232dfbbdf3..ee2b23e622 100644 --- a/tests/unit/indexing/test_exports.py +++ b/tests/unit/indexing/test_exports.py @@ -1,10 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index import ( - create_pipeline_config, - run_pipeline, - run_pipeline_with_config, -) +from graphrag.index.create_pipeline_config import create_pipeline_config +from graphrag.index.run import run_pipeline, run_pipeline_with_config def test_exported_functions(): diff --git a/tests/unit/indexing/test_init_content.py b/tests/unit/indexing/test_init_content.py index 5001a2f9e5..8e6d8d3fdb 100644 --- a/tests/unit/indexing/test_init_content.py +++ b/tests/unit/indexing/test_init_content.py @@ -6,11 +6,9 @@ import yaml -from graphrag.config import ( - GraphRagConfig, - create_graphrag_config, -) +from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.init_content import INIT_YAML +from graphrag.config.models.graph_rag_config import GraphRagConfig def test_init_yaml(): diff --git a/tests/unit/indexing/workflows/test_emit.py b/tests/unit/indexing/workflows/test_emit.py index 5c16f66cbf..2d17bb199b 100644 --- a/tests/unit/indexing/workflows/test_emit.py +++ b/tests/unit/indexing/workflows/test_emit.py @@ -11,9 +11,10 @@ create_verb_result, ) -from graphrag.index.config import PipelineWorkflowReference +from graphrag.index.config.pipeline import PipelineWorkflowReference from graphrag.index.run import run_pipeline -from graphrag.index.storage import MemoryPipelineStorage, PipelineStorage +from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage async def mock_verb( diff --git a/tests/unit/indexing/workflows/test_load.py b/tests/unit/indexing/workflows/test_load.py index 6d037d51a5..60ae6647b4 100644 --- a/tests/unit/indexing/workflows/test_load.py +++ b/tests/unit/indexing/workflows/test_load.py @@ -4,7 +4,7 @@ import pytest -from graphrag.index.config import PipelineWorkflowReference +from graphrag.index.config.pipeline import PipelineWorkflowReference from graphrag.index.errors import 
UnknownWorkflowError from graphrag.index.workflows.load import create_workflow, load_workflows diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py index 969a16ff0a..b796bd9f33 100644 --- a/tests/unit/query/context_builder/test_entity_extraction.py +++ b/tests/unit/query/context_builder/test_entity_extraction.py @@ -3,14 +3,14 @@ from typing import Any -from graphrag.model import Entity +from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder from graphrag.query.context_builder.entity_extraction import ( EntityVectorStoreKey, map_query_to_entities, ) from graphrag.query.llm.base import BaseTextEmbedding -from graphrag.vector_stores import ( +from graphrag.vector_stores.base import ( BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult, diff --git a/tests/unit/query/input/retrieval/test_entities.py b/tests/unit/query/input/retrieval/test_entities.py index a66e3432b9..f7175882e2 100644 --- a/tests/unit/query/input/retrieval/test_entities.py +++ b/tests/unit/query/input/retrieval/test_entities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.model import Entity +from graphrag.model.entity import Entity from graphrag.query.input.retrieval.entities import ( get_entity_by_id, get_entity_by_key, diff --git a/tests/verbs/util.py b/tests/verbs/util.py index c53c4e5be4..6ce7b478dc 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -7,12 +7,10 @@ from datashaper import Workflow from pandas.testing import assert_series_equal -from graphrag.config import create_graphrag_config -from graphrag.index import ( - PipelineWorkflowConfig, - create_pipeline_config, -) +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.config.workflow import PipelineWorkflowConfig from graphrag.index.context import PipelineRunContext +from graphrag.index.create_pipeline_config import create_pipeline_config from graphrag.index.run.utils import create_run_context pd.set_option("display.max_columns", None)