Skip to content

Commit 6066b68

Browse files
authored
Merge pull request #200 from sebastiondev/fix/cwe95-main-sympy-9383
fix(math): restrict sympy expression parsing
2 parents 06801f6 + 70f0978 commit 6066b68

8 files changed

Lines changed: 487 additions & 27 deletions

File tree

.coverage

-24 KB
Binary file not shown.

src/qwed_new/api/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ async def verify_math(
457457

458458
try:
459459
import sympy
460-
from sympy.parsing.sympy_parser import parse_expr
460+
from qwed_new.core.safe_parser import safe_parse_expr
461461
from sympy import simplify, symbols, Eq, solve
462462

463463
expression = request.get("expression")
@@ -472,8 +472,8 @@ async def verify_math(
472472
left_str, right_str = expression.split("=", 1)
473473

474474
# Parse both sides
475-
left = parse_expr(left_str)
476-
right = parse_expr(right_str)
475+
left = safe_parse_expr(left_str)
476+
right = safe_parse_expr(right_str)
477477

478478
# Simplify and check equivalence
479479
difference = simplify(left - right)
@@ -501,7 +501,7 @@ async def verify_math(
501501
if re.search(r'/\d+\(', expression.replace(" ", "")):
502502
is_ambiguous = True
503503

504-
parsed = parse_expr(expression_normalized)
504+
parsed = safe_parse_expr(expression_normalized)
505505

506506
# Check for division by zero before simplifying
507507
if "/0" in expression.replace(" ", "") or "/ 0" in expression:

src/qwed_new/core/batch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,14 @@ async def _verify_item(
220220
)
221221

222222
elif item.verification_type == VerificationType.MATH:
223-
from sympy.parsing.sympy_parser import parse_expr
223+
from qwed_new.core.safe_parser import safe_parse_expr
224224
from sympy import simplify
225225

226226
expression = item.query
227227
if "=" in expression:
228228
left, right = expression.split("=", 1)
229-
left_expr = parse_expr(left)
230-
right_expr = parse_expr(right)
229+
left_expr = safe_parse_expr(left)
230+
right_expr = safe_parse_expr(right)
231231
diff = simplify(left_expr - right_expr)
232232
is_valid = diff == 0
233233
return {
@@ -236,7 +236,7 @@ async def _verify_item(
236236
"message": "Identity verified" if is_valid else "Not equal"
237237
}
238238
else:
239-
parsed = parse_expr(expression)
239+
parsed = safe_parse_expr(expression)
240240
simplified = simplify(parsed)
241241
return {
242242
"is_valid": False,

src/qwed_new/core/safe_parser.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""
2+
Safe SymPy expression parser.
3+
4+
Wraps sympy.parsing.sympy_parser.parse_expr with input validation,
5+
a denylist for dangerous constructs, and a restricted evaluation
6+
namespace. This module is the ONLY approved entry point for parsing
7+
user-supplied math expressions in production code.
8+
9+
Security boundary:
10+
1. Reject known-dangerous Python/OS constructs (denylist).
11+
2. Remove __builtins__ from the eval global dict.
12+
3. Allow-list only expected math symbols, constants, and functions.
13+
4. Enforce basic input validation (type, length, empty check).
14+
15+
CWE-95 mitigation -- see PR #200 for full security analysis.
16+
"""
17+
18+
import ast
19+
import re
20+
from typing import Any, Dict, Optional, Tuple
21+
22+
import sympy
23+
from sympy import (
24+
E, I, Integer, Float, Rational, Symbol, oo, pi,
25+
)
26+
from sympy.parsing.sympy_parser import (
27+
parse_expr,
28+
standard_transformations,
29+
implicit_multiplication_application,
30+
)
31+
32+
# Names exported by ``from safe_parser import *``.
__all__ = ["safe_parse_expr", "validate_variable_name", "get_safe_symbol", "SafeParserError"]

# Longest (whitespace-stripped) expression safe_parse_expr will accept.
MAX_EXPRESSION_LENGTH = 5_000
# Deepest Python AST that _check_ast_depth tolerates before rejecting input.
_AST_MAX_DEPTH = 30

# Case-insensitive denylist searched against raw input before any parsing.
# The dunder alternatives carry no \b anchors, so they match anywhere inside a
# longer token; the plain names are anchored with \b and only match as whole
# words. Because whole words such as "path", "type", "file" or "code" are
# listed, identifiers spelled exactly like that are rejected as well.
_DENYLIST_PATTERN = re.compile(
    r"(?:"
    r"__import__|__builtins__|__subclasses__|__globals__|__locals__"
    r"|__getattr__|__setattr__|__delattr__|__class__|__bases__|__mro__"
    r"|\beval\b|\bexec\b|\bcompile\b|\bgetattr\b|\bsetattr\b|\bdelattr\b"
    r"|\bimport\b|\bimportlib\b"
    r"|\bos\b|\bsys\b|\bsubprocess\b|\bshutil\b|\bsocket\b"
    r"|\bpopen\b|\bsystem\b|\bspawn\b"
    r"|\bopen\b|\bfile\b|\bpath\b|\bglob\b"
    r"|\bchr\b|\bord\b|\bhex\b|\btype\b|\bvars\b|\bdir\b|\brepr\b"
    r"|\binput\b|\bprint\b|\bbreakpoint\b|\bexit\b|\bquit\b"
    r"|\bcodecs\b|\bcode\b|\bctypes\b"
    r")",
    re.IGNORECASE,
)

# Template for the globals passed to parse_expr: "__builtins__" is mapped to an
# empty dict. safe_parse_expr takes a shallow copy of this template per call.
_SAFE_GLOBAL_DICT_TEMPLATE: Dict[str, Any] = {"__builtins__": {}}
54+
55+
56+
def _check_ast_depth(expression: str) -> None:
    """Reject Python-parseable expressions exceeding max AST depth (DoS defence).

    Expressions using implicit multiplication (e.g. 2x, sin x) fail ast.parse
    with a SyntaxError and skip this check; they are covered by the post-parse
    SymPy depth check instead.

    Args:
        expression: Raw expression text, already stripped by the caller.

    Raises:
        SafeParserError: If the AST is deeper than ``_AST_MAX_DEPTH``, or if
            the text cannot be analysed at all (for example it contains NUL
            bytes, or it is nested so deeply that the parser or the depth walk
            runs out of stack or memory).
    """
    try:
        tree = ast.parse(expression, mode="eval")
    except SyntaxError:
        # Not plain Python syntax: leave the decision to the SymPy parser.
        return
    except (RecursionError, MemoryError, ValueError) as exc:
        # ast.parse can also fail with these on hostile input; surface them as
        # the module's own error type instead of leaking them to callers.
        raise SafeParserError(
            f"Expression could not be analysed safely: {exc}"
        ) from exc

    try:
        depth = _ast_node_depth(tree)
    except RecursionError as exc:
        # Running out of stack while measuring means the tree is far deeper
        # than the permitted limit.
        raise SafeParserError(
            f"Expression AST depth exceeds limit of {_AST_MAX_DEPTH}"
        ) from exc
    if depth > _AST_MAX_DEPTH:
        raise SafeParserError(
            f"Expression AST depth {depth} exceeds limit of {_AST_MAX_DEPTH}"
        )
71+
72+
73+
def _ast_node_depth(node: ast.AST, current: int = 0) -> int:
    """Return the greatest depth reached anywhere in the subtree under *node*.

    The walk uses an explicit stack rather than recursion, so measuring a very
    deep tree cannot itself exhaust the interpreter's recursion limit.

    Args:
        node: Root of the AST (sub)tree to measure.
        current: Depth assigned to *node*; each child level adds one.

    Returns:
        The largest depth value assigned to any node, which equals *current*
        when *node* has no child nodes.
    """
    deepest = current
    pending = [(node, current)]
    while pending:
        item, level = pending.pop()
        if level > deepest:
            deepest = level
        for child in ast.iter_child_nodes(item):
            pending.append((child, level + 1))
    return deepest
80+
81+
82+
# Deepest SymPy expression tree that _validate_sympy_result accepts.
_SYMPY_MAX_DEPTH = 40
83+
84+
85+
def _sympy_tree_depth(expr: Any, current: int = 0) -> int:
    """Compute nesting depth of a SymPy expression tree.

    Children are read from each object's ``args`` attribute; an object without
    ``args`` counts as a leaf. The traversal is iterative so that a very deep
    tree is measured instead of overflowing the recursion limit.

    Args:
        expr: Root object of the tree.
        current: Depth assigned to *expr*; each level of ``args`` adds one.

    Returns:
        The largest depth value assigned to any object in the tree.
    """
    deepest = current
    pending = [(expr, current)]
    while pending:
        item, level = pending.pop()
        if level > deepest:
            deepest = level
        for arg in getattr(item, "args", ()):
            pending.append((arg, level + 1))
    return deepest
93+
94+
95+
def _validate_sympy_result(result: Any) -> None:
    """Check that a parse result is a ``sympy.Expr`` within the depth limit.

    Raises:
        SafeParserError: If *result* is not a ``sympy.Expr`` instance, or if
            its tree is deeper than ``_SYMPY_MAX_DEPTH``.
    """
    import sympy

    is_expression = isinstance(result, sympy.Expr)
    if not is_expression:
        kind_name = type(result).__name__
        raise SafeParserError(
            f"Parsed result is not a supported arithmetic expression, got {kind_name}"
        )

    measured = _sympy_tree_depth(result)
    if measured <= _SYMPY_MAX_DEPTH:
        return
    raise SafeParserError(
        f"Expression tree depth {measured} exceeds limit of {_SYMPY_MAX_DEPTH}"
    )
107+
108+
109+
def _build_safe_local_dict(
    extra_symbols: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the allow-listed name-to-object mapping used as the local namespace.

    The mapping holds single-letter and Greek-letter symbols, a few constants,
    a fixed set of SymPy functions, and the number/symbol constructors.

    Args:
        extra_symbols: Optional additional names to expose. Each key must be a
            string that does not match the denylist, and each value must be a
            SymPy object.

    Returns:
        A fresh dict; callers may mutate it freely.

    Raises:
        SafeParserError: If an ``extra_symbols`` entry fails validation.
    """
    leading_latin = ("x", "y", "z", "a", "b", "c", "d", "f", "g", "h", "k", "m")
    trailing_latin = ("p", "q", "r", "s", "t", "u", "v", "w")
    greek = (
        "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta",
        "iota", "kappa", "mu", "nu", "xi", "omicron", "rho", "sigma", "tau",
        "phi", "chi", "psi", "omega",
    )

    namespace: Dict[str, Any] = {label: Symbol(label) for label in leading_latin}
    # "n" is the only symbol that is created with assumptions attached.
    namespace["n"] = Symbol("n", integer=True, positive=True)
    for label in trailing_latin + greek:
        namespace[label] = Symbol(label)

    # Constants.
    namespace.update({"pi": pi, "E": E, "I": I, "oo": oo})

    # Elementary functions.
    namespace.update({
        "sin": sympy.sin, "cos": sympy.cos, "tan": sympy.tan,
        "cot": sympy.cot, "sec": sympy.sec, "csc": sympy.csc,
        "asin": sympy.asin, "acos": sympy.acos, "atan": sympy.atan,
        "atan2": sympy.atan2,
        "sinh": sympy.sinh, "cosh": sympy.cosh, "tanh": sympy.tanh,
        "log": sympy.log, "ln": sympy.log, "exp": sympy.exp,
        "sqrt": sympy.sqrt, "cbrt": sympy.cbrt,
        "abs": sympy.Abs, "Abs": sympy.Abs,
        "factorial": sympy.factorial, "binomial": sympy.binomial,
    })

    # Constructors. Symbol is required because SymPy standard_transformations
    # may emit Symbol('name') during evaluation. This lets users create symbols
    # with arbitrary names; the denylist and stripped builtins mitigate
    # downstream attribute-access risks on the resulting objects.
    namespace.update({
        "Integer": Integer, "Float": Float, "Rational": Rational,
        "Symbol": Symbol,
    })

    for key, value in (extra_symbols or {}).items():
        if not isinstance(key, str):
            raise SafeParserError(
                f"extra_symbols keys must be strings, got {type(key).__name__}"
            )
        if _DENYLIST_PATTERN.search(key):
            raise SafeParserError(
                f"extra_symbols key {key!r} contains disallowed construct"
            )
        if not isinstance(value, (Symbol, sympy.Basic)):
            raise SafeParserError(
                f"extra_symbols[{key!r}] must be a SymPy Symbol or Basic, "
                f"got {type(value).__name__}"
            )
        namespace[key] = value
    return namespace
166+
167+
168+
class SafeParserError(ValueError):
    """Raised when input is rejected by the safe parser or cannot be parsed."""
170+
171+
172+
def safe_parse_expr(
    expression: str,
    *,
    extra_symbols: Optional[Dict[str, Any]] = None,
    transformations: Optional[Tuple] = None,
) -> Any:
    """Parse an untrusted math expression with a restricted namespace.

    Validation happens in this order: type check, empty check, length check,
    denylist check, Python AST depth check, then parsing with an allow-listed
    local namespace and a globals dict whose ``__builtins__`` is empty, and
    finally validation of the parsed result.

    Args:
        expression: The expression text supplied by the user.
        extra_symbols: Optional extra names to expose to the parser; validated
            by ``_build_safe_local_dict``.
        transformations: Parser transformations. Defaults to SymPy's standard
            transformations plus implicit multiplication application.

    Returns:
        The parsed SymPy expression.

    Raises:
        SafeParserError: For any rejected input or any failure while parsing.
            Unexpected exceptions are wrapped, with the original as the cause.
    """
    import unicodedata  # local import keeps this function self-contained

    if not isinstance(expression, str):
        raise SafeParserError(
            f"Expression must be a string, got {type(expression).__name__}"
        )
    stripped = expression.strip()
    if not stripped:
        raise SafeParserError("Expression is empty")
    if len(stripped) > MAX_EXPRESSION_LENGTH:
        raise SafeParserError(
            f"Expression exceeds maximum length of {MAX_EXPRESSION_LENGTH} characters"
        )

    # Python folds identifiers to NFKC when it compiles source, so a name
    # written with compatibility characters (e.g. full-width letters) would
    # slip past an ASCII-only search yet still resolve to the blocked name.
    # Search the folded text as well; the text that gets parsed is unchanged.
    folded = unicodedata.normalize("NFKC", stripped)
    match = _DENYLIST_PATTERN.search(stripped) or _DENYLIST_PATTERN.search(folded)
    if match:
        raise SafeParserError(
            f"Expression contains disallowed construct: {match.group()!r}"
        )

    try:
        # The pre-parse depth check runs inside the try so that any
        # interpreter-level failure it hits is reported as SafeParserError.
        _check_ast_depth(stripped)
        local_dict = _build_safe_local_dict(extra_symbols)
        if transformations is None:
            transformations = standard_transformations + (
                implicit_multiplication_application,
            )
        global_dict = dict(_SAFE_GLOBAL_DICT_TEMPLATE)
        result = parse_expr(
            stripped,
            local_dict=local_dict,
            global_dict=global_dict,
            transformations=transformations,
        )
        _validate_sympy_result(result)
        return result
    except SafeParserError:
        raise
    except Exception as exc:
        raise SafeParserError(f"Failed to parse expression: {exc}") from exc
214+
215+
216+
def validate_variable_name(variable: str) -> str:
    """Validate a user-supplied variable name and return it stripped.

    A valid name is 1-50 characters after stripping whitespace, starts with an
    ASCII letter, continues with ASCII letters, digits or underscores, and
    does not match the denylist.

    Args:
        variable: Candidate variable name.

    Returns:
        The name with surrounding whitespace removed.

    Raises:
        SafeParserError: If the name is not a string, is empty, is too long,
            has an invalid shape, or contains a denylisted construct.
    """
    if not isinstance(variable, str):
        raise SafeParserError(
            f"Variable name must be a string, got {type(variable).__name__}"
        )
    stripped = variable.strip()
    if not stripped:
        raise SafeParserError("Variable name is empty")
    if len(stripped) > 50:
        raise SafeParserError("Variable name is too long")
    if not re.fullmatch(r"[A-Za-z][A-Za-z0-9_]*", stripped):
        # The message lists underscores because the pattern accepts them.
        raise SafeParserError(
            f"Invalid variable name: {stripped!r}. "
            "Must start with a letter and contain only letters, digits, and underscores."
        )
    match = _DENYLIST_PATTERN.search(stripped)
    if match:
        raise SafeParserError(
            f"Variable name contains disallowed construct: {match.group()!r}"
        )
    return stripped
237+
238+
239+
def get_safe_symbol(name: str) -> Symbol:
    """Return the Symbol that safe_parse_expr would use for *name*.

    The name is validated first. If the safe namespace already maps it to a
    Symbol (including any assumptions, such as those attached to "n"), that
    exact object is returned so calculus operations such as diff, integrate
    and limit see the same symbol the parser produced. Otherwise a plain
    Symbol with that name is created.

    Raises:
        SafeParserError: If *name* fails ``validate_variable_name``.
    """
    checked = validate_variable_name(name)
    known = _build_safe_local_dict().get(checked)
    # Namespace entries that are not Symbols (constants, functions,
    # constructors) are ignored here and a plain Symbol is returned instead.
    return known if isinstance(known, Symbol) else Symbol(checked)

src/qwed_new/core/validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
3. Evaluable: Can we calculate a numerical result?
1111
"""
1212

13-
from sympy.parsing.sympy_parser import parse_expr
13+
from qwed_new.core.safe_parser import safe_parse_expr
1414
from typing import Dict
1515

1616

@@ -61,7 +61,7 @@ def validate(self, expression: str) -> Dict[str, any]:
6161

6262
# Check 1: Syntax validation
6363
try:
64-
expr = parse_expr(expression)
64+
expr = safe_parse_expr(expression)
6565
checks_passed.append("syntax")
6666
except Exception as e:
6767
checks_failed.append("syntax")

0 commit comments

Comments
 (0)