Skip to content

Commit 761b423

Browse files
Merge branch 'master' into dependency-upgrade
2 parents ec2ab46 + e706e57 commit 761b423

1 file changed

Lines changed: 146 additions & 37 deletions

File tree

  • sagemaker-core/src/sagemaker/core/jumpstart

sagemaker-core/src/sagemaker/core/jumpstart/search.py

Lines changed: 146 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,78 @@
77
logger = logging.getLogger(__name__)
88

99

10+
class _ExpressionNode:
11+
"""Base class for expression AST nodes."""
12+
13+
def evaluate(self, keywords: List[str]) -> bool:
14+
"""Evaluate this node against the given keywords."""
15+
raise NotImplementedError
16+
17+
18+
class _AndNode(_ExpressionNode):
19+
"""AND logical operator node."""
20+
21+
def __init__(self, left: _ExpressionNode, right: _ExpressionNode):
22+
self.left = left
23+
self.right = right
24+
25+
def evaluate(self, keywords: List[str]) -> bool:
26+
return self.left.evaluate(keywords) and self.right.evaluate(keywords)
27+
28+
29+
class _OrNode(_ExpressionNode):
30+
"""OR logical operator node."""
31+
32+
def __init__(self, left: _ExpressionNode, right: _ExpressionNode):
33+
self.left = left
34+
self.right = right
35+
36+
def evaluate(self, keywords: List[str]) -> bool:
37+
return self.left.evaluate(keywords) or self.right.evaluate(keywords)
38+
39+
40+
class _NotNode(_ExpressionNode):
41+
"""NOT logical operator node."""
42+
43+
def __init__(self, operand: _ExpressionNode):
44+
self.operand = operand
45+
46+
def evaluate(self, keywords: List[str]) -> bool:
47+
return not self.operand.evaluate(keywords)
48+
49+
50+
class _PatternNode(_ExpressionNode):
51+
"""Pattern matching node for keywords with wildcard support."""
52+
53+
def __init__(self, pattern: str):
54+
self.pattern = pattern.strip('"').strip("'")
55+
56+
def evaluate(self, keywords: List[str]) -> bool:
57+
"""Check if any keyword matches this pattern."""
58+
for keyword in keywords:
59+
if self._matches_pattern(keyword, self.pattern):
60+
return True
61+
return False
62+
63+
def _matches_pattern(self, keyword: str, pattern: str) -> bool:
64+
"""Check if a keyword matches a pattern with wildcard support."""
65+
if pattern.startswith("*") and pattern.endswith("*"):
66+
# Contains pattern: *text*
67+
stripped = pattern.strip("*")
68+
return stripped in keyword
69+
elif pattern.startswith("*"):
70+
# Ends with pattern: *text
71+
stripped = pattern[1:]
72+
return keyword.endswith(stripped)
73+
elif pattern.endswith("*"):
74+
# Starts with pattern: text*
75+
stripped = pattern[:-1]
76+
return keyword.startswith(stripped)
77+
else:
78+
# Exact match
79+
return keyword == pattern
80+
81+
1082
class _Filter:
1183
"""
1284
A filter that evaluates logical expressions against a list of keyword strings.
@@ -28,6 +100,7 @@ def __init__(self, expression: str) -> None:
28100
Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
29101
"""
30102
self.expression: str = expression
103+
self._ast: Optional[_ExpressionNode] = None
31104

32105
def match(self, keywords: List[str]) -> bool:
33106
"""
@@ -39,54 +112,90 @@ def match(self, keywords: List[str]) -> bool:
39112
Returns:
40113
bool: True if the expression evaluates to True for the given keywords, else False.
41114
"""
42-
expr: str = self._convert_expression(self.expression)
43115
try:
44-
return eval(expr, {"__builtins__": {}}, {"keywords": keywords, "any": any})
116+
if self._ast is None:
117+
self._ast = self._parse_expression(self.expression)
118+
return self._ast.evaluate(keywords)
45119
except Exception:
46120
return False
47121

