Skip to content

Commit 16acecb

Browse files
committed
crucible_tensor_patch: initial repo and tensor patch migration
1 parent 8bb2e35 commit 16acecb

24 files changed

Lines changed: 1179 additions & 16 deletions

.github/workflows/ci.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
8+
jobs:
9+
test:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
- uses: erlef/setup-beam@v1
14+
with:
15+
otp-version: "28.3"
16+
elixir-version: "1.18.4"
17+
- uses: actions/cache@v4
18+
with:
19+
path: |
20+
deps
21+
_build
22+
key: ${{ runner.os }}-mix-${{ hashFiles('mix.lock') }}
23+
restore-keys: ${{ runner.os }}-mix-
24+
- run: mix local.hex --force
25+
- run: mix local.rebar --force
26+
- run: mix ci

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2026 North Shore AI
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

MIGRATION.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Migration Notes
2+
3+
This repo was scaffolded during TRINITY decomposition Phase 1 so
4+
`trinity_framework` could resolve local path dependencies before the public
5+
GitHub repos existed.
6+
7+
Source material for the Phase 4 implementation:
8+
9+
- `nshkrdotcom/trinity_coordinator` tag `v0.1.0-monolith`
10+
- source commit `64144a2983950e5fc9f2db2d26323a576c7379a1`
11+
- `lib/trinity_coordinator/runtime/backend_label.ex`
12+
- path traversal and param-tree patching portions of `lib/trinity_coordinator/sakana/artifact.ex`
13+
- patch/export flow from `lib/trinity_coordinator/sakana/exporter.ex`
14+
15+
The implementation keeps the artifact runtime's product-specific loading and
16+
model-head wiring out of this package. This repo owns reusable tensor patch
17+
plans, path traversal, checksum/manifest output, resume behavior, and SVF-based
18+
patch application.

README.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# CrucibleTensorPatch
22

3-
**TODO: Add description**
3+
Deterministic tensor patch plans and patch application for model surgery.
4+
5+
The package owns generic patch behavior: plan parsing, tensor path traversal,
6+
identity and SVF patch operations, manifest/checksum emission, resume/force
7+
rules, backend label round trips, and stage comparison. It intentionally avoids
8+
provider, tracing, and orchestration dependencies.
49

510
## Installation
611

@@ -19,3 +24,15 @@ Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_do
1924
and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
2025
be found at <https://hexdocs.pm/crucible_tensor_patch>.
2126

