|
defmodule CrucibleTensorPatch.Apply do
  @moduledoc """
  Applies tensor patch plans to source tensors.

  Every operation of a plan is turned into one tensor, checked against the
  operation's expected shape and dtype, and written to its own output file
  below the output directory. The per-operation reports are then combined
  into a manifest.
  """

  alias Crucible.Factorization.SVD
  alias CrucibleSafetensors.{Checksum, Writer}
  alias CrucibleTensorPatch.{Errors, Manifest, Operation, ParamTree, Plan, TensorPath}

  @doc """
  Applies a plan to source tensors and writes output safetensors.

  Non-raising counterpart of `apply!/4`: takes the same arguments and options,
  returns `{:ok, manifest}` on success and `{:error, exception}` when anything
  raised while the plan was being applied.
  """
  @spec apply(Plan.t(), map(), Path.t(), keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def apply(%Plan{} = plan, source_artifact, out_dir, opts \\ []) when is_map(source_artifact) do
    {:ok, apply!(plan, source_artifact, out_dir, opts)}
  rescue
    # Deliberately broad: any exception raised in the body is handed back as data.
    exception -> {:error, exception}
  end

  @doc """
  Applies a plan, raising on invalid input.

  Creates `out_dir` when it does not exist, runs every operation in
  `plan.operations`, builds a manifest from the resulting reports and returns it.

  ## Options

    * `:components` - map used to resolve operation inputs that are given as a
      string key instead of a tensor (default: `%{}`).
    * `:force` - when truthy, outputs are always rebuilt, even if a matching
      file already exists (default: `false`).
    * `:write_manifest?` - when truthy, the manifest is also written into
      `out_dir` (default: `true`).

  Unknown options raise `ArgumentError` (see `Keyword.validate!/2`).
  """
  @spec apply!(Plan.t(), map(), Path.t(), keyword()) :: map()
  def apply!(%Plan{} = plan, source_artifact, out_dir, opts \\ []) when is_map(source_artifact) do
    opts = Keyword.validate!(opts, components: %{}, force: false, write_manifest?: true)
    File.mkdir_p!(out_dir)

    reports =
      Enum.map(plan.operations, fn operation ->
        apply_operation!(operation, source_artifact, out_dir, opts)
      end)

    manifest = Manifest.build(reports, source: plan.source)

    if opts[:write_manifest?] do
      Manifest.write!(out_dir, manifest)
    end

    manifest
  end

  @doc "Applies manifest-style adapted tensors into a params tree."
  defdelegate patch_params!(params, manifest, tensors, opts \\ []), to: ParamTree, as: :patch!

  # Runs a single operation and returns its report map. An output that already
  # exists with the expected checksum is reused instead of being rebuilt.
  defp apply_operation!(%Operation{} = operation, source_artifact, out_dir, opts) do
    output_path = Path.join(out_dir, operation.output_path)

    case existing_output_status(operation, output_path, opts) do
      {:skip, report} ->
        report

      :write ->
        tensor = build_tensor!(operation, source_artifact, opts[:components])
        validate_tensor!(operation, tensor)
        write_tensor!(operation, output_path, tensor)
    end
  end

  # Decides whether the output has to be (re)written.
  #
  #   * `:force` set                                    -> `:write`
  #   * regular file present and an expected checksum   -> `{:skip, report}` when the
  #     file's sha256 matches, otherwise raises `Errors` (a stale or corrupt output
  #     is never silently overwritten)
  #   * anything else                                   -> `:write`
  defp existing_output_status(operation, output_path, opts) do
    cond do
      opts[:force] ->
        :write

      File.regular?(output_path) and is_binary(operation.expected_output_sha256) ->
        actual = Checksum.file_sha256!(output_path)

        if actual == operation.expected_output_sha256 do
          {:skip, report(operation, output_path, actual, "complete", true)}
        else
          raise Errors,
                "resume checksum mismatch for #{operation.id}: expected #{operation.expected_output_sha256}, got #{actual}"
        end

      true ->
        :write
    end
  end

  # Produces the output tensor for an operation.
  #
  # `:identity` passes the source tensor through unchanged.
  defp build_tensor!(%Operation{operation: :identity} = operation, source_artifact, _components) do
    fetch_source_tensor!(source_artifact, operation)
  end

  # `:svf_apply` hands the "u"/"s"/"v" inputs plus "scale_offsets" to
  # `SVD.reconstruct/2`. The offsets are cast to the dtype of `s` before the call
  # and the result is cast back to the dtype of the source tensor.
  defp build_tensor!(%Operation{operation: :svf_apply} = operation, source_artifact, components) do
    source = fetch_source_tensor!(source_artifact, operation)
    inputs = operation.inputs || %{}

    decomposition = %{
      u: fetch_input!(inputs, components, "u"),
      s: fetch_input!(inputs, components, "s"),
      v: fetch_input!(inputs, components, "v")
    }

    offsets = fetch_input!(inputs, components, "scale_offsets")

    decomposition
    |> SVD.reconstruct(Nx.as_type(offsets, Nx.type(decomposition.s)))
    |> Nx.as_type(Nx.type(source))
  end

  # Any other operation kind is reported as a regular `Errors` exception that
  # names the offending operation, instead of an opaque `FunctionClauseError`.
  defp build_tensor!(%Operation{} = operation, _source_artifact, _components) do
    raise Errors,
          "operation #{operation.id} has unsupported operation #{inspect(operation.operation)}"
  end

  # Looks up the source tensor of an operation. Resolution order:
  #
  #   1. `source_path` used directly as a key of `source_artifact`
  #   2. `segments`, normalized together with `source_path`, via `TensorPath.fetch!/3`
  #   3. `source_path` alone, normalized, via `TensorPath.fetch!/3`
  #
  # Raises `Errors` when the operation has neither `source_path` nor `segments`.
  defp fetch_source_tensor!(source_artifact, operation) do
    cond do
      operation.source_path && Map.has_key?(source_artifact, operation.source_path) ->
        Map.fetch!(source_artifact, operation.source_path)

      operation.segments ->
        TensorPath.fetch!(
          source_artifact,
          TensorPath.normalize(operation.segments, operation.source_path),
          operation.source_path
        )

      operation.source_path ->
        TensorPath.fetch!(
          source_artifact,
          TensorPath.normalize(nil, operation.source_path),
          operation.source_path
        )

      true ->
        raise Errors, "operation #{operation.id} has no source_path"
    end
  end

  # Resolves one named input of an operation. The input may be stored under the
  # string key or under the equivalent atom key, and may be either a tensor or a
  # string that is looked up in `components`.
  defp fetch_input!(inputs, components, name) do
    # `name` is always one of the literals passed by build_tensor!/3, so the
    # atom conversion cannot create an unbounded number of atoms.
    value = Map.get(inputs, name) || Map.get(inputs, String.to_atom(name))

    cond do
      match?(%Nx.Tensor{}, value) -> value
      # NOTE(review): an unknown component key raises KeyError here, not Errors.
      is_binary(value) -> Map.fetch!(components, value)
      true -> raise Errors, "missing input #{inspect(name)}"
    end
  end

  # Checks the built tensor against the operation's expected shape and dtype.
  # Either expectation may be nil, in which case that check is skipped.
  # Returns `:ok` or raises `Errors` describing the mismatch.
  defp validate_tensor!(operation, %Nx.Tensor{} = tensor) do
    expected_shape = normalize_shape(operation.expected_shape)

    if expected_shape && Nx.shape(tensor) != expected_shape do
      raise Errors,
            "operation #{operation.id} shape mismatch: expected #{inspect(expected_shape)}, got #{inspect(Nx.shape(tensor))}"
    end

    if operation.expected_dtype &&
         normalize_dtype(Nx.type(tensor)) != normalize_dtype(operation.expected_dtype) do
      raise Errors,
            "operation #{operation.id} dtype mismatch: expected #{inspect(operation.expected_dtype)}, got #{inspect(Nx.type(tensor))}"
    end

    :ok
  end

  # Writes the tensor to `output_path` keyed by the operation id, then hashes the
  # written file and returns the report for a freshly written output.
  defp write_tensor!(operation, output_path, tensor) do
    Writer.write!(%{operation.id => tensor_payload(tensor)}, output_path)
    checksum = Checksum.file_sha256!(output_path)
    report(operation, output_path, checksum, "complete", false)
  end

  # Converts a tensor into the `%{dtype, shape, data}` map handed to `Writer.write!/2`.
  defp tensor_payload(%Nx.Tensor{} = tensor) do
    # NOTE(review): backend_transfer may release the tensor's original backend
    # data; for :identity operations that tensor belongs to the caller's
    # source_artifact - confirm this is intended (Nx.backend_copy keeps it alive).
    host = Nx.backend_transfer(tensor, Nx.BinaryBackend)

    %{
      dtype: writer_dtype!(Nx.type(host)),
      shape: host |> Nx.shape() |> Tuple.to_list(),
      data: Nx.to_binary(host)
    }
  end

  # Builds the string-keyed report map for one operation.
  defp report(operation, output_path, checksum, status, skipped?) do
    %{
      "id" => operation.id,
      "operation" => Atom.to_string(operation.operation),
      "source_path" => operation.source_path,
      "output_path" => output_path,
      "status" => status,
      "skipped" => skipped?,
      "sha256" => checksum
    }
  end

  # Maps an Nx type tuple to the dtype atom used in the writer payload.
  # Raises `Errors` for any type outside this list.
  defp writer_dtype!({:f, 16}), do: :f16
  defp writer_dtype!({:bf, 16}), do: :bf16
  defp writer_dtype!({:f, 32}), do: :f32
  defp writer_dtype!({:s, 32}), do: :i32
  defp writer_dtype!({:s, 64}), do: :i64
  defp writer_dtype!(type), do: raise(Errors, "unsupported output dtype #{inspect(type)}")

  # Expected shapes may be given as a tuple or as a list; compare as tuples.
  defp normalize_shape(nil), do: nil
  defp normalize_shape(shape) when is_tuple(shape), do: shape
  defp normalize_shape(shape) when is_list(shape), do: List.to_tuple(shape)

  # Brings Nx type tuples, atoms and strings to one comparable string form.
  defp normalize_dtype({:f, 16}), do: "f16"
  defp normalize_dtype({:bf, 16}), do: "bf16"
  defp normalize_dtype({:f, 32}), do: "f32"
  defp normalize_dtype({:s, 32}), do: "i32"
  defp normalize_dtype({:s, 64}), do: "i64"
  defp normalize_dtype(dtype) when is_atom(dtype), do: Atom.to_string(dtype)
  # Strings are already in normalized form. Without this clause they would go
  # through inspect/1 below and gain surrounding quotes, so a string expectation
  # such as "f32" could never match.
  defp normalize_dtype(dtype) when is_binary(dtype), do: dtype
  defp normalize_dtype(dtype), do: inspect(dtype)
end
0 commit comments