Skip to content

Commit 57ec4e3

Browse files
Ecursoragent
authored and committed
Queue jobs until models available; fix progress tracking for UI
- Backend: jobs wait in pending_model until the required DiT is installed; they are promoted to queued when the model becomes available
- API: get_status returns pendingReason; cancel supports pending_model; POST /api/generate/retry-pending promotes and starts jobs after a model download
- Frontend: SongList shows 'Waiting for model' + reason and Open Settings; poll sets generationStatus/generationPendingReason; Settings onDownloadComplete calls retryPending; cancel and pending_model handling in poll
- Progress: broaden tqdm regex so all progress lines match (INFO); add progress updater from log parser so job progress/ETA updates for UI even when callback path fails

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent fa8d7d9 commit 57ec4e3

12 files changed

Lines changed: 643 additions & 133 deletions

api/ace_step_models.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from pathlib import Path
8+
import shutil
89
import subprocess
910
import sys
1011
import threading
@@ -24,22 +25,22 @@ def _bundled_downloader_available() -> bool:
2425

2526
bp = Blueprint("api_ace_step_models", __name__)
2627

27-
# DiT variants from Tutorial (DiT Selection Summary)
28+
# DiT variants from Tutorial (DiT Selection Summary). size_gb: approximate for download confirmation.
2829
DIT_MODELS = [
29-
{"id": "turbo", "label": "Turbo (default)", "description": "Best balance, 8 steps", "steps": 8, "cfg": False},
30-
{"id": "turbo-shift1", "label": "Turbo shift=1", "description": "Richer details", "steps": 8, "cfg": False},
31-
{"id": "turbo-shift3", "label": "Turbo shift=3", "description": "Clearer timbre", "steps": 8, "cfg": False},
32-
{"id": "turbo-continuous", "label": "Turbo continuous", "description": "Flexible shift 1–5", "steps": 8, "cfg": False},
33-
{"id": "sft", "label": "SFT", "description": "50 steps, CFG", "steps": 50, "cfg": True},
34-
{"id": "base", "label": "Base", "description": "50 steps, CFG; lego/extract/complete", "steps": 50, "cfg": True, "exclusive_tasks": ["lego", "extract", "complete"]},
30+
{"id": "turbo", "label": "Turbo (default)", "description": "Best balance, 8 steps", "steps": 8, "cfg": False, "size_gb": 8},
31+
{"id": "turbo-shift1", "label": "Turbo shift=1", "description": "Richer details", "steps": 8, "cfg": False, "size_gb": 0.5},
32+
{"id": "turbo-shift3", "label": "Turbo shift=3", "description": "Clearer timbre", "steps": 8, "cfg": False, "size_gb": 0.5},
33+
{"id": "turbo-continuous", "label": "Turbo continuous", "description": "Flexible shift 1–5", "steps": 8, "cfg": False, "size_gb": 0.5},
34+
{"id": "sft", "label": "SFT", "description": "50 steps, CFG", "steps": 50, "cfg": True, "size_gb": 8},
35+
{"id": "base", "label": "Base", "description": "50 steps, CFG; lego/extract/complete", "steps": 50, "cfg": True, "exclusive_tasks": ["lego", "extract", "complete"], "size_gb": 8},
3536
]
3637

37-
# LM planner options from Tutorial
38+
# LM planner options from Tutorial. size_gb: approximate for download confirmation.
3839
LM_MODELS = [
39-
{"id": "none", "label": "No LM"},
40-
{"id": "0.6B", "label": "0.6B"},
41-
{"id": "1.7B", "label": "1.7B (default)"},
42-
{"id": "4B", "label": "4B"},
40+
{"id": "none", "label": "No LM", "size_gb": 0},
41+
{"id": "0.6B", "label": "0.6B", "size_gb": 2},
42+
{"id": "1.7B", "label": "1.7B (default)", "size_gb": 4},
43+
{"id": "4B", "label": "4B", "size_gb": 10},
4344
]
4445

4546
# ACE-Step 1.5 CLI model ids (for acestep-download --model)
@@ -333,6 +334,23 @@ def _do_download_worker(model: str, root: Path) -> None:
333334
_download_cancel_requested = False
334335

