From 8d5813534a9c66d493eaef633f79a41326861f5c Mon Sep 17 00:00:00 2001
From: noah
Date: Sat, 4 Apr 2026 23:04:54 +0800
Subject: [PATCH 1/2] Preserve Tinker user metadata when reloading checkpoints

---
 src/art/tinker/service.py         |  4 +-
 tests/unit/test_tinker_service.py | 74 +++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 2 deletions(-)
 create mode 100644 tests/unit/test_tinker_service.py

diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py
index 90e704ebe..5de927539 100644
--- a/src/art/tinker/service.py
+++ b/src/art/tinker/service.py
@@ -199,17 +199,17 @@ async def _get_state(self) -> "TinkerState":
         assert config is not None, "Tinker args are required"
         service_client = tinker.ServiceClient()
         rest_client = service_client.create_rest_client()
+        training_client_args = dict(config.get("training_client_args", {}))
         checkpoint_dir = self._get_last_checkpoint_dir()
         if checkpoint_dir:
             info = yaml.safe_load(open(checkpoint_dir / "info.yaml", "r"))
             with log_timing("Creating Tinker training client from checkpoint"):
                 training_client = await service_client.create_training_client_from_state_with_optimizer_async(
                     path=info["state_with_optimizer_path"],
-                    user_metadata=config.get("user_metadata", None),
+                    user_metadata=training_client_args.get("user_metadata"),
                 )
         else:
             with log_timing("Creating Tinker training client"):
-                training_client_args = config.get("training_client_args", {})
                 if "rank" not in training_client_args:
                     training_client_args["rank"] = 8
                 if "train_unembed" not in training_client_args:
diff --git a/tests/unit/test_tinker_service.py b/tests/unit/test_tinker_service.py
new file mode 100644
index 000000000..021f7e99c
--- /dev/null
+++ b/tests/unit/test_tinker_service.py
@@ -0,0 +1,74 @@
+import pytest
+import yaml
+
+
+@pytest.fixture
+def tinker_service_module():
+    try:
+        import art.tinker.service as service_module
+
+        return service_module
+    except ImportError as e:
+        pytest.skip(f"Tinker dependencies not available: {e}")
+
+
+@pytest.mark.asyncio
+async def test_get_state_reuses_nested_user_metadata_from_training_client_args(
+    monkeypatch: pytest.MonkeyPatch,
+    tmp_path,
+    tinker_service_module,
+) -> None:
+    checkpoint_dir = tmp_path / "checkpoints" / "0001"
+    checkpoint_dir.mkdir(parents=True)
+    info_path = checkpoint_dir / "info.yaml"
+    info_path.write_text(
+        yaml.safe_dump(
+            {
+                "state_with_optimizer_path": "tinker://state/0001",
+                "sampler_weights_path": "tinker://sampler/0001",
+            }
+        )
+    )
+
+    observed: dict[str, object] = {}
+    fake_training_client = object()
+
+    class FakeServiceClient:
+        def create_rest_client(self) -> object:
+            return object()
+
+        async def create_training_client_from_state_with_optimizer_async(
+            self,
+            *,
+            path: str,
+            user_metadata: dict[str, str] | None = None,
+        ) -> object:
+            observed["path"] = path
+            observed["user_metadata"] = user_metadata
+            return fake_training_client
+
+    monkeypatch.setattr(
+        tinker_service_module.tinker,
+        "ServiceClient",
+        FakeServiceClient,
+    )
+
+    service = tinker_service_module.TinkerService(
+        model_name="test-model",
+        base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
+        config={
+            "tinker_args": {
+                "renderer_name": "qwen3_5",
+                "training_client_args": {
+                    "user_metadata": {"tenant": "test-tenant"},
+                },
+            }
+        },
+        output_dir=str(tmp_path),
+    )
+
+    state = await service._get_state()
+
+    assert observed["path"] == "tinker://state/0001"
+    assert observed["user_metadata"] == {"tenant": "test-tenant"}
+    assert state.training_client is fake_training_client

