Skip to content

Commit bbf9ade

Browse files
committed
fix: address review — add Greek symbols, copy global_dict per call, validate variable names
1 parent dc9d4db commit bbf9ade

3 files changed

Lines changed: 155 additions & 5 deletions

File tree

src/qwed_new/core/safe_parser.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,52 @@
6565
re.IGNORECASE,
6666
)
6767

68+
# Regex for validating variable names passed to Symbol().
69+
# Allows single-letter, Greek names, and conventional multi-letter math
70+
# variable names. Must start with a letter and contain only alphanumerics
71+
# and underscores, with a reasonable length cap.
72+
_SAFE_VARIABLE_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]{0,49}$")
73+
74+
75+
def validate_variable_name(name: str) -> None:
76+
"""
77+
Validate a user-supplied variable name before it reaches Symbol().
78+
79+
Applies the same length cap, denylist, and character-set checks that
80+
safe_parse_expr applies to full expressions, keeping the hardened
81+
boundary consistent across all user-controlled string inputs.
82+
83+
Raises:
84+
ValueError: If the name is invalid or contains dangerous patterns.
85+
"""
86+
if not isinstance(name, str):
87+
raise ValueError("Variable name must be a string")
88+
89+
stripped = name.strip()
90+
if not stripped:
91+
raise ValueError("Variable name must not be empty")
92+
93+
if not _SAFE_VARIABLE_RE.match(stripped):
94+
raise ValueError(
95+
"Variable name must start with a letter, contain only "
96+
"alphanumerics/underscores, and be at most 50 characters"
97+
)
98+
99+
if _DANGEROUS_PATTERNS.search(stripped):
100+
raise ValueError("Variable name contains disallowed constructs")
101+
68102