335336

337+
@bp.route("/models/disk-space", methods=["GET"])
def disk_space():
    """
    GET /api/ace-step/models/disk-space

    Report free and total disk space for the checkpoint root (used for
    download confirmation). Sizes are expressed in units of 1024**3 bytes,
    rounded to two decimals. The checkpoint directory is created first if it
    does not exist, since disk usage is queried on that path.

    On any exception the response is HTTP 500 with the error text, null
    sizes and an empty path.
    """
    bytes_per_gib = 1024 ** 3
    try:
        checkpoint_dir = _checkpoint_root()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        stats = shutil.disk_usage(str(checkpoint_dir))
        payload = {
            "free_gb": round(stats.free / bytes_per_gib, 2),
            "total_gb": round(stats.total / bytes_per_gib, 2),
            "path": str(checkpoint_dir),
        }
        return jsonify(payload)
    except Exception as exc:
        failure = {"error": str(exc), "free_gb": None, "total_gb": None, "path": ""}
        return jsonify(failure), 500
352+
353+
336354
@bp.route("/models/download", methods=["POST"])
337355
def download_model():
338356
"""

api/generate.py

Lines changed: 168 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _uppercase_track_in_instruction(instruction):
2121
return instruction[: m.start(2)] + m.group(2).upper() + instruction[m.end(2) :]
2222
return instruction
2323

24-
from cdmf_paths import get_output_dir, get_user_data_dir, load_config
24+
from cdmf_paths import get_output_dir, get_user_data_dir, get_models_folder, load_config, save_config
2525
from cdmf_tracks import get_audio_duration, list_lora_adapters, load_track_meta, save_track_meta
2626
from cdmf_generation_job import GenerationCancelled
2727
import cdmf_state
@@ -59,6 +59,38 @@ def _is_cancel_requested(job_id: str) -> bool:
5959
return job_id in _cancel_requested
6060

6161

62+
def _is_model_available(dit_tag: str) -> bool:
63+
"""Return True if the given DiT model is installed and ready (no download needed). Used to promote pending_model jobs."""
64+
if not dit_tag or not isinstance(dit_tag, str):
65+
return False
66+
dit = dit_tag.strip().lower()
67+
DIT_15_FOLDERS = {
68+
"turbo": "acestep-v15-turbo",
69+
"base": "acestep-v15-base",
70+
"sft": "acestep-v15-sft",
71+
"turbo-shift1": "acestep-v15-turbo-shift1",
72+
"turbo-shift3": "acestep-v15-turbo-shift3",
73+
"turbo-continuous": "acestep-v15-turbo-continuous",
74+
}
75+
REQUIRED_SUBDIRS = ("music_dcae_f8c8", "music_vocoder", "ace_step_transformer", "umt5-base")
76+
folder = DIT_15_FOLDERS.get(dit)
77+
models_root = Path(get_models_folder()) / "checkpoints"
78+
if folder:
79+
candidate = models_root / folder
80+
if not candidate.exists():
81+
return False
82+
for sub in REQUIRED_SUBDIRS:
83+
if not (candidate / sub).exists():
84+
return False
85+
return True
86+
# Legacy v1
87+
try:
88+
from ace_model_setup import ace_models_present
89+
return ace_models_present()
90+
except Exception:
91+
return False
92+
93+
6294
def _refs_dir() -> Path:
6395
d = get_user_data_dir() / "references"
6496
d.mkdir(parents=True, exist_ok=True)
@@ -115,10 +147,45 @@ def _on_job_progress(
115147
register_job_progress_callback(_on_job_progress)
116148

117149

150+
def _update_job_progress_from_log(
    percent: int, current: int, total: int, eta_seconds: float | None
) -> None:
    """Write progress parsed from a tqdm log line onto the current job.

    Looks up the job id reported by ``cdmf_state`` and, under ``_jobs_lock``,
    stores ``progressPercent`` (one decimal), ``progressSteps`` as
    ``"current/total"`` and, when an ETA is given, ``progressEta`` (one
    decimal). Does nothing when there is no current job id or the job is not
    in ``_jobs``.

    NOTE(review): the original docstring states the log handler runs in the
    same thread as the worker; that is not verifiable from this function.
    """
    with _jobs_lock:
        active_id = cdmf_state.get_current_generation_job_id()
        entry = _jobs.get(active_id) if active_id else None
        if not entry:
            return
        entry["progressPercent"] = round(float(percent), 1)
        entry["progressSteps"] = f"{current}/{total}"
        if eta_seconds is None:
            return
        entry["progressEta"] = round(float(eta_seconds), 1)
165+
166+
118167
def _run_generation(job_id: str) -> None:
119168
"""Background: run generate_track_ace and update job."""
120169
global _generation_busy, _current_job_id
170+
prev_config = None
171+
config_switched = False
121172
try:
173+
with _jobs_lock:
174+
job = _jobs.get(job_id)
175+
if not job or job.get("status") != "queued":
176+
return
177+
job_dit = job.get("dit_model") or "turbo"
178+
# If required model is not installed, leave job as pending_model so it runs after user installs it.
179+
if not _is_model_available(job_dit):
180+
logging.info("[API generate] Job %s waiting for model %s (not installed)", job_id, job_dit)
181+
with _jobs_lock:
182+
job = _jobs.get(job_id)
183+
if job and job.get("status") == "queued":
184+
job["status"] = "pending_model"
185+
job["pendingReason"] = (
186+
f"Model '{job_dit}' is not installed. Install it from Settings → Models to run this job."
187+
)
188+
return
122189
with _jobs_lock:
123190
job = _jobs.get(job_id)
124191
if not job or job.get("status") != "queued":
@@ -131,8 +198,20 @@ def _run_generation(job_id: str) -> None:
131198
_current_job_id = job_id
132199

133200
cdmf_state.set_current_generation_job_id(job_id)
201+
cdmf_state.set_progress_updater(_update_job_progress_from_log)
134202
cancel_check = lambda: _is_cancel_requested(job_id)
135-
from generate_ace import generate_track_ace
203+
from generate_ace import generate_track_ace, clear_ace_pipeline
204+
205+
# Use job's dit_model (e.g. base for cover). Temporarily switch config so pipeline loads the right model.
206+
with _jobs_lock:
207+
j = _jobs.get(job_id)
208+
job_dit = (j.get("dit_model") or "turbo") if j else "turbo"
209+
prev_config = load_config() or {}
210+
prev_config = dict(prev_config)
211+
config_switched = job_dit != (prev_config.get("ace_step_dit_model") or "turbo")
212+
if config_switched:
213+
save_config({**prev_config, "ace_step_dit_model": job_dit})
214+
clear_ace_pipeline()
136215

137216
params = job.get("params") or {}
138217
if not isinstance(params, dict):
@@ -451,27 +530,64 @@ def _run_generation(job_id: str) -> None:
451530
job["status"] = "cancelled"
452531
job["error"] = "Cancelled by user"
453532
except Exception as e:
454-
logging.exception("Generation job %s failed", job_id)
455-
with _jobs_lock:
456-
job = _jobs.get(job_id)
457-
if job:
458-
job["status"] = "failed"
459-
job["error"] = str(e)
533+
err_msg = str(e)
534+
# Keep job queued as pending_model when model is missing so it can run after user installs it.
535+
if "not installed" in err_msg.lower() or "Settings → Models" in err_msg:
536+
logging.info("[API generate] Job %s waiting for model: %s", job_id, err_msg[:120])
537+
with _jobs_lock:
538+
job = _jobs.get(job_id)
539+
if job:
540+
job["status"] = "pending_model"
541+
job["pendingReason"] = err_msg
542+
job["error"] = None
543+
else:
544+
logging.exception("Generation job %s failed", job_id)
545+
with _jobs_lock:
546+
job = _jobs.get(job_id)
547+
if job:
548+
job["status"] = "failed"
549+
job["error"] = err_msg
460550
finally:
551+
cdmf_state.set_progress_updater(None)
552+
if config_switched and prev_config:
553+
save_config(prev_config)
554+
clear_ace_pipeline()
461555
cdmf_state.set_current_generation_job_id(None)
462556
_generation_busy = False
463557
with _jobs_lock:
464558
_current_job_id = None
465559
_cancel_requested.discard(job_id)
466-
# Start next queued job (skips cancelled: they are no longer "queued")
560+
# Promote pending_model jobs to queued when their model is now available; then start first queued job.
467561
with _jobs_lock:
562+
for jid in _job_order:
563+
j = _jobs.get(jid)
564+
if j and j.get("status") == "pending_model":
565+
dit = j.get("dit_model") or "turbo"
566+
if _is_model_available(dit):
567+
j["status"] = "queued"
568+
j["pendingReason"] = None
569+
logging.info("[API generate] Job %s promoted to queued (model %s now available)", jid, dit)
468570
for jid in _job_order:
469571
j = _jobs.get(jid)
470572
if j and j.get("status") == "queued":
471573
threading.Thread(target=_run_generation, args=(jid,), daemon=True).start()
472574
break
473575

474576

577+
@bp.route("/model-download-status", methods=["GET"])
def get_model_download_status():
    """GET /api/generate/model-download-status — report whether the pipeline is loading.

    Reads ``cdmf_state.GENERATION_MODEL_LOADING`` (an empty mapping is used
    when the attribute is absent) and returns its ``in_progress`` flag plus a
    message, substituting a default message when none is set. Any exception
    is logged as a warning and answered with ``in_progress: False`` and an
    empty message rather than an error status.
    """
    try:
        loading_state = getattr(cdmf_state, "GENERATION_MODEL_LOADING", {})
        is_loading = bool(loading_state.get("in_progress"))
        text = loading_state.get("message") or "Preparing model (downloading if needed)..."
        return jsonify({"in_progress": is_loading, "message": text})
    except Exception as exc:
        logging.warning("[API generate] model-download-status failed: %s", exc)
        return jsonify({"in_progress": False, "message": ""})
589+
590+
475591
@bp.route("/lora_adapters", methods=["GET"])
476592
def get_lora_adapters():
477593
"""GET /api/generate/lora_adapters — list LoRA adapters (e.g. from Training or custom_lora)."""
@@ -554,7 +670,15 @@ def _str(v):
554670
except (TypeError, ValueError):
555671
params_copy = {}
556672
config = load_config()
557-
dit_tag = config.get("ace_step_dit_model") or params_copy.get("aceStepDitModel") or "turbo"
673+
# User override from Generation tab model selector takes precedence; else auto base for cover (per docs).
674+
task_for_dit = (params_copy.get("task_type") or params_copy.get("taskType") or "text2music").strip().lower()
675+
explicit_dit = (params_copy.get("aceStepDitModel") or params_copy.get("ace_step_dit_model") or "").strip()
676+
if explicit_dit:
677+
dit_tag = explicit_dit
678+
elif task_for_dit == "cover":
679+
dit_tag = "base"
680+
else:
681+
dit_tag = config.get("ace_step_dit_model") or "turbo"
558682
lm_tag = config.get("ace_step_lm") or params_copy.get("aceStepLm") or "1.7B"
559683
with _jobs_lock:
560684
_jobs[job_id] = {
@@ -608,10 +732,42 @@ def get_status(job_id: str):
608732
"progressStage": job.get("progressStage"),
609733
"result": job.get("result"),
610734
"error": job.get("error"),
735+
"pendingReason": job.get("pendingReason") if status == "pending_model" else None,
611736
}
612737
return jsonify(out)
613738

614739

740+
@bp.route("/retry-pending", methods=["POST"])
def retry_pending():
    """POST /api/generate/retry-pending — re-evaluate jobs waiting for a model.

    First pass: every ``pending_model`` job whose DiT model (default
    ``"turbo"``) is now available is switched back to ``queued`` and its
    ``pendingReason`` cleared. Second pass: if no generation is marked busy,
    the first ``queued`` job in ``_job_order`` is started on a daemon thread
    and the busy flag is set.

    Intended to be called after a model download completes. Responds with
    ``ok``, the number of promoted jobs and the id of the started job (or
    null).
    """
    global _generation_busy
    promoted_count = 0
    started_job_id = None
    # NOTE(review): both passes are assumed to run under _jobs_lock, matching
    # the promote/start pattern used elsewhere in this module — confirm.
    with _jobs_lock:
        for job_id in _job_order:
            entry = _jobs.get(job_id)
            if not entry or entry.get("status") != "pending_model":
                continue
            model_tag = entry.get("dit_model") or "turbo"
            if not _is_model_available(model_tag):
                continue
            entry["status"] = "queued"
            entry["pendingReason"] = None
            promoted_count += 1
            logging.info(
                "[API generate] Job %s promoted to queued (model %s now available)",
                job_id,
                model_tag,
            )

        if not _generation_busy:
            for job_id in _job_order:
                entry = _jobs.get(job_id)
                if not entry or entry.get("status") != "queued":
                    continue
                # Mark busy before the worker starts so a concurrent caller
                # does not launch a second worker.
                _generation_busy = True
                worker = threading.Thread(target=_run_generation, args=(job_id,), daemon=True)
                worker.start()
                started_job_id = job_id
                break

    response = {
        "ok": True,
        "promoted": promoted_count,
        "startedJobId": started_job_id,
    }
    return jsonify(response)
769+
770+
615771
@bp.route("/unstick", methods=["POST"])
616772
def unstick_queue():
617773
"""POST /api/generate/unstick — clear stuck worker state and start the next queued job (if any)."""
@@ -641,9 +797,10 @@ def cancel_job(job_id: str):
641797
if not job:
642798
return jsonify({"error": "Job not found"}), 404
643799
status = job.get("status", "unknown")
644-
if status == "queued":
800+
if status in ("queued", "pending_model"):
645801
job["status"] = "cancelled"
646802
job["error"] = "Cancelled by user"
803+
job["pendingReason"] = None
647804
return jsonify({"cancelled": True, "jobId": job_id, "message": "Job removed from queue."})
648805
if status == "running":
649806
_cancel_requested.add(job_id)

cdmf_pipeline_ace_step.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,8 @@ def add_latents_noise(
11951195
sigma_max=sigma_max
11961196
)
11971197

1198-
infer_steps = int(sigma_max * infer_steps)
1198+
# Ensure enough steps for cover/audio2audio so reference is audible (INFERENCE.md: base 32-64 recommended).
1199+
infer_steps = max(16, int(sigma_max * infer_steps))
11991200
timesteps, num_inference_steps = retrieve_timesteps(
12001201
scheduler,
12011202
num_inference_steps=infer_steps,
@@ -1295,6 +1296,17 @@ def text2music_diffusion_process(
12951296

12961297
if ref_latents is not None:
12971298
frame_length = ref_latents.shape[-1]
1299+
# Cap ref length for cover/audio2audio so each diffusion step stays fast (avoids 80s+ per step on long refs)
1300+
max_cover_sec = float(os.environ.get("ACE_COVER_MAX_REF_SECONDS", "90"))
1301+
max_cover_frames = int(max_cover_sec * 44100 / 512 / 8)
1302+
if frame_length > max_cover_frames:
1303+
ref_latents = ref_latents[:, :, :, :max_cover_frames].contiguous()
1304+
frame_length = max_cover_frames
1305+
logger.info(
1306+
"Capped ref_latents to %d frames (~%.0fs) for faster cover/audio2audio generation (set ACE_COVER_MAX_REF_SECONDS to override).",
1307+
max_cover_frames,
1308+
max_cover_sec,
1309+
)
12981310

12991311
if len(oss_steps) > 0:
13001312
infer_steps = max(oss_steps)
@@ -2087,6 +2099,7 @@ def __call__(
20872099

20882100
ref_latents = None
20892101
if ref_audio_input is not None and audio2audio_enable:
2102+
# For cover mode: ref_audio_input = source song (song to cover), per docs/ACE-Step-INFERENCE.md.
20902103
assert ref_audio_input is not None, "ref_audio_input is required for audio2audio task"
20912104
assert os.path.exists(
20922105
ref_audio_input

0 commit comments

Comments
 (0)