From 980a6d3614682e99887c277247a081d5a43dcbeb Mon Sep 17 00:00:00 2001
From: noah
Date: Sat, 4 Apr 2026 23:12:11 +0800
Subject: [PATCH 2/2] Make Tinker metadata regression test self-contained

---
 tests/unit/test_tinker_service.py | 136 ++++++++++++++++++++++++++----
 1 file changed, 120 insertions(+), 16 deletions(-)

diff --git a/tests/unit/test_tinker_service.py b/tests/unit/test_tinker_service.py
index 021f7e99c..6d0674e17 100644
--- a/tests/unit/test_tinker_service.py
+++ b/tests/unit/test_tinker_service.py
@@ -1,28 +1,131 @@
+import asyncio
+import importlib.util
+import json
+from pathlib import Path
+import sys
+import types
+
 import pytest
-import yaml
 
 
-@pytest.fixture
-def tinker_service_module():
-    try:
-        import art.tinker.service as service_module
+def _install_stub(monkeypatch: pytest.MonkeyPatch, name: str, module: types.ModuleType):
+    monkeypatch.setitem(sys.modules, name, module)
+
+
+def _load_tinker_service_module(monkeypatch: pytest.MonkeyPatch):
+    repo_root = Path(__file__).resolve().parents[2]
+    service_path = repo_root / "src" / "art" / "tinker" / "service.py"
+
+    art_pkg = types.ModuleType("art")
+    art_pkg.__path__ = [str(repo_root / "src" / "art")]  # type: ignore[attr-defined]
+    _install_stub(monkeypatch, "art", art_pkg)
+
+    tinker_pkg = types.ModuleType("art.tinker")
+    tinker_pkg.__path__ = [str(repo_root / "src" / "art" / "tinker")]  # type: ignore[attr-defined]
+    _install_stub(monkeypatch, "art.tinker", tinker_pkg)
+
+    preprocessing_pkg = types.ModuleType("art.preprocessing")
+    preprocessing_pkg.__path__ = [str(repo_root / "src" / "art" / "preprocessing")]  # type: ignore[attr-defined]
+    _install_stub(monkeypatch, "art.preprocessing", preprocessing_pkg)
+
+    dev_mod = types.ModuleType("art.dev")
+    dev_mod.InternalModelConfig = dict
+    dev_mod.OpenAIServerConfig = dict
+    dev_mod.TrainConfig = dict
+    _install_stub(monkeypatch, "art.dev", dev_mod)
+
+    types_mod = types.ModuleType("art.types")
+    types_mod.TrainConfig = dict
+    _install_stub(monkeypatch, "art.types", types_mod)
+
+    loss_mod = types.ModuleType("art.loss")
+    loss_mod.loss_fn = lambda *args, **kwargs: None
+    loss_mod.shift_tensor = lambda tensor, _: tensor
+    _install_stub(monkeypatch, "art.loss", loss_mod)
+
+    inputs_mod = types.ModuleType("art.preprocessing.inputs")
+    inputs_mod.TrainInputs = dict
+    inputs_mod.create_train_inputs = lambda *args, **kwargs: {}
+    _install_stub(monkeypatch, "art.preprocessing.inputs", inputs_mod)
+
+    pack_mod = types.ModuleType("art.preprocessing.pack")
+    pack_mod.DiskPackedTensors = dict
+    pack_mod.packed_tensors_from_dir = lambda **kwargs: kwargs
+    _install_stub(monkeypatch, "art.preprocessing.pack", pack_mod)
+
+    server_mod = types.ModuleType("art.tinker.server")
+    server_mod.OpenAICompatibleTinkerServer = type(
+        "OpenAICompatibleTinkerServer", (), {}
+    )
+    _install_stub(monkeypatch, "art.tinker.server", server_mod)
 
-        return service_module
-    except ImportError as e:
-        pytest.skip(f"Tinker dependencies not available: {e}")
+    yaml_mod = types.ModuleType("yaml")
 
+    def safe_load(stream_or_text):
+        if hasattr(stream_or_text, "read"):
+            return json.loads(stream_or_text.read())
+        return json.loads(stream_or_text)
 
