Commit 8ffa429

fix: remove tensor patch product leakage
1 parent 6a04db7 commit 8ffa429

8 files changed

Lines changed: 219 additions & 28 deletions

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.1.0] - 2026-05-23
 
 ### Added
-- Initial release of `CrucibleTensorPatch` under the TRINITY decomposition.
+- Initial release of `CrucibleTensorPatch` from the monolith extraction.

MIGRATION.md

Lines changed: 2 additions & 3 deletions
@@ -1,8 +1,7 @@
 # Migration Notes
 
-This repo was scaffolded during TRINITY decomposition Phase 1 so
-`trinity_framework` could resolve local path dependencies before the public
-GitHub repos existed.
+This repo was scaffolded during the monolith extraction so framework consumers
+could resolve local path dependencies before the public GitHub repos existed.
 
 Source material for the Phase 4 implementation:

README.md

Lines changed: 128 additions & 2 deletions
@@ -20,6 +20,22 @@ identity and SVF patch operations, manifest/checksum emission, resume/force
 rules, backend label round trips, and stage comparison. It intentionally avoids
 provider, tracing, and orchestration dependencies.
 
+## What It Provides
+
+- `CrucibleTensorPatch.Plan` loads and validates patch plans from maps or JSON
+  files.
+- `CrucibleTensorPatch.Apply` applies identity and SVF patch operations and
+  writes deterministic safetensors outputs.
+- `CrucibleTensorPatch.ParamTree` patches nested parameter trees using manifest
+  selected-tensor entries.
+- `CrucibleTensorPatch.Manifest` emits deterministic operation manifests.
+- `CrucibleTensorPatch.BackendLabel` round-trips known Nx/EXLA backend labels.
+- `CrucibleTensorPatch.StageCheck` compares stage reports using shared tensor
+  comparison behavior.
+
+The package does not fetch models, load provider credentials, run coordination
+loops, or own application runtime configuration.
+
 ## Installation
 
 If [available in Hex](https://hex.pm/docs/publish), the package can be installed
@@ -37,6 +53,112 @@ Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_do
 and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
 be found at <https://hexdocs.pm/crucible_tensor_patch>.
 
+## Plan Schema
+
+A plan contains a `schema` string and an ordered `operations` list:
+
+```elixir
+plan_doc = %{
+  "schema" => "example.v1",
+  "operations" => [
+    %{
+      "id" => "copy_layer",
+      "operation" => "identity",
+      "source_path" => "layers.0.kernel",
+      "output_path" => "layers/0000_kernel.safetensors",
+      "expected_shape" => [2, 2],
+      "expected_dtype" => "f32"
+    }
+  ]
+}
+
+{:ok, plan} = CrucibleTensorPatch.Plan.load(plan_doc)
+```
+
+Supported operations are:
+
+- `"identity"`: copies a source tensor to a safetensors output.
+- `"svf_apply"`: reconstructs a tensor from SVD/SVF components and scale
+  offsets before writing the output.
+
+Supported dtype strings are `bf16`, `f16`, `f32`, `i32`, and `i64`.
+
+## Applying Identity Operations
+
+```elixir
+source = %{
+  "layers.0.kernel" => Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: :f32)
+}
+
+{:ok, manifest} =
+  CrucibleTensorPatch.Apply.apply(
+    plan,
+    source,
+    "out/patch",
+    force: false
+  )
+```
+
+The result manifest includes operation status, output paths, skip state, and
+SHA-256 checksums. `manifest.json` is written by default.
+
+## Applying SVF Operations
+
+SVF operations reference component tensors by name:
+
+```elixir
+operation = %{
+  "id" => "svf_layer",
+  "operation" => "svf_apply",
+  "source_path" => "layers.0.kernel",
+  "output_path" => "layers/0000_kernel.safetensors",
+  "inputs" => %{
+    "u" => "layer_0_u",
+    "s" => "layer_0_s",
+    "v" => "layer_0_v",
+    "scale_offsets" => "layer_0_offsets"
+  },
+  "expected_shape" => [2, 2],
+  "expected_dtype" => "f32"
+}
+```
+
+Pass the component tensors through the `:components` option:
+
+```elixir
+CrucibleTensorPatch.Apply.apply(plan, source, "out/patch",
+  components: %{
+    "layer_0_u" => u,
+    "layer_0_s" => s,
+    "layer_0_v" => v,
+    "layer_0_offsets" => offsets
+  }
+)
+```
+
+## Patching Parameter Trees
+
+```elixir
+patched =
+  CrucibleTensorPatch.Apply.patch_params!(
+    params,
+    manifest,
+    tensors,
+    cast_tensors: true
+  )
+```
+
+The patcher accepts maps or structs with a `:data` field. Manifest entries may
+include explicit `segments`; otherwise the path string is split into traversal
+segments.
+
+## Resume And Force
+
+When an operation has `expected_output_sha256` and the output file already
+exists, the applier verifies the checksum and skips completed output. A checksum
+mismatch raises. Pass `force: true` to rewrite outputs regardless of existing
+files.
+
 ## CI
 
 ```sh
@@ -46,6 +168,10 @@ mix ci
 Large local fixture checks are opt-in:
 
 ```sh
-TRINITY_ARTIFACT_DIR=~/p/g/n/trinity_coordinator/priv/sakana_trinity/adapted_qwen3_0_6b_layer26 \
-  mix test --only large_tensor_patch
+mkdir -p tmp
+ln -s /path/to/artifact_bundle tmp/crucible_tensor_patch_fixture
+mix test --only large_tensor_patch
 ```
+
+`mix ci` runs dependency fetch, format check, warning-as-error compile, tests,
+Credo strict, Dialyzer, and docs generation.
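
The "Resume And Force" rules described in the README diff above can be sketched as a small decision function. This is a hypothetical illustration, not the package's implementation: the module name `ResumeSketch`, the function `decide/3`, the `ArgumentError`, and the choice to rewrite when no expected checksum is given are all assumptions.

```elixir
defmodule ResumeSketch do
  # Hypothetical helper: decides whether an output file should be written
  # or skipped, following the README's resume/force wording.
  def decide(path, expected_sha256, opts \\ []) do
    cond do
      # force: true rewrites regardless of existing files
      Keyword.get(opts, :force, false) -> :write
      # assumption: with no expected checksum there is nothing to verify
      is_nil(expected_sha256) -> :write
      not File.exists?(path) -> :write
      # existing output with a matching checksum counts as completed
      sha256_hex(File.read!(path)) == String.downcase(expected_sha256) -> :skip
      # a checksum mismatch raises (error type is an assumption)
      true -> raise ArgumentError, "checksum mismatch for #{path}"
    end
  end

  defp sha256_hex(bytes) do
    :sha256 |> :crypto.hash(bytes) |> Base.encode16(case: :lower)
  end
end
```

Returning `:write` or `:skip` keeps the decision separate from the file I/O, which makes the rule easy to test on its own.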

lib/crucible_tensor_patch.ex

Lines changed: 1 addition & 14 deletions
@@ -1,18 +1,5 @@
 defmodule CrucibleTensorPatch do
   @moduledoc """
-  Documentation for `CrucibleTensorPatch`.
+  Deterministic tensor patch plans and patch application for model surgery.
   """
-
-  @doc """
-  Hello world.
-
-  ## Examples
-
-      iex> CrucibleTensorPatch.hello()
-      :world
-
-  """
-  def hello do
-    :world
-  end
 end

lib/crucible_tensor_patch/apply.ex

Lines changed: 21 additions & 1 deletion
@@ -5,6 +5,13 @@ defmodule CrucibleTensorPatch.Apply do
   alias CrucibleSafetensors.{Checksum, Writer}
   alias CrucibleTensorPatch.{Errors, Manifest, Operation, ParamTree, Plan, TensorPath}
 
+  @input_atom_keys %{
+    "scale_offsets" => :scale_offsets,
+    "s" => :s,
+    "u" => :u,
+    "v" => :v
+  }
+
   @doc "Applies a plan to source tensors and writes output safetensors."
   @spec apply(Plan.t(), map(), Path.t(), keyword()) :: {:ok, map()} | {:error, Exception.t()}
   def apply(%Plan{} = plan, source_artifact, out_dir, opts \\ []) when is_map(source_artifact) do
@@ -116,7 +123,7 @@ defmodule CrucibleTensorPatch.Apply do
   end
 
   defp fetch_input!(inputs, components, name) do
-    value = Map.get(inputs, name) || Map.get(inputs, String.to_atom(name))
+    value = input_value(inputs, name)
 
     cond do
       match?(%Nx.Tensor{}, value) -> value
@@ -125,6 +132,19 @@ defmodule CrucibleTensorPatch.Apply do
     end
   end
 
+  defp input_value(inputs, name) do
+    case Map.fetch(inputs, name) do
+      {:ok, value} ->
+        value
+
+      :error ->
+        case Map.fetch(@input_atom_keys, name) do
+          {:ok, atom_key} -> Map.get(inputs, atom_key)
+          :error -> nil
+        end
+    end
+  end
+
   defp validate_tensor!(operation, %Nx.Tensor{} = tensor) do
     expected_shape = normalize_shape(operation.expected_shape)
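
The lookup introduced in this file replaces `String.to_atom/1`, which creates a new atom for any string it is given (atoms are never garbage collected), with a fixed map of known keys. The standalone sketch below mirrors that lookup so it can be run in isolation; the module name `InputKeySketch` and the public visibility of `input_value/2` are assumptions made for illustration only.

```elixir
defmodule InputKeySketch do
  # Same four keys as @input_atom_keys in the diff above.
  @input_atom_keys %{
    "scale_offsets" => :scale_offsets,
    "s" => :s,
    "u" => :u,
    "v" => :v
  }

  # Prefer the string key; fall back to the atom key only for known names.
  # Unknown names return nil without ever creating an atom.
  def input_value(inputs, name) do
    case Map.fetch(inputs, name) do
      {:ok, value} ->
        value

      :error ->
        case Map.fetch(@input_atom_keys, name) do
          {:ok, atom_key} -> Map.get(inputs, atom_key)
          :error -> nil
        end
    end
  end
end

InputKeySketch.input_value(%{"u" => "layer_0_u"}, "u")
InputKeySketch.input_value(%{u: "layer_0_u"}, "u")
InputKeySketch.input_value(%{}, "not_a_known_input")
```

The first two calls return `"layer_0_u"`; the third returns `nil`.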

lib/crucible_tensor_patch/plan.ex

Lines changed: 15 additions & 2 deletions
@@ -5,6 +5,13 @@ defmodule CrucibleTensorPatch.Plan do
 
   @operations [:identity, :svf_apply]
   @operation_strings %{"identity" => :identity, "svf_apply" => :svf_apply}
+  @dtype_strings %{
+    "bf16" => :bf16,
+    "f16" => :f16,
+    "f32" => :f32,
+    "i32" => :i32,
+    "i64" => :i64
+  }
   @field_atom_keys %{
     "checksum_policy" => :checksum_policy,
     "expected_dtype" => :expected_dtype,
@@ -98,8 +105,14 @@ defmodule CrucibleTensorPatch.Plan do
   defp normalize_dtype(nil), do: nil
   defp normalize_dtype(dtype) when is_atom(dtype), do: dtype
 
-  defp normalize_dtype(dtype) when is_binary(dtype),
-    do: dtype |> String.downcase() |> String.to_existing_atom()
+  defp normalize_dtype(dtype) when is_binary(dtype) do
+    normalized = String.downcase(dtype)
+
+    case Map.fetch(@dtype_strings, normalized) do
+      {:ok, atom} -> atom
+      :error -> raise Errors, "unsupported dtype #{inspect(dtype)}"
+    end
+  end
 
   defp normalize_checksum(nil), do: nil
   defp normalize_checksum(policy) when is_atom(policy), do: policy
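
The dtype change above swaps `String.to_existing_atom/1` for an explicit allow-list. The old call accepted any string that happened to name an existing atom (for example `"true"`) and raised a generic `ArgumentError` for everything else. A minimal standalone sketch of the allow-list approach follows; the module `DtypeSketch` and its tagged-tuple return value are assumptions for illustration, whereas the committed code raises the package's own `Errors` exception.

```elixir
defmodule DtypeSketch do
  # Same allow-list as @dtype_strings in the diff above.
  @dtype_strings %{
    "bf16" => :bf16,
    "f16" => :f16,
    "f32" => :f32,
    "i32" => :i32,
    "i64" => :i64
  }

  # Case-insensitive lookup; unknown strings never reach an atom conversion.
  def normalize(dtype) when is_binary(dtype) do
    case Map.fetch(@dtype_strings, String.downcase(dtype)) do
      {:ok, atom} -> {:ok, atom}
      :error -> {:error, "unsupported dtype #{inspect(dtype)}"}
    end
  end
end

DtypeSketch.normalize("BF16")
DtypeSketch.normalize("true")
```

The first call returns `{:ok, :bf16}`; the second returns an error tuple rather than the atom `true`.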

test/crucible_tensor_patch/large_tensor_patch_test.exs

Lines changed: 1 addition & 5 deletions
@@ -5,11 +5,7 @@ defmodule CrucibleTensorPatch.LargeTensorPatchTest do
 
   @tag :large_tensor_patch
   test "fixture artifact manifest can seed a patch plan" do
-    artifact_dir =
-      System.get_env(
-        "TRINITY_ARTIFACT_DIR",
-        Path.expand("~/p/g/n/trinity_coordinator/priv/sakana_trinity/adapted_qwen3_0_6b_layer26")
-      )
+    artifact_dir = Path.expand("tmp/crucible_tensor_patch_fixture")
 
     manifest_path = Path.join(artifact_dir, "manifest.json")
     assert File.regular?(manifest_path)

test/crucible_tensor_patch/plan_apply_test.exs

Lines changed: 50 additions & 0 deletions
@@ -24,6 +24,19 @@ defmodule CrucibleTensorPatch.PlanApplyTest do
     assert operation.expected_dtype == :f32
   end
 
+  test "Plan.load/1 normalizes supported dtype strings without dynamic atoms" do
+    assert {:ok, plan} =
+             Plan.load(%{
+               "operations" => [
+                 identity_op("copy", "layer.kernel", "copy.safetensors")
+                 |> Map.put("expected_dtype", "BF16")
+               ]
+             })
+
+    assert [operation] = plan.operations
+    assert operation.expected_dtype == :bf16
+  end
+
   test "Plan.load/1 rejects unknown operations" do
     assert {:error, %Errors{message: message}} =
              Plan.load(%{
@@ -85,6 +98,43 @@ defmodule CrucibleTensorPatch.PlanApplyTest do
     assert [%{"status" => "complete"}] = manifest["operations"]
   end
 
+  test "Apply.apply/4 accepts fixed atom input keys" do
+    dir = tmp_dir()
+    source = %{"layer.kernel" => Nx.tensor([[2.0, 4.0], [1.0, 2.0]], type: :f32)}
+
+    decomp =
+      CrucibleFactorization.SVD.decompose_tensor(source["layer.kernel"], compute_type: :f32)
+
+    offsets = Nx.broadcast(0.0, {Nx.axis_size(decomp.s, 0)})
+
+    {:ok, plan} =
+      Plan.load(%{
+        "operations" => [
+          %{
+            "id" => "svf",
+            "operation" => "svf_apply",
+            "source_path" => "layer.kernel",
+            "output_path" => "svf.safetensors",
+            "inputs" => %{u: "u", s: "s", v: "v", scale_offsets: "offsets"},
+            "expected_shape" => [2, 2],
+            "expected_dtype" => "f32"
+          }
+        ]
+      })
+
+    assert {:ok, manifest} =
+             Apply.apply(plan, source, dir,
+               components: %{
+                 "u" => decomp.u,
+                 "s" => decomp.s,
+                 "v" => decomp.v,
+                 "offsets" => offsets
+               }
+             )
+
+    assert [%{"status" => "complete"}] = manifest["operations"]
+  end
+
   test "Apply.apply/4 supports resume and force" do
     dir = tmp_dir()
     source = %{"layer.kernel" => Nx.tensor([[1.0, 2.0]], type: :f32)}
