4242import torch .nn .functional as F
4343
4444
# Fibonacci frequency table: 32 strictly increasing terms (1, 2, 3, 5, ...),
# i.e. the Fibonacci sequence without the duplicated leading 1.
# NOTE: the earlier 16-entry table made any K > 16 clamp silently. Extending
# the table only moves that limit: code that takes min(K, len(FIBONACCI))
# still clamps for K > 32.
FIBONACCI = [
    1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987,
    1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025,
    121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578,
]
4652
4753
4854class FibGenLinear (nn .Module ):
4955 """Drop-in replacement for nn.Linear where W is generated from a seed.
5056
57+ Two generator modes:
58+
59+ "separable" (the original): each component uses the SAME Fibonacci
60+ frequency on both axes. Generates rank-K terms.
61+ W[i,j] = Σ_k [a_k cos(F_k·i) cos(F_k·j) + ...]
62+ Seed: 4·K params.
63+
64+ "cross" (new): each component uses INDEPENDENT Fibonacci frequencies
65+ on the two axes. Generates a full K_i × K_j grid of frequency
66+ pairs, so the matrix is a sum of K_i·K_j outer products of
67+ single-frequency 1-D bases.
68+ W[i,j] = Σ_{k_i, k_j} [a_{kk'} cos(F_{k_i}·i) cos(F_{k_j}·j) + ...]
69+ Seed: 4·K² params — the same seed size as "separable" with K² components,
70+ though not the same function family: "cross" covers every (F_{k_i}, F_{k_j})
71+ frequency pair at rank ≤ 2·K, whereas "separable" only pairs a frequency
72+ with itself. The Fibonacci frequencies are intended to keep the basis non-degenerate.
73+
5174 Args:
5275 in_features: input dim.
5376 out_features: output dim.
54- K: number of Fibonacci-frequency components in the generator.
55- Higher K = more capacity, more params. K=16 → 64 params
56- (vs in·out for a stored matrix).
77+ K: number of Fibonacci frequencies per axis (clamped to len(FIBONACCI)).
78+ mode: "separable" or "cross".
5779 bias: whether to include a learnable bias vector.
58- init_scale: scales the seed initialization. The generated W has
59- magnitude ~ init_scale · sqrt(4K), so smaller init_scale
60- gives smaller initial weights.
80+ init_scale: scales the seed initialization.
6181 """
6282
6383 def __init__ (self , in_features : int , out_features : int , K : int = 16 ,
84+ mode : str = "separable" ,
6485 bias : bool = True , init_scale : float = 0.1 ):
6586 super ().__init__ ()
6687 self .in_features = in_features
6788 self .out_features = out_features
6889 self .K = min (K , len (FIBONACCI ))
69- # Seed: 4 coefficients per Fibonacci component (cc, sc, cs, ss).
90+ if mode not in ("separable" , "cross" ):
91+ raise ValueError (f"unknown mode: { mode } " )
92+ self .mode = mode
93+ n_components = self .K if mode == "separable" else self .K * self .K
7094 self .seed = nn .Parameter (
71- torch .randn (self . K , 4 ) * (init_scale / max (1 , math .sqrt (self . K )))
95+ torch .randn (n_components , 4 ) * (init_scale / max (1 , math .sqrt (n_components )))
7296 )
7397 if bias :
7498 self .bias = nn .Parameter (torch .zeros (out_features ))
7599 else :
76100 self .register_parameter ("bias" , None )
77- # Precompute the cos/sin of position·Fibonacci-frequency for both
78- # axes. These are FIXED — no gradient flows through positions.
101+ # Precompute cos/sin position·Fibonacci-frequency tables.
79102 i_idx = torch .arange (out_features ).float ()
80103 j_idx = torch .arange (in_features ).float ()
81104 freqs = torch .tensor (FIBONACCI [:self .K ], dtype = torch .float )
82- # angles: [out, K], [in, K]
83105 a_i = 2 * math .pi * i_idx .unsqueeze (1 ) * freqs .unsqueeze (0 ) / max (out_features , 1 )
84106 a_j = 2 * math .pi * j_idx .unsqueeze (1 ) * freqs .unsqueeze (0 ) / max (in_features , 1 )
85107 self .register_buffer ("cos_i" , torch .cos (a_i )) # [out, K]
@@ -88,21 +110,24 @@ def __init__(self, in_features: int, out_features: int, K: int = 16,
88110 self .register_buffer ("sin_j" , torch .sin (a_j ))
89111
90112 def generate_W (self ) -> torch .Tensor :
91- # seed: [K, 4] → split into 4 [K] tensors.
92- a , b , c , d = self .seed [:, 0 ], self .seed [:, 1 ], self .seed [:, 2 ], self .seed [:, 3 ]
93- # W = sum_k (
94- # a_k · cos_i[:, k] · cos_j[:, k]^T +
95- # b_k · sin_i[:, k] · cos_j[:, k]^T +
96- # c_k · cos_i[:, k] · sin_j[:, k]^T +
97- # d_k · sin_i[:, k] · sin_j[:, k]^T
98- # )
99- # Each term is an [out, in] outer product.
100- # Compose via einsum: [out, K] · [K] · [K, in] (with the diagonal)
101- # → [out, in].
102- W = torch .einsum ("ok,k,jk->oj" , self .cos_i , a , self .cos_j )
103- W = W + torch .einsum ("ok,k,jk->oj" , self .sin_i , b , self .cos_j )
104- W = W + torch .einsum ("ok,k,jk->oj" , self .cos_i , c , self .sin_j )
105- W = W + torch .einsum ("ok,k,jk->oj" , self .sin_i , d , self .sin_j )
113+ if self .mode == "separable" :
114+ a , b , c , d = self .seed [:, 0 ], self .seed [:, 1 ], self .seed [:, 2 ], self .seed [:, 3 ]
115+ W = torch .einsum ("ok,k,jk->oj" , self .cos_i , a , self .cos_j )
116+ W = W + torch .einsum ("ok,k,jk->oj" , self .sin_i , b , self .cos_j )
117+ W = W + torch .einsum ("ok,k,jk->oj" , self .cos_i , c , self .sin_j )
118+ W = W + torch .einsum ("ok,k,jk->oj" , self .sin_i , d , self .sin_j )
119+ return W
120+ # mode == "cross": seed shape [K*K, 4], reshape to [K, K, 4]
121+ K = self .K
122+ seed = self .seed .view (K , K , 4 )
123+ a , b , c , d = seed [..., 0 ], seed [..., 1 ], seed [..., 2 ], seed [..., 3 ]
124+ # W[i,j] = Σ_{k_i, k_j} [a · cos_i[i, k_i] cos_j[j, k_j] + ...]
125+ # einsum: cos_i [out, k_i] @ a [k_i, k_j] -> [out, k_j], then
126+ # · cos_j [in, k_j] -> [out, in].
127+ W = torch .einsum ("ol,lm,jm->oj" , self .cos_i , a , self .cos_j )
128+ W = W + torch .einsum ("ol,lm,jm->oj" , self .sin_i , b , self .cos_j )
129+ W = W + torch .einsum ("ol,lm,jm->oj" , self .cos_i , c , self .sin_j )
130+ W = W + torch .einsum ("ol,lm,jm->oj" , self .sin_i , d , self .sin_j )
106131 return W
107132
108133 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -127,11 +152,11 @@ def n_dense_equivalent_params(self) -> int:
127152class FibGenAttention (nn .Module ):
128153 """Single-head self-attention with all linear layers FibGen-generated."""
129154
130- def __init__ (self , d_model : int , K : int = 16 ):
155+ def __init__ (self , d_model : int , K : int = 16 , mode : str = "separable" ):
131156 super ().__init__ ()
132157 self .d_model = d_model
133- self .qkv = FibGenLinear (d_model , 3 * d_model , K = K )
134- self .out = FibGenLinear (d_model , d_model , K = K )
158+ self .qkv = FibGenLinear (d_model , 3 * d_model , K = K , mode = mode )
159+ self .out = FibGenLinear (d_model , d_model , K = K , mode = mode )
135160
136161 def forward (self , x : torch .Tensor , mask : torch .Tensor ) -> torch .Tensor :
137162 B , T , D = x .shape
@@ -148,21 +173,22 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
class FibGenFeedForward(nn.Module):
    """FFN with FibGen-generated linear layers.

    Computes w2(gelu(w1(x))), where w1 maps d_model -> d_model * expansion
    and w2 maps back to d_model.
    """

    def __init__(self, d_model: int, expansion: int = 4, K: int = 16,
                 mode: str = "separable"):
        """Build the two generated projections.

        Args:
            d_model: input and output width.
            expansion: multiplier for the hidden width (d_model * expansion).
            K: number of Fibonacci frequencies, forwarded to FibGenLinear.
            mode: generator mode, forwarded to FibGenLinear.
        """
        super().__init__()
        d_inner = d_model * expansion
        self.w1 = FibGenLinear(d_model, d_inner, K=K, mode=mode)
        self.w2 = FibGenLinear(d_inner, d_model, K=K, mode=mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply w1, GELU, then w2 to ``x``."""
        return self.w2(F.gelu(self.w1(x)))
159185
160186
161187class FibGenBlock (nn .Module ):
162- def __init__ (self , d_model : int , K : int = 16 ):
188+ def __init__ (self , d_model : int , K : int = 16 , mode : str = "separable" ):
163189 super ().__init__ ()
164- self .attn = FibGenAttention (d_model , K = K )
165- self .ff = FibGenFeedForward (d_model , K = K )
190+ self .attn = FibGenAttention (d_model , K = K , mode = mode )
191+ self .ff = FibGenFeedForward (d_model , K = K , mode = mode )
166192 self .ln1 = nn .LayerNorm (d_model )
167193 self .ln2 = nn .LayerNorm (d_model )
168194
@@ -184,25 +210,20 @@ class FibGenLM(nn.Module):
184210 """
185211
186212 def __init__ (self , vocab_size : int , d_model : int , n_blocks : int ,
187- seq_len : int , K : int = 16 ):
213+ seq_len : int , K : int = 16 , mode : str = "separable" ):
188214 super ().__init__ ()
189215 self .seq_len = seq_len
190216 self .K = K
191- # Embedding implemented as FibGen + index → FibGen produces a
192- # [vocab, d_model] table that we index into.
193- self .embed_gen = FibGenLinear (vocab_size , d_model , K = K , bias = False )
194- # Positional encoding stays CRT-Fibonacci (already substrate-aligned,
195- # and it's a buffer, not a learned weight).
217+ self .mode = mode
218+ self .embed_gen = FibGenLinear (vocab_size , d_model , K = K , mode = mode ,
219+ bias = False )
196220 pe = self ._crt_pe (seq_len , d_model )
197221 self .register_buffer ("pe" , pe )
198222 self .blocks = nn .ModuleList ([
199- FibGenBlock (d_model , K = K ) for _ in range (n_blocks )
223+ FibGenBlock (d_model , K = K , mode = mode ) for _ in range (n_blocks )
200224 ])
201225 self .ln_f = nn .LayerNorm (d_model )
202- # Head: FibGen too (or tied with embed — but tied with a generator
203- # means head and embed share the SAME generator seed which forces
204- # a constraint. Pick untied for now to test capacity.)
205- self .head = FibGenLinear (d_model , vocab_size , K = K , bias = False )
226+ self .head = FibGenLinear (d_model , vocab_size , K = K , mode = mode , bias = False )
206227 mask = torch .tril (torch .ones (seq_len , seq_len ))
207228 self .register_buffer ("mask" , mask )
208229
0 commit comments