Skip to content

Commit 4bbb3b0

Browse files
authored
fix: eval schedule string (#125)
1 parent cf1f581 commit 4bbb3b0

2 files changed

Lines changed: 21 additions & 9 deletions

File tree

discoart/helper.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -692,17 +692,18 @@ def _version_check(package: str = None, github_repo: str = None):
692692

693693
def _eval_scheduling_str(val) -> List[float]:
694694
if isinstance(val, str):
695-
r = eval(val)
696-
elif isinstance(val, (int, float, bool)):
697-
r = [val] * _MAX_DIFFUSION_STEPS
695+
val = eval(val)
696+
697+
if isinstance(val, (int, float, bool)):
698+
val = [val] * _MAX_DIFFUSION_STEPS
699+
elif isinstance(val, (list, tuple)):
700+
if len(val) != _MAX_DIFFUSION_STEPS:
701+
raise ValueError(
702+
f'invalid scheduling string: {val} the schedule steps should be exactly {_MAX_DIFFUSION_STEPS}'
703+
)
698704
else:
699705
raise ValueError(f'unsupported scheduling type: {val}: {type(val)}')
700-
701-
if len(r) != _MAX_DIFFUSION_STEPS:
702-
raise ValueError(
703-
f'invalid scheduling string: {val} the schedule steps should be exactly {_MAX_DIFFUSION_STEPS}'
704-
)
705-
return r
706+
return val
706707

707708

708709
def _get_current_schedule(schedule_table: Dict, t: int) -> 'SimpleNamespace':

tests/test_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from discoart.config import default_args, save_config, load_config, export_python
2+
from discoart.helper import _eval_scheduling_str
23

34

45
def test_export_load_config(tmpfile):
@@ -23,3 +24,13 @@ def test_format_config():
2324

2425
def test_export_python():
2526
assert export_python(default_args)
27+
28+
29+
def test_eval_schedule_string():
30+
assert _eval_scheduling_str('1') == [1] * 1000
31+
assert _eval_scheduling_str('[1] * 1000') == [1] * 1000
32+
assert _eval_scheduling_str(1) == [1] * 1000
33+
assert _eval_scheduling_str('1.') == [1] * 1000
34+
assert _eval_scheduling_str('True') == [True] * 1000
35+
assert _eval_scheduling_str('False') == [False] * 1000
36+
assert _eval_scheduling_str(True) == [True] * 1000

0 commit comments

Comments
 (0)