48-
def _convert_expression(self, expr: str) -> str:
122+
def _parse_expression(self, expr: str) -> _ExpressionNode:
49123
"""
50-
Convert the logical filter expression into a Python-evaluable string.
124+
Parse the logical filter expression into an AST.
51125
52126
Args:
53-
expr (str): The raw expression to convert.
127+
expr (str): The raw expression to parse.
54128
55129
Returns:
56-
str: A Python expression string using 'any' and logical operators.
130+
_ExpressionNode: Root node of the parsed expression AST.
57131
"""
58-
tokens: List[str] = re.findall(
59-
r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE
60-
)
61-
62-
def wildcard_condition(pattern: str) -> str:
63-
pattern = pattern.strip('"').strip("'")
64-
stripped = pattern.strip("*")
132+
tokens = self._tokenize(expr)
133+
result, _ = self._parse_or_expression(tokens, 0)
134+
return result
135+
136+
def _tokenize(self, expr: str) -> List[str]:
137+
"""Tokenize the expression into logical operators, keywords, and parentheses."""
138+
return re.findall(r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE)
139+
140+
def _parse_or_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
141+
"""Parse OR expression (lowest precedence)."""
142+
left, pos = self._parse_and_expression(tokens, pos)
143+
144+
while pos < len(tokens) and tokens[pos].upper() == "OR":
145+
pos += 1 # Skip OR token
146+
right, pos = self._parse_and_expression(tokens, pos)
147+
left = _OrNode(left, right)
148+
149+
return left, pos
150+
151+
def _parse_and_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
152+
"""Parse AND expression (medium precedence)."""
153+
left, pos = self._parse_not_expression(tokens, pos)
154+
155+
while pos < len(tokens) and tokens[pos].upper() == "AND":
156+
pos += 1 # Skip AND token
157+
right, pos = self._parse_not_expression(tokens, pos)
158+
left = _AndNode(left, right)
159+
160+
return left, pos
161+
162+
def _parse_not_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
163+
"""Parse NOT expression (highest precedence)."""
164+
if pos < len(tokens) and tokens[pos].upper() == "NOT":
165+
pos += 1 # Skip NOT token
166+
operand, pos = self._parse_primary_expression(tokens, pos)
167+
return _NotNode(operand), pos
168+
else:
169+
return self._parse_primary_expression(tokens, pos)
170+
171+
def _parse_primary_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
172+
"""Parse primary expression (parentheses or pattern)."""
173+
if pos >= len(tokens):
174+
raise ValueError("Unexpected end of expression")
175+
176+
token = tokens[pos]
177+
178+
if token == "(":
179+
pos += 1 # Skip opening parenthesis
180+
expr, pos = self._parse_or_expression(tokens, pos)
181+
if pos >= len(tokens) or tokens[pos] != ")":
182+
raise ValueError("Missing closing parenthesis")
183+
pos += 1 # Skip closing parenthesis
184+
return expr, pos
185+
elif token == ")":
186+
raise ValueError("Unexpected closing parenthesis")
187+
else:
188+
# Pattern token
189+
return _PatternNode(token), pos + 1
65190

66-
if pattern.startswith("*") and pattern.endswith("*"):
67-
return f"{repr(stripped)} in k"
68-
elif pattern.startswith("*"):
69-
return f"k.endswith({repr(stripped)})"
70-
elif pattern.endswith("*"):
71-
return f"k.startswith({repr(stripped)})"
72-
else:
73-
return f"k == {repr(pattern)}"
74-
75-
def convert_token(token: str) -> str:
76-
upper = token.upper()
77-
if upper == "AND":
78-
return "and"
79-
elif upper == "OR":
80-
return "or"
81-
elif upper == "NOT":
82-
return "not"
83-
elif token in ("(", ")"):
84-
return token
85-
else:
86-
return f"any({wildcard_condition(token)} for k in keywords)"
87-
88-
converted_tokens = [convert_token(tok) for tok in tokens]
89-
return " ".join(converted_tokens)
191+
def _convert_expression(self, expr: str) -> str:
192+
"""
193+
Legacy method for backward compatibility.
194+
This method is no longer used but kept to avoid breaking changes.
195+
"""
196+
# This method is deprecated and should not be used
197+
# It's kept only for backward compatibility
198+
return expr
90199

91200

92201
def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]:

0 commit comments

Comments
 (0)