Skip to content

Commit 993a021

Browse files
authored
fix: make eval schedule string safe (#126)
1 parent a8cd98c commit 993a021

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

discoart/helper.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import yaml
2121
from clip.simple_tokenizer import SimpleTokenizer, whitespace_clean, basic_clean
2222
from packaging.version import Version
23-
2423
from spellchecker import SpellChecker
2524
from tqdm.auto import tqdm
2625

@@ -690,9 +689,21 @@ def _version_check(package: str = None, github_repo: str = None):
690689
_MAX_DIFFUSION_STEPS = 1000
691690

692691

692+
def _is_valid_schedule_str(val) -> bool:
693+
r = re.match(r'(False\b|True\b|[\(\)\[\]0-9\, \.\*\+\-])+', val)
694+
if r and r.group(0) == val:
695+
return True
696+
return False
697+
698+
693699
def _eval_scheduling_str(val) -> List[float]:
694700
if isinstance(val, str):
695-
val = eval(val)
701+
if _is_valid_schedule_str(val):
702+
val = eval(val)
703+
else:
704+
raise ValueError(
705+
f'invalid scheduling string: {val}, it contains unsafe code'
706+
)
696707

697708
if isinstance(val, (int, float, bool)):
698709
val = [val] * _MAX_DIFFUSION_STEPS

tests/test_config.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import pytest
2+
13
from discoart.config import default_args, save_config, load_config, export_python
2-
from discoart.helper import _eval_scheduling_str
4+
from discoart.helper import _eval_scheduling_str, _is_valid_schedule_str
35

46

57
def test_export_load_config(tmpfile):
@@ -34,3 +36,28 @@ def test_eval_schedule_string():
3436
assert _eval_scheduling_str('True') == [True] * 1000
3537
assert _eval_scheduling_str('False') == [False] * 1000
3638
assert _eval_scheduling_str(True) == [True] * 1000
39+
40+
41+
@pytest.mark.parametrize(
42+
'val, expected',
43+
[
44+
('[100]*600+[200]*400', True),
45+
('[100]*600+[2.3]*400', True),
46+
('[100]*600+[2.3]*400', True),
47+
('1', True),
48+
('True', True),
49+
('Truetrue', False),
50+
('False', True),
51+
('true', False),
52+
('sdd ds', False),
53+
('[True, False]*1000', True),
54+
('[True]*500+[False]*400', True),
55+
('[0.5]*400+[0.2]*300+[True]*200', True),
56+
('[hello]*1000', False),
57+
('del a', False),
58+
('([1]+[2])*50', True),
59+
('[False,True,1,0.23,23,]*1000', True),
60+
],
61+
)
62+
def test_chec_schedule_str(val, expected):
63+
assert _is_valid_schedule_str(val) == expected

0 commit comments

Comments
 (0)