Skip to content

Commit f1342df

Browse files
committed
fix: make router vector defaults generic
1 parent 6686d24 commit f1342df

7 files changed

Lines changed: 127 additions & 26 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [0.1.0] - 2026-05-23
99

1010
### Added
11-
- Initial release of `CrucibleFactorization` under the TRINITY decomposition.
11+
- Initial release of `CrucibleFactorization` from the monolith extraction.

MIGRATION.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Migration Notes
22

3-
This repo was scaffolded during TRINITY decomposition Phase 1 so
4-
`trinity_framework` could resolve local path dependencies before the public
5-
GitHub repos existed.
3+
This repo was scaffolded during the monolith extraction so framework consumers
4+
could resolve local path dependencies before the public GitHub repos existed.
65

76
Source material for the Phase 3 implementation:
87

@@ -16,4 +15,5 @@ Source material for the Phase 3 implementation:
1615
The implementation keeps provider, orchestration, tracing, and product runtime
1716
dependencies out of the factorization package. Compatibility functions that
1817
previously accepted model-state structs now operate on generic maps or structs
19-
with a `:data` field.
18+
with a `:data` field. Product-specific tensor keys belong in callers, not in
19+
the reusable package defaults.

README.md

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,30 @@
1313
</a>
1414
</p>
1515

16-
Nx SVD/SVF factorization primitives for model surgery and TRINITY artifact
17-
export.
16+
Nx SVD/SVF factorization primitives for model surgery and artifact export.
1817

19-
This package intentionally owns the temporary Nx/EXLA git pin required for the
20-
thin-SVD memory behavior used by the coordinator. Contract packages should not
21-
inherit that pin directly.
18+
This package is intentionally narrow. It owns numerical factorization,
19+
reconstruction, tensor traversal, manifest helpers, and router-vector splitting.
20+
It avoids provider, tracing, orchestration, and application runtime dependencies.
21+
Callers that use product-specific artifact names should pass those names
22+
explicitly at the API boundary.
23+
24+
## What It Provides
25+
26+
- `CrucibleFactorization.SVD.thin/2` and `thin!/2` run reduced SVD with timing,
27+
backend, rank, and sync metadata.
28+
- `CrucibleFactorization.SVD.reconstruct/3` rebuilds tensors from SVD
29+
components and scale offsets.
30+
- `CrucibleFactorization.SVD.decompose_tensors/2`,
31+
`reconstruct_tensors/3`, and `adapt_tensors/3` operate on selected tensor
32+
entries from nested parameter trees.
33+
- `CrucibleFactorization.SVF` provides low-rank singular-vector-field helpers
34+
for `base_tensor + low_rank_delta` workflows.
35+
- `CrucibleFactorization.StageCheck` and `ParityReport` provide math-only
36+
tensor comparison summaries.
37+
- `CrucibleFactorization.SVD.load_router_vector!/2` and
38+
`split_router_vector/4` load and split a flat vector into scale offsets and
39+
dense head weights. The default tensor name is the generic `"router_vector"`.
2240

2341
## Installation
2442

