Skip to content

Commit 6ba3bca

Browse files
committed
feat: enhance ask() functionality to normalize inequality patterns and support wildcard symbols in bound queries
1 parent c66fb5b commit 6ba3bca

3 files changed

Lines changed: 174 additions & 7 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
- **Improved `ask()` Queries**: `ce.ask()` now matches patterns with wildcards
66
correctly, can answer common “bound” queries such as
7-
`ask(["Greater", "x", "_k"])`, and falls back to `verify()` for closed
8-
predicates when the fact is known but not stored as an explicit assumption.
7+
`ask(["Greater", "x", "_k"])` and `ask(["Greater", "_x", "_k"])`, normalizes
8+
inequality patterns for matching (e.g. `ask(["Greater", "_x", 0])`), and falls
9+
back to `verify()` for closed predicates when the fact is known but not stored
10+
as an explicit assumption.
911

1012
- **Tri-state `verify()`**: Implemented `ce.verify()` as a truth query that
1113
returns `true`, `false` or `undefined` when a predicate cannot be determined
@@ -335,7 +337,8 @@
335337
want them reduced to `\frac{0}{0}``\operatorname{NaN}`.
336338
- **Implicit multiplication powers**: `xx` now simplifies to `x^2`.
337339
- **Exponential/log separation**: `\exp(\log(x)+y)` and `\exp(\log(x)-y)` now
338-
simplify without leaving a remaining `\log(...)` term in the exponent.
340+
simplify without leaving a remaining `\log(...)` term in the exponent
341+
(preferred by the default cost function).
339342

340343
### New Features
341344

src/compute-engine/index.ts

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ import { SIMPLIFY_RULES } from './symbolic/simplify-rules';
133133
import { bigint } from './numerics/bigint';
134134
import { canonicalFunctionLiteral, lookup } from './function-utils';
135135

