@@ -21,7 +21,7 @@ def _uppercase_track_in_instruction(instruction):
2121 return instruction [: m .start (2 )] + m .group (2 ).upper () + instruction [m .end (2 ) :]
2222 return instruction
2323
24- from cdmf_paths import get_output_dir , get_user_data_dir , load_config
24+ from cdmf_paths import get_output_dir , get_user_data_dir , get_models_folder , load_config , save_config
2525from cdmf_tracks import get_audio_duration , list_lora_adapters , load_track_meta , save_track_meta
2626from cdmf_generation_job import GenerationCancelled
2727import cdmf_state
@@ -59,6 +59,38 @@ def _is_cancel_requested(job_id: str) -> bool:
5959 return job_id in _cancel_requested
6060
6161
62+ def _is_model_available (dit_tag : str ) -> bool :
63+ """Return True if the given DiT model is installed and ready (no download needed). Used to promote pending_model jobs."""
64+ if not dit_tag or not isinstance (dit_tag , str ):
65+ return False
66+ dit = dit_tag .strip ().lower ()
67+ DIT_15_FOLDERS = {
68+ "turbo" : "acestep-v15-turbo" ,
69+ "base" : "acestep-v15-base" ,
70+ "sft" : "acestep-v15-sft" ,
71+ "turbo-shift1" : "acestep-v15-turbo-shift1" ,
72+ "turbo-shift3" : "acestep-v15-turbo-shift3" ,
73+ "turbo-continuous" : "acestep-v15-turbo-continuous" ,
74+ }
75+ REQUIRED_SUBDIRS = ("music_dcae_f8c8" , "music_vocoder" , "ace_step_transformer" , "umt5-base" )
76+ folder = DIT_15_FOLDERS .get (dit )
77+ models_root = Path (get_models_folder ()) / "checkpoints"
78+ if folder :
79+ candidate = models_root / folder
80+ if not candidate .exists ():
81+ return False
82+ for sub in REQUIRED_SUBDIRS :
83+ if not (candidate / sub ).exists ():
84+ return False
85+ return True
86+ # Legacy v1
87+ try :
88+ from ace_model_setup import ace_models_present
89+ return ace_models_present ()
90+ except Exception :
91+ return False
92+
93+
6294def _refs_dir () -> Path :
6395 d = get_user_data_dir () / "references"
6496 d .mkdir (parents = True , exist_ok = True )
@@ -115,10 +147,45 @@ def _on_job_progress(
115147register_job_progress_callback (_on_job_progress )
116148
117149
150+ def _update_job_progress_from_log (
151+ percent : int , current : int , total : int , eta_seconds : float | None
152+ ) -> None :
153+ """Update current job progress from parsed tqdm log line (log handler runs in same thread as worker)."""
154+ with _jobs_lock :
155+ jid = cdmf_state .get_current_generation_job_id ()
156+ if not jid :
157+ return
158+ job = _jobs .get (jid )
159+ if not job :
160+ return
161+ job ["progressPercent" ] = round (float (percent ), 1 )
162+ job ["progressSteps" ] = f"{ current } /{ total } "
163+ if eta_seconds is not None :
164+ job ["progressEta" ] = round (float (eta_seconds ), 1 )
165+
166+
118167def _run_generation (job_id : str ) -> None :
119168 """Background: run generate_track_ace and update job."""
120169 global _generation_busy , _current_job_id
170+ prev_config = None
171+ config_switched = False
121172 try :
173+ with _jobs_lock :
174+ job = _jobs .get (job_id )
175+ if not job or job .get ("status" ) != "queued" :
176+ return
177+ job_dit = job .get ("dit_model" ) or "turbo"
178+ # If required model is not installed, leave job as pending_model so it runs after user installs it.
179+ if not _is_model_available (job_dit ):
180+ logging .info ("[API generate] Job %s waiting for model %s (not installed)" , job_id , job_dit )
181+ with _jobs_lock :
182+ job = _jobs .get (job_id )
183+ if job and job .get ("status" ) == "queued" :
184+ job ["status" ] = "pending_model"
185+ job ["pendingReason" ] = (
186+ f"Model '{ job_dit } ' is not installed. Install it from Settings → Models to run this job."
187+ )
188+ return
122189 with _jobs_lock :
123190 job = _jobs .get (job_id )
124191 if not job or job .get ("status" ) != "queued" :
@@ -131,8 +198,20 @@ def _run_generation(job_id: str) -> None:
131198 _current_job_id = job_id
132199
133200 cdmf_state .set_current_generation_job_id (job_id )
201+ cdmf_state .set_progress_updater (_update_job_progress_from_log )
134202 cancel_check = lambda : _is_cancel_requested (job_id )
135- from generate_ace import generate_track_ace
203+ from generate_ace import generate_track_ace , clear_ace_pipeline
204+
205+ # Use job's dit_model (e.g. base for cover). Temporarily switch config so pipeline loads the right model.
206+ with _jobs_lock :
207+ j = _jobs .get (job_id )
208+ job_dit = (j .get ("dit_model" ) or "turbo" ) if j else "turbo"
209+ prev_config = load_config () or {}
210+ prev_config = dict (prev_config )
211+ config_switched = job_dit != (prev_config .get ("ace_step_dit_model" ) or "turbo" )
212+ if config_switched :
213+ save_config ({** prev_config , "ace_step_dit_model" : job_dit })
214+ clear_ace_pipeline ()
136215
137216 params = job .get ("params" ) or {}
138217 if not isinstance (params , dict ):
@@ -451,27 +530,64 @@ def _run_generation(job_id: str) -> None:
451530 job ["status" ] = "cancelled"
452531 job ["error" ] = "Cancelled by user"
453532 except Exception as e :
454- logging .exception ("Generation job %s failed" , job_id )
455- with _jobs_lock :
456- job = _jobs .get (job_id )
457- if job :
458- job ["status" ] = "failed"
459- job ["error" ] = str (e )
533+ err_msg = str (e )
534+ # Keep job queued as pending_model when model is missing so it can run after user installs it.
535+ if "not installed" in err_msg .lower () or "Settings → Models" in err_msg :
536+ logging .info ("[API generate] Job %s waiting for model: %s" , job_id , err_msg [:120 ])
537+ with _jobs_lock :
538+ job = _jobs .get (job_id )
539+ if job :
540+ job ["status" ] = "pending_model"
541+ job ["pendingReason" ] = err_msg
542+ job ["error" ] = None
543+ else :
544+ logging .exception ("Generation job %s failed" , job_id )
545+ with _jobs_lock :
546+ job = _jobs .get (job_id )
547+ if job :
548+ job ["status" ] = "failed"
549+ job ["error" ] = err_msg
460550 finally :
551+ cdmf_state .set_progress_updater (None )
552+ if config_switched and prev_config :
553+ save_config (prev_config )
554+ clear_ace_pipeline ()
461555 cdmf_state .set_current_generation_job_id (None )
462556 _generation_busy = False
463557 with _jobs_lock :
464558 _current_job_id = None
465559 _cancel_requested .discard (job_id )
466- # Start next queued job (skips cancelled: they are no longer " queued")
560+ # Promote pending_model jobs to queued when their model is now available; then start first queued job.
467561 with _jobs_lock :
562+ for jid in _job_order :
563+ j = _jobs .get (jid )
564+ if j and j .get ("status" ) == "pending_model" :
565+ dit = j .get ("dit_model" ) or "turbo"
566+ if _is_model_available (dit ):
567+ j ["status" ] = "queued"
568+ j ["pendingReason" ] = None
569+ logging .info ("[API generate] Job %s promoted to queued (model %s now available)" , jid , dit )
468570 for jid in _job_order :
469571 j = _jobs .get (jid )
470572 if j and j .get ("status" ) == "queued" :
471573 threading .Thread (target = _run_generation , args = (jid ,), daemon = True ).start ()
472574 break
473575
474576
577+ @bp .route ("/model-download-status" , methods = ["GET" ])
578+ def get_model_download_status ():
579+ """GET /api/generate/model-download-status — whether pipeline is loading (may be downloading model files)."""
580+ try :
581+ st = getattr (cdmf_state , "GENERATION_MODEL_LOADING" , {})
582+ return jsonify ({
583+ "in_progress" : bool (st .get ("in_progress" )),
584+ "message" : st .get ("message" ) or "Preparing model (downloading if needed)..." ,
585+ })
586+ except Exception as e :
587+ logging .warning ("[API generate] model-download-status failed: %s" , e )
588+ return jsonify ({"in_progress" : False , "message" : "" })
589+
590+
475591@bp .route ("/lora_adapters" , methods = ["GET" ])
476592def get_lora_adapters ():
477593 """GET /api/generate/lora_adapters — list LoRA adapters (e.g. from Training or custom_lora)."""
@@ -554,7 +670,15 @@ def _str(v):
554670 except (TypeError , ValueError ):
555671 params_copy = {}
556672 config = load_config ()
557- dit_tag = config .get ("ace_step_dit_model" ) or params_copy .get ("aceStepDitModel" ) or "turbo"
673+ # User override from Generation tab model selector takes precedence; else auto base for cover (per docs).
674+ task_for_dit = (params_copy .get ("task_type" ) or params_copy .get ("taskType" ) or "text2music" ).strip ().lower ()
675+ explicit_dit = (params_copy .get ("aceStepDitModel" ) or params_copy .get ("ace_step_dit_model" ) or "" ).strip ()
676+ if explicit_dit :
677+ dit_tag = explicit_dit
678+ elif task_for_dit == "cover" :
679+ dit_tag = "base"
680+ else :
681+ dit_tag = config .get ("ace_step_dit_model" ) or "turbo"
558682 lm_tag = config .get ("ace_step_lm" ) or params_copy .get ("aceStepLm" ) or "1.7B"
559683 with _jobs_lock :
560684 _jobs [job_id ] = {
@@ -608,10 +732,42 @@ def get_status(job_id: str):
608732 "progressStage" : job .get ("progressStage" ),
609733 "result" : job .get ("result" ),
610734 "error" : job .get ("error" ),
735+ "pendingReason" : job .get ("pendingReason" ) if status == "pending_model" else None ,
611736 }
612737 return jsonify (out )
613738
614739
740+ @bp .route ("/retry-pending" , methods = ["POST" ])
741+ def retry_pending ():
742+ """POST /api/generate/retry-pending — promote pending_model jobs to queued if model is now available, start first queued job. Call after model download completes."""
743+ global _generation_busy
744+ promoted = 0
745+ started = None
746+ with _jobs_lock :
747+ for jid in _job_order :
748+ j = _jobs .get (jid )
749+ if j and j .get ("status" ) == "pending_model" :
750+ dit = j .get ("dit_model" ) or "turbo"
751+ if _is_model_available (dit ):
752+ j ["status" ] = "queued"
753+ j ["pendingReason" ] = None
754+ promoted += 1
755+ logging .info ("[API generate] Job %s promoted to queued (model %s now available)" , jid , dit )
756+ if not _generation_busy :
757+ for jid in _job_order :
758+ j = _jobs .get (jid )
759+ if j and j .get ("status" ) == "queued" :
760+ _generation_busy = True
761+ threading .Thread (target = _run_generation , args = (jid ,), daemon = True ).start ()
762+ started = jid
763+ break
764+ return jsonify ({
765+ "ok" : True ,
766+ "promoted" : promoted ,
767+ "startedJobId" : started ,
768+ })
769+
770+
615771@bp .route ("/unstick" , methods = ["POST" ])
616772def unstick_queue ():
617773 """POST /api/generate/unstick — clear stuck worker state and start the next queued job (if any)."""
@@ -641,9 +797,10 @@ def cancel_job(job_id: str):
641797 if not job :
642798 return jsonify ({"error" : "Job not found" }), 404
643799 status = job .get ("status" , "unknown" )
644- if status == "queued" :
800+ if status in ( "queued" , "pending_model" ) :
645801 job ["status" ] = "cancelled"
646802 job ["error" ] = "Cancelled by user"
803+ job ["pendingReason" ] = None
647804 return jsonify ({"cancelled" : True , "jobId" : job_id , "message" : "Job removed from queue." })
648805 if status == "running" :
649806 _cancel_requested .add (job_id )
0 commit comments