@@ -232,36 +232,176 @@ private double ComputeFrechetDistance(
232232 meanDiffSq = NumOps . Add ( meanDiffSq , NumOps . Multiply ( diff , diff ) ) ;
233233 }
234234
235- // 2. Compute trace of covariance matrices: Tr(Σ₁ + Σ₂)
236- var traceCov = NumOps . Zero ;
237- for ( int i = 0 ; i < cov1 . Rows ; i ++ )
238- {
239- traceCov = NumOps . Add ( traceCov , cov1 [ i , i ] ) ;
240- traceCov = NumOps . Add ( traceCov , cov2 [ i , i ] ) ;
241- }
242-
243- // 3. Compute sqrt(Σ₁ * Σ₂) using simplified approximation
244- // Full implementation would use proper matrix square root
245- // For now, use trace approximation: Tr(2√(Σ₁Σ₂)) ≈ 2√(Tr(Σ₁)Tr(Σ₂))
235+ // 2. Compute trace of covariance matrices: Tr(Σ₁) + Tr(Σ₂)
246236 var trace1 = NumOps . Zero ;
247237 var trace2 = NumOps . Zero ;
248238 for ( int i = 0 ; i < cov1 . Rows ; i ++ )
249239 {
250240 trace1 = NumOps . Add ( trace1 , cov1 [ i , i ] ) ;
251241 trace2 = NumOps . Add ( trace2 , cov2 [ i , i ] ) ;
252242 }
243+ var traceCov = NumOps . Add ( trace1 , trace2 ) ;
253244
254- var covProduct = NumOps . Multiply ( trace1 , trace2 ) ;
255- var sqrtCovProduct = NumOps . Sqrt ( covProduct ) ;
256- var traceSqrtCovProduct = NumOps . Multiply ( NumOps . FromDouble ( 2.0 ) , sqrtCovProduct ) ;
245+ // 3. Compute Tr(√(Σ₁Σ₂)) using proper matrix square root
246+ // For symmetric positive semi-definite matrices, we compute the product
247+ // and then find the trace of its square root
248+ var traceSqrtCovProduct = ComputeTraceSqrtCovProduct ( cov1 , cov2 ) ;
257249
258- // FID = ||μ₁ - μ₂||² + Tr(Σ₁ + Σ₂ - 2√(Σ₁Σ₂))
250+ // FID = ||μ₁ - μ₂||² + Tr(Σ₁) + Tr(Σ₂) - 2*Tr( √(Σ₁Σ₂))
259251 var fid = NumOps . Add ( meanDiffSq , traceCov ) ;
260- fid = NumOps . Subtract ( fid , traceSqrtCovProduct ) ;
252+ fid = NumOps . Subtract ( fid , NumOps . Multiply ( NumOps . FromDouble ( 2.0 ) , traceSqrtCovProduct ) ) ;
261253
262254 return Convert . ToDouble ( fid ) ;
263255 }
264256
/// <summary>
/// Computes Tr(√(Σ₁Σ₂)) using the coupled Newton-Schulz iteration for the matrix square root.
/// </summary>
/// <param name="cov1">First covariance matrix Σ₁; its row count defines the dimension n.</param>
/// <param name="cov2">Second covariance matrix Σ₂; it is indexed as an n×n matrix (assumed to match cov1 — not validated here).</param>
/// <returns>
/// An approximation of the trace of the principal square root of Σ₁Σ₂,
/// or zero when the Frobenius norm of the product is below 1e-10.
/// </returns>
/// <remarks>
/// With A = Σ₁Σ₂ and c = ||A||_F, the iteration is
///   Y₀ = A / c, Z₀ = I,
///   T_k = 0.5 * (3I - Z_k * Y_k),
///   Y_{k+1} = Y_k * T_k,  Z_{k+1} = T_k * Z_k,
/// for which Y_k → √(A / c) and Z_k → (A / c)^(-1/2). Hence Tr(√A) = √c * Tr(Y).
/// The single-matrix recurrence Y ← 0.5 * Y * (3I - Y²) is NOT used: it converges to the
/// matrix sign function (the identity for a matrix with positive eigenvalues), not to a square root.
/// </remarks>
private T ComputeTraceSqrtCovProduct(Matrix<T> cov1, Matrix<T> cov2)
{
    int n = cov1.Rows;

    // A = Σ₁ * Σ₂. The product is deliberately NOT symmetrized: Tr(√A) is the sum of the
    // square roots of A's eigenvalues, and (A + Aᵀ) / 2 has different eigenvalues in general.
    var product = MatrixMultiply(cov1, cov2);

    // Frobenius norm of A; it bounds the spectral radius, so A / ||A||_F has eigenvalues in [0, 1].
    var frobNormSq = NumOps.Zero;
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            frobNormSq = NumOps.Add(frobNormSq, NumOps.Multiply(product[i, j], product[i, j]));
        }
    }
    var frobNorm = NumOps.Sqrt(frobNormSq);

    // If the product is essentially zero, its square root (and the trace) is zero.
    // This also avoids dividing by zero in the normalization below.
    if (NumOps.LessThan(frobNorm, NumOps.FromDouble(1e-10)))
    {
        return NumOps.Zero;
    }

    // Y₀ = A / ||A||_F, Z₀ = I
    var Y = new Matrix<T>(n, n);
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            Y[i, j] = NumOps.Divide(product[i, j], frobNorm);
        }
    }
    var Z = Matrix<T>.CreateIdentity(n);

    // Fixed iteration budget. Components belonging to (near-)zero eigenvalues are amplified
    // by roughly 1.5x per step, so a small budget keeps round-off noise on rank-deficient
    // covariance products from being blown up.
    const int maxIterations = 15;
    var half = NumOps.FromDouble(0.5);
    var three = NumOps.FromDouble(3.0);

    for (int iter = 0; iter < maxIterations; iter++)
    {
        var ZY = MatrixMultiply(Z, Y);

        // T = 0.5 * (3I - Z * Y)
        var step = new Matrix<T>(n, n);
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                var diagonal = i == j ? three : NumOps.Zero;
                step[i, j] = NumOps.Multiply(half, NumOps.Subtract(diagonal, ZY[i, j]));
            }
        }

        // Both updates must use the same T, computed from the previous Y and Z.
        Y = MatrixMultiply(Y, step);
        Z = MatrixMultiply(step, Z);
    }

    // Y ≈ √(A / ||A||_F), so Tr(√A) = √(||A||_F) * Tr(Y).
    var traceY = NumOps.Zero;
    for (int i = 0; i < n; i++)
    {
        traceY = NumOps.Add(traceY, Y[i, i]);
    }

    return NumOps.Multiply(NumOps.Sqrt(frobNorm), traceY);
}
380+
/// <summary>
/// Computes the matrix product a * b, treating both operands as square matrices
/// whose size is the row count of <paramref name="a"/>.
/// </summary>
/// <param name="a">Left operand; its row count is used as every dimension of the computation.</param>
/// <param name="b">Right operand; it is indexed over the same n×n range as <paramref name="a"/>.</param>
/// <returns>A newly allocated n×n matrix holding the product.</returns>
private Matrix<T> MatrixMultiply(Matrix<T> a, Matrix<T> b)
{
    int size = a.Rows;
    var product = new Matrix<T>(size, size);

    for (int row = 0; row < size; row++)
    {
        for (int col = 0; col < size; col++)
        {
            // Dot product of row `row` of a with column `col` of b.
            var dot = NumOps.Zero;
            for (int inner = 0; inner < size; inner++)
            {
                var term = NumOps.Multiply(a[row, inner], b[inner, col]);
                dot = NumOps.Add(dot, term);
            }

            product[row, col] = dot;
        }
    }

    return product;
}
404+
265405 /// <summary>
266406 /// Computes FID using pre-computed statistics.
267407 /// Useful when you want to compare against a fixed set of real images.
0 commit comments