136-
import { assume } from './assume';
136+
import { assume, getInequalityBoundsFromAssumptions } from './assume';
137137
import {
138138
createSequenceHandler,
139139
validateSequenceDefinition,
@@ -151,6 +151,8 @@ import {
151151
checkSequence as checkSequenceImpl,
152152
} from './oeis';
153153

154+
import { isWildcard, wildcardName } from './boxed-expression/boxed-patterns';
155+
154156
export * from './global-types';
155157

156158
export { validatePattern };
@@ -2188,11 +2190,155 @@ export class ComputeEngine implements IComputeEngine {
21882190
ask(pattern: BoxedExpression): BoxedSubstitution[] {
21892191
const pat = this.box(pattern, { canonical: false });
21902192
const result: BoxedSubstitution[] = [];
2193+
2194+
const patternHasWildcards = (expr: BoxedExpression): boolean => {
2195+
if (expr.operator?.startsWith('_')) return true;
2196+
if (isWildcard(expr)) return true;
2197+
if (expr.ops) return expr.ops.some(patternHasWildcards);
2198+
return false;
2199+
};
2200+
2201+
const pushResult = (m: BoxedSubstitution) => {
2202+
const keys = Object.keys(m).sort();
2203+
for (const prev of result) {
2204+
const prevKeys = Object.keys(prev).sort();
2205+
if (prevKeys.length !== keys.length) continue;
2206+
let same = true;
2207+
for (let i = 0; i < keys.length; i++) {
2208+
if (prevKeys[i] !== keys[i]) {
2209+
same = false;
2210+
break;
2211+
}
2212+
const k = keys[i]!;
2213+
if (!m[k]!.isSame(prev[k]!)) {
2214+
same = false;
2215+
break;
2216+
}
2217+
}
2218+
if (same) return;
2219+
}
2220+
result.push(m);
2221+
};
2222+
21912223
const assumptions = this.context.assumptions;
2224+
2225+
const candidatesFromAssumptions = (): string[] => {
2226+
const candidates = new Set<string>();
2227+
for (const [assumption, val] of assumptions) {
2228+
if (val !== true) continue;
2229+
for (const s of assumption.symbols) candidates.add(s);
2230+
}
2231+
return [...candidates];
2232+
};
2233+
2234+
const normalizedInequalityPatterns = (
2235+
expr: BoxedExpression
2236+
): Array<{ pattern: BoxedExpression; matchPermutations?: boolean }> => {
2237+
const op = expr.operator;
2238+
if (
2239+
op !== 'Less' &&
2240+
op !== 'LessEqual' &&
2241+
op !== 'Greater' &&
2242+
op !== 'GreaterEqual'
2243+
)
2244+
return [{ pattern: expr }];
2245+
2246+
const lhs = op === 'Greater' || op === 'GreaterEqual' ? expr.op2 : expr.op1;
2247+
const rhs = op === 'Greater' || op === 'GreaterEqual' ? expr.op1 : expr.op2;
2248+
const normalizedOp = op === 'Less' || op === 'Greater' ? 'Less' : 'LessEqual';
2249+
2250+
// Normalize to Less/LessEqual with RHS = 0, matching how assumptions are stored:
2251+
// Greater(a, b) -> Less(b - a, 0)
2252+
// Less(a, b) -> Less(a - b, 0)
2253+
const diff = this.box(['Add', lhs, ['Negate', rhs]], { canonical: false });
2254+
return [
2255+
{ pattern: expr },
2256+
// For the normalized form, disable permutations: for commutative
2257+
// subexpressions (notably Add), allowing permutations can lead to
2258+
// ambiguous wildcard bindings and duplicate, surprising matches.
2259+
{
2260+
pattern: this.box([normalizedOp, diff, 0], { canonical: false }),
2261+
matchPermutations: false,
2262+
},
2263+
];
2264+
};
2265+
2266+
// B1: Element(x, _T) can be answered from the declared/inferred type of x
2267+
if (pat.operator === 'Element' && pat.op1?.symbol && isWildcard(pat.op2)) {
2268+
const typeWildcard = wildcardName(pat.op2);
2269+
if (typeWildcard && !typeWildcard.startsWith('__')) {
2270+
const symbolType = this.box(pat.op1.symbol).type;
2271+
if (!symbolType.isUnknown) {
2272+
pushResult({
2273+
[typeWildcard]: this.box(symbolType.toString(), { canonical: false }),
2274+
});
2275+
}
2276+
}
2277+
}
2278+
2279+
// B2: Inequality bound queries, e.g. Greater(x, _k) -> {_k: lowerBound}
2280+
if (
2281+
(pat.operator === 'Greater' ||
2282+
pat.operator === 'GreaterEqual' ||
2283+
pat.operator === 'Less' ||
2284+
pat.operator === 'LessEqual') &&
2285+
isWildcard(pat.op2)
2286+
) {
2287+
const boundWildcard = wildcardName(pat.op2);
2288+
if (boundWildcard && !boundWildcard.startsWith('__')) {
2289+
const isLower = pat.operator === 'Greater' || pat.operator === 'GreaterEqual';
2290+
const isStrict = pat.operator === 'Greater' || pat.operator === 'Less';
2291+
2292+
// Symbol on LHS: Greater(x, _k)
2293+
if (pat.op1?.symbol) {
2294+
const bounds = getInequalityBoundsFromAssumptions(this, pat.op1.symbol);
2295+
const bound = isLower ? bounds.lowerBound : bounds.upperBound;
2296+
const strictOk = isLower ? bounds.lowerStrict : bounds.upperStrict;
2297+
if (bound !== undefined && (!isStrict || strictOk === true))
2298+
pushResult({ [boundWildcard]: bound });
2299+
}
2300+
2301+
// Wildcard on LHS: Greater(_x, _k)
2302+
if (isWildcard(pat.op1)) {
2303+
const symbolWildcard = wildcardName(pat.op1);
2304+
if (symbolWildcard && !symbolWildcard.startsWith('__')) {
2305+
for (const s of candidatesFromAssumptions()) {
2306+
const bounds = getInequalityBoundsFromAssumptions(this, s);
2307+
const bound = isLower ? bounds.lowerBound : bounds.upperBound;
2308+
const strictOk = isLower ? bounds.lowerStrict : bounds.upperStrict;
2309+
if (bound === undefined || (isStrict && strictOk !== true)) continue;
2310+
pushResult({
2311+
[symbolWildcard]: this.box(s, { canonical: true }),
2312+
[boundWildcard]: bound,
2313+
});
2314+
}
2315+
}
2316+
}
2317+
}
2318+
}
2319+
2320+
const patternsToTry = normalizedInequalityPatterns(pat);
21922321
for (const [assumption, val] of assumptions) {
2193-
const m = pat.match(assumption);
2194-
if (m !== null && val === true) result.push(m);
2322+
if (val !== true) continue;
2323+
for (const { pattern: p, matchPermutations } of patternsToTry) {
2324+
const m = assumption.match(p, {
2325+
useVariations: true,
2326+
matchPermutations,
2327+
});
2328+
if (m !== null) pushResult(m);
2329+
}
2330+
}
2331+
2332+
// B3: For closed predicates (no wildcards), fall back to verify().
2333+
// This makes `ask()` useful for "is this known?" queries even when the
2334+
// fact is not explicitly stored in the assumptions DB (e.g. declarations).
2335+
if (result.length === 0 && !patternHasWildcards(pat)) {
2336+
// Use the canonical form so symbol declarations/definitions are visible
2337+
// to the evaluator.
2338+
const verified = this.verify(this.box(pattern, { canonical: true }));
2339+
if (verified === true) pushResult({});
21952340
}
2341+
21962342
return result;
21972343
}
21982344

test/compute-engine/ask.test.ts

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@ describe('ASK', () => {
1919
expect(r[0]!._k.json).toBe(0);
2020
});
2121

22+
test('normalizes inequality patterns for matching', () => {
23+
const ce = new ComputeEngine();
24+
ce.assume(ce.parse('x > 0'));
25+
26+
const r = ce.ask(['Greater', '_x', 0]);
27+
expect(r.length).toBe(1);
28+
expect(r[0]!._x.symbol).toBe('x');
29+
});
30+
31+
test('supports wildcard symbols in bound queries', () => {
32+
const ce = new ComputeEngine();
33+
ce.assume(ce.parse('x > 0'));
34+
35+
const r = ce.ask(['Greater', '_x', '_k']);
36+
expect(r.length).toBe(1);
37+
expect(r[0]!._x.symbol).toBe('x');
38+
expect(r[0]!._k.json).toBe(0);
39+
});
40+
2241
test('is conservative about strictness of bounds', () => {
2342
const ce = new ComputeEngine();
2443
ce.assume(ce.parse('x \\ge 0'));
@@ -43,4 +62,3 @@ describe('ASK', () => {
4362
expect(r[0]!._T.json).toBe('finite_real');
4463
});
4564
});
46-

0 commit comments

Comments
 (0)