Skip to content

Commit b88ca6c

Browse files
committed
feat: implement norm computation for vectors and matrices with corresponding tests
1 parent 53c457b commit b88ca6c

2 files changed

Lines changed: 307 additions & 0 deletions

File tree

src/compute-engine/library/linear-algebra.ts

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,166 @@ export const LINEAR_ALGEBRA_LIBRARY: SymbolDefinitions[] = [
586586
return ce.box(['List', ...rows]);
587587
},
588588
},
589+
590+
// Computes vector and matrix norms
591+
// For vectors:
592+
// - L2 (Euclidean, default): √(Σ|xi|²)
593+
// - L1: Σ|xi|
594+
// - L∞ (max): max(|xi|)
595+
// - Lp: (Σ|xi|^p)^(1/p)
596+
// For matrices:
597+
// - Frobenius (default): √(ΣΣ|aij|²)
598+
Norm: {
599+
complexity: 8200,
600+
signature: '(value, number|string?) -> number',
601+
evaluate: (ops, { engine: ce }): BoxedExpression | undefined => {
602+
const x = ops[0].evaluate();
603+
const normTypeExpr = ops.length > 1 ? ops[1].evaluate() : undefined;
604+
605+
// Scalar: |x| (absolute value)
606+
if (x.isNumber) {
607+
return ce.box(['Abs', x]).evaluate();
608+
}
609+
610+
if (!isBoxedTensor(x)) return undefined;
611+
612+
const shape = x.shape;
613+
614+
// Determine norm type
615+
let normType: number | string = 2; // Default to L2/Frobenius
616+
if (normTypeExpr) {
617+
if (
618+
normTypeExpr.string === 'Infinity' ||
619+
normTypeExpr.symbol === 'Infinity' ||
620+
normTypeExpr.re === Infinity
621+
) {
622+
normType = 'infinity';
623+
} else if (normTypeExpr.string === 'Frobenius') {
624+
normType = 'frobenius';
625+
} else if (normTypeExpr.re !== undefined) {
626+
normType = normTypeExpr.re;
627+
}
628+
}
629+
630+
// Vector norm (rank 1)
631+
if (shape.length === 1) {
632+
const elements: BoxedExpression[] = [];
633+
const n = shape[0];
634+
for (let i = 0; i < n; i++) {
635+
const val = x.tensor.at(i + 1);
636+
elements.push(val !== undefined ? ce.box(val) : ce.Zero);
637+
}
638+
639+
if (normType === 1) {
640+
// L1 norm: sum of absolute values
641+
let sum: BoxedExpression = ce.Zero;
642+
for (const el of elements) {
643+
sum = sum.add(ce.box(['Abs', el]).evaluate());
644+
}
645+
return sum.evaluate();
646+
}
647+
648+
if (normType === 2) {
649+
// L2 norm: sqrt of sum of squares
650+
let sumSq: BoxedExpression = ce.Zero;
651+
for (const el of elements) {
652+
const absEl = ce.box(['Abs', el]).evaluate();
653+
sumSq = sumSq.add(absEl.mul(absEl));
654+
}
655+
return ce.box(['Sqrt', sumSq]).evaluate();
656+
}
657+
658+
if (normType === 'infinity') {
659+
// L∞ norm: max absolute value
660+
let maxVal: BoxedExpression = ce.Zero;
661+
for (const el of elements) {
662+
const absEl = ce.box(['Abs', el]).evaluate();
663+
// Compare: use numeric comparison
664+
const absNum = absEl.re ?? 0;
665+
const maxNum = maxVal.re ?? 0;
666+
if (absNum > maxNum) {
667+
maxVal = absEl;
668+
}
669+
}
670+
return maxVal;
671+
}
672+
673+
// General Lp norm: (Σ|xi|^p)^(1/p)
674+
if (typeof normType === 'number' && normType > 0) {
675+
const p = normType;
676+
let sumPow: BoxedExpression = ce.Zero;
677+
for (const el of elements) {
678+
const absEl = ce.box(['Abs', el]).evaluate();
679+
sumPow = sumPow.add(ce.box(['Power', absEl, p]).evaluate());
680+
}
681+
// Use Root for integer p values, Power for non-integer
682+
// Use .N() to get numeric result for non-perfect roots
683+
if (Number.isInteger(p)) {
684+
return ce.box(['Root', sumPow, p]).N();
685+
}
686+
return ce.box(['Power', sumPow, ce.box(['Divide', 1, p])]).N();
687+
}
688+
689+
return undefined;
690+
}
691+
692+
// Matrix norm (rank 2)
693+
if (shape.length === 2) {
694+
const [m, n] = shape;
695+
696+
// Frobenius norm (default for matrices): √(ΣΣ|aij|²)
697+
if (normType === 2 || normType === 'frobenius') {
698+
let sumSq: BoxedExpression = ce.Zero;
699+
for (let i = 0; i < m; i++) {
700+
for (let j = 0; j < n; j++) {
701+
const val = x.tensor.at(i + 1, j + 1);
702+
const el = val !== undefined ? ce.box(val) : ce.Zero;
703+
const absEl = ce.box(['Abs', el]).evaluate();
704+
sumSq = sumSq.add(absEl.mul(absEl));
705+
}
706+
}
707+
return ce.box(['Sqrt', sumSq]).evaluate();
708+
}
709+
710+
// L1 (max column sum of absolute values)
711+
if (normType === 1) {
712+
let maxColSum = 0;
713+
for (let j = 0; j < n; j++) {
714+
let colSum = 0;
715+
for (let i = 0; i < m; i++) {
716+
const val = x.tensor.at(i + 1, j + 1);
717+
const el = val !== undefined ? ce.box(val) : ce.Zero;
718+
const absEl = ce.box(['Abs', el]).evaluate();
719+
colSum += absEl.re ?? 0;
720+
}
721+
if (colSum > maxColSum) maxColSum = colSum;
722+
}
723+
return ce.number(maxColSum);
724+
}
725+
726+
// L∞ (max row sum of absolute values)
727+
if (normType === 'infinity') {
728+
let maxRowSum = 0;
729+
for (let i = 0; i < m; i++) {
730+
let rowSum = 0;
731+
for (let j = 0; j < n; j++) {
732+
const val = x.tensor.at(i + 1, j + 1);
733+
const el = val !== undefined ? ce.box(val) : ce.Zero;
734+
const absEl = ce.box(['Abs', el]).evaluate();
735+
rowSum += absEl.re ?? 0;
736+
}
737+
if (rowSum > maxRowSum) maxRowSum = rowSum;
738+
}
739+
return ce.number(maxRowSum);
740+
}
741+
742+
return undefined;
743+
}
744+
745+
// Higher-rank tensors: not supported yet
746+
return undefined;
747+
},
748+
},
589749
},
590750
];
591751