69103
def _build_safe_local_dict(extra_symbols: Optional[Dict[str, Any]] = None) -> dict:
70104
"""
71105
Build the allow-listed local namespace for parse_expr.
72106
73107
Only mathematical symbols, constants, functions, and the internal
74108
sympy types that parse_expr's transformations emit are included.
109+
Includes common Greek-letter and multi-letter symbolic variable names
110+
used in standard mathematical and scientific notation.
75111
"""
76112
safe = {
77-
# Common symbolic variables
113+
# Common single-letter symbolic variables
78114
"x": Symbol("x"),
79115
"y": Symbol("y"),
80116
"z": Symbol("z"),
@@ -91,6 +127,40 @@ def _build_safe_local_dict(extra_symbols: Optional[Dict[str, Any]] = None) -> di
91127
"u": Symbol("u"),
92128
"v": Symbol("v"),
93129
"w": Symbol("w"),
130+
# Greek-letter symbolic variables (common in verification workloads)
131+
"alpha": Symbol("alpha"),
132+
"beta": Symbol("beta"),
133+
"gamma": Symbol("gamma"),
134+
"delta": Symbol("delta"),
135+
"epsilon": Symbol("epsilon"),
136+
"zeta": Symbol("zeta"),
137+
"eta": Symbol("eta"),
138+
"theta": Symbol("theta"),
139+
"iota": Symbol("iota"),
140+
"kappa": Symbol("kappa"),
141+
"mu": Symbol("mu"),
142+
"nu": Symbol("nu"),
143+
"xi": Symbol("xi"),
144+
"omicron": Symbol("omicron"),
145+
"rho": Symbol("rho"),
146+
"sigma": Symbol("sigma"),
147+
"tau": Symbol("tau"),
148+
"upsilon": Symbol("upsilon"),
149+
"phi": Symbol("phi"),
150+
"chi": Symbol("chi"),
151+
"psi": Symbol("psi"),
152+
"omega": Symbol("omega"),
153+
# Capital Greek letters commonly used as symbols
154+
"Alpha": Symbol("Alpha"),
155+
"Beta": Symbol("Beta"),
156+
"Gamma": Symbol("Gamma"),
157+
"Delta": Symbol("Delta"),
158+
"Theta": Symbol("Theta"),
159+
"Lambda": Symbol("Lambda"),
160+
"Sigma": Symbol("Sigma"),
161+
"Phi": Symbol("Phi"),
162+
"Psi": Symbol("Psi"),
163+
"Omega": Symbol("Omega"),
94164
# Mathematical constants
95165
"pi": pi,
96166
"e": E,
@@ -141,7 +211,9 @@ def _build_safe_local_dict(extra_symbols: Optional[Dict[str, Any]] = None) -> di
141211
return safe
142212

143213

144-
# Pre-built global dict that strips builtins
214+
# Pre-built global dict that strips builtins.
215+
# IMPORTANT: A shallow copy is made per invocation (see safe_parse_expr)
216+
# to prevent cross-call mutation by SymPy transformations.
145217
_SAFE_GLOBAL_DICT: dict = {"__builtins__": {}}
146218

147219

@@ -192,7 +264,7 @@ def safe_parse_expr(
192264
return parse_expr(
193265
stripped,
194266
local_dict=local_dict,
195-
global_dict=_SAFE_GLOBAL_DICT,
267+
global_dict=dict(_SAFE_GLOBAL_DICT),
196268
transformations=transformations,
197269
)
198270
except Exception as exc:

src/qwed_new/core/verifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
diff, integrate, limit, oo,
1919
simplify, expand
2020
)
21-
from qwed_new.core.safe_parser import safe_parse_expr, SAFE_TRANSFORMATIONS
21+
from qwed_new.core.safe_parser import safe_parse_expr, validate_variable_name, SAFE_TRANSFORMATIONS
2222
from typing import Any, Dict, List, Optional
2323
from decimal import Decimal, ROUND_HALF_UP
2424
from dataclasses import dataclass
@@ -282,6 +282,7 @@ def verify_derivative(
282282
"""
283283
try:
284284
expr = safe_parse_expr(expression)
285+
validate_variable_name(variable)
285286
var = Symbol(variable)
286287
expected_expr = safe_parse_expr(expected)
287288

@@ -329,6 +330,7 @@ def verify_integral(
329330
"""
330331
try:
331332
expr = safe_parse_expr(expression)
333+
validate_variable_name(variable)
332334
var = Symbol(variable)
333335
expected_expr = safe_parse_expr(expected)
334336

@@ -386,6 +388,7 @@ def verify_limit(
386388
"""
387389
try:
388390
expr = safe_parse_expr(expression)
391+
validate_variable_name(variable)
389392
var = Symbol(variable)
390393
expected_expr = safe_parse_expr(expected)
391394

tests/security/test_safe_parser.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import pytest
9-
from qwed_new.core.safe_parser import safe_parse_expr
9+
from qwed_new.core.safe_parser import safe_parse_expr, validate_variable_name
1010

1111

1212
class TestSafeParseExprBlocksCodeExecution:
@@ -81,6 +81,32 @@ def test_valid_expression_parses(
8181
assert str(result) == expected_str
8282

8383

84+
class TestSafeParseExprMultiLetterSymbols:
85+
"""Verify that multi-letter symbolic variable names parse correctly."""
86+
87+
@pytest.mark.parametrize(
88+
"expression,expected_str",
89+
[
90+
("alpha + beta", "alpha + beta"),
91+
("theta**2", "theta**2"),
92+
("sin(phi)", "sin(phi)"),
93+
("gamma * delta", "delta*gamma"),
94+
("epsilon + tau", "epsilon + tau"),
95+
("sigma * omega", "omega*sigma"),
96+
("Lambda + Omega", "Lambda + Omega"),
97+
],
98+
)
99+
def test_greek_letter_variables_parse(
100+
self, expression: str, expected_str: str
101+
) -> None:
102+
result = safe_parse_expr(expression)
103+
assert str(result) == expected_str
104+
105+
def test_mixed_single_and_greek_variables(self) -> None:
106+
result = safe_parse_expr("x + alpha")
107+
assert str(result) == "alpha + x"
108+
109+
84110
class TestSafeParseExprInputValidation:
85111
"""Verify input validation guards."""
86112

@@ -95,3 +121,52 @@ def test_rejects_non_string(self) -> None:
95121
def test_rejects_oversized_input(self) -> None:
96122
with pytest.raises(ValueError, match="too long"):
97123
safe_parse_expr("x + " * 2000)
124+
125+
126+
class TestSafeParseExprGlobalDictIsolation:
127+
"""Verify that _SAFE_GLOBAL_DICT is not mutated across calls."""
128+
129+
def test_global_dict_not_shared_between_calls(self) -> None:
130+
"""Parse two different expressions and confirm no cross-contamination."""
131+
safe_parse_expr("x + 1")
132+
safe_parse_expr("alpha + beta")
133+
# If global_dict were shared mutably, SymPy transformations could
134+
# leak symbols from one call into another. The shallow-copy fix
135+
# prevents this. We just verify no exception is raised and both
136+
# parse independently.
137+
result = safe_parse_expr("y + 2")
138+
assert str(result) == "y + 2"
139+
140+
141+
class TestValidateVariableName:
142+
"""Verify variable name validation for calculus methods."""
143+
144+
@pytest.mark.parametrize(
145+
"name",
146+
["x", "y", "theta", "alpha", "x1", "var_name"],
147+
)
148+
def test_valid_variable_names_accepted(self, name: str) -> None:
149+
# Should not raise
150+
validate_variable_name(name)
151+
152+
@pytest.mark.parametrize(
153+
"name",
154+
[
155+
"", # empty
156+
" ", # whitespace only
157+
"123", # starts with digit
158+
"__import__", # dunder pattern
159+
"a" * 51, # too long
160+
"os", # denylist hit
161+
"sys", # denylist hit
162+
"eval", # denylist hit
163+
"exec", # denylist hit
164+
],
165+
)
166+
def test_invalid_variable_names_rejected(self, name: str) -> None:
167+
with pytest.raises(ValueError):
168+
validate_variable_name(name)
169+
170+
def test_rejects_non_string_variable(self) -> None:
171+
with pytest.raises(ValueError, match="must be a string"):
172+
validate_variable_name(123) # type: ignore[arg-type]

0 commit comments

Comments
 (0)