Skip to content

Commit e99a766

Browse files
committed
Use env credentials instead of passing them through
1 parent cc1dfc9 commit e99a766

13 files changed

Lines changed: 46 additions & 47 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfi
324324

325325
# 2. Get overrides from CLI and kwargs
326326
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
327+
if "hf_access_token" in cli_cfg:
328+
logger.warning(
329+
"WARNING: Passing 'hf_access_token' via command-line arguments is deprecated and insecure because it makes "
330+
"your token visible in 'ps' and shell history. Please set the 'HF_TOKEN' environment variable instead."
331+
)
327332
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
328333
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
329334

src/maxtext/experimental/agent/ckpt_conversion_agent/README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ The agent is used to automate the model-specific mappings of checkpoint conversi
44
## Quick starts
55
To begin, you'll need:
66

7-
1. A Google atccount.
8-
2. An API key (create one in [Google AI Studio](https://aistudio.google.com/app/apikey)).
7+
1. A Google account.
8+
2. An API key (create one in [Google AI Studio](https://aistudio.google.com/app/apikey)), and set it as an environment variable:
9+
```bash
10+
export GEMINI_API_KEY="<Your-API-KEY>"
11+
```
912
3. Install the Google Generative AI Python library:
1013
```
1114
pip install -q -U "google-genai>=1.0.0"
@@ -30,7 +33,7 @@ After it, you can get two `*.json` files in `config.base_output_directory` folde
3033

3134
```bash
3235
python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step1 --target_model=<model_name> \
33-
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key=<Your-API-KEY>
36+
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent
3437
```
3538

3639
Our engineer should check the `src/maxtext/experimental/agent/ckpt_conversion_agent/outputs/proposed_dsl.txt` for potential new DSL and assess if it's needed. Then we need to add this ops into `src/maxtext/experimental/agent/ckpt_conversion_agent/context/dsl.txt`.
@@ -39,7 +42,7 @@ Our engineer should check the `src/maxtext/experimental/agent/ckpt_conversion_ag
3942

4043
```bash
4144
python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step2 --target_model=<model_name> \
42-
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key=<Your-API-KEY>
45+
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent
4346
```
4447

4548
## Evaluation and Debugging
@@ -53,7 +56,7 @@ You can automatically verify the output by comparing the generated code against
5356

5457
```bash
5558
python3 -m maxtext.experimental.agent.ckpt_conversion_agent.evaluation --files ground_truth/<model>.py \
56-
outputs/hook_fn.py --api_key=<Your-API-KEY> --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent
59+
outputs/hook_fn.py --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent
5760
```
5861

5962
### Manual Debugging (No Ground-Truth Code)
@@ -121,5 +124,5 @@ Run the [One-shot agent Jyputer notebook](./baselines/one-shot-agent.ipynb)
121124
### Prompt-chain Agent:
122125
```bash
123126
python3 -m maxtext.experimental.agent.ckpt_conversion_agent.prompt_chain --target_model=<model_name> \
124-
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key=<Your-API-KEY>
127+
--dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent
125128
```

src/maxtext/experimental/agent/ckpt_conversion_agent/analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ class AnalysisAgent(BaseAgent):
2727
conversion script, with verification that every parameter is mapped.
2828
"""
2929

30-
def __init__(self, api_key, dir_path, target_model="gemma3", max_retries=3):
30+
def __init__(self, dir_path, target_model="gemma3", max_retries=3):
3131
"""
3232
Initializes the PlanAgent.
3333
3434
Args:
3535
target_model (str): The target model for conversion.
3636
max_retries (int): The maximum number of retries for generation.
3737
"""
38-
super().__init__(api_key)
38+
super().__init__()
3939

4040
self.target_model = target_model
4141
self.max_retries = max_retries

src/maxtext/experimental/agent/ckpt_conversion_agent/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
A base class for other agents
1717
"""
1818

19+
import os
20+
1921
from google import genai
2022
from google.genai.types import GenerateContentConfig
2123

@@ -25,16 +27,14 @@
2527
class BaseAgent:
2628
"""A base class for agents that provides text generation capabilities."""
2729

28-
def __init__(self, api_key, model_id=MODEL_ID):
30+
def __init__(self, model_id=MODEL_ID):
2931
"""
3032
Initializes the BaseAgent with a genai client.
31-
32-
Args:
33-
client: An initialized genai.Client object.
3433
"""
35-
if not api_key:
36-
raise ValueError("A valid api_key must be provided.")
37-
client = genai.Client(api_key=api_key)
34+
resolved_api_key = os.environ.get("GEMINI_API_KEY")
35+
if not resolved_api_key:
36+
raise ValueError("GEMINI_API_KEY environment variable is not set.")
37+
client = genai.Client(api_key=resolved_api_key)
3838
self.client = client
3939
self.model_id = model_id
4040

src/maxtext/experimental/agent/ckpt_conversion_agent/dsl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
class DSLAgent(BaseAgent):
2525
"""DSL Agent"""
2626

27-
def __init__(self, api_key, dir_path, target_model="gemma3-4b", max_retries=3):
27+
def __init__(self, dir_path, target_model="gemma3-4b", max_retries=3):
2828
"""
2929
Initializes the DSLAgent.
3030
@@ -33,7 +33,7 @@ def __init__(self, api_key, dir_path, target_model="gemma3-4b", max_retries=3):
3333
max_retries (int): The maximum number of retries for generation.
3434
"""
3535
# Initialize the parent BaseAgent with the client
36-
super().__init__(api_key)
36+
super().__init__()
3737

3838
self.target_model = target_model
3939
self.max_retries = max_retries
@@ -83,7 +83,6 @@ def verify_dsl(self):
8383
parser.add_argument(
8484
"--dir_path", type=str, required=True, help='The file path to the context directory (e.g., "context/gemma3").'
8585
)
86-
parser.add_argument("--api_key", type=str, help="Optional API key for external services.")
8786
args = parser.parse_args()
88-
agent = DSLAgent(api_key=args.api_key, dir_path=args.dir_path, target_model=TARGET_MODEL)
87+
agent = DSLAgent(dir_path=args.dir_path, target_model=TARGET_MODEL)
8988
global_verification_dsl = agent.verify_dsl()

src/maxtext/experimental/agent/ckpt_conversion_agent/evaluation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,11 @@ def main():
6666
description="Gemini evaluate the agent code implementation against human-written ground truth code"
6767
)
6868
parser.add_argument("--files", nargs=2, help="Paths to code files to analyze.")
69-
parser.add_argument("--api_key", type=str, help="API key.")
7069
parser.add_argument("--dir_path", type=str, help="Directory path.")
7170

7271
args = parser.parse_args()
7372

74-
baseAgent = BaseAgent(api_key=args.api_key)
73+
baseAgent = BaseAgent()
7574
dir_path = args.dir_path
7675

7776
prompt_templates = {

src/maxtext/experimental/agent/ckpt_conversion_agent/mapping.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ class MappingAgent(BaseAgent):
2828
An agent that generates and verifies mapping functions for model conversion.
2929
"""
3030

31-
def __init__(self, api_key, dir_path, target_model="gemma3-4b", max_retries=3):
31+
def __init__(self, dir_path, target_model="gemma3-4b", max_retries=3):
3232
"""
3333
Initializes the MappingAgent.
3434
3535
Args:
3636
target_model (str): The target model for conversion.
3737
max_retries (int): The maximum number of retries for generation.
3838
"""
39-
super().__init__(api_key)
39+
super().__init__()
4040

4141
self.target_model = target_model
4242
self.max_retries = max_retries
@@ -171,9 +171,8 @@ def generate_shape_mapping(self):
171171
parser.add_argument(
172172
"--dir_path", type=str, required=True, help='The file path to the context directory (e.g., "context/gemma3").'
173173
)
174-
parser.add_argument("--api_key", type=str, help="Optional API key for external services.")
175174
args = parser.parse_args()
176-
agent = MappingAgent(api_key=args.api_key, dir_path=args.dir_path, target_model=TARGET_MODEL)
175+
agent = MappingAgent(dir_path=args.dir_path, target_model=TARGET_MODEL)
177176
try:
178177
param_mapping_code = agent.generate_param_mapping()
179178
shape_mapping_code = agent.generate_shape_mapping()

src/maxtext/experimental/agent/ckpt_conversion_agent/plan.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ class PlanAgent(BaseAgent):
2828
conversion script, with verification that every parameter is mapped.
2929
"""
3030

31-
def __init__(self, api_key, dir_path, target_model="gemma3", max_retries=3):
31+
def __init__(self, dir_path, target_model="gemma3", max_retries=3):
3232
"""
3333
Initializes the PlanAgent.
3434
3535
Args:
3636
target_model (str): The target model for conversion.
3737
max_retries (int): The maximum number of retries for generation.
3838
"""
39-
super().__init__(api_key)
39+
super().__init__()
4040

4141
self.target_model = target_model
4242
self.max_retries = max_retries
@@ -109,7 +109,6 @@ def plan_conversion(self):
109109
parser.add_argument(
110110
"--dir_path", type=str, required=True, help='The file path to the context directory (e.g., "context/gemma3").'
111111
)
112-
parser.add_argument("--api_key", type=str, help="Optional API key for external services.")
113112
args = parser.parse_args()
114-
agent = PlanAgent(api_key=args.api_key, dir_path=args.dir_path, target_model=TARGET_MODEL)
113+
agent = PlanAgent(dir_path=args.dir_path, target_model=TARGET_MODEL)
115114
agent.plan_conversion()

src/maxtext/experimental/agent/ckpt_conversion_agent/prompt_chain.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ class prompt_chaining_agent(BaseAgent):
2929
with verification that every parameter is actually mapped.
3030
"""
3131

32-
def __init__(self, api_key, target_model="gemma3", max_retries=3, dir_path="context"):
32+
def __init__(self, target_model="gemma3", max_retries=3, dir_path="context"):
3333
# Initialize the parent BaseAgent with the client
34-
super().__init__(api_key)
34+
super().__init__()
3535
self.target_model = target_model
3636
self.max_retries = max_retries
3737
self.dir_path = dir_path
@@ -178,7 +178,6 @@ def run_chain(self, max_retries=3):
178178

179179
parser.add_argument("--target_model", type=str, required=True, help='The name of the target model (e.g., "GEMMA3").')
180180
parser.add_argument("--dir_path", type=str, required=True, help="The file path to the ckpt conversion agent directory.")
181-
parser.add_argument("--api_key", type=str, required=True, help="Gemini API key.")
182181
args = parser.parse_args()
183-
agent = prompt_chaining_agent(api_key=args.api_key, target_model=args.target_model, dir_path=args.dir_path)
182+
agent = prompt_chaining_agent(target_model=args.target_model, dir_path=args.dir_path)
184183
agent.run_chain()

src/maxtext/experimental/agent/ckpt_conversion_agent/step1.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,15 @@
2828
parser.add_argument(
2929
"--dir_path", type=str, required=True, help='The file path to the context directory (e.g., "context/gemma3").'
3030
)
31-
parser.add_argument("--api_key", type=str, help="Optional API key for external services.")
3231
args = parser.parse_args()
3332

3433
TARGET_MODEL = args.target_model
3534
dir_path = args.dir_path
36-
api_key = args.api_key
3735

38-
analysisAgent = AnalysisAgent(api_key=api_key, dir_path=dir_path, target_model=TARGET_MODEL)
36+
analysisAgent = AnalysisAgent(dir_path=dir_path, target_model=TARGET_MODEL)
3937
analysisAgent.analyze_model_structures()
4038

41-
dslAgent = DSLAgent(api_key=api_key, dir_path=dir_path, target_model=TARGET_MODEL)
39+
dslAgent = DSLAgent(dir_path=dir_path, target_model=TARGET_MODEL)
4240
dslAgent.verify_dsl()
4341

4442
# Human interaction needed,

0 commit comments

Comments
 (0)