test/compute-engine/linear-algebra.test.ts

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,150 @@ describe('OnesMatrix', () => {
832832
);
833833
});
834834
});
835+
836+
describe('Norm', () => {
837+
// Scalar norm (absolute value)
838+
it('should compute the norm of a scalar', () => {
839+
const result = ce.box(['Norm', 5]).evaluate();
840+
expect(result.toString()).toMatchInlineSnapshot(`5`);
841+
});
842+
843+
it('should compute the norm of a negative scalar', () => {
844+
const result = ce.box(['Norm', -7]).evaluate();
845+
expect(result.toString()).toMatchInlineSnapshot(`7`);
846+
});
847+
848+
// Vector L2 norm (default)
849+
it('should compute the L2 norm of a vector (3-4-5 triangle)', () => {
850+
// √(3² + 4²) = √(9 + 16) = √25 = 5
851+
const result = ce.box(['Norm', ['List', 3, 4]]).evaluate();
852+
expect(result.toString()).toMatchInlineSnapshot(`5`);
853+
});
854+
855+
it('should compute the L2 norm of a vector with negatives', () => {
856+
// √(3² + (-4)²) = 5
857+
const result = ce.box(['Norm', ['List', 3, -4]]).evaluate();
858+
expect(result.toString()).toMatchInlineSnapshot(`5`);
859+
});
860+
861+
it('should compute the L2 norm of a 3D vector', () => {
862+
// √(1² + 2² + 2²) = √(1 + 4 + 4) = √9 = 3
863+
const result = ce.box(['Norm', ['List', 1, 2, 2]]).evaluate();
864+
expect(result.toString()).toMatchInlineSnapshot(`3`);
865+
});
866+
867+
// Vector L1 norm
868+
it('should compute the L1 norm of a vector', () => {
869+
// |3| + |-4| = 3 + 4 = 7
870+
const result = ce.box(['Norm', ['List', 3, -4], 1]).evaluate();
871+
expect(result.toString()).toMatchInlineSnapshot(`7`);
872+
});
873+
874+
it('should compute the L1 norm of a longer vector', () => {
875+
// |1| + |-2| + |3| + |-4| = 1 + 2 + 3 + 4 = 10
876+
const result = ce.box(['Norm', ['List', 1, -2, 3, -4], 1]).evaluate();
877+
expect(result.toString()).toMatchInlineSnapshot(`10`);
878+
});
879+
880+
// Vector L∞ norm (max absolute value)
881+
it('should compute the L-infinity norm of a vector', () => {
882+
// max(|3|, |-4|) = 4
883+
const result = ce
884+
.box(['Norm', ['List', 3, -4], 'Infinity'])
885+
.evaluate();
886+
expect(result.toString()).toMatchInlineSnapshot(`4`);
887+
});
888+
889+
it('should compute the L-infinity norm with string', () => {
890+
// max(|1|, |-5|, |3|) = 5
891+
const result = ce
892+
.box(['Norm', ['List', 1, -5, 3], { str: 'Infinity' }])
893+
.evaluate();
894+
expect(result.toString()).toMatchInlineSnapshot(`5`);
895+
});
896+
897+
// Vector Lp norm (general)
898+
it('should compute the L3 norm of a vector', () => {
899+
// (|3|³ + |4|³)^(1/3) = (27 + 64)^(1/3) = 91^(1/3) ≈ 4.498
900+
const result = ce.box(['Norm', ['List', 3, 4], 3]).evaluate();
901+
expect(result.re).toBeCloseTo(4.4979, 3);
902+
});
903+
904+
it('should compute the L4 norm of a vector', () => {
905+
// (|2|⁴ + |2|⁴)^(1/4) = (16 + 16)^(1/4) = 32^(1/4) ≈ 2.378
906+
const result = ce.box(['Norm', ['List', 2, 2], 4]).evaluate();
907+
expect(result.re).toBeCloseTo(2.3784, 3);
908+
});
909+
910+
// Matrix Frobenius norm (default)
911+
it('should compute the Frobenius norm of a matrix', () => {
912+
// √(1² + 2² + 3² + 4²) = √(1 + 4 + 9 + 16) = √30 ≈ 5.477
913+
const result = ce.box(['Norm', sq2_n]).evaluate();
914+
expect(result.re).toBeCloseTo(5.4772, 3);
915+
});
916+
917+
it('should compute the Frobenius norm of a non-square matrix', () => {
918+
// √(1² + 2² + 3² + 4² + 5² + 6²) = √(1+4+9+16+25+36) = √91 ≈ 9.539
919+
const result = ce.box(['Norm', m23_n]).evaluate();
920+
expect(result.re).toBeCloseTo(9.5394, 3);
921+
});
922+
923+
it('should compute the Frobenius norm with explicit type', () => {
924+
const result = ce
925+
.box(['Norm', sq2_n, { str: 'Frobenius' }])
926+
.evaluate();
927+
expect(result.re).toBeCloseTo(5.4772, 3);
928+
});
929+
930+
// Matrix L1 norm (max column sum)
931+
it('should compute the L1 norm of a matrix', () => {
932+
// [[1, 2], [3, 4]]
933+
// Column sums: |1| + |3| = 4, |2| + |4| = 6
934+
// max = 6
935+
const result = ce.box(['Norm', sq2_n, 1]).evaluate();
936+
expect(result.toString()).toMatchInlineSnapshot(`6`);
937+
});
938+
939+
it('should compute the L1 norm of a matrix with negatives', () => {
940+
// [[1, -2], [-3, 4]]
941+
// Column sums: |1| + |-3| = 4, |-2| + |4| = 6
942+
// max = 6
943+
const result = ce
944+
.box(['Norm', ['List', ['List', 1, -2], ['List', -3, 4]], 1])
945+
.evaluate();
946+
expect(result.toString()).toMatchInlineSnapshot(`6`);
947+
});
948+
949+
// Matrix L∞ norm (max row sum)
950+
it('should compute the L-infinity norm of a matrix', () => {
951+
// [[1, 2], [3, 4]]
952+
// Row sums: |1| + |2| = 3, |3| + |4| = 7
953+
// max = 7
954+
const result = ce
955+
.box(['Norm', sq2_n, { str: 'Infinity' }])
956+
.evaluate();
957+
expect(result.toString()).toMatchInlineSnapshot(`7`);
958+
});
959+
960+
it('should compute the L-infinity norm of a non-square matrix', () => {
961+
// [[1, 2, 3], [4, 5, 6]]
962+
// Row sums: 1 + 2 + 3 = 6, 4 + 5 + 6 = 15
963+
// max = 15
964+
const result = ce
965+
.box(['Norm', m23_n, { str: 'Infinity' }])
966+
.evaluate();
967+
expect(result.toString()).toMatchInlineSnapshot(`15`);
968+
});
969+
970+
// Zero vector
971+
it('should compute the norm of a zero vector', () => {
972+
const result = ce.box(['Norm', ['List', 0, 0, 0]]).evaluate();
973+
expect(result.toString()).toMatchInlineSnapshot(`0`);
974+
});
975+
976+
// Single element vector
977+
it('should compute the norm of a single element vector', () => {
978+
const result = ce.box(['Norm', ['List', -5]]).evaluate();
979+
expect(result.toString()).toMatchInlineSnapshot(`5`);
980+
});
981+
});

0 commit comments

Comments
 (0)