Skip to content

Commit 1ee6fb9

Browse files
committed
crucible_factorization: initial repo and svd svf migration
1 parent e967e64 commit 1ee6fb9

26 files changed

Lines changed: 1417 additions & 16 deletions

.github/workflows/ci.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
8+
jobs:
9+
test:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
- uses: erlef/setup-beam@v1
14+
with:
15+
otp-version: "28.3"
16+
elixir-version: "1.18.4"
17+
- uses: actions/cache@v4
18+
with:
19+
path: |
20+
deps
21+
_build
22+
key: ${{ runner.os }}-mix-${{ hashFiles('mix.lock') }}
23+
restore-keys: ${{ runner.os }}-mix-
24+
- run: mix local.hex --force
25+
- run: mix local.rebar --force
26+
- run: mix ci

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2026 North Shore AI
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

MIGRATION.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Migration Notes
2+
3+
This repo was scaffolded during TRINITY decomposition Phase 1 so
4+
`trinity_framework` could resolve local path dependencies before the public
5+
GitHub repos existed.
6+
7+
Source material for the Phase 3 implementation:
8+
9+
- `nshkrdotcom/trinity_coordinator` tag `v0.1.0-monolith`
10+
- source commit `64144a2983950e5fc9f2db2d26323a576c7379a1`
11+
- `lib/trinity_coordinator/sakana/svd.ex`
12+
- math portions of `lib/trinity_coordinator/sakana/stage_check.ex`
13+
- math/reporting portions of `lib/trinity_coordinator/sakana/parity_trace.ex`
14+
- sync-timing pattern from `lib/trinity_coordinator/sakana/exporter.ex`
15+
16+
The implementation keeps provider, orchestration, tracing, and product runtime
17+
dependencies out of the factorization package. Compatibility functions that
18+
previously accepted model-state structs now operate on generic maps or structs
19+
with a `:data` field.

README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# CrucibleFactorization
22

3-
**TODO: Add description**
3+
Nx SVD/SVF factorization primitives for model surgery and TRINITY artifact
4+
export.
5+
6+
This package intentionally owns the temporary Nx/EXLA git pin required for the
7+
thin-SVD memory behavior used by the coordinator. Contract packages should not
8+
inherit that pin directly.
49

510
## Installation
611

@@ -19,3 +24,14 @@ Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_do
1924
and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
2025
be found at <https://hexdocs.pm/crucible_factorization>.
2126

