diff --git a/docs/examples_notebooks/api_overview.ipynb b/docs/examples_notebooks/api_overview.ipynb index 86b0184b29..06187a5771 100644 --- a/docs/examples_notebooks/api_overview.ipynb +++ b/docs/examples_notebooks/api_overview.ipynb @@ -16,7 +16,7 @@ "source": [ "## API Overview\n", "\n", - "This notebook provides a demonstration of how to interact with graphrag as a library using the API as opposed to the CLI. Note that graphrag's CLI actually connects to the library through this API for all operations. " + "This notebook provides a demonstration of how to interact with graphrag as a library using the API as opposed to the CLI. Note that graphrag's CLI actually connects to the library through this API for all operations.\n" ] }, { @@ -48,16 +48,17 @@ "metadata": {}, "source": [ "## Prerequisite\n", + "\n", "As a prerequisite to all API operations, a `GraphRagConfig` object is required. It is the primary means to control the behavior of graphrag and can be instantiated from a `settings.yaml` configuration file.\n", "\n", - "Please refer to the [CLI docs](https://microsoft.github.io/graphrag/cli/#init) for more detailed information on how to generate the `settings.yaml` file." + "Please refer to the [CLI docs](https://microsoft.github.io/graphrag/cli/#init) for more detailed information on how to generate the `settings.yaml` file.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Generate a `GraphRagConfig` object" + "### Generate a `GraphRagConfig` object\n" ] }, { @@ -77,14 +78,14 @@ "source": [ "## Indexing API\n", "\n", - "*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats." + "_Indexing_ is the process of ingesting raw text data and constructing a knowledge graph. 
GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Build an index" + "## Build an index\n" ] }, { @@ -107,7 +108,7 @@ "source": [ "## Query an index\n", "\n", - "To query an index, several index files must first be read into memory and passed to the query API. " + "To query an index, several index files must first be read into memory and passed to the query API.\n" ] }, { @@ -138,7 +139,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response." + "The response object is the official response from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response.\n" ] }, { @@ -154,7 +155,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)." + "Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model.\n" ] }, { diff --git a/docs/examples_notebooks/input_documents.ipynb b/docs/examples_notebooks/input_documents.ipynb index 73e51780d8..b9af6075ab 100644 --- a/docs/examples_notebooks/input_documents.ipynb +++ b/docs/examples_notebooks/input_documents.ipynb @@ -18,7 +18,7 @@ "\n", "Newer versions of GraphRAG let you submit a dataframe directly instead of running through the input processing step. 
This notebook demonstrates with regular or update runs.\n", "\n", - "If performing an update, the assumption is that your dataframe contains only the new documents to add to the index." + "If performing an update, the assumption is that your dataframe contains only the new documents to add to the index.\n" ] }, { @@ -54,7 +54,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Generate a `GraphRagConfig` object" + "### Generate a `GraphRagConfig` object\n" ] }, { @@ -72,14 +72,14 @@ "source": [ "## Indexing API\n", "\n", - "*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats." + "_Indexing_ is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Build an index" + "## Build an index\n" ] }, { @@ -109,7 +109,7 @@ "source": [ "## Query an index\n", "\n", - "To query an index, several index files must first be read into memory and passed to the query API. " + "To query an index, several index files must first be read into memory and passed to the query API.\n" ] }, { @@ -140,7 +140,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response." 
+ "The response object is the official response from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response.\n" ] }, { @@ -156,7 +156,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)." + "Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model.\n" ] }, { diff --git a/docs/get_started.md b/docs/get_started.md index 5de773924d..9a6f496383 100644 --- a/docs/get_started.md +++ b/docs/get_started.md @@ -10,40 +10,59 @@ The following is a simple end-to-end example for using GraphRAG on the command l It shows how to use the system to index some text, and then use the indexed data to answer questions about the documents. -# Install GraphRAG +## Install GraphRAG + +To get started, create a project space and Python virtual environment to install `graphrag`. + +### Create Project Space ```bash -pip install graphrag +mkdir graphrag_quickstart +cd graphrag_quickstart +python -m venv .venv ``` +### Activate Python Virtual Environment - Unix/MacOS -# Running the Indexer -We need to set up a data project and some initial configuration. 
First let's get a sample dataset ready: +```bash +source .venv/bin/activate +``` -```sh -mkdir -p ./christmas/input +### Activate Python Virtual Environment - Windows + +```bash +.venv\Scripts\activate ``` -Get a copy of A Christmas Carol by Charles Dickens from a trusted source: +### Install GraphRAG -```sh -curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./christmas/input/book.txt +```bash +python -m pip install graphrag ``` -## Set Up Your Workspace Variables +### Initialize GraphRAG To initialize your workspace, first run the `graphrag init` command. -Since we have already configured a directory named `./christmas` in the previous step, run the following command: ```sh -graphrag init --root ./christmas +graphrag init ``` -This will create two files: `.env` and `settings.yaml` in the `./christmas` directory. +This will create two files, `.env` and `settings.yaml`, and a directory `input`, in the current directory. +- `input` Location of text files to process with `graphrag`. - `.env` contains the environment variables required to run the GraphRAG pipeline. If you inspect the file, you'll see a single environment variable defined, `GRAPHRAG_API_KEY=`. Replace `` with your own OpenAI or Azure API key. - `settings.yaml` contains the settings for the pipeline. You can modify this file to change the settings for the pipeline. -
+ +### Download Sample Text + +Get a copy of A Christmas Carol by Charles Dickens from a trusted source: + +```sh +curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./input/book.txt +``` + +## Set Up Workspace Variables ### Using OpenAI @@ -56,13 +75,14 @@ In addition to setting your API key, Azure OpenAI users should set the variables ```yaml type: chat model_provider: azure +model: gpt-4.1 +deployment_name: api_base: https://.openai.azure.com api_version: 2024-02-15-preview # You can customize this for other versions ``` -Most people tend to name their deployments the same as their model - if yours are different, add the `deployment_name` as well. - #### Using Managed Auth on Azure + To use managed auth, edit the auth_type in your model config and *remove* the api_key line: ```yaml @@ -71,38 +91,34 @@ auth_type: azure_managed_identity # Default auth_type is is api_key You will also need to login with [az login](https://learn.microsoft.com/en-us/cli/azure/authenticate-azure-cli) and select the subscription with your endpoint. -## Running the Indexing pipeline +## Index -Now we're ready to run the pipeline! +Now we're ready to index! ```sh -graphrag index --root ./christmas +graphrag index ``` ![pipeline executing from the CLI](img/pipeline-running.png) -This process will usually take a few minutes to run. Once the pipeline is complete, you should see a new folder called `./christmas/output` with a series of parquet files. +This process will usually take a few minutes to run. Once the pipeline is complete, you should see a new folder called `./output` with a series of parquet files. -# Using the Query Engine +## Query Now let's ask some questions using this dataset. Here is an example using Global search to ask a high-level question: ```sh -graphrag query \ ---root ./christmas \ ---method global \ ---query "What are the top themes in this story?" +graphrag query "What are the top themes in this story?" 
``` Here is an example using Local search to ask a more specific question about a particular character: ```sh graphrag query \ ---root ./christmas \ ---method local \ ---query "Who is Scrooge and what are his main relationships?" +"Who is Scrooge and what are his main relationships?" \ +--method local ``` Please refer to [Query Engine](query/overview.md) docs for detailed information about how to leverage our Local and Global search mechanisms for extracting meaningful insights from data after the Indexer has wrapped up execution. diff --git a/docs/index/byog.md b/docs/index/byog.md index 7866e4b8f6..acd7679348 100644 --- a/docs/index/byog.md +++ b/docs/index/byog.md @@ -65,4 +65,4 @@ Putting it all together: - `output`: Create an output folder and put your entities and relationships (and optionally text_units) parquet files in it. - Update your config as noted above to only run the workflows subset you need. -- Run `graphrag index --root ` \ No newline at end of file +- Run `graphrag index --root ` \ No newline at end of file diff --git a/docs/prompt_tuning/auto_prompt_tuning.md b/docs/prompt_tuning/auto_prompt_tuning.md index 440741f571..eb77044cd5 100644 --- a/docs/prompt_tuning/auto_prompt_tuning.md +++ b/docs/prompt_tuning/auto_prompt_tuning.md @@ -20,16 +20,14 @@ Before running auto tuning, ensure you have already initialized your workspace w You can run the main script from the command line with various options: ```bash -graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selection-method METHOD] [--limit LIMIT] [--language LANGUAGE] \ +graphrag prompt-tune [--root ROOT] [--domain DOMAIN] [--selection-method METHOD] [--limit LIMIT] [--language LANGUAGE] \ [--max-tokens MAX_TOKENS] [--chunk-size CHUNK_SIZE] [--n-subset-max N_SUBSET_MAX] [--k K] \ [--min-examples-required MIN_EXAMPLES_REQUIRED] [--discover-entity-types] [--output OUTPUT] ``` ## Command-Line Options -- `--config` (required): The path to the configuration file. 
This is required to load the data and model settings. - -- `--root` (optional): The data project root directory, including the config files (YML, JSON, or .env). Defaults to the current directory. +- `--root` (optional): Path to the project directory that contains the config file (settings.yaml). Defaults to the current directory. - `--domain` (optional): The domain related to your input data, such as 'space science', 'microbiology', or 'environmental news'. If left empty, the domain will be inferred from the input data. @@ -56,7 +54,7 @@ graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selec ## Example Usage ```bash -python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --domain "environmental news" \ +python -m graphrag prompt-tune --root /path/to/project --domain "environmental news" \ --selection-method random --limit 10 --language English --max-tokens 2048 --chunk-size 256 --min-examples-required 3 \ --no-discover-entity-types --output /path/to/output ``` @@ -64,7 +62,7 @@ python -m graphrag prompt-tune --root /path/to/project --config /path/to/setting or, with minimal configuration (suggested): ```bash -python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-discover-entity-types +python -m graphrag prompt-tune --root /path/to/project --no-discover-entity-types ``` ## Document Selection Methods diff --git a/packages/graphrag-common/README.md b/packages/graphrag-common/README.md index 4447dc7219..45565a33a7 100644 --- a/packages/graphrag-common/README.md +++ b/packages/graphrag-common/README.md @@ -48,4 +48,72 @@ single2 = factory.create("some_other_strategy", {"value": "ignored"}) assert single1 is single2 assert single1.get_value() == "singleton" assert single2.get_value() == "singleton" +``` + +## Config module + +```python +from pydantic import BaseModel, Field +from graphrag_common.config import load_config + +from pathlib import Path + +class Logging(BaseModel): 
+ """Test nested model.""" + + directory: str = Field(default="output/logs") + filename: str = Field(default="logs.txt") + +class Config(BaseModel): + """Test configuration model.""" + + name: str = Field(description="Name field.") + logging: Logging = Field(description="Nested model field.") + +# Basic - by default: +# - searches for Path.cwd() / settings.[yaml|yml|json] +# - sets the CWD to the directory containing the config file. +# so if no custom config path is provided then CWD remains unchanged. +# - loads config_directory/.env file +# - parses ${env} in the config file +config = load_config(Config) + +# Custom file location +config = load_config(Config, "path_to_config_filename_or_directory_containing_settings.[yaml|yml|json]") + +# Using a custom file extension with +# custom config parser (str) -> dict[str, Any] +config = load_config( + config_initializer=Config, + config_path="config.toml", + config_parser=lambda contents: toml.loads(contents) # Needs toml pypi package +) + +# With overrides - provided values override what's in the config file +# Only overrides what is specified - recursively merges settings. +config = load_config( + config_initializer=Config, + overrides={ + "name": "some name", + "logging": { + "filename": "my_logs.txt" + } + }, +) + +# By default, sets CWD to directory containing config file +# So custom config paths will change the CWD. 
+config = load_config( + config_initializer=Config, + config_path="some/path/to/config.yaml", + set_cwd=True # default +) + +# now cwd is the resolved (absolute) form of some/path/to +assert Path.cwd().as_posix().endswith("some/path/to") + +# And now throughout the codebase resolving relative paths in config +# will resolve relative to the config directory +assert Path(config.logging.directory).resolve().as_posix().endswith("some/path/to/output/logs") + ``` \ No newline at end of file diff --git a/packages/graphrag-common/graphrag_common/config/__init__.py b/packages/graphrag-common/graphrag_common/config/__init__.py new file mode 100644 index 0000000000..71f24b79b7 --- /dev/null +++ b/packages/graphrag-common/graphrag_common/config/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The GraphRAG config module.""" + +from graphrag_common.config.load_config import ConfigParsingError, load_config + +__all__ = ["ConfigParsingError", "load_config"] diff --git a/packages/graphrag-common/graphrag_common/config/load_config.py b/packages/graphrag-common/graphrag_common/config/load_config.py new file mode 100644 index 0000000000..c8929149f1 --- /dev/null +++ b/packages/graphrag-common/graphrag_common/config/load_config.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Load configuration.""" + +import json +import os +from collections.abc import Callable +from pathlib import Path +from string import Template +from typing import Any, TypeVar + +import yaml +from dotenv import load_dotenv + +T = TypeVar("T", covariant=True) + +_default_config_files = ["settings.yaml", "settings.yml", "settings.json"] + + +class ConfigParsingError(ValueError): + """Configuration Parsing Error.""" + + def __init__(self, msg: str) -> None: + """Initialize the ConfigParsingError.""" + super().__init__(msg) + + +def _get_config_file_path(config_dir_or_file: Path) -> Path: + """Resolve the config path from the given directory or file.""" + config_dir_or_file = Path(config_dir_or_file) + + if config_dir_or_file.is_file(): + return config_dir_or_file + + if not config_dir_or_file.is_dir(): + msg = f"Invalid config path: {config_dir_or_file} is not a directory" + raise FileNotFoundError(msg) + + for file in _default_config_files: + if (config_dir_or_file / file).is_file(): + return config_dir_or_file / file + + msg = f"No 'settings.[yaml|yml|json]' config file found in directory: {config_dir_or_file}" + raise FileNotFoundError(msg) + + +def _load_dotenv(env_file_path: Path, required: bool) -> None: + """Load the .env file if it exists.""" + if not env_file_path.is_file(): + if not required: + return + msg = f"dot_env_path not found: {env_file_path}" + raise FileNotFoundError(msg) + load_dotenv(env_file_path) + + +def _parse_json(data: str) -> dict[str, Any]: + """Parse JSON data.""" + return json.loads(data) + + +def _parse_yaml(data: str) -> dict[str, Any]: + """Parse YAML data.""" + return yaml.safe_load(data) + + +def _get_parser_for_file(file_path: str | Path) -> Callable[[str], dict[str, Any]]: + """Get the parser for the given file path.""" + file_path = Path(file_path).resolve() + match file_path.suffix.lower(): + case ".json": + return _parse_json + case ".yaml" | ".yml": + return _parse_yaml + case _: + msg = 
( + f"Failed to parse, {file_path}. Unsupported file extension, " + + f"{file_path.suffix}. Pass in a custom config_parser argument or " + + "use one of the supported file extensions, .json, .yaml, .yml." + ) + raise ConfigParsingError(msg) + + +def _parse_env_variables(text: str) -> str: + """Parse environment variables in the configuration text.""" + try: + return Template(text).substitute(os.environ) + except KeyError as error: + msg = f"Environment variable not found: {error}" + raise ConfigParsingError(msg) from error + + +def _recursive_merge_dicts(dest: dict[str, Any], src: dict[str, Any]) -> None: + """Recursively merge two dictionaries in place.""" + for key, value in src.items(): + if isinstance(value, dict): + if isinstance(dest.get(key), dict): + _recursive_merge_dicts(dest[key], value) + else: + dest[key] = value + else: + dest[key] = value + + +def load_config( + config_initializer: Callable[..., T], + config_path: str | Path | None = None, + overrides: dict[str, Any] | None = None, + set_cwd: bool = True, + parse_env_vars: bool = True, + load_dot_env_file: bool = True, + dot_env_path: str | Path | None = None, + config_parser: Callable[[str], dict[str, Any]] | None = None, + file_encoding: str = "utf-8", +) -> T: + """Load configuration from a file. + + Parameters + ---------- + config_initializer : Callable[..., T] + Configuration constructor/initializer. + Should accept **kwargs to initialize the configuration, + e.g., Config(**kwargs). + config_path : str | Path | None, optional (default=None) + Path to the configuration directory containing settings.[yaml|yml|json]. + Or path to a configuration file itself. + If None, search the current working directory for + settings.[yaml|yml|json]. + overrides : dict[str, Any] | None, optional (default=None) + Configuration overrides. + Useful for overriding configuration settings programmatically, + perhaps from CLI flags. 
+ set_cwd : bool, optional (default=True) + Whether to set the current working directory to the directory + containing the configuration file. Helpful for resolving relative paths + in the configuration file. + parse_env_vars : bool, optional (default=True) + Whether to parse environment variables in the configuration text. + load_dot_env_file : bool, optional (default=True) + Whether to load the .env file prior to parsing environment variables. + dot_env_path : str | Path | None, optional (default=None) + Optional .env file to load prior to parsing env variables. + If None and load_dot_env_file is True, looks for a .env file in the + same directory as the config file. + config_parser : Callable[[str], dict[str, Any]] | None, optional (default=None) + Function to parse the configuration text, (str) -> dict[str, Any]. + If None, the parser is inferred from the file extension. + Supported extensions: .json, .yaml, .yml. + file_encoding : str, optional (default="utf-8") + File encoding to use when reading the configuration file. + + Returns + ------- + T + The initialized configuration object. + + Raises + ------ + FileNotFoundError + - If the config file is not found. + - If the .env file is not found when parse_env_vars and load_dot_env_file are True and dot_env_path is provided. + + ConfigParsingError + - If an environment variable is not found when parsing env variables. + - If there was a problem merging the overrides with the configuration. + - If config_parser is None and load_config was unable to determine how to parse + the file based on the file extension. + - If the parser fails to parse the configuration text. 
+ """ + config_path = Path(config_path).resolve() if config_path else Path.cwd() + config_path = _get_config_file_path(config_path) + + file_contents = config_path.read_text(encoding=file_encoding) + + if parse_env_vars: + if load_dot_env_file: + required = dot_env_path is not None + dot_env_path = ( + Path(dot_env_path) if dot_env_path else config_path.parent / ".env" + ) + _load_dotenv(dot_env_path, required=required) + file_contents = _parse_env_variables(file_contents) + + if config_parser is None: + config_parser = _get_parser_for_file(config_path) + + config_data: dict[str, Any] = {} + try: + config_data = config_parser(file_contents) + except Exception as error: + msg = f"Failed to parse config_path: {config_path}. Error: {error}" + raise ConfigParsingError(msg) from error + + if overrides is not None: + try: + _recursive_merge_dicts(config_data, overrides) + except Exception as error: + msg = f"Failed to merge overrides with config_path: {config_path}. Error: {error}" + raise ConfigParsingError(msg) from error + + if set_cwd: + os.chdir(config_path.parent) + + return config_initializer(**config_data) diff --git a/packages/graphrag-common/pyproject.toml b/packages/graphrag-common/pyproject.toml index e8396e2cdb..b2bcb952e3 100644 --- a/packages/graphrag-common/pyproject.toml +++ b/packages/graphrag-common/pyproject.toml @@ -30,7 +30,10 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] -dependencies = [] +dependencies = [ + "python-dotenv>=1.0.1", + "pyyaml>=6.0.2", +] [project.urls] Source = "https://github.com/microsoft/graphrag" diff --git a/packages/graphrag/graphrag/api/query.py b/packages/graphrag/graphrag/api/query.py index e49a0976df..f573559bcb 100644 --- a/packages/graphrag/graphrag/api/query.py +++ b/packages/graphrag/graphrag/api/query.py @@ -167,13 +167,9 @@ def global_search_streaming( entities_ = read_indexer_entities( entities, communities, community_level=community_level ) - map_prompt = 
load_search_prompt(config.root_dir, config.global_search.map_prompt) - reduce_prompt = load_search_prompt( - config.root_dir, config.global_search.reduce_prompt - ) - knowledge_prompt = load_search_prompt( - config.root_dir, config.global_search.knowledge_prompt - ) + map_prompt = load_search_prompt(config.global_search.map_prompt) + reduce_prompt = load_search_prompt(config.global_search.reduce_prompt) + knowledge_prompt = load_search_prompt(config.global_search.knowledge_prompt) logger.debug("Executing streaming global search query: %s", query) search_engine = get_global_search_engine( @@ -304,7 +300,7 @@ def local_search_streaming( entities_ = read_indexer_entities(entities, communities, community_level) covariates_ = read_indexer_covariates(covariates) if covariates is not None else [] - prompt = load_search_prompt(config.root_dir, config.local_search.prompt) + prompt = load_search_prompt(config.local_search.prompt) logger.debug("Executing streaming local search query: %s", query) search_engine = get_local_search_engine( @@ -435,10 +431,8 @@ def drift_search_streaming( entities_ = read_indexer_entities(entities, communities, community_level) reports = read_indexer_reports(community_reports, communities, community_level) read_indexer_report_embeddings(reports, full_content_embedding_store) - prompt = load_search_prompt(config.root_dir, config.drift_search.prompt) - reduce_prompt = load_search_prompt( - config.root_dir, config.drift_search.reduce_prompt - ) + prompt = load_search_prompt(config.drift_search.prompt) + reduce_prompt = load_search_prompt(config.drift_search.reduce_prompt) logger.debug("Executing streaming drift search query: %s", query) search_engine = get_drift_search_engine( @@ -538,7 +532,7 @@ def basic_search_streaming( embedding_name=text_unit_text_embedding, ) - prompt = load_search_prompt(config.root_dir, config.basic_search.prompt) + prompt = load_search_prompt(config.basic_search.prompt) logger.debug("Executing streaming basic search query: 
%s", query) search_engine = get_basic_search_engine( diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index 1f2c700a4e..971c22c6d5 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -28,11 +28,9 @@ class CacheFactory(Factory[PipelineCache]): # --- register built-in cache implementations --- -def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache: +def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" - # Create storage with base_dir in kwargs since FilePipelineStorage expects it there - storage_kwargs = {"base_dir": root_dir, **kwargs} - storage = FilePipelineStorage(**storage_kwargs).child(base_dir) + storage = FilePipelineStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/packages/graphrag/graphrag/cli/index.py b/packages/graphrag/graphrag/cli/index.py index b5464d2544..0a638f63e9 100644 --- a/packages/graphrag/graphrag/cli/index.py +++ b/packages/graphrag/graphrag/cli/index.py @@ -45,18 +45,11 @@ def index_cli( verbose: bool, memprofile: bool, cache: bool, - config_filepath: Path | None, dry_run: bool, skip_validation: bool, - output_dir: Path | None, ): """Run the pipeline with the given config.""" - cli_overrides = {} - if output_dir: - cli_overrides["output.base_dir"] = str(output_dir) - cli_overrides["reporting.base_dir"] = str(output_dir) - cli_overrides["update_index_output.base_dir"] = str(output_dir) - config = load_config(root_dir, config_filepath, cli_overrides) + config = load_config(root_dir=root_dir) _run_index( config=config, method=method, @@ -75,18 +68,12 @@ def update_cli( verbose: bool, memprofile: bool, cache: bool, - config_filepath: Path | None, skip_validation: bool, - output_dir: Path | None, ): """Run the pipeline with the given config.""" - cli_overrides = {} - if output_dir: - cli_overrides["output.base_dir"] = str(output_dir) - 
cli_overrides["reporting.base_dir"] = str(output_dir) - cli_overrides["update_index_output.base_dir"] = str(output_dir) - - config = load_config(root_dir, config_filepath, cli_overrides) + config = load_config( + root_dir=root_dir, + ) _run_index( config=config, diff --git a/packages/graphrag/graphrag/cli/initialize.py b/packages/graphrag/graphrag/cli/initialize.py index 09215f8c5d..8dbaf30f93 100644 --- a/packages/graphrag/graphrag/cli/initialize.py +++ b/packages/graphrag/graphrag/cli/initialize.py @@ -6,6 +6,7 @@ import logging from pathlib import Path +from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.init_content import INIT_DOTENV, INIT_YAML from graphrag.prompts.index.community_report import ( COMMUNITY_REPORT_PROMPT, @@ -51,26 +52,27 @@ def initialize_project_at(path: Path, force: bool) -> None: If the project already exists and force is False. """ logger.info("Initializing project at %s", path) - root = Path(path) - if not root.exists(): - root.mkdir(parents=True, exist_ok=True) + root = Path(path).resolve() + root.mkdir(parents=True, exist_ok=True) settings_yaml = root / "settings.yaml" if settings_yaml.exists() and not force: msg = f"Project already initialized at {root}" raise ValueError(msg) - with settings_yaml.open("wb") as file: - file.write(INIT_YAML.encode(encoding="utf-8", errors="strict")) + input_path = ( + root / (graphrag_config_defaults.input.storage.base_dir or "input") + ).resolve() + input_path.mkdir(parents=True, exist_ok=True) + + settings_yaml.write_text(INIT_YAML, encoding="utf-8", errors="strict") dotenv = root / ".env" if not dotenv.exists() or force: - with dotenv.open("wb") as file: - file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict")) + dotenv.write_text(INIT_DOTENV, encoding="utf-8", errors="strict") prompts_dir = root / "prompts" - if not prompts_dir.exists(): - prompts_dir.mkdir(parents=True, exist_ok=True) + prompts_dir.mkdir(parents=True, exist_ok=True) prompts = { 
"extract_graph": GRAPH_EXTRACTION_PROMPT, @@ -91,5 +93,4 @@ def initialize_project_at(path: Path, force: bool) -> None: for name, content in prompts.items(): prompt_file = prompts_dir / f"{name}.txt" if not prompt_file.exists() or force: - with prompt_file.open("wb") as file: - file.write(content.encode(encoding="utf-8", errors="strict")) + prompt_file.write_text(content, encoding="utf-8", errors="strict") diff --git a/packages/graphrag/graphrag/cli/main.py b/packages/graphrag/graphrag/cli/main.py index 7b4d3bb0cd..89768e736d 100644 --- a/packages/graphrag/graphrag/cli/main.py +++ b/packages/graphrag/graphrag/cli/main.py @@ -94,12 +94,13 @@ def completer(incomplete: str) -> list[str]: @app.command("init") def _initialize_cli( root: Path = typer.Option( - Path(), + Path.cwd(), "--root", "-r", help="The project root directory.", dir_okay=True, writable=True, + file_okay=False, resolve_path=True, autocompletion=ROOT_AUTOCOMPLETE, ), @@ -118,23 +119,14 @@ def _initialize_cli( @app.command("index") def _index_cli( - config: Path | None = typer.Option( - None, - "--config", - "-c", - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=CONFIG_AUTOCOMPLETE, - ), root: Path = typer.Option( - Path(), + Path.cwd(), "--root", "-r", help="The project root directory.", exists=True, dir_okay=True, + file_okay=False, writable=True, resolve_path=True, autocompletion=ROOT_AUTOCOMPLETE, @@ -174,18 +166,6 @@ def _index_cli( "--skip-validation", help="Skip any preflight validation. Useful when running no LLM steps.", ), - output: Path | None = typer.Option( - None, - "--output", - "-o", - help=( - "Indexing pipeline output directory. " - "Overrides output.base_dir in the configuration file." 
- ), - dir_okay=True, - writable=True, - resolve_path=True, - ), ) -> None: """Build a knowledge graph index.""" from graphrag.cli.index import index_cli @@ -195,33 +175,22 @@ def _index_cli( verbose=verbose, memprofile=memprofile, cache=cache, - config_filepath=config, dry_run=dry_run, skip_validation=skip_validation, - output_dir=output, method=method, ) @app.command("update") def _update_cli( - config: Path | None = typer.Option( - None, - "--config", - "-c", - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=CONFIG_AUTOCOMPLETE, - ), root: Path = typer.Option( - Path(), + Path.cwd(), "--root", "-r", help="The project root directory.", exists=True, dir_okay=True, + file_okay=False, writable=True, resolve_path=True, autocompletion=ROOT_AUTOCOMPLETE, @@ -253,18 +222,6 @@ def _update_cli( "--skip-validation", help="Skip any preflight validation. Useful when running no LLM steps.", ), - output: Path | None = typer.Option( - None, - "--output", - "-o", - help=( - "Indexing pipeline output directory. " - "Overrides output.base_dir in the configuration file." - ), - dir_okay=True, - writable=True, - resolve_path=True, - ), ) -> None: """ Update an existing knowledge graph index. 
@@ -278,9 +235,7 @@ def _update_cli( verbose=verbose, memprofile=memprofile, cache=cache, - config_filepath=config, skip_validation=skip_validation, - output_dir=output, method=method, ) @@ -288,26 +243,17 @@ def _update_cli( @app.command("prompt-tune") def _prompt_tune_cli( root: Path = typer.Option( - Path(), + Path.cwd(), "--root", "-r", help="The project root directory.", exists=True, dir_okay=True, + file_okay=False, writable=True, resolve_path=True, autocompletion=ROOT_AUTOCOMPLETE, ), - config: Path | None = typer.Option( - None, - "--config", - "-c", - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=CONFIG_AUTOCOMPLETE, - ), verbose: bool = typer.Option( False, "--verbose", @@ -392,7 +338,6 @@ def _prompt_tune_cli( loop.run_until_complete( prompt_tune( root=root, - config=config, domain=domain, verbose=verbose, selection_method=selection_method, @@ -412,28 +357,27 @@ def _prompt_tune_cli( @app.command("query") def _query_cli( + query: str = typer.Argument( + help="The query to execute.", + ), + root: Path = typer.Option( + Path.cwd(), + "--root", + "-r", + help="The project root directory.", + exists=True, + dir_okay=True, + file_okay=False, + writable=True, + resolve_path=True, + autocompletion=ROOT_AUTOCOMPLETE, + ), method: SearchMethod = typer.Option( - ..., + SearchMethod.GLOBAL, "--method", "-m", help="The query algorithm to use.", ), - query: str = typer.Option( - ..., - "--query", - "-q", - help="The query to execute.", - ), - config: Path | None = typer.Option( - None, - "--config", - "-c", - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=CONFIG_AUTOCOMPLETE, - ), verbose: bool = typer.Option( False, "--verbose", @@ -451,17 +395,6 @@ def _query_cli( resolve_path=True, autocompletion=ROOT_AUTOCOMPLETE, ), - root: Path = typer.Option( - Path(), - "--root", - "-r", - help="The project root directory.", - exists=True, - dir_okay=True, - 
writable=True, - resolve_path=True, - autocompletion=ROOT_AUTOCOMPLETE, - ), community_level: int = typer.Option( 2, "--community-level", @@ -500,7 +433,6 @@ def _query_cli( match method: case SearchMethod.LOCAL: run_local_search( - config_filepath=config, data_dir=data, root_dir=root, community_level=community_level, @@ -511,7 +443,6 @@ def _query_cli( ) case SearchMethod.GLOBAL: run_global_search( - config_filepath=config, data_dir=data, root_dir=root, community_level=community_level, @@ -523,7 +454,6 @@ def _query_cli( ) case SearchMethod.DRIFT: run_drift_search( - config_filepath=config, data_dir=data, root_dir=root, community_level=community_level, @@ -534,7 +464,6 @@ def _query_cli( ) case SearchMethod.BASIC: run_basic_search( - config_filepath=config, data_dir=data, root_dir=root, response_type=response_type, diff --git a/packages/graphrag/graphrag/cli/prompt_tune.py b/packages/graphrag/graphrag/cli/prompt_tune.py index 2646776f6f..2ba02e0ead 100644 --- a/packages/graphrag/graphrag/cli/prompt_tune.py +++ b/packages/graphrag/graphrag/cli/prompt_tune.py @@ -24,7 +24,6 @@ async def prompt_tune( root: Path, - config: Path | None, domain: str | None, verbose: bool, selection_method: api.DocSelectionType, @@ -43,7 +42,6 @@ async def prompt_tune( Parameters ---------- - - config: The configuration file. - root: The root directory. - domain: The domain to map the input documents to. - verbose: Enable verbose logging. @@ -58,8 +56,9 @@ async def prompt_tune( - k: The number of documents to select when using auto selection method. - min_examples_required: The minimum number of examples required for entity extraction prompts. 
""" - root_path = Path(root).resolve() - graph_config = load_config(root_path, config) + graph_config = load_config( + root_dir=root, + ) # override chunking config in the configuration if chunk_size != graph_config.chunks.size: diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index 914196b865..93163db19d 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -22,7 +22,6 @@ def run_global_search( - config_filepath: Path | None, data_dir: Path | None, root_dir: Path, community_level: int | None, @@ -36,11 +35,13 @@ def run_global_search( Loads index files required for global search and calls the Query API. """ - root = root_dir.resolve() - cli_overrides = {} + cli_overrides: dict[str, Any] = {} if data_dir: - cli_overrides["output.base_dir"] = str(data_dir) - config = load_config(root, config_filepath, cli_overrides) + cli_overrides["output"] = {"base_dir": str(data_dir)} + config = load_config( + root_dir=root_dir, + cli_overrides=cli_overrides, + ) dataframe_dict = _resolve_output_files( config=config, @@ -108,7 +109,6 @@ def on_context(context: Any) -> None: def run_local_search( - config_filepath: Path | None, data_dir: Path | None, root_dir: Path, community_level: int, @@ -121,11 +121,13 @@ def run_local_search( Loads index files required for local search and calls the Query API. 
""" - root = root_dir.resolve() - cli_overrides = {} + cli_overrides: dict[str, Any] = {} if data_dir: - cli_overrides["output.base_dir"] = str(data_dir) - config = load_config(root, config_filepath, cli_overrides) + cli_overrides["output"] = {"base_dir": str(data_dir)} + config = load_config( + root_dir=root_dir, + cli_overrides=cli_overrides, + ) dataframe_dict = _resolve_output_files( config=config, @@ -204,7 +206,6 @@ def on_context(context: Any) -> None: def run_drift_search( - config_filepath: Path | None, data_dir: Path | None, root_dir: Path, community_level: int, @@ -217,11 +218,13 @@ def run_drift_search( Loads index files required for local search and calls the Query API. """ - root = root_dir.resolve() - cli_overrides = {} + cli_overrides: dict[str, Any] = {} if data_dir: - cli_overrides["output.base_dir"] = str(data_dir) - config = load_config(root, config_filepath, cli_overrides) + cli_overrides["output"] = {"base_dir": str(data_dir)} + config = load_config( + root_dir=root_dir, + cli_overrides=cli_overrides, + ) dataframe_dict = _resolve_output_files( config=config, @@ -295,7 +298,6 @@ def on_context(context: Any) -> None: def run_basic_search( - config_filepath: Path | None, data_dir: Path | None, root_dir: Path, response_type: str, @@ -307,11 +309,13 @@ def run_basic_search( Loads index files required for basic search and calls the Query API. 
""" - root = root_dir.resolve() - cli_overrides = {} + cli_overrides: dict[str, Any] = {} if data_dir: - cli_overrides["output.base_dir"] = str(data_dir) - config = load_config(root, config_filepath, cli_overrides) + cli_overrides["output"] = {"base_dir": str(data_dir)} + config = load_config( + root_dir=root_dir, + cli_overrides=cli_overrides, + ) dataframe_dict = _resolve_output_files( config=config, diff --git a/packages/graphrag/graphrag/config/create_graphrag_config.py b/packages/graphrag/graphrag/config/create_graphrag_config.py deleted file mode 100644 index 59d7699e71..0000000000 --- a/packages/graphrag/graphrag/config/create_graphrag_config.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Parameterization settings for the default configuration, loaded from environment variables.""" - -from pathlib import Path -from typing import Any - -from graphrag.config.models.graph_rag_config import GraphRagConfig - - -def create_graphrag_config( - values: dict[str, Any] | None = None, - root_dir: str | None = None, -) -> GraphRagConfig: - """Load Configuration Parameters from a dictionary. - - Parameters - ---------- - values : dict[str, Any] | None - Dictionary of configuration values to pass into pydantic model. - root_dir : str | None - Root directory for the project. - skip_validation : bool - Skip pydantic model validation of the configuration. - This is useful for testing and mocking purposes but - should not be used in the core code or API. - - Returns - ------- - GraphRagConfig - The configuration object. - - Raises - ------ - ValidationError - If the configuration values do not satisfy pydantic validation. 
- """ - values = values or {} - if root_dir: - root_path = Path(root_dir).resolve() - values["root_dir"] = str(root_path) - return GraphRagConfig(**values) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index f57bf610f9..88449a6050 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -386,7 +386,6 @@ class VectorStoreDefaults: class GraphRagConfigDefaults: """Default values for GraphRAG.""" - root_dir: str = "" models: dict = field(default_factory=dict) reporting: ReportingDefaults = field(default_factory=ReportingDefaults) storage: StorageDefaults = field(default_factory=StorageDefaults) diff --git a/packages/graphrag/graphrag/config/environment_reader.py b/packages/graphrag/graphrag/config/environment_reader.py deleted file mode 100644 index 258422666c..0000000000 --- a/packages/graphrag/graphrag/config/environment_reader.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A configuration reader utility class.""" - -from collections.abc import Callable -from contextlib import contextmanager -from enum import Enum -from typing import Any, TypeVar - -from environs import Env - -T = TypeVar("T") - -KeyValue = str | Enum -EnvKeySet = str | list[str] - - -def read_key(value: KeyValue) -> str: - """Read a key value.""" - if not isinstance(value, str): - return value.value.lower() - return value.lower() - - -class EnvironmentReader: - """A configuration reader utility class.""" - - _env: Env - _config_stack: list[dict] - - def __init__(self, env: Env): - self._env = env - self._config_stack = [] - - @property - def env(self): - """Get the environment object.""" - return self._env - - def _read_env( - self, env_key: str | list[str], default_value: T, read: Callable[[str, T], T] - ) -> T | None: - if isinstance(env_key, str): - env_key = [env_key] - - for k in env_key: - result = read(k.upper(), default_value) - if result is not default_value: - return result - - return default_value - - def envvar_prefix(self, prefix: KeyValue): - """Set the environment variable prefix.""" - prefix = read_key(prefix) - prefix = f"{prefix}_".upper() - return self._env.prefixed(prefix) - - def use(self, value: Any | None): - """Create a context manager to push the value into the config_stack.""" - - @contextmanager - def config_context(): - self._config_stack.append(value or {}) - try: - yield - finally: - self._config_stack.pop() - - return config_context() - - @property - def section(self) -> dict: - """Get the current section.""" - return self._config_stack[-1] if self._config_stack else {} - - def str( - self, - key: KeyValue, - env_key: EnvKeySet | None = None, - default_value: str | None = None, - ) -> str | None: - """Read a configuration value.""" - key = read_key(key) - if self.section and key in self.section: - return self.section[key] - - return self._read_env( - env_key or key, default_value, (lambda k, dv: 
self._env(k, dv)) - ) - - def int( - self, - key: KeyValue, - env_key: EnvKeySet | None = None, - default_value: int | None = None, - ) -> int | None: - """Read an integer configuration value.""" - key = read_key(key) - if self.section and key in self.section: - return int(self.section[key]) - return self._read_env( - env_key or key, default_value, lambda k, dv: self._env.int(k, dv) - ) - - def bool( - self, - key: KeyValue, - env_key: EnvKeySet | None = None, - default_value: bool | None = None, - ) -> bool | None: - """Read an integer configuration value.""" - key = read_key(key) - if self.section and key in self.section: - return bool(self.section[key]) - - return self._read_env( - env_key or key, default_value, lambda k, dv: self._env.bool(k, dv) - ) - - def float( - self, - key: KeyValue, - env_key: EnvKeySet | None = None, - default_value: float | None = None, - ) -> float | None: - """Read a float configuration value.""" - key = read_key(key) - if self.section and key in self.section: - return float(self.section[key]) - return self._read_env( - env_key or key, default_value, lambda k, dv: self._env.float(k, dv) - ) - - def list( - self, - key: KeyValue, - env_key: EnvKeySet | None = None, - default_value: list | None = None, - ) -> list | None: - """Parse an list configuration value.""" - key = read_key(key) - result = None - if self.section and key in self.section: - result = self.section[key] - if isinstance(result, list): - return result - - if result is None: - result = self.str(key, env_key) - if result is not None: - result = [s.strip() for s in result.split(",")] - return [s for s in result if s] - return default_value diff --git a/packages/graphrag/graphrag/config/load_config.py b/packages/graphrag/graphrag/config/load_config.py index de9026037d..104bd59d43 100644 --- a/packages/graphrag/graphrag/config/load_config.py +++ b/packages/graphrag/graphrag/config/load_config.py @@ -3,149 +3,16 @@ """Default method for loading config.""" -import json 
-import os from pathlib import Path -from string import Template from typing import Any -import yaml -from dotenv import load_dotenv +from graphrag_common.config import load_config as lc -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.models.graph_rag_config import GraphRagConfig -_default_config_files = ["settings.yaml", "settings.yml", "settings.json"] - - -def _search_for_config_in_root_dir(root: str | Path) -> Path | None: - """Resolve the config path from the given root directory. - - Parameters - ---------- - root : str | Path - The path to the root directory containing the config file. - Searches for a default config file (settings.{yaml,yml,json}). - - Returns - ------- - Path | None - returns a Path if there is a config in the root directory - Otherwise returns None. - """ - root = Path(root) - - if not root.is_dir(): - msg = f"Invalid config path: {root} is not a directory" - raise FileNotFoundError(msg) - - for file in _default_config_files: - if (root / file).is_file(): - return root / file - - return None - - -def _parse_env_variables(text: str) -> str: - """Parse environment variables in the configuration text. - - Parameters - ---------- - text : str - The configuration text. - - Returns - ------- - str - The configuration text with environment variables parsed. - - Raises - ------ - KeyError - If an environment variable is not found. - """ - return Template(text).substitute(os.environ) - - -def _load_dotenv(config_path: Path | str) -> None: - """Load the .env file if it exists in the same directory as the config file. - - Parameters - ---------- - config_path : Path | str - The path to the config file. - """ - config_path = Path(config_path) - dotenv_path = config_path.parent / ".env" - if dotenv_path.exists(): - load_dotenv(dotenv_path) - - -def _get_config_path(root_dir: Path, config_filepath: Path | None) -> Path: - """Get the configuration file path. 
- - Parameters - ---------- - root_dir : str | Path - The root directory of the project. Will search for the config file in this directory. - config_filepath : str | None - The path to the config file. - If None, searches for config file in root. - - Returns - ------- - Path - The configuration file path. - """ - if config_filepath: - config_path = config_filepath.resolve() - if not config_path.exists(): - msg = f"Specified Config file not found: {config_path}" - raise FileNotFoundError(msg) - else: - config_path = _search_for_config_in_root_dir(root_dir) - - if not config_path: - msg = f"Config file not found in root directory: {root_dir}" - raise FileNotFoundError(msg) - - return config_path - - -def _apply_overrides(data: dict[str, Any], overrides: dict[str, Any]) -> None: - """Apply the overrides to the raw configuration.""" - for key, value in overrides.items(): - keys = key.split(".") - target = data - current_path = keys[0] - for k in keys[:-1]: - current_path += f".{k}" - target_obj = target.get(k, {}) - if not isinstance(target_obj, dict): - msg = f"Cannot override non-dict value: data[{current_path}] is not a dict." - raise TypeError(msg) - target[k] = target_obj - target = target[k] - target[keys[-1]] = value - - -def _parse(file_extension: str, contents: str) -> dict[str, Any]: - """Parse configuration.""" - match file_extension: - case ".yaml" | ".yml": - return yaml.safe_load(contents) - case ".json": - return json.loads(contents) - case _: - msg = ( - f"Unable to parse config. Unsupported file extension: {file_extension}" - ) - raise ValueError(msg) - def load_config( - root_dir: Path, - config_filepath: Path | None = None, + root_dir: str | Path, cli_overrides: dict[str, Any] | None = None, ) -> GraphRagConfig: """Load configuration from a file. @@ -153,13 +20,11 @@ def load_config( Parameters ---------- root_dir : str | Path - The root directory of the project. Will search for the config file in this directory. 
- config_filepath : str | None - The path to the config file. - If None, searches for config file in root. + The root directory of the project. + Searches for settings.[yaml|yml|json] config files. cli_overrides : dict[str, Any] | None - A flat dictionary of cli overrides. - Example: {'output.base_dir': 'override_value'} + A nested dictionary of cli overrides. + Example: {'output': {'base_dir': 'override_value'}} Returns ------- @@ -170,22 +35,13 @@ def load_config( ------ FileNotFoundError If the config file is not found. - ValueError - If the config file extension is not supported. - TypeError - If applying cli overrides to the config fails. - KeyError - If config file references a non-existent environment variable. + ConfigParsingError + If there was an error parsing the config file or its environment variables. ValidationError If there are pydantic validation errors when instantiating the config. """ - root = root_dir.resolve() - config_path = _get_config_path(root, config_filepath) - _load_dotenv(config_path) - config_extension = config_path.suffix - config_text = config_path.read_text(encoding="utf-8") - config_text = _parse_env_variables(config_text) - config_data = _parse(config_extension, config_text) - if cli_overrides: - _apply_overrides(config_data, cli_overrides) - return create_graphrag_config(config_data, root_dir=str(root)) + return lc( + config_initializer=GraphRagConfig, + config_path=root_dir, + overrides=cli_overrides, + ) diff --git a/packages/graphrag/graphrag/config/models/community_reports_config.py b/packages/graphrag/graphrag/config/models/community_reports_config.py index 1257124bde..c4f920cefe 100644 --- a/packages/graphrag/graphrag/config/models/community_reports_config.py +++ b/packages/graphrag/graphrag/config/models/community_reports_config.py @@ -51,15 +51,13 @@ class CommunityReportsConfig(BaseModel): default=graphrag_config_defaults.community_reports.max_input_length, ) - def resolved_prompts(self, root_dir: str) -> 
CommunityReportPrompts: + def resolved_prompts(self) -> CommunityReportPrompts: """Get the resolved community report extraction prompts.""" return CommunityReportPrompts( - graph_prompt=(Path(root_dir) / self.graph_prompt).read_text( - encoding="utf-8" - ) + graph_prompt=Path(self.graph_prompt).read_text(encoding="utf-8") if self.graph_prompt else COMMUNITY_REPORT_PROMPT, - text_prompt=(Path(root_dir) / self.text_prompt).read_text(encoding="utf-8") + text_prompt=Path(self.text_prompt).read_text(encoding="utf-8") if self.text_prompt else COMMUNITY_REPORT_TEXT_PROMPT, ) diff --git a/packages/graphrag/graphrag/config/models/extract_claims_config.py b/packages/graphrag/graphrag/config/models/extract_claims_config.py index 77a633b05f..63fec7ac5a 100644 --- a/packages/graphrag/graphrag/config/models/extract_claims_config.py +++ b/packages/graphrag/graphrag/config/models/extract_claims_config.py @@ -47,10 +47,10 @@ class ExtractClaimsConfig(BaseModel): default=graphrag_config_defaults.extract_claims.max_gleanings, ) - def resolved_prompts(self, root_dir: str) -> ClaimExtractionPrompts: + def resolved_prompts(self) -> ClaimExtractionPrompts: """Get the resolved claim extraction prompts.""" return ClaimExtractionPrompts( - extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8") + extraction_prompt=Path(self.prompt).read_text(encoding="utf-8") if self.prompt else EXTRACT_CLAIMS_PROMPT, ) diff --git a/packages/graphrag/graphrag/config/models/extract_graph_config.py b/packages/graphrag/graphrag/config/models/extract_graph_config.py index 8a61585e5c..81c8df4235 100644 --- a/packages/graphrag/graphrag/config/models/extract_graph_config.py +++ b/packages/graphrag/graphrag/config/models/extract_graph_config.py @@ -43,10 +43,10 @@ class ExtractGraphConfig(BaseModel): default=graphrag_config_defaults.extract_graph.max_gleanings, ) - def resolved_prompts(self, root_dir: str) -> ExtractGraphPrompts: + def resolved_prompts(self) -> ExtractGraphPrompts: """Get the 
resolved graph extraction prompts.""" return ExtractGraphPrompts( - extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8") + extraction_prompt=Path(self.prompt).read_text(encoding="utf-8") if self.prompt else GRAPH_EXTRACTION_PROMPT, ) diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 71509f0176..15d02eaf3a 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -53,22 +53,6 @@ def __str__(self): """Get a string representation.""" return self.model_dump_json(indent=4) - root_dir: str = Field( - description="The root directory for the configuration.", - default=graphrag_config_defaults.root_dir, - ) - - def _validate_root_dir(self) -> None: - """Validate the root directory.""" - if self.root_dir.strip() == "": - self.root_dir = str(Path.cwd()) - - root_dir = Path(self.root_dir).resolve() - if not root_dir.is_dir(): - msg = f"Invalid root directory: {self.root_dir} is not a directory." - raise FileNotFoundError(msg) - self.root_dir = str(root_dir) - models: dict[str, LanguageModelConfig] = Field( description="Available language model configurations.", default=graphrag_config_defaults.models, @@ -156,7 +140,7 @@ def _validate_input_base_dir(self) -> None: msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." raise ValueError(msg) self.input.storage.base_dir = str( - (Path(self.root_dir) / self.input.storage.base_dir).resolve() + Path(self.input.storage.base_dir).resolve() ) chunks: ChunkingConfig = Field( @@ -179,9 +163,7 @@ def _validate_output_base_dir(self) -> None: if not self.output.base_dir: msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." 
raise ValueError(msg) - self.output.base_dir = str( - (Path(self.root_dir) / self.output.base_dir).resolve() - ) + self.output.base_dir = str(Path(self.output.base_dir).resolve()) update_index_output: StorageConfig = Field( description="The output configuration for the updated index.", @@ -198,7 +180,7 @@ def _validate_update_index_output_base_dir(self) -> None: msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." raise ValueError(msg) self.update_index_output.base_dir = str( - (Path(self.root_dir) / self.update_index_output.base_dir).resolve() + Path(self.update_index_output.base_dir).resolve() ) cache: CacheConfig = Field( @@ -217,9 +199,7 @@ def _validate_reporting_base_dir(self) -> None: if self.reporting.base_dir.strip() == "": msg = "Reporting base directory is required for file reporting. Please rerun `graphrag init` and set the reporting configuration." raise ValueError(msg) - self.reporting.base_dir = str( - (Path(self.root_dir) / self.reporting.base_dir).resolve() - ) + self.reporting.base_dir = str(Path(self.reporting.base_dir).resolve()) vector_store: VectorStoreConfig = Field( description="The vector store configuration.", default=VectorStoreConfig() @@ -315,7 +295,7 @@ def _validate_vector_store_db_uri(self) -> None: if not store.db_uri or store.db_uri.strip == "": msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration." 
raise ValueError(msg) - store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve()) + store.db_uri = str(Path(store.db_uri).resolve()) def _validate_factories(self) -> None: """Validate the factories used in the configuration.""" @@ -349,7 +329,6 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig: @model_validator(mode="after") def _validate_model(self): """Validate the model configuration.""" - self._validate_root_dir() self._validate_models() self._validate_input_pattern() self._validate_input_base_dir() diff --git a/packages/graphrag/graphrag/config/models/summarize_descriptions_config.py b/packages/graphrag/graphrag/config/models/summarize_descriptions_config.py index 3414db71ec..024d2d964b 100644 --- a/packages/graphrag/graphrag/config/models/summarize_descriptions_config.py +++ b/packages/graphrag/graphrag/config/models/summarize_descriptions_config.py @@ -43,10 +43,10 @@ class SummarizeDescriptionsConfig(BaseModel): default=graphrag_config_defaults.summarize_descriptions.max_input_tokens, ) - def resolved_prompts(self, root_dir: str) -> SummarizeDescriptionsPrompts: + def resolved_prompts(self) -> SummarizeDescriptionsPrompts: """Get the resolved description summarization prompts.""" return SummarizeDescriptionsPrompts( - summarize_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8") + summarize_prompt=Path(self.prompt).read_text(encoding="utf-8") if self.prompt else SUMMARIZE_PROMPT, ) diff --git a/packages/graphrag/graphrag/config/read_dotenv.py b/packages/graphrag/graphrag/config/read_dotenv.py deleted file mode 100644 index a2da5d43fe..0000000000 --- a/packages/graphrag/graphrag/config/read_dotenv.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing the read_dotenv utility.""" - -import logging -import os -from pathlib import Path - -from dotenv import dotenv_values - -logger = logging.getLogger(__name__) - - -def read_dotenv(root: str) -> None: - """Read a .env file in the given root path.""" - env_path = Path(root) / ".env" - if env_path.exists(): - logger.info("Loading pipeline .env file") - env_config = dotenv_values(f"{env_path}") - for key, value in env_config.items(): - if key not in os.environ: - os.environ[key] = value or "" - else: - logger.info("No .env file found at %s", root) diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index f652db7acd..a0b2011eab 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -35,11 +35,9 @@ async def run_pipeline( input_documents: pd.DataFrame | None = None, ) -> AsyncIterable[PipelineRunResult]: """Run all workflows using a simplified pipeline.""" - root_dir = config.root_dir - input_storage = create_storage_from_config(config.input.storage) output_storage = create_storage_from_config(config.output) - cache = create_cache_from_config(config.cache, root_dir) + cache = create_cache_from_config(config.cache) # load existing state in case any workflows are stateful state_json = await output_storage.get("context.json") diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports.py b/packages/graphrag/graphrag/index/workflows/create_community_reports.py index 0415cb3b0b..981e47227e 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports.py @@ -55,7 +55,7 @@ async def run_workflow( claims = await load_table_from_storage("covariates", context.output_storage) model_config = config.get_language_model_config(config.community_reports.model_id) - prompts = 
config.community_reports.resolved_prompts(config.root_dir) + prompts = config.community_reports.resolved_prompts() model = ModelManager().get_or_create_chat_model( name=config.community_reports.model_instance_name, diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py index 94d79ca572..ac269df8bc 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py @@ -56,7 +56,7 @@ async def run_workflow( tokenizer = get_tokenizer(model_config) - prompts = config.community_reports.resolved_prompts(config.root_dir) + prompts = config.community_reports.resolved_prompts() output = await create_community_reports_text( entities, diff --git a/packages/graphrag/graphrag/index/workflows/extract_covariates.py b/packages/graphrag/graphrag/index/workflows/extract_covariates.py index 52da07451e..99a8f279ed 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/extract_covariates.py @@ -45,7 +45,7 @@ async def run_workflow( cache=context.cache, ) - prompts = config.extract_claims.resolved_prompts(config.root_dir) + prompts = config.extract_claims.resolved_prompts() output = await extract_covariates( text_units=text_units, diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index 1a28a3978a..de9c454261 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -36,7 +36,7 @@ async def run_workflow( extraction_model_config = config.get_language_model_config( config.extract_graph.model_id ) - extraction_prompts = config.extract_graph.resolved_prompts(config.root_dir) + extraction_prompts = config.extract_graph.resolved_prompts() 
extraction_model = ModelManager().get_or_create_chat_model( name=config.extract_graph.model_instance_name, model_type=extraction_model_config.type, @@ -47,9 +47,7 @@ async def run_workflow( summarization_model_config = config.get_language_model_config( config.summarize_descriptions.model_id ) - summarization_prompts = config.summarize_descriptions.resolved_prompts( - config.root_dir - ) + summarization_prompts = config.summarize_descriptions.resolved_prompts() summarization_model = ModelManager().get_or_create_chat_model( name=config.summarize_descriptions.model_instance_name, model_type=summarization_model_config.type, diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 69bdefa3f2..1245303559 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -81,7 +81,7 @@ async def _update_entities_and_relationships( summarization_model_config = config.get_language_model_config( config.summarize_descriptions.model_id ) - prompts = config.summarize_descriptions.resolved_prompts(config.root_dir) + prompts = config.summarize_descriptions.resolved_prompts() model = ModelManager().get_or_create_chat_model( name="summarize_descriptions", model_type=summarization_model_config.type, diff --git a/packages/graphrag/graphrag/logger/factory.py b/packages/graphrag/graphrag/logger/factory.py index 744a563272..1094d080f3 100644 --- a/packages/graphrag/graphrag/logger/factory.py +++ b/packages/graphrag/graphrag/logger/factory.py @@ -32,10 +32,9 @@ class LoggerFactory(Factory[logging.Handler]): # --- register built-in logger implementations --- def create_file_logger(**kwargs) -> logging.Handler: """Create a file-based logger.""" - root_dir = kwargs["root_dir"] base_dir = kwargs["base_dir"] filename = kwargs["filename"] - log_dir = Path(root_dir) / 
base_dir + log_dir = Path(base_dir) log_dir.mkdir(parents=True, exist_ok=True) log_file_path = log_dir / filename diff --git a/packages/graphrag/graphrag/logger/standard_logging.py b/packages/graphrag/graphrag/logger/standard_logging.py index 31296e12d9..d9e4d0f26f 100644 --- a/packages/graphrag/graphrag/logger/standard_logging.py +++ b/packages/graphrag/graphrag/logger/standard_logging.py @@ -77,7 +77,7 @@ def init_loggers( reporting_config = config.reporting config_dict = reporting_config.model_dump() - args = {**config_dict, "root_dir": config.root_dir, "filename": filename} + args = {**config_dict, "filename": filename} handler = LoggerFactory().create(reporting_config.type, args) logger.addHandler(handler) diff --git a/packages/graphrag/graphrag/utils/api.py b/packages/graphrag/graphrag/utils/api.py index a972d21db9..f264c1a9ed 100644 --- a/packages/graphrag/graphrag/utils/api.py +++ b/packages/graphrag/graphrag/utils/api.py @@ -87,7 +87,7 @@ def reformat_context_data(context_data: dict) -> dict: return final_format -def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: +def load_search_prompt(prompt_config: str | None) -> str | None: """ Load the search prompt from disk if configured. 
@@ -95,7 +95,7 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: """ if prompt_config: - prompt_file = Path(root_dir) / prompt_config + prompt_file = Path(prompt_config).resolve() if prompt_file.exists(): return prompt_file.read_bytes().decode(encoding="utf-8") return None @@ -110,13 +110,12 @@ def create_storage_from_config(output: StorageConfig) -> PipelineStorage: ) -def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache: +def create_cache_from_config(cache: CacheConfig) -> PipelineCache: """Create a cache object from the config.""" cache_config = cache.model_dump() - args = {**cache_config, "root_dir": root_dir} return CacheFactory().create( strategy=cache_config["type"], - init_args=args, + init_args=cache_config, ) diff --git a/packages/graphrag/pyproject.toml b/packages/graphrag/pyproject.toml index a3167a8045..a7b97f1f0f 100644 --- a/packages/graphrag/pyproject.toml +++ b/packages/graphrag/pyproject.toml @@ -52,8 +52,6 @@ dependencies = [ "pandas>=2.2.3", "pyarrow>=17.0.0", "pydantic>=2.10.3", - "python-dotenv>=1.0.1", - "pyyaml>=6.0.2", "spacy>=3.8.4", "textblob>=0.18.0.post0", "tiktoken>=0.11.0", diff --git a/pyproject.toml b/pyproject.toml index 199fd6a514..3cb1ae1d67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,7 +217,11 @@ convention = "numpy" # https://github.com/microsoft/pyright/blob/9f81564a4685ff5c55edd3959f9b39030f590b2f/docs/configuration.md#sample-pyprojecttoml-file [tool.pyright] -include = ["packages/graphrag/graphrag", "packages/graphrag-common/graphrag_common", "tests"] +include = [ + "packages/graphrag/graphrag", + "packages/graphrag-common/graphrag_common", + "tests" +] exclude = ["**/node_modules", "**/__pycache__"] [tool.pytest.ini_options] diff --git a/tests/integration/cache/test_factory.py b/tests/integration/cache/test_factory.py index 53ec3cba56..766adc1f8f 100644 --- a/tests/integration/cache/test_factory.py +++ b/tests/integration/cache/test_factory.py @@ 
-34,7 +34,7 @@ def test_create_memory_cache(): def test_create_file_cache(): cache = CacheFactory().create( strategy=CacheType.file.value, - init_args={"root_dir": "/tmp", "base_dir": "testcache"}, + init_args={"base_dir": "testcache"}, ) assert isinstance(cache, JsonPipelineCache) diff --git a/tests/integration/logging/test_standard_logging.py b/tests/integration/logging/test_standard_logging.py index 106e57f9fa..6bb28a7343 100644 --- a/tests/integration/logging/test_standard_logging.py +++ b/tests/integration/logging/test_standard_logging.py @@ -4,6 +4,7 @@ """Tests for standard logging functionality.""" import logging +import os import tempfile from pathlib import Path @@ -38,7 +39,11 @@ def test_logger_hierarchy(): def test_init_loggers_file_config(): """Test that init_loggers works with file configuration.""" with tempfile.TemporaryDirectory() as temp_dir: - config = get_default_graphrag_config(root_dir=temp_dir) + # Need to manually change cwd since we are not using load_config + # to create graphrag config. + cwd = Path.cwd() + os.chdir(temp_dir) + config = get_default_graphrag_config() # call init_loggers with file config init_loggers(config=config) @@ -68,12 +73,17 @@ def test_init_loggers_file_config(): if isinstance(handler, logging.FileHandler): handler.close() logger.handlers.clear() + os.chdir(cwd) def test_init_loggers_file_verbose(): """Test that init_loggers works with verbose flag.""" with tempfile.TemporaryDirectory() as temp_dir: - config = get_default_graphrag_config(root_dir=temp_dir) + # Need to manually change cwd since we are not using load_config + # to create graphrag config. 
+ cwd = Path.cwd() + os.chdir(temp_dir) + config = get_default_graphrag_config() # call init_loggers with file config init_loggers(config=config, verbose=True) @@ -96,12 +106,17 @@ def test_init_loggers_file_verbose(): if isinstance(handler, logging.FileHandler): handler.close() logger.handlers.clear() + os.chdir(cwd) def test_init_loggers_custom_filename(): """Test that init_loggers works with custom filename.""" with tempfile.TemporaryDirectory() as temp_dir: - config = get_default_graphrag_config(root_dir=temp_dir) + # Need to manually change cwd since we are not using load_config + # to create graphrag config. + cwd = Path.cwd() + os.chdir(temp_dir) + config = get_default_graphrag_config() # call init_loggers with file config init_loggers(config=config, filename="custom-log.log") @@ -117,3 +132,4 @@ def test_init_loggers_custom_filename(): if isinstance(handler, logging.FileHandler): handler.close() logger.handlers.clear() + os.chdir(cwd) diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index b6966930e7..9821bed551 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -204,14 +204,13 @@ def __run_query(self, root: Path, query_config: dict[str, str]): "run", "poe", "query", + query_config["query"], "--root", root.resolve().as_posix(), "--method", query_config["method"], "--community-level", str(query_config.get("community_level", 2)), - "--query", - query_config["query"], ] logger.info("running command ", " ".join(command)) diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py index ddf1d17a5f..545e66b7c5 100644 --- a/tests/unit/config/test_config.py +++ b/tests/unit/config/test_config.py @@ -7,9 +7,9 @@ import graphrag.config.defaults as defs import pytest -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import AuthType, ModelType from graphrag.config.load_config import load_config +from graphrag.config.models.graph_rag_config import 
GraphRagConfig from pydantic import ValidationError from tests.unit.config.utils import ( @@ -33,14 +33,14 @@ def test_missing_openai_required_api_key() -> None: # API Key required for OpenAIChat with pytest.raises(ValidationError): - create_graphrag_config({"models": model_config_missing_api_key}) + GraphRagConfig(models=model_config_missing_api_key) # API Key required for OpenAIEmbedding model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["type"] = ( ModelType.Embedding ) with pytest.raises(ValidationError): - create_graphrag_config({"models": model_config_missing_api_key}) + GraphRagConfig(models=model_config_missing_api_key) def test_missing_azure_api_key() -> None: @@ -58,13 +58,13 @@ def test_missing_azure_api_key() -> None: } with pytest.raises(ValidationError): - create_graphrag_config({"models": model_config_missing_api_key}) + GraphRagConfig(models=model_config_missing_api_key) # API Key not required for managed identity model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["auth_type"] = ( AuthType.AzureManagedIdentity ) - create_graphrag_config({"models": model_config_missing_api_key}) + GraphRagConfig(models=model_config_missing_api_key) def test_conflicting_auth_type() -> None: @@ -79,7 +79,7 @@ def test_conflicting_auth_type() -> None: } with pytest.raises(ValidationError): - create_graphrag_config({"models": model_config_invalid_auth_type}) + GraphRagConfig(models=model_config_invalid_auth_type) def test_conflicting_azure_api_key() -> None: @@ -98,7 +98,7 @@ def test_conflicting_azure_api_key() -> None: } with pytest.raises(ValidationError): - create_graphrag_config({"models": model_config_conflicting_api_key}) + GraphRagConfig(models=model_config_conflicting_api_key) base_azure_model_config = { @@ -117,12 +117,12 @@ def test_missing_azure_api_base() -> None: del missing_api_base_config["api_base"] with pytest.raises(ValidationError): - create_graphrag_config({ - "models": { + GraphRagConfig( + models={ defs.DEFAULT_CHAT_MODEL_ID: 
missing_api_base_config, defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG, - } - }) + } # type: ignore + ) def test_missing_azure_api_version() -> None: @@ -130,46 +130,49 @@ def test_missing_azure_api_version() -> None: del missing_api_version_config["api_version"] with pytest.raises(ValidationError): - create_graphrag_config({ - "models": { + GraphRagConfig( + models={ defs.DEFAULT_CHAT_MODEL_ID: missing_api_version_config, defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG, - } - }) + } # type: ignore + ) def test_default_config() -> None: expected = get_default_graphrag_config() - actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + actual = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore assert_graphrag_configs(actual, expected) @mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True) def test_load_minimal_config() -> None: - cwd = Path(__file__).parent - root_dir = (cwd / "fixtures" / "minimal_config").resolve() - expected = get_default_graphrag_config(str(root_dir)) - actual = load_config(root_dir=root_dir) + cwd = Path.cwd() + root_dir = (Path(__file__).parent / "fixtures" / "minimal_config").resolve() + os.chdir(root_dir) + expected = get_default_graphrag_config() + + actual = load_config( + root_dir=root_dir, + ) assert_graphrag_configs(actual, expected) + # Need to reset cwd after test + os.chdir(cwd) @mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True) def test_load_config_with_cli_overrides() -> None: - cwd = Path(__file__).parent - root_dir = (cwd / "fixtures" / "minimal_config").resolve() + cwd = Path.cwd() + root_dir = (Path(__file__).parent / "fixtures" / "minimal_config").resolve() + os.chdir(root_dir) output_dir = "some_output_dir" expected_output_base_dir = root_dir / output_dir - expected = get_default_graphrag_config(str(root_dir)) + expected = get_default_graphrag_config() expected.output.base_dir = str(expected_output_base_dir) + actual = 
load_config( root_dir=root_dir, - cli_overrides={"output.base_dir": output_dir}, + cli_overrides={"output": {"base_dir": output_dir}}, ) assert_graphrag_configs(actual, expected) - - -def test_load_config_missing_env_vars() -> None: - cwd = Path(__file__).parent - root_dir = (cwd / "fixtures" / "minimal_config_missing_env_var").resolve() - with pytest.raises(KeyError): - load_config(root_dir=root_dir) + # Need to reset cwd after test + os.chdir(cwd) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index c7dbb4cd5c..001518f62a 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -54,11 +54,10 @@ } -def get_default_graphrag_config(root_dir: str | None = None) -> GraphRagConfig: +def get_default_graphrag_config() -> GraphRagConfig: return GraphRagConfig(**{ **asdict(defs.graphrag_config_defaults), "models": DEFAULT_MODEL_CONFIG, - **({"root_dir": root_dir} if root_dir else {}), }) @@ -350,8 +349,6 @@ def assert_basic_search_configs( def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> None: - assert actual.root_dir == expected.root_dir - a_keys = sorted(actual.models.keys()) e_keys = sorted(expected.models.keys()) assert len(a_keys) == len(e_keys) diff --git a/tests/unit/indexing/test_init_content.py b/tests/unit/indexing/test_init_content.py index cfd59cd9ad..61f2d11fa6 100644 --- a/tests/unit/indexing/test_init_content.py +++ b/tests/unit/indexing/test_init_content.py @@ -5,14 +5,13 @@ from typing import Any, cast import yaml -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.init_content import INIT_YAML from graphrag.config.models.graph_rag_config import GraphRagConfig def test_init_yaml(): data = yaml.load(INIT_YAML, Loader=yaml.FullLoader) - config = create_graphrag_config(data) + config = GraphRagConfig(**data) GraphRagConfig.model_validate(config, strict=True) @@ -26,5 +25,5 @@ def uncomment_line(line: str) -> str: content = 
"\n".join([uncomment_line(line) for line in lines]) data = yaml.load(content, Loader=yaml.FullLoader) - config = create_graphrag_config(data) + config = GraphRagConfig(**data) GraphRagConfig.model_validate(config, strict=True) diff --git a/tests/unit/load_config/__init__.py b/tests/unit/load_config/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/load_config/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/load_config/config.py b/tests/unit/load_config/config.py new file mode 100644 index 0000000000..f3b77feb65 --- /dev/null +++ b/tests/unit/load_config/config.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Config models for load_config unit tests.""" + +from pydantic import BaseModel, ConfigDict, Field + + +class TestNestedModel(BaseModel): + """Test nested model.""" + + model_config = ConfigDict(extra="forbid") + + nested_str: str = Field(description="A nested field.") + nested_int: int = Field(description="Another nested field.") + + +class TestConfigModel(BaseModel): + """Test configuration model.""" + + model_config = ConfigDict(extra="forbid") + __test__ = False # type: ignore + + name: str = Field(description="Name field.") + value: int = Field(description="Value field.") + nested: TestNestedModel = Field(description="Nested model field.") + nested_list: list[TestNestedModel] = Field(description="List of nested models.") diff --git a/tests/unit/load_config/fixtures/config_with_env.yaml b/tests/unit/load_config/fixtures/config_with_env.yaml new file mode 100644 index 0000000000..ecefbbc457 --- /dev/null +++ b/tests/unit/load_config/fixtures/config_with_env.yaml @@ -0,0 +1,10 @@ +name: ${LOAD_CONFIG_NAME} +value: 100 +nested: + nested_str: nested_value + nested_int: 42 +nested_list: + - nested_str: list_value_1 + nested_int: 7 + - nested_str: list_value_2 + nested_int: 8 \ No 
newline at end of file diff --git a/tests/unit/load_config/fixtures/invalid_config.yaml b/tests/unit/load_config/fixtures/invalid_config.yaml new file mode 100644 index 0000000000..d2da11d0eb --- /dev/null +++ b/tests/unit/load_config/fixtures/invalid_config.yaml @@ -0,0 +1 @@ +name: test_name \ No newline at end of file diff --git a/tests/unit/load_config/fixtures/invalid_config_format.yaml b/tests/unit/load_config/fixtures/invalid_config_format.yaml new file mode 100644 index 0000000000..b851bf08c7 --- /dev/null +++ b/tests/unit/load_config/fixtures/invalid_config_format.yaml @@ -0,0 +1,8 @@ +{ + "key": "value", + "invalid_yaml": true +} +{ + "key": "value", + "invalid_yaml": true +} \ No newline at end of file diff --git a/tests/unit/load_config/fixtures/settings.yaml b/tests/unit/load_config/fixtures/settings.yaml new file mode 100644 index 0000000000..a54919d1eb --- /dev/null +++ b/tests/unit/load_config/fixtures/settings.yaml @@ -0,0 +1,10 @@ +name: test_name +value: 100 +nested: + nested_str: nested_value + nested_int: 42 +nested_list: + - nested_str: list_value_1 + nested_int: 7 + - nested_str: list_value_2 + nested_int: 8 \ No newline at end of file diff --git a/tests/unit/load_config/fixtures/test.env b/tests/unit/load_config/fixtures/test.env new file mode 100644 index 0000000000..0ca30592c0 --- /dev/null +++ b/tests/unit/load_config/fixtures/test.env @@ -0,0 +1 @@ +LOAD_CONFIG_NAME=env_name \ No newline at end of file diff --git a/tests/unit/load_config/test_load_config.py b/tests/unit/load_config/test_load_config.py new file mode 100644 index 0000000000..0945cf214f --- /dev/null +++ b/tests/unit/load_config/test_load_config.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Unit tests for graphrag-config.load_config.""" + +import os +from pathlib import Path + +import pytest +from graphrag_common.config import ConfigParsingError, load_config +from pydantic import ValidationError + +from .config import TestConfigModel + + +def test_load_config_validation(): + """Test loading config validation.""" + + with pytest.raises( + FileNotFoundError, + ): + _ = load_config(TestConfigModel, "non_existent_config.yaml") + + config_directory = Path(__file__).parent / "fixtures" + invalid_config_formatting_path = config_directory / "invalid_config_format.yaml" + + with pytest.raises( + FileNotFoundError, + ): + _ = load_config( + config_initializer=TestConfigModel, + config_path=invalid_config_formatting_path, + dot_env_path="non_existent.env", + ) + + # Using yaml to parse invalid json formatting + with pytest.raises( + ConfigParsingError, + ): + _ = load_config(TestConfigModel, invalid_config_formatting_path) + + invalid_config_path = config_directory / "invalid_config.yaml" + + # Test validation error from config model + with pytest.raises( + ValidationError, + ): + _ = load_config( + config_initializer=TestConfigModel, + config_path=invalid_config_path, + set_cwd=False, + ) + + +def test_load_config(): + """Test loading configuration.""" + + config_directory = Path(__file__).parent / "fixtures" + config_path = config_directory / "settings.yaml" + + # Load from dir + config = load_config( + config_initializer=TestConfigModel, config_path=config_directory, set_cwd=False + ) + + assert config.name == "test_name" + assert config.value == 100 + assert config.nested.nested_str == "nested_value" + assert config.nested.nested_int == 42 + assert len(config.nested_list) == 2 + assert config.nested_list[0].nested_str == "list_value_1" + assert config.nested_list[0].nested_int == 7 + assert config.nested_list[1].nested_str == "list_value_2" + assert config.nested_list[1].nested_int == 8 + + # Should not have changed 
directories + root_repo_dir = Path(__file__).parent.parent.parent.parent.resolve() + assert Path.cwd().resolve() == root_repo_dir + + config = load_config( + config_initializer=TestConfigModel, + config_path=config_path, + set_cwd=False, + ) + + assert config.name == "test_name" + assert config.value == 100 + assert config.nested.nested_str == "nested_value" + assert config.nested.nested_int == 42 + assert len(config.nested_list) == 2 + assert config.nested_list[0].nested_str == "list_value_1" + assert config.nested_list[0].nested_int == 7 + assert config.nested_list[1].nested_str == "list_value_2" + assert config.nested_list[1].nested_int == 8 + + overrides = { + "value": 65537, + "nested": {"nested_int": 84}, + "nested_list": [ + {"nested_str": "overridden_list_value_1", "nested_int": 23}, + ], + } + + cwd = Path.cwd() + config_with_overrides = load_config( + config_initializer=TestConfigModel, + config_path=config_path, + overrides=overrides, + ) + + # Should have changed directories to the config file location + assert Path.cwd() == config_directory + assert ( + Path("some/new/path").resolve() + == (config_directory / "some/new/path").resolve() + ) + # Reset cwd + os.chdir(cwd) + + assert config_with_overrides.name == "test_name" + assert config_with_overrides.value == 65537 + assert config_with_overrides.nested.nested_str == "nested_value" + assert config_with_overrides.nested.nested_int == 84 + assert len(config_with_overrides.nested_list) == 1 + assert config_with_overrides.nested_list[0].nested_str == "overridden_list_value_1" + assert config_with_overrides.nested_list[0].nested_int == 23 + + config_with_env_vars_path = config_directory / "config_with_env.yaml" + + # Config contains env vars that do not exist + # and no .env file is provided + with pytest.raises( + ConfigParsingError, + ): + _ = load_config( + config_initializer=TestConfigModel, + config_path=config_with_env_vars_path, + load_dot_env_file=False, + set_cwd=False, + ) + + env_path = 
config_directory / "test.env" + config_with_env_vars = load_config( + config_initializer=TestConfigModel, + config_path=config_with_env_vars_path, + dot_env_path=env_path, + ) + + assert config_with_env_vars.name == "env_name" + assert config_with_env_vars.value == 100 + assert config_with_env_vars.nested.nested_str == "nested_value" + assert config_with_env_vars.nested.nested_int == 42 + assert len(config_with_env_vars.nested_list) == 2 + assert config_with_env_vars.nested_list[0].nested_str == "list_value_1" + assert config_with_env_vars.nested_list[0].nested_int == 7 + assert config_with_env_vars.nested_list[1].nested_str == "list_value_2" + assert config_with_env_vars.nested_list[1].nested_int == 8 diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index 87148a981d..73b72fc2cb 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.workflows.create_base_text_units import run_workflow from graphrag.utils.storage import load_table_from_storage @@ -19,7 +19,7 @@ async def test_create_base_text_units(): context = await create_test_context() - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore await run_workflow(config, context) @@ -33,7 +33,7 @@ async def test_create_base_text_units_metadata(): context = await create_test_context() - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.input.metadata = ["title"] config.chunks.prepend_metadata = True @@ -50,7 +50,7 @@ async def test_create_base_text_units_metadata_included_in_chunk(): context = await create_test_context() - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.input.metadata = ["title"] config.chunks.prepend_metadata = True config.chunks.chunk_size_includes_metadata = True diff --git a/tests/verbs/test_create_communities.py b/tests/verbs/test_create_communities.py index 1f51667cf3..5754c814a4 100644 --- a/tests/verbs/test_create_communities.py +++ b/tests/verbs/test_create_communities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS from graphrag.index.workflows.create_communities import ( run_workflow, @@ -26,7 +26,7 @@ async def test_create_communities(): ], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore await run_workflow( config, diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index 3947d4dc58..d479120ce2 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -2,7 +2,7 @@ # Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS from graphrag.index.operations.summarize_communities.community_reports_extractor import ( CommunityReportResponse, @@ -50,7 +50,7 @@ async def test_create_community_reports(): ] ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.models["default_chat_model"].type = "mock_chat" config.models["default_chat_model"].responses = MOCK_RESPONSES # type: ignore diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index 89ff2d7c5c..05f4b60691 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS from graphrag.index.workflows.create_final_documents import ( run_workflow, @@ -24,7 +24,7 @@ async def test_create_final_documents(): storage=["text_units"], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore await run_workflow(config, context) @@ -41,7 +41,7 @@ async def test_create_final_documents_with_metadata_column(): storage=["text_units"], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.input.metadata = ["title"] # simulate the metadata construction during initial input loading diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 979f48d5b4..d8d7686d41 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS from graphrag.index.workflows.create_final_text_units import ( run_workflow, @@ -28,7 +28,7 @@ async def test_create_final_text_units(): ], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.extract_claims.enabled = True await run_workflow(config, context) diff --git a/tests/verbs/test_extract_covariates.py b/tests/verbs/test_extract_covariates.py index d4645a34f9..3351f04273 100644 --- a/tests/verbs/test_extract_covariates.py +++ b/tests/verbs/test_extract_covariates.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import ModelType +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS from graphrag.index.workflows.extract_covariates import ( run_workflow, @@ -30,7 +30,7 @@ async def test_extract_covariates(): storage=["text_units"], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.extract_claims.enabled = True config.extract_claims.description = "description" llm_settings = config.get_language_model_config(config.extract_claims.model_id) diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 28dbfcf3c0..7bdc38e4a2 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -1,8 +1,8 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import ModelType +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.workflows.extract_graph import ( run_workflow, ) @@ -43,12 +43,12 @@ async def test_extract_graph(): extraction_model = DEFAULT_CHAT_MODEL_CONFIG.copy() extraction_model["type"] = ModelType.MockChat extraction_model["responses"] = MOCK_LLM_ENTITY_RESPONSES # type: ignore - config = create_graphrag_config({ - "models": { + config = GraphRagConfig( + models={ "default_chat_model": extraction_model, "default_embedding_model": DEFAULT_EMBEDDING_MODEL_CONFIG, - } - }) + } # type: ignore + ) summarize_llm_settings = config.get_language_model_config( config.summarize_descriptions.model_id diff --git a/tests/verbs/test_extract_graph_nlp.py b/tests/verbs/test_extract_graph_nlp.py index 92da89288c..0c1f54c66e 100644 --- a/tests/verbs/test_extract_graph_nlp.py +++ b/tests/verbs/test_extract_graph_nlp.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.workflows.extract_graph_nlp import ( run_workflow, ) @@ -18,7 +18,7 @@ async def test_extract_graph_nlp(): storage=["text_units"], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore await run_workflow(config, context) diff --git a/tests/verbs/test_finalize_graph.py b/tests/verbs/test_finalize_graph.py index 0950738ae7..cce5adad94 100644 --- a/tests/verbs/test_finalize_graph.py +++ b/tests/verbs/test_finalize_graph.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import ( ENTITIES_FINAL_COLUMNS, RELATIONSHIPS_FINAL_COLUMNS, @@ -21,7 +21,7 @@ async def test_finalize_graph(): context = await _prep_tables() - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore await run_workflow(config, context) diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 33254874e7..e9e6c00afa 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -1,11 +1,11 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.embeddings import ( all_embeddings, ) from graphrag.config.enums import ModelType +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.workflows.generate_text_embeddings import ( run_workflow, ) @@ -28,7 +28,7 @@ async def test_generate_text_embeddings(): ] ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore llm_settings = config.get_language_model_config(config.embed_text.model_id) llm_settings.type = ModelType.MockEmbedding diff --git a/tests/verbs/test_pipeline_state.py b/tests/verbs/test_pipeline_state.py index 306cd84256..2e41ecc4af 100644 --- a/tests/verbs/test_pipeline_state.py +++ b/tests/verbs/test_pipeline_state.py @@ -3,7 +3,6 @@ """Tests for pipeline state passthrough.""" -from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import create_run_context from 
graphrag.index.typing.context import PipelineRunContext @@ -32,7 +31,7 @@ async def test_pipeline_state(): PipelineFactory.register("workflow_1", run_workflow_1) PipelineFactory.register("workflow_2", run_workflow_2) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.workflows = ["workflow_1", "workflow_2"] context = create_run_context() @@ -45,7 +44,7 @@ async def test_pipeline_state(): async def test_pipeline_existing_state(): PipelineFactory.register("workflow_2", run_workflow_2) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) + config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore config.workflows = ["workflow_2"] context = create_run_context(state={"count": 4}) diff --git a/tests/verbs/test_prune_graph.py b/tests/verbs/test_prune_graph.py index 6ed0001973..426230161a 100644 --- a/tests/verbs/test_prune_graph.py +++ b/tests/verbs/test_prune_graph.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
 # Licensed under the MIT License
 
-from graphrag.config.create_graphrag_config import create_graphrag_config
+from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
 from graphrag.index.workflows.prune_graph import (
     run_workflow,
@@ -19,7 +19,7 @@ async def test_prune_graph():
         storage=["entities", "relationships"],
     )
 
-    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
+    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
     config.prune_graph = PruneGraphConfig(
         min_node_freq=4, min_node_degree=0, min_edge_weight_pct=0
     )
diff --git a/unified-search-app/app/knowledge_loader/data_sources/blob_source.py b/unified-search-app/app/knowledge_loader/data_sources/blob_source.py
index f54ed71854..63774d79f7 100644
--- a/unified-search-app/app/knowledge_loader/data_sources/blob_source.py
+++ b/unified-search-app/app/knowledge_loader/data_sources/blob_source.py
@@ -13,7 +13,6 @@
 import yaml
 from azure.identity import DefaultAzureCredential
 from azure.storage.blob import BlobServiceClient, ContainerClient
-from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from knowledge_loader.data_sources.typing import Datasource
 
@@ -114,7 +113,7 @@ def read_settings(
             str_settings = settings.read().decode("utf-8")
             config = os.path.expandvars(str_settings)
             settings_yaml = yaml.safe_load(config)
-            graphrag_config = create_graphrag_config(values=settings_yaml)
+            graphrag_config = GraphRagConfig(**settings_yaml)
         except Exception as err:
             if throw_on_missing:
                 error_msg = f"File {file} does not exist"
diff --git a/uv.lock b/uv.lock
index 761cd308f0..6ef295df9d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1055,8 +1055,6 @@ dependencies = [
     { name = "pandas" },
     { name = "pyarrow" },
     { name = "pydantic" },
-    { name = "python-dotenv" },
-    { name = "pyyaml" },
     { name = "spacy" },
     { name = "textblob" },
     { name = "tiktoken" },
@@ -1086,8 +1084,6 @@ requires-dist = [
     { name = "pandas", specifier = ">=2.2.3" },
     { name = "pyarrow", specifier = ">=17.0.0" },
     { name = "pydantic", specifier = ">=2.10.3" },
-    { name = "python-dotenv", specifier = ">=1.0.1" },
-    { name = "pyyaml", specifier = ">=6.0.2" },
     { name = "spacy", specifier = ">=3.8.4" },
     { name = "textblob", specifier = ">=0.18.0.post0" },
     { name = "tiktoken", specifier = ">=0.11.0" },
@@ -1100,6 +1096,16 @@ requires-dist = [
 name = "graphrag-common"
 version = "2.7.0"
 source = { editable = "packages/graphrag-common" }
+dependencies = [
+    { name = "python-dotenv" },
+    { name = "pyyaml" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "python-dotenv", specifier = ">=1.0.1" },
+    { name = "pyyaml", specifier = ">=6.0.2" },
+]
 
 [[package]]
 name = "graphrag-monorepo"