-@pytest.mark.asyncio
-async def test_get_state_reuses_nested_user_metadata_from_training_client_args(
+    def safe_dump(data, stream=None):
+        text = json.dumps(data)
+        if stream is None:
+            return text
+        stream.write(text)
+        return None
+
+    yaml_mod.safe_load = safe_load
+    yaml_mod.safe_dump = safe_dump
+    _install_stub(monkeypatch, "yaml", yaml_mod)
+
+    torch_mod = types.ModuleType("torch")
+    torch_mod.Tensor = type("Tensor", (), {})
+    torch_mod.float32 = "float32"
+    _install_stub(monkeypatch, "torch", torch_mod)
+
+    tinker_mod = types.ModuleType("tinker")
+    tinker_mod.ServiceClient = type("ServiceClient", (), {})
+    tinker_mod.TrainingClient = type("TrainingClient", (), {})
+    tinker_mod.Datum = type("Datum", (), {})
+    tinker_mod.TensorData = type(
+        "TensorData", (), {"from_torch": staticmethod(lambda value: value)}
+    )
+    tinker_mod.ModelInput = type(
+        "ModelInput", (), {"from_ints": staticmethod(lambda ints: ints)}
+    )
+    tinker_mod.AdamParams = type("AdamParams", (), {})
+    _install_stub(monkeypatch, "tinker", tinker_mod)
+
+    tinker_lib_pkg = types.ModuleType("tinker.lib")
+    tinker_lib_pkg.__path__ = []  # type: ignore[attr-defined]
+    _install_stub(monkeypatch, "tinker.lib", tinker_lib_pkg)
+
+    public_interfaces_pkg = types.ModuleType("tinker.lib.public_interfaces")
+    public_interfaces_pkg.__path__ = []  # type: ignore[attr-defined]
+    _install_stub(monkeypatch, "tinker.lib.public_interfaces", public_interfaces_pkg)
+
+    rest_client_mod = types.ModuleType("tinker.lib.public_interfaces.rest_client")
+    rest_client_mod.RestClient = type("RestClient", (), {})
+    _install_stub(monkeypatch, "tinker.lib.public_interfaces.rest_client", rest_client_mod)
+
+    spec = importlib.util.spec_from_file_location("art.tinker.service", service_path)
+    assert spec is not None and spec.loader is not None
+    service_module = importlib.util.module_from_spec(spec)
+    monkeypatch.setitem(sys.modules, "art.tinker.service", service_module)
+    spec.loader.exec_module(service_module)
+    return service_module
+
+
+def test_get_state_reuses_nested_user_metadata_from_training_client_args(
     monkeypatch: pytest.MonkeyPatch,
-    tmp_path,
-    tinker_service_module,
+    tmp_path: Path,
 ) -> None:
+    service_module = _load_tinker_service_module(monkeypatch)
+
     checkpoint_dir = tmp_path / "checkpoints" / "0001"
     checkpoint_dir.mkdir(parents=True)
     info_path = checkpoint_dir / "info.yaml"
     info_path.write_text(
-        yaml.safe_dump(
+        json.dumps(
             {
                 "state_with_optimizer_path": "tinker://state/0001",
                 "sampler_weights_path": "tinker://sampler/0001",
@@ -48,12 +151,12 @@ async def create_training_client_from_state_with_optimizer_async(
             return fake_training_client
 
     monkeypatch.setattr(
-        tinker_service_module.tinker,
+        service_module.tinker,
         "ServiceClient",
         FakeServiceClient,
     )
 
-    service = tinker_service_module.TinkerService(
+    service = service_module.TinkerService(
         model_name="test-model",
         base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
         config={
@@ -67,8 +170,9 @@ async def create_training_client_from_state_with_optimizer_async(
         output_dir=str(tmp_path),
     )
 
-    state = await service._get_state()
+    state = asyncio.run(service._get_state())
 
     assert observed["path"] == "tinker://state/0001"
     assert observed["user_metadata"] == {"tenant": "test-tenant"}
     assert state.training_client is fake_training_client
+    assert state.models["test-model"] == "tinker://sampler/0001"