27+
## CI
28+
29+
```sh
30+
mix ci
31+
```
32+
33+
CUDA is opt-in:
34+
35+
```sh
36+
XLA_TARGET=cuda12 mix test --only cuda
37+
```
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
defmodule Crucible.Factorization.ParityReport do
2+
@moduledoc "Math-only parity summaries for factorization outputs."
3+
4+
alias Crucible.Factorization.StageCheck
5+
6+
@doc "Compares two tensors and returns numeric error metrics."
7+
@spec compare_tensors(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: map()
8+
def compare_tensors(%Nx.Tensor{} = computed, %Nx.Tensor{} = reference, opts \\ []) do
9+
opts = Keyword.validate!(opts, stage: "tensor")
10+
11+
[check] =
12+
StageCheck.compare_stage_tensors(
13+
%{opts[:stage] => computed},
14+
%{opts[:stage] => reference},
15+
include_alt_hashes: false,
16+
include_tensor_summaries: false
17+
)
18+
19+
check
20+
end
21+
22+
@doc "Returns a compact summary of a named tensor set."
23+
@spec tensor_set_summary(%{String.t() => Nx.Tensor.t()}) :: [map()]
24+
def tensor_set_summary(tensors) when is_map(tensors) do
25+
tensors
26+
|> Enum.sort_by(fn {name, _tensor} -> name end)
27+
|> Enum.map(fn {name, tensor} ->
28+
Map.put(StageCheck.tensor_summary(tensor, include_alt_hashes: false), "name", name)
29+
end)
30+
end
31+
end
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
defmodule Crucible.Factorization.StageCheck do
2+
@moduledoc "Shared tensor comparisons for factorization parity checks."
3+
4+
alias CrucibleFactorization.{Backend, TensorHash}
5+
6+
@doc "Compares computed stage tensors with reference stage tensors."
7+
def compare_stage_tensors(stage_tensors, reference_stage_tensors, opts \\ [])
8+
def compare_stage_tensors(_stage_tensors, nil, _opts), do: []
9+
10+
def compare_stage_tensors(stage_tensors, reference_stage_tensors, opts)
11+
when is_map(stage_tensors) do
12+
opts =
13+
Keyword.validate!(opts,
14+
include_alt_hashes: true,
15+
include_tensor_summaries: true,
16+
compute_byte_match: true
17+
)
18+
19+
Nx.with_default_backend(Nx.BinaryBackend, fn ->
20+
stage_tensors
21+
|> Enum.sort_by(fn {key, _tensor} -> key end)
22+
|> Enum.map(&compare_one_stage(&1, reference_stage_tensors, opts))
23+
end)
24+
end
25+
26+
@doc "Returns true when every required stage check passed."
27+
def checks_passed?([]), do: nil
28+
29+
def checks_passed?(checks) when is_list(checks) do
30+
Enum.all?(checks, fn check ->
31+
not check["required_for_functional_parity"] or check["functional_passed"]
32+
end)
33+
end
34+
35+
@doc "Returns a JSON-safe summary for a tensor."
36+
def tensor_summary(tensor, opts \\ []) do
37+
opts = Keyword.validate!(opts, prefix_count: 8, include_alt_hashes: true, backend_label: nil)
38+
39+
Nx.with_default_backend(Nx.BinaryBackend, fn ->
40+
tensor = host_snapshot(tensor)
41+
tensor_f32 = Nx.as_type(tensor, :f32) |> host_snapshot()
42+
size = Nx.size(tensor)
43+
prefix_count = min(size, opts[:prefix_count])
44+
45+
base = %{
46+
"shape" => shape_list(tensor),
47+
"type" => inspect(Nx.type(tensor)),
48+
"backend" => opts[:backend_label] || Backend.label(tensor),
49+
"snapshot_backend" => Backend.label(tensor),
50+
"size" => size,
51+
"sha256" => TensorHash.tensor_sha256(tensor),
52+
"min" => scalar(Nx.reduce_min(tensor_f32)),
53+
"max" => scalar(Nx.reduce_max(tensor_f32)),
54+
"sum" => scalar(Nx.sum(tensor_f32)),
55+
"prefix_f32" => prefix_f32(tensor, prefix_count)
56+
}
57+
58+
if opts[:include_alt_hashes] do
59+
Map.merge(base, %{
60+
"sha256_as_f32" => TensorHash.tensor_sha256(Nx.as_type(tensor, :f32)),
61+
"sha256_as_bf16" => TensorHash.tensor_sha256(Nx.as_type(tensor, :bf16))
62+
})
63+
else
64+
base
65+
end
66+
end)
67+
end
68+
69+
defp compare_one_stage({key, tensor}, reference_stage_tensors, opts) do
70+
case Map.fetch(reference_stage_tensors, key) do
71+
{:ok, reference_tensor} -> stage_check(key, tensor, reference_tensor, opts)
72+
:error -> missing_stage_check(key)
73+
end
74+
end
75+
76+
defp stage_check(key, computed_tensor, reference_tensor, opts) do
77+
computed_tensor = host_snapshot(computed_tensor)
78+
reference_tensor = host_snapshot(reference_tensor)
79+
tolerance = stage_tolerance(key)
80+
shape_match = Nx.shape(computed_tensor) == Nx.shape(reference_tensor)
81+
byte_match = maybe_byte_match(computed_tensor, reference_tensor, opts)
82+
83+
if shape_match do
84+
stage_value_check(key, computed_tensor, reference_tensor, tolerance, byte_match, opts)
85+
else
86+
%{
87+
"stage" => key,
88+
"required_for_functional_parity" => tolerance.required?,
89+
"byte_match" => byte_match,
90+
"shape_match" => false,
91+
"computed" => maybe_tensor_summary(computed_tensor, opts),
92+
"reference" => maybe_tensor_summary(reference_tensor, opts),
93+
"max_abs_error" => nil,
94+
"mean_abs_error" => nil,
95+
"mismatched_element_count" => nil,
96+
"tolerance" => %{
97+
"max_abs_error" => tolerance.max_abs,
98+
"mean_abs_error" => tolerance.mean_abs
99+
},
100+
"functional_passed" => false
101+
}
102+
end
103+
end
104+
105+
defp stage_value_check(key, computed_tensor, reference_tensor, tolerance, byte_match, opts) do
106+
computed_f32 = Nx.as_type(computed_tensor, :f32)
107+
reference_f32 = Nx.as_type(reference_tensor, :f32)
108+
abs_diff = Nx.abs(Nx.subtract(computed_f32, reference_f32))
109+
max_abs = scalar(Nx.reduce_max(abs_diff))
110+
mean_abs = scalar(Nx.divide(Nx.sum(abs_diff), Nx.tensor(Nx.size(abs_diff), type: :f32)))
111+
112+
mismatch_count =
113+
scalar(Nx.sum(Nx.as_type(Nx.not_equal(computed_tensor, reference_tensor), :s64)))
114+
115+
%{
116+
"stage" => key,
117+
"required_for_functional_parity" => tolerance.required?,
118+
"byte_match" => byte_match,
119+
"shape_match" => true,
120+
"computed" => maybe_tensor_summary(computed_tensor, opts),
121+
"reference" => maybe_tensor_summary(reference_tensor, opts),
122+
"max_abs_error" => max_abs,
123+
"mean_abs_error" => mean_abs,
124+
"mismatched_element_count" => mismatch_count,
125+
"tolerance" => %{
126+
"max_abs_error" => tolerance.max_abs,
127+
"mean_abs_error" => tolerance.mean_abs
128+
},
129+
"functional_passed" => max_abs <= tolerance.max_abs and mean_abs <= tolerance.mean_abs
130+
}
131+
end
132+
133+
defp maybe_byte_match(computed_tensor, reference_tensor, opts) do
134+
if opts[:compute_byte_match] do
135+
TensorHash.tensor_sha256(computed_tensor) == TensorHash.tensor_sha256(reference_tensor)
136+
end
137+
end
138+
139+
defp maybe_tensor_summary(tensor, opts) do
140+
if opts[:include_tensor_summaries] do
141+
tensor_summary(tensor,
142+
prefix_count: 8,
143+
include_alt_hashes: opts[:include_alt_hashes]
144+
)
145+
else
146+
%{
147+
"shape" => shape_list(tensor),
148+
"type" => inspect(Nx.type(tensor)),
149+
"backend" => Backend.label(tensor),
150+
"size" => Nx.size(tensor),
151+
"summary_omitted" => true
152+
}
153+
end
154+
end
155+
156+
defp missing_stage_check(key) do
157+
tolerance = stage_tolerance(key)
158+
159+
%{
160+
"stage" => key,
161+
"required_for_functional_parity" => tolerance.required?,
162+
"byte_match" => false,
163+
"shape_match" => false,
164+
"missing_reference_stage" => true,
165+
"functional_passed" => not tolerance.required?
166+
}
167+
end
168+
169+
defp stage_tolerance("stage.source_f32"), do: %{required?: true, max_abs: 0.0, mean_abs: 0.0}
170+
defp stage_tolerance("stage.offsets_f32"), do: %{required?: true, max_abs: 0.0, mean_abs: 0.0}
171+
172+
defp stage_tolerance("stage.scaled_s"),
173+
do: %{required?: true, max_abs: 1.0e-6, mean_abs: 1.0e-8}
174+
175+
defp stage_tolerance("stage.normalization"),
176+
do: %{required?: true, max_abs: 1.0e-6, mean_abs: 1.0e-6}
177+
178+
defp stage_tolerance("stage.u_scaled"),
179+
do: %{required?: true, max_abs: 1.0e-6, mean_abs: 1.0e-8}
180+
181+
defp stage_tolerance("stage.zero_source_f32"),
182+
do: %{required?: true, max_abs: 1.0e-3, mean_abs: 1.0e-5}
183+
184+
defp stage_tolerance("stage.matmul_pre_norm"),
185+
do: %{required?: true, max_abs: 1.0e-3, mean_abs: 1.0e-5}
186+
187+
defp stage_tolerance("stage.adapted_source_f32"),
188+
do: %{required?: true, max_abs: 1.0e-3, mean_abs: 1.0e-5}
189+
190+
defp stage_tolerance("stage.final_f32"),
191+
do: %{required?: true, max_abs: 1.0e-3, mean_abs: 1.0e-5}
192+
193+
defp stage_tolerance("stage.final_bf16"),
194+
do: %{required?: false, max_abs: 1.0e-3, mean_abs: 1.0e-5}
195+
196+
defp stage_tolerance(_key), do: %{required?: false, max_abs: 1.0e-3, mean_abs: 1.0e-5}
197+
198+
defp prefix_f32(_tensor, 0), do: []
199+
200+
defp prefix_f32(tensor, count) do
201+
tensor
202+
|> host_snapshot()
203+
|> Nx.as_type(:f32)
204+
|> Nx.reshape({Nx.size(tensor)})
205+
|> Nx.slice([0], [count])
206+
|> host_snapshot()
207+
|> Nx.to_flat_list()
208+
end
209+
210+
defp scalar(tensor), do: tensor |> host_snapshot() |> Nx.to_number() |> finite_float()
211+
defp finite_float(value) when is_float(value), do: value
212+
defp finite_float(value), do: value
213+
defp shape_list(tensor), do: tensor |> Nx.shape() |> Tuple.to_list()
214+
defp host_snapshot(%Nx.Tensor{} = tensor), do: Nx.backend_transfer(tensor, Nx.BinaryBackend)
215+
end

0 commit comments

Comments
 (0)