27+
## CI
28+
29+
```sh
30+
mix ci
31+
```
32+
33+
Large local fixture checks are opt-in:
34+
35+
```sh
36+
TRINITY_ARTIFACT_DIR=~/p/g/n/trinity_coordinator/priv/sakana_trinity/adapted_qwen3_0_6b_layer26 \
37+
mix test --only large_tensor_patch
38+
```
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
%{
2+
deps: %{
3+
crucible_safetensors: %{
4+
path: "../crucible_safetensors",
5+
github: %{repo: "North-Shore-AI/crucible_safetensors", branch: "main"},
6+
hex: "~> 0.1.0",
7+
default_order: [:path, :github, :hex],
8+
publish_order: [:hex]
9+
},
10+
crucible_factorization: %{
11+
path: "../crucible_factorization",
12+
github: %{repo: "North-Shore-AI/crucible_factorization", branch: "main"},
13+
hex: "~> 0.1.0",
14+
default_order: [:path, :github, :hex],
15+
publish_order: [:hex]
16+
}
17+
}
18+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
defmodule DependencySources do
  @moduledoc false

  # Location of the dependency source table, relative to the repo root.
  @config_file "build_support/dependency_sources.config.exs"

  # Source preference used when a dependency entry has no `:default_order`.
  @fallback_order [:path, :github, :hex]

  # Builds the Mix dependency tuple for `app`.
  #
  # The entry for `app` is read from the config file under `repo_root`, the
  # first available source in the entry's preference order is selected, and
  # `extra_opts` are merged into the resulting dependency options.
  #
  # Raises `KeyError` when the config has no `:deps` key or no entry for `app`,
  # and `ArgumentError` when none of the listed sources is available.
  def dep(app, repo_root, extra_opts \\ []) when is_atom(app) and is_binary(repo_root) do
    root = Path.expand(repo_root)
    {config, _bindings} = root |> Path.join(@config_file) |> Code.eval_file()
    entry = config |> Map.fetch!(:deps) |> Map.fetch!(app)

    build(pick_source!(entry, root), app, entry, root, extra_opts)
  end

  # Returns the first source in the entry's order that is usable.
  defp pick_source!(entry, root) do
    entry
    |> Map.get(:default_order, @fallback_order)
    |> Enum.find(&available?(&1, entry, root)) ||
      raise ArgumentError, "no dependency source available"
  end

  # A `:path` source only counts when the directory actually exists on disk;
  # every other source counts as soon as the entry declares it.
  defp available?(:path, entry, root) do
    entry[:path] && File.exists?(Path.expand(entry[:path], root))
  end

  defp available?(source, entry, _root), do: Map.has_key?(entry, source)

  # Local checkout: absolute path plus caller options.
  defp build(:path, app, entry, root, extra_opts) do
    local = [path: Path.expand(entry[:path], root)]
    {app, Keyword.merge(local, extra_opts)}
  end

  # GitHub: `:repo` becomes the `:github` option, remaining keys (e.g. `:branch`)
  # are passed through, and caller options win over configured ones.
  defp build(:github, app, entry, _root, extra_opts) do
    %{repo: repo} = github = Map.fetch!(entry, :github)
    passthrough = github |> Map.delete(:repo) |> Map.to_list()

    {app, Keyword.merge([github: repo], Keyword.merge(passthrough, extra_opts))}
  end

  # Hex: a two-element tuple when there are no extra options, otherwise the
  # three-element form carrying them.
  defp build(:hex, app, entry, _root, []), do: {app, Map.fetch!(entry, :hex)}
  defp build(:hex, app, entry, _root, extra_opts), do: {app, Map.fetch!(entry, :hex), extra_opts}
end
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
defmodule Crucible.TensorPatch.BackendLabel do
  @moduledoc "Compatibility namespace for `CrucibleTensorPatch.BackendLabel`."

  alias CrucibleTensorPatch.BackendLabel, as: Target

  @doc "Forwards to `CrucibleTensorPatch.BackendLabel.from_label/1`."
  def from_label(label), do: Target.from_label(label)

  @doc "Forwards to `CrucibleTensorPatch.BackendLabel.from_label!/1`."
  def from_label!(label), do: Target.from_label!(label)
end

lib/crucible_tensor_patch/apply.ex

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
defmodule CrucibleTensorPatch.Apply do
  @moduledoc """
  Applies tensor patch plans to source tensors.

  For every operation in a plan a tensor is built, checked against the
  operation's optional expected shape and dtype, and written with
  `CrucibleSafetensors.Writer` to the operation's output path under the output
  directory. The per-operation reports are combined into a manifest.
  """

  alias Crucible.Factorization.SVD
  alias CrucibleSafetensors.{Checksum, Writer}
  alias CrucibleTensorPatch.{Errors, Manifest, Operation, ParamTree, Plan, TensorPath}

  @doc """
  Applies a plan to source tensors and writes output safetensors.

  Same as `apply!/4`, but any exception raised while applying the plan is
  returned as `{:error, exception}` instead of propagating.
  """
  @spec apply(Plan.t(), map(), Path.t(), keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def apply(%Plan{} = plan, source_artifact, out_dir, opts \\ []) when is_map(source_artifact) do
    {:ok, apply!(plan, source_artifact, out_dir, opts)}
  rescue
    # Deliberately broad: this is the non-raising boundary of the module.
    exception -> {:error, exception}
  end

  @doc """
  Applies a plan, raising on invalid input.

  ## Options

    * `:components` - map consulted when an operation input is given as a
      binary key rather than a tensor. Defaults to `%{}`.
    * `:force` - when truthy, existing outputs are never reused and every
      operation is rewritten. Defaults to `false`.
    * `:write_manifest?` - when truthy, the manifest is written to `out_dir`
      via `Manifest.write!/2`. Defaults to `true`.

  Returns the manifest built by `Manifest.build/2`. Unknown options raise
  `ArgumentError` (from `Keyword.validate!/2`).
  """
  @spec apply!(Plan.t(), map(), Path.t(), keyword()) :: map()
  def apply!(%Plan{} = plan, source_artifact, out_dir, opts \\ []) when is_map(source_artifact) do
    opts = Keyword.validate!(opts, components: %{}, force: false, write_manifest?: true)
    File.mkdir_p!(out_dir)

    reports = Enum.map(plan.operations, &apply_operation!(&1, source_artifact, out_dir, opts))
    manifest = Manifest.build(reports, source: plan.source)

    if opts[:write_manifest?] do
      Manifest.write!(out_dir, manifest)
    end

    manifest
  end

  @doc "Applies manifest-style adapted tensors into a params tree."
  defdelegate patch_params!(params, manifest, tensors, opts \\ []), to: ParamTree, as: :patch!

  # Runs one operation and returns its report map. An output that already
  # exists and passes the resume check is reported as skipped, not rebuilt.
  defp apply_operation!(%Operation{} = operation, source_artifact, out_dir, opts) do
    # NOTE(review): `operation.output_path` is joined as-is; confirm plans are
    # trusted or validated upstream so it cannot point outside `out_dir`.
    output_path = Path.join(out_dir, operation.output_path)

    case existing_output_status(operation, output_path, opts) do
      {:skip, report} ->
        report

      :write ->
        tensor = build_tensor!(operation, source_artifact, opts[:components])
        validate_tensor!(operation, tensor)
        write_tensor!(operation, output_path, tensor)
    end
  end

  # Resume rule: with `:force` everything is rewritten. Otherwise an existing
  # regular file is reused only when the operation carries an expected SHA-256
  # and the file matches it; a mismatch is an error rather than a silent
  # overwrite. Files without an expected checksum are rewritten.
  defp existing_output_status(operation, output_path, opts) do
    cond do
      opts[:force] ->
        :write

      File.regular?(output_path) and is_binary(operation.expected_output_sha256) ->
        actual = Checksum.file_sha256!(output_path)

        if actual == operation.expected_output_sha256 do
          {:skip, report(operation, output_path, actual, "complete", true)}
        else
          raise Errors,
                "resume checksum mismatch for #{operation.id}: expected #{operation.expected_output_sha256}, got #{actual}"
        end

      true ->
        :write
    end
  end

  # `:identity` passes the source tensor through unchanged.
  defp build_tensor!(%Operation{operation: :identity} = operation, source_artifact, _components) do
    fetch_source_tensor!(source_artifact, operation)
  end

  # `:svf_apply` hands the u/s/v inputs and the scale offsets (cast to the type
  # of `s`) to `SVD.reconstruct/2`, then casts the result to the source
  # tensor's type.
  defp build_tensor!(%Operation{operation: :svf_apply} = operation, source_artifact, components) do
    source = fetch_source_tensor!(source_artifact, operation)
    inputs = operation.inputs || %{}

    decomposition = %{
      u: fetch_input!(inputs, components, :u),
      s: fetch_input!(inputs, components, :s),
      v: fetch_input!(inputs, components, :v)
    }

    offsets = fetch_input!(inputs, components, :scale_offsets)

    decomposition
    |> SVD.reconstruct(Nx.as_type(offsets, Nx.type(decomposition.s)))
    |> Nx.as_type(Nx.type(source))
  end

  # Looks up the operation's source tensor. A flat key equal to `source_path`
  # takes priority; otherwise the artifact is traversed with `TensorPath`,
  # using explicit segments when the operation has them.
  defp fetch_source_tensor!(source_artifact, operation) do
    cond do
      operation.source_path && Map.has_key?(source_artifact, operation.source_path) ->
        Map.fetch!(source_artifact, operation.source_path)

      operation.segments ->
        TensorPath.fetch!(
          source_artifact,
          TensorPath.normalize(operation.segments, operation.source_path),
          operation.source_path
        )

      operation.source_path ->
        TensorPath.fetch!(
          source_artifact,
          TensorPath.normalize(nil, operation.source_path),
          operation.source_path
        )

      true ->
        raise Errors, "operation #{operation.id} has no source_path"
    end
  end

  # Resolves one named operation input. `inputs` may be keyed by string or by
  # atom; callers pass the atom literal so no atom is created at runtime. A
  # tensor value is used directly, a binary value is a key into `components`
  # (raising `KeyError` when absent), anything else is a missing input.
  defp fetch_input!(inputs, components, key) when is_atom(key) do
    name = Atom.to_string(key)
    value = Map.get(inputs, name) || Map.get(inputs, key)

    cond do
      match?(%Nx.Tensor{}, value) -> value
      is_binary(value) -> Map.fetch!(components, value)
      true -> raise Errors, "missing input #{inspect(name)}"
    end
  end

  # Checks the built tensor against the operation's optional expectations.
  # Either check is skipped when the corresponding expectation is nil.
  defp validate_tensor!(operation, %Nx.Tensor{} = tensor) do
    expected_shape = normalize_shape(operation.expected_shape)

    if expected_shape && Nx.shape(tensor) != expected_shape do
      raise Errors,
            "operation #{operation.id} shape mismatch: expected #{inspect(expected_shape)}, got #{inspect(Nx.shape(tensor))}"
    end

    if operation.expected_dtype &&
         normalize_dtype(Nx.type(tensor)) != normalize_dtype(operation.expected_dtype) do
      raise Errors,
            "operation #{operation.id} dtype mismatch: expected #{inspect(operation.expected_dtype)}, got #{inspect(Nx.type(tensor))}"
    end

    :ok
  end

  # Writes the tensor under the operation id and reports the file's checksum.
  defp write_tensor!(operation, output_path, tensor) do
    # Only `out_dir` itself is created by apply!/4; an output path with
    # subdirectories needs its parent created before writing.
    File.mkdir_p!(Path.dirname(output_path))

    Writer.write!(%{operation.id => tensor_payload(tensor)}, output_path)
    checksum = Checksum.file_sha256!(output_path)
    report(operation, output_path, checksum, "complete", false)
  end

  # Converts a tensor into the dtype/shape/data map passed to the writer,
  # transferring it to `Nx.BinaryBackend` first.
  defp tensor_payload(%Nx.Tensor{} = tensor) do
    host = Nx.backend_transfer(tensor, Nx.BinaryBackend)

    %{
      dtype: writer_dtype!(Nx.type(host)),
      shape: host |> Nx.shape() |> Tuple.to_list(),
      data: Nx.to_binary(host)
    }
  end

  # String-keyed report entry for one operation.
  defp report(operation, output_path, checksum, status, skipped?) do
    %{
      "id" => operation.id,
      "operation" => Atom.to_string(operation.operation),
      "source_path" => operation.source_path,
      "output_path" => output_path,
      "status" => status,
      "skipped" => skipped?,
      "sha256" => checksum
    }
  end

  # Nx type tuple -> writer dtype atom; other types cannot be written.
  defp writer_dtype!({:f, 16}), do: :f16
  defp writer_dtype!({:bf, 16}), do: :bf16
  defp writer_dtype!({:f, 32}), do: :f32
  defp writer_dtype!({:s, 32}), do: :i32
  defp writer_dtype!({:s, 64}), do: :i64
  defp writer_dtype!(type), do: raise(Errors, "unsupported output dtype #{inspect(type)}")

  defp normalize_shape(nil), do: nil
  defp normalize_shape(shape) when is_tuple(shape), do: shape
  defp normalize_shape(shape) when is_list(shape), do: List.to_tuple(shape)

  # Canonical string form used to compare dtypes given as Nx type tuples,
  # atoms or strings.
  defp normalize_dtype({:f, 16}), do: "f16"
  defp normalize_dtype({:bf, 16}), do: "bf16"
  defp normalize_dtype({:f, 32}), do: "f32"
  defp normalize_dtype({:s, 32}), do: "i32"
  defp normalize_dtype({:s, 64}), do: "i64"
  defp normalize_dtype(dtype) when is_atom(dtype), do: Atom.to_string(dtype)
  # A string dtype is already in canonical form. Without this clause it fell
  # through to `inspect/1`, gained surrounding quotes, and could never match.
  defp normalize_dtype(dtype) when is_binary(dtype), do: dtype
  defp normalize_dtype(dtype), do: inspect(dtype)
end
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
defmodule CrucibleTensorPatch.BackendLabel do
  @moduledoc """
  Recovers an Nx backend specifier from a stored backend label.
  """

  require Logger

  @type backend_spec :: module() | {module(), keyword()}

  @doc """
  Returns `{:ok, backend_spec}` for known labels.

  Labels are recognised by prefix (`"EXLA.Backend<cuda"`, `"EXLA.Backend<host"`,
  `"EMLX.Backend"`) or, for `"Nx.BinaryBackend"`, by exact match. Any other
  string yields `{:error, {:unknown_backend_label, label}}`.
  """
  @spec from_label(String.t()) ::
          {:ok, backend_spec()} | {:error, {:unknown_backend_label, String.t()}}
  def from_label(label) when is_binary(label) do
    cond do
      String.starts_with?(label, "EXLA.Backend<cuda") -> {:ok, {EXLA.Backend, client: :cuda}}
      String.starts_with?(label, "EXLA.Backend<host") -> {:ok, {EXLA.Backend, client: :host}}
      label == "Nx.BinaryBackend" -> {:ok, Nx.BinaryBackend}
      String.starts_with?(label, "EMLX.Backend") -> {:ok, {EMLX.Backend, device: :gpu}}
      true -> {:error, {:unknown_backend_label, label}}
    end
  end

  @doc """
  Returns a backend spec, falling back audibly to `Nx.BinaryBackend` for unknown labels.

  Despite the bang, an unknown label does not raise: a warning is logged and
  `Nx.BinaryBackend` is returned.
  """
  @spec from_label!(String.t()) :: backend_spec()
  def from_label!(label) when is_binary(label) do
    case from_label(label) do
      {:error, {:unknown_backend_label, ^label}} ->
        Logger.warning(
          "unknown backend label #{inspect(label)}, falling back to Nx.BinaryBackend"
        )

        Nx.BinaryBackend

      {:ok, resolved} ->
        resolved
    end
  end
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
defmodule CrucibleTensorPatch.Errors do
  @moduledoc """
  Raised for invalid tensor patch plans or patch inputs.

  Carries a single `:message` field, so it can be raised with a plain string:
  `raise CrucibleTensorPatch.Errors, "reason"`.
  """

  defexception message: nil
end

0 commit comments

Comments
 (0)