Skip to content

Commit bccae5c

Browse files
committed
feat: implement nested sqrt equation solving and add corresponding tests
1 parent 2113483 commit bccae5c

3 files changed

Lines changed: 237 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
ce.parse('\\sqrt{x} = x').solve('x'); // → [0, 1]
1414
```
1515

16+
- **Nested Sqrt Equation Solving**: The equation solver now handles nested
17+
sqrt equations of the form `√(x + √x) = a` using substitution. These patterns
18+
have √x inside the argument of an outer sqrt. The solver uses u = √x
19+
substitution, solves the resulting quadratic, and filters negative u values.
20+
21+
```javascript
22+
ce.parse('\\sqrt{x + 2\\sqrt{x}} = 3').solve('x'); // → [11 - 2√10] ≈ 4.675
23+
ce.parse('\\sqrt{x + \\sqrt{x}} = 2').solve('x'); // → [9/2 - √17/2] ≈ 2.438
24+
ce.parse('\\sqrt{x - \\sqrt{x}} = 1').solve('x'); // → [φ²] ≈ 2.618
25+
```
26+
1627
- **Quadratic Equations Without Constant Term**: Added support for solving
1728
quadratic equations of the form `ax² + bx = 0` (missing constant term).
1829
These are solved by factoring: `x(ax + b) = 0``x = 0` or `x = -b/a`.

src/compute-engine/boxed-expression/solve.ts

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,184 @@ function transformSqrtLinearEquation(
729729
return transformed;
730730
}
731731

732+
/**
733+
* Detect and solve nested sqrt equations of the form √(f(x, √x)) = a.
734+
*
735+
* Pattern 4: √(x + √x) = a (or similar with √x inside outer sqrt)
736+
* - Use substitution u = √x, so x = u²
737+
* - √(u² + u) = a becomes u² + u = a² (after squaring)
738+
* - Solve quadratic for u, then x = u² for valid u ≥ 0
739+
*
740+
* Returns the solutions for x, or null if pattern not detected.
741+
*/
742+
function solveNestedSqrtEquation(
743+
expr: BoxedExpression,
744+
variable: string
745+
): BoxedExpression[] | null {
746+
if (expr.operator !== 'Add') return null;
747+
748+
const ce = expr.engine;
749+
const ops = expr.ops;
750+
if (!ops || ops.length === 0) return null;
751+
752+
// Find the outer sqrt term
753+
let outerSqrt: BoxedExpression | null = null;
754+
let sqrtIndex = -1;
755+
756+
for (let i = 0; i < ops.length; i++) {
757+
if (ops[i].operator === 'Sqrt') {
758+
outerSqrt = ops[i];
759+
sqrtIndex = i;
760+
break;
761+
}
762+
}
763+
764+
if (!outerSqrt || sqrtIndex < 0) return null;
765+
766+
// Get the argument of the outer sqrt
767+
const outerArg = outerSqrt.op1;
768+
if (!outerArg) return null;
769+
770+
// Check if the outer sqrt argument contains an inner √x (Sqrt of just the variable)
771+
// Pattern: √(... + √x + ...) or √(... + a*√x + ...)
772+
let hasInnerSqrtX = false;
773+
let innerSqrtCoeff: BoxedExpression | null = null;
774+
775+
if (outerArg.operator === 'Add' && outerArg.ops) {
776+
for (const term of outerArg.ops) {
777+
// Check for √x directly
778+
if (
779+
term.operator === 'Sqrt' &&
780+
term.op1?.symbol === variable
781+
) {
782+
hasInnerSqrtX = true;
783+
innerSqrtCoeff = ce.One;
784+
break;
785+
}
786+
// Check for Negate(Sqrt(x))
787+
if (
788+
term.operator === 'Negate' &&
789+
term.op1?.operator === 'Sqrt' &&
790+
term.op1?.op1?.symbol === variable
791+
) {
792+
hasInnerSqrtX = true;
793+
innerSqrtCoeff = ce.NegativeOne;
794+
break;
795+
}
796+
// Check for coefficient * √x
797+
if (term.operator === 'Multiply' && term.ops) {
798+
for (const factor of term.ops) {
799+
if (
800+
factor.operator === 'Sqrt' &&
801+
factor.op1?.symbol === variable
802+
) {
803+
hasInnerSqrtX = true;
804+
// Get coefficient (product of other factors)
805+
const otherFactors = term.ops.filter((f) => f !== factor);
806+
innerSqrtCoeff =
807+
otherFactors.length === 1
808+
? otherFactors[0]
809+
: ce.function('Multiply', otherFactors);
810+
break;
811+
}
812+
}
813+
if (hasInnerSqrtX) break;
814+
}
815+
}
816+
}
817+
818+
if (!hasInnerSqrtX) return null;
819+
820+
// We have √(f(x, √x)) = a pattern
821+
// Collect the constant terms (non-sqrt parts of the Add expression)
822+
const nonSqrtTerms = ops.filter((_, i) => i !== sqrtIndex);
823+
if (nonSqrtTerms.length === 0) return null;
824+
825+
// a = -(sum of non-sqrt terms)
826+
let aExpr: BoxedExpression;
827+
if (nonSqrtTerms.length === 1) {
828+
aExpr = nonSqrtTerms[0].neg();
829+
} else {
830+
aExpr = ce.function('Add', nonSqrtTerms).neg();
831+
}
832+
833+
// The constant should not contain the variable
834+
if (aExpr.has(variable)) return null;
835+
836+
// Now we have: √(f(x, √x)) = a
837+
// Substitute u = √x, so x = u², √x = u
838+
// The outer arg f(x, √x) becomes f(u², u)
839+
840+
// Create a unique internal symbol for u (avoiding wildcard prefix _)
841+
// Use __internalU to avoid collision with user symbols
842+
const uSymbolName = '__internalU';
843+
const uSymbol = ce.symbol(uSymbolName);
844+
845+
// Substitute √x → u and x → u² in the outer sqrt argument
846+
// IMPORTANT: Must replace √x first, THEN x, otherwise √x becomes √(u²)
847+
const step1 = outerArg.replace(
848+
{ match: ['Sqrt', variable], replace: uSymbol },
849+
{ recursive: true }
850+
);
851+
const substitutedArg = step1?.subs({ [variable]: ce.box(['Power', uSymbolName, 2]) });
852+
853+
if (!substitutedArg) return null;
854+
855+
// Now we have √(g(u)) = a where g(u) = substitutedArg
856+
// Square both sides: g(u) = a²
857+
// So g(u) - a² = 0
858+
859+
const aSquared = aExpr.mul(aExpr);
860+
const uEquation = substitutedArg.sub(aSquared).simplify();
861+
862+
// Solve for u
863+
ce.pushScope();
864+
ce.declare(uSymbolName, { type: 'real' });
865+
866+
const uSolutions = findUnivariateRoots(uEquation, uSymbolName);
867+
868+
ce.popScope();
869+
870+
if (uSolutions.length === 0) return null;
871+
872+
// Convert u solutions back to x = u²
873+
// Only keep solutions where u ≥ 0 (since u = √x ≥ 0)
874+
const xSolutions: BoxedExpression[] = [];
875+
876+
for (const uVal of uSolutions) {
877+
// Check if u is real and non-negative (since u = √x ≥ 0)
878+
const uNumeric = uVal.N();
879+
880+
// Use the expression's isNegative property for reliable checking
881+
if (uNumeric.isNegative) continue; // Skip negative u values
882+
883+
// Also check numericValue for cases where isNegative might not be set
884+
const uNum = uNumeric.numericValue;
885+
if (uNum !== null) {
886+
let uReal: number | null = null;
887+
if (typeof uNum === 'number') {
888+
uReal = uNum;
889+
} else if (typeof uNum === 'object' && 'decimal' in uNum) {
890+
// BigNumericValue object - extract numeric value from decimal
891+
const decimal = (uNum as any).decimal;
892+
if (decimal && typeof decimal.toNumber === 'function') {
893+
uReal = decimal.toNumber();
894+
}
895+
} else if (typeof uNum === 'object' && 're' in uNum) {
896+
// Complex number object
897+
uReal = (uNum as any).re;
898+
}
899+
if (uReal !== null && uReal < -1e-10) continue; // Skip negative u values
900+
}
901+
902+
// x = u²
903+
const xVal = uVal.mul(uVal).simplify();
904+
xSolutions.push(xVal);
905+
}
906+
907+
return xSolutions.length > 0 ? xSolutions : null;
908+
}
909+
732910
/**
733911
* Expression is a function of a single variable (`x`) or an Equality
734912
*
@@ -755,6 +933,14 @@ export function findUnivariateRoots(
755933
// Clear denominators to enable matching of expressions like F - 3x/h = 0
756934
expr = clearDenominators(expr);
757935

936+
// Try to solve nested sqrt equations: √(f(x, √x)) = a
937+
// This uses substitution u = √x, solves for u, then converts back to x = u²
938+
const nestedSqrtSolutions = solveNestedSqrtEquation(expr, x);
939+
if (nestedSqrtSolutions !== null) {
940+
// Validate and return the solutions
941+
return validateRoots(originalExpr, x, nestedSqrtSolutions);
942+
}
943+
758944
// Transform sqrt-linear equations: √(f(x)) = g(x) → f(x) - g(x)² = 0
759945
// This handles Pattern 2: √(ax+b) = cx+d by squaring both sides.
760946
// Must be done before pattern matching so quadratic formula can match.

test/compute-engine/solve.test.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,46 @@ describe('SQRT-LINEAR EQUATIONS (Pattern 2)', () => {
479479
});
480480
});
481481

482+
// Tests for nested sqrt equations: √(f(x, √x)) = a (Pattern 4 from TODO #15)
483+
// Uses substitution u = √x, solves for u, then x = u² with u ≥ 0 filtering
484+
describe('NESTED SQRT EQUATIONS (Pattern 4)', () => {
485+
// √(x + 2√x) = 3 → u = √x, √(u² + 2u) = 3 → u² + 2u = 9 → u² + 2u - 9 = 0
486+
// u = (-2 ± √40)/2 = -1 ± √10
487+
// u₁ = -1 + √10 ≈ 2.16 ≥ 0 ✓, u₂ = -1 - √10 ≈ -4.16 < 0 ❌
488+
// x = u² = (-1 + √10)² = 1 - 2√10 + 10 = 11 - 2√10 ≈ 4.675
489+
test('should solve sqrt(x + 2sqrt(x)) = 3 with negative u filtered', () => {
490+
const e = expr('\\sqrt{x + 2\\sqrt{x}} = 3');
491+
const result = e.solve('x');
492+
expect(result?.length).toBe(1);
493+
// x = 11 - 2√10 ≈ 4.675
494+
expect(result?.[0]?.N().toString()).toMatch(/^4\.67/);
495+
});
496+
497+
// √(x + √x) = 2 → u² + u = 4 → u² + u - 4 = 0
498+
// u = (-1 ± √17)/2
499+
// u₁ = (-1 + √17)/2 ≈ 1.56 ≥ 0 ✓, u₂ = (-1 - √17)/2 ≈ -2.56 < 0 ❌
500+
// x = u² ≈ 2.44
501+
test('should solve sqrt(x + sqrt(x)) = 2 with negative u filtered', () => {
502+
const e = expr('\\sqrt{x + \\sqrt{x}} = 2');
503+
const result = e.solve('x');
504+
expect(result?.length).toBe(1);
505+
expect(result?.[0]?.N().toString()).toMatch(/^2\.43/);
506+
});
507+
508+
// √(x - √x) = 1 → u² - u = 1 → u² - u - 1 = 0
509+
// u = (1 ± √5)/2
510+
// u₁ = (1 + √5)/2 ≈ 1.618 (golden ratio) ≥ 0 ✓
511+
// u₂ = (1 - √5)/2 ≈ -0.618 < 0 ❌
512+
// x = u² = φ² ≈ 2.618
513+
test('should solve sqrt(x - sqrt(x)) = 1 with negative u filtered', () => {
514+
const e = expr('\\sqrt{x - \\sqrt{x}} = 1');
515+
const result = e.solve('x');
516+
expect(result?.length).toBe(1);
517+
// x = φ² = ((1+√5)/2)² ≈ 2.618
518+
expect(result?.[0]?.N().toString()).toMatch(/^2\.61/);
519+
});
520+
});
521+
482522
// Tests for trigonometric equations
483523
describe('SOLVING TRIGONOMETRIC EQUATIONS', () => {
484524
test('should solve sin(x) = 0', () => {

0 commit comments

Comments
 (0)