@@ -37,6 +55,78 @@ Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_do
3755
and published on [HexDocs](https://hexdocs.pm). Once published, the docs can
3856
be found at <https://hexdocs.pm/crucible_factorization>.
3957

58+
## Thin SVD
59+
60+
```elixir
61+
alias CrucibleFactorization.SVD
62+
63+
matrix = Nx.tensor([[2.0, 4.0], [1.0, 2.0]], type: :f32)
64+
65+
{:ok, result} = SVD.thin(matrix, rank: 1, compute_type: :f32, force_sync?: true)
66+
reconstructed = SVD.reconstruct(result, Nx.broadcast(0.0, {result.rank}))
67+
```
68+
69+
`result` includes `:u`, `:s`, `:v`, `:rank`, source type, backend label,
70+
decompose timing, and optional force-sync timing.
71+
72+
## SVF Delta Reconstruction
73+
74+
```elixir
75+
alias CrucibleFactorization.SVF
76+
77+
base = Nx.tensor([[1.0, 1.0], [1.0, 1.0]], type: :f32)
78+
delta = Nx.tensor([[2.0, 4.0], [1.0, 2.0]], type: :f32)
79+
80+
{:ok, svf} = SVF.decompose(delta, rank: 1)
81+
{:ok, adapted} = SVF.reconstruct(base, svf)
82+
```
83+
84+
## Tensor Traversal
85+
86+
```elixir
87+
entries =
88+
params
89+
|> SVD.decomposable_tensor_entries(path_filter: SVD.layer_index_filter([26]))
90+
91+
manifest = SVD.tensor_manifest(entries)
92+
count = SVD.singular_value_count(entries)
93+
```
94+
95+
The helpers accept generic nested maps, lists, tuples, and structs with a
96+
`:data` field. They do not require a framework runtime struct.
97+
98+
## Router Vector Helpers
99+
100+
```elixir
101+
vector = SVD.load_router_vector!("router_vector.safetensors")
102+
103+
split =
104+
SVD.split_router_vector(
105+
vector,
106+
scale_count,
107+
hidden_size,
108+
output_count
109+
)
110+
111+
split.scale_offsets
112+
split.head_weights
113+
```
114+
115+
For product-specific safetensors keys, pass the key explicitly:
116+
117+
```elixir
118+
vector = SVD.load_router_vector!("artifact.safetensors", "product_router_vector")
119+
```
120+
121+
## Backend And Sync Options
122+
123+
`thin/2` accepts:
124+
125+
- `:rank` for reduced rank selection.
126+
- `:compute_type`, either `:source` or `:f32`.
127+
- `:backend` for `Nx.backend_transfer/2`.
128+
- `:force_sync?` and `:sync_fun` for timing asynchronous backends.
129+
40130
## CI
41131

42132
```sh
@@ -48,3 +138,6 @@ CUDA is opt-in:
48138
```sh
49139
XLA_TARGET=cuda12 mix test --only cuda
50140
```
141+
142+
`mix ci` runs dependency fetch, format check, warning-as-error compile, tests,
143+
Credo strict, Dialyzer, and docs generation.

lib/crucible/factorization/svd.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule Crucible.Factorization.SVD do
55

66
alias CrucibleFactorization.{Backend, Errors, StageTiming}
77

8-
@router_vector_key "trinity_router_es_vector"
8+
@router_vector_key "router_vector"
99

1010
@type decomposition :: %{
1111
required(:u) => Nx.Tensor.t(),

lib/crucible_factorization.ex

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,5 @@
11
defmodule CrucibleFactorization do
22
@moduledoc """
3-
Documentation for `CrucibleFactorization`.
3+
Nx SVD/SVF factorization primitives for model surgery and artifact export.
44
"""
5-
6-
@doc """
7-
Hello world.
8-
9-
## Examples
10-
11-
iex> CrucibleFactorization.hello()
12-
:world
13-
14-
"""
15-
def hello do
16-
:world
17-
end
185
end

lib/crucible_factorization/svd.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ defmodule CrucibleFactorization.SVD do
2424
defdelegate adapt_tensors(tensors, scale_offsets, opts \\ []), to: Crucible.Factorization.SVD
2525
defdelegate put_tensor_entries(container, tensor_entries), to: Crucible.Factorization.SVD
2626

27-
defdelegate load_router_vector!(path, tensor_name \\ "trinity_router_es_vector"),
27+
defdelegate load_router_vector!(path, tensor_name \\ "router_vector"),
2828
to: Crucible.Factorization.SVD
2929

3030
defdelegate split_router_vector(vector, scale_count, hidden_size, output_count),

test/crucible_factorization/svd_test.exs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ defmodule CrucibleFactorization.SVDTest do
6262
assert result.backend_label == "Nx.BinaryBackend"
6363
end
6464

65+
test "load_router_vector!/2 defaults to a generic tensor name" do
66+
path = tmp_path("router_vector.safetensors")
67+
Safetensors.write!(path, %{"router_vector" => Nx.tensor([1.0, 2.0], type: :f32)})
68+
69+
assert SVD.load_router_vector!(path) |> Nx.to_flat_list() == [1.0, 2.0]
70+
end
71+
72+
test "load_router_vector!/2 accepts explicit product tensor names" do
73+
path = tmp_path("custom_router_vector.safetensors")
74+
tensor_name = "product_router_vector"
75+
Safetensors.write!(path, %{tensor_name => Nx.tensor([3.0, 4.0], type: :f32)})
76+
77+
assert SVD.load_router_vector!(path, tensor_name) |> Nx.to_flat_list() == [3.0, 4.0]
78+
end
79+
6580
defp reconstruction_error(result, matrix) do
6681
result
6782
|> SVD.reconstruct(Nx.broadcast(0.0, {result.rank}))
@@ -81,4 +96,10 @@ defmodule CrucibleFactorization.SVDTest do
8196

8297
assert max_abs <= tolerance
8398
end
99+
100+
defp tmp_path(name) do
101+
dir = Path.join(System.tmp_dir!(), "crucible_factorization_tests")
102+
File.mkdir_p!(dir)
103+
Path.join(dir, "#{System.unique_integer([:positive])}_#{name}")
104+
end
84105
end

0 commit comments

Comments
 (0)