204 lines
5.5 KiB
Python
204 lines
5.5 KiB
Python
import torch
|
|
import queue
|
|
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from uuid import UUID
|
|
|
|
from . import UPLOAD_DIR
|
|
from .db import DATABASE
|
|
from .sse import bus
|
|
|
|
import sqlite3
|
|
|
|
|
|
@dataclass()
class DemucsTask:
    """One separation job; mirrors the id, orig_name and song_location columns of ``tasks``."""

    # Primary key of the task's row in ``tasks``; also names its output directory.
    database_id: UUID
    # Original song name; used as the prefix of the exported FLAC file names.
    orig_name: str
    # Location of the uploaded song, relative to UPLOAD_DIR.
    song_location: str
|
|
|
|
|
|
# Unbounded FIFO of DemucsTask jobs consumed by worker_main; on startup
# worker_main also refills it with tasks whose status is not done/errored.
work_queue: queue.Queue[DemucsTask] = queue.Queue(maxsize=0)
|
|
|
|
|
|
def get_device_name() -> str:
    """Return a human-readable name of the device separation will run on.

    Returns:
        ``"<gpu name> (CUDA)"`` when CUDA is available. Otherwise
        ``"<cpu model> (CPU)"``, using the first "model name" entry of
        /proc/cpuinfo. Falls back to ``"Unknown CPU"`` when that file cannot
        be read (e.g. on non-Linux hosts) or has no "model name" line.
    """
    if torch.cuda.is_available():
        return f"{torch.cuda.get_device_name()} (CUDA)"

    try:
        lines = Path("/proc/cpuinfo").read_text().splitlines()
    except (OSError, UnicodeDecodeError):
        # Best effort only: the file is Linux-specific and may be unreadable.
        return "Unknown CPU"

    for line in lines:
        key, sep, value = line.partition(":")
        if sep and key.strip() == "model name":
            # partition() keeps any further ':' inside the model name intact,
            # whereas split(":")[1] would cut it off.
            return f"{value.strip()} (CPU)"
    return "Unknown CPU"
|
|
|
|
|
|
def get_queue_size() -> int:
    """Count tasks that were queued but have not yet been marked done.

    ``worker_main`` calls ``task_done()`` only after a task has been handled,
    so the task currently being processed is part of this count.
    """
    pending = work_queue.unfinished_tasks
    return pending
|
|
|
|
|
|
def report_task_status(task_id: UUID, status: str):
    """Broadcast a task's new status on the SSE bus and persist it.

    Args:
        task_id: ID of the task's row in the ``tasks`` table.
        status: New status string (this module uses "in progress",
            "saving", "done" and "errored").
    """
    bus.announce({"task": str(task_id), "status": status})

    # Short-lived sqlite connection, since the worker thread spends most of
    # its time just waiting. ``with conn:`` commits on success (rolls back on
    # error) but does NOT close the connection, so close it explicitly.
    conn = sqlite3.connect(DATABASE)
    try:
        with conn:
            conn.execute(
                "UPDATE tasks SET status = ? WHERE id = ?",
                [status, str(task_id)],
            )
    finally:
        conn.close()
|
|
|
|
|
|
def report_task_output(task_id: UUID, track: str, output_path: Path):
    """Announce a finished output file on the SSE bus and record it.

    Args:
        task_id: ID of the task that produced the file.
        track: Output track name (process_task passes "master" or one of the
            model's source names).
        output_path: Location of the written file; stored as a string.
    """
    bus.announce({"task": str(task_id), "output": track})

    # Short-lived sqlite connection, since the worker thread spends most of
    # its time just waiting. ``with conn:`` commits on success (rolls back on
    # error) but does NOT close the connection, so close it explicitly.
    conn = sqlite3.connect(DATABASE)
    try:
        with conn:
            conn.execute(
                "INSERT INTO task_outputs (task_id, track, output_path) VALUES (?, ?, ?)",
                [str(task_id), track, str(output_path)],
            )
    finally:
        conn.close()
|
|
|
|
|
|
def report_task_progress(
    task_id: UUID, stage: int, stages: int, step: int, stage_steps: int
):
    """Broadcast separation progress on the SSE bus and store the latest value.

    Args:
        task_id: ID of the task being processed.
        stage: Current stage (process_task counts from 1).
        stages: Total number of stages.
        step: Current step within the stage (process_task counts from 1).
        stage_steps: Total number of steps in the current stage.

    The row is written with INSERT OR REPLACE, so (presumably, given a unique
    key on task_id) only the most recent progress is kept.
    """
    bus.announce(
        {
            "task": str(task_id),
            "progress": {
                "stage": stage,
                "stages": stages,
                "step": step,
                "stage-steps": stage_steps,
            },
        }
    )

    # ``with conn:`` only manages the transaction; close the connection
    # explicitly instead of leaving it to garbage collection.
    conn = sqlite3.connect(DATABASE)
    try:
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO task_progress (task_id, stage, stages, step, stage_steps) VALUES (?, ?, ?, ?, ?)",
                [str(task_id), stage, stages, step, stage_steps],
            )
    finally:
        conn.close()
|
|
|
|
|
|
from demucs.pretrained import get_model
|
|
from demucs.apply import apply_model, tqdm, BagOfModels
|
|
from demucs.separate import load_track
|
|
from demucs.audio import save_audio
|
|
|
|
import subprocess
|
|
|
|
|
|
def convert_audio(source: Path, output: Path):
    """Transcode ``source`` into ``output`` with ffmpeg.

    ffmpeg picks the output format from the file extension of ``output``.
    It runs silently (``-loglevel panic``) and overwrites any existing
    ``output`` (``-y``). It is also detached from stdin. By default ffmpeg
    reads standard input for interactive commands, which can stop it or make
    it swallow keystrokes when the server runs in a terminal or as a
    background job.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    command = [
        "ffmpeg",
        "-nostdin",
        "-y",
        "-loglevel",
        "panic",
        "-i",
        str(source),
        str(output),
    ]
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
|
|
|
|
|
|
def process_task(task: DemucsTask, model):
    """Separate one uploaded song into stems and export everything as FLAC.

    Outputs are written to ``UPLOAD_DIR / <task id>``:

    1. the upload converted to ``<orig_name>-master.flac`` (track "master");
    2. one ``<orig_name>-<source>.flac`` per entry of ``model.sources``.

    Each file is announced and recorded via ``report_task_output``. Separation
    progress goes through ``report_task_progress``. The uploaded song is
    deleted once the separated sources have been written.

    Args:
        task: The job to run.
        model: A Demucs model (possibly a ``BagOfModels``), e.g. as returned
            by ``get_model``.

    Raises:
        Any error from ffmpeg, Demucs, the filesystem or the database.
        Intermediate WAV files are removed before such an error propagates.
    """
    task_id = task.database_id
    song_path = UPLOAD_DIR / task.song_location

    task_dir = UPLOAD_DIR / str(task_id)
    task_dir.mkdir(parents=True, exist_ok=True)

    # orig_name appears to be a client-supplied file name. Keep only its last
    # path component so names like "../x" cannot write outside task_dir.
    # Plain file names pass through unchanged.
    base_name = Path(task.orig_name).name

    master_path = task_dir / f"{base_name}-master.flac"
    convert_audio(song_path, master_path)
    report_task_output(task_id, "master", master_path)

    wav = load_track(song_path, model.audio_channels, model.samplerate)

    # Normalize amplitude / DC before separation; the same statistics undo it
    # afterwards. The epsilon keeps a silent (constant) track from dividing by
    # zero, which would turn every stem into NaNs.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std() + 1e-8
    wav -= ref_mean
    wav /= ref_std

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Progress hook. With progress=True, apply_model appears to wrap its chunk
    # futures in tqdm.tqdm(...), so substitute a generator that reports each
    # consumed chunk. `stages` assumes one such call per sub-model of a
    # BagOfModels.
    stages = len(model.models) if isinstance(model, BagOfModels) else 1
    stage = 0

    def progress(futures, **kwargs):  # extra (tqdm) keyword args are ignored
        nonlocal stage
        stage += 1
        stage_steps = len(futures)
        for step, future in enumerate(futures, start=1):
            report_task_progress(task_id, stage, stages, step, stage_steps)
            yield future

    # Assigning tqdm.tqdm mutates the shared imported tqdm object, which is
    # visible to every other user in this process. Always restore it, even
    # when separation fails.
    original_tqdm = tqdm.tqdm
    tqdm.tqdm = progress
    try:
        sources = apply_model(
            model,
            wav[None],
            device=device,
            shifts=1,
            split=True,
            overlap=0.25,
            progress=True,
            num_workers=0,
        )[0]
    finally:
        tqdm.tqdm = original_tqdm

    # Denormalize amplitude / DC.
    sources *= ref_std
    sources += ref_mean

    report_task_status(task_id, "saving")

    wav_paths: dict[str, Path] = {}
    try:
        for source, source_name in zip(sources, model.sources):
            wav_path = song_path.parent / f"{song_path.name}.{source_name}.wav"
            # Track the path before writing so cleanup also sees partial files.
            wav_paths[source_name] = wav_path
            save_audio(source, str(wav_path), samplerate=model.samplerate)

        song_path.unlink()

        for source_name, wav_path in sorted(wav_paths.items()):
            flac_path = task_dir / f"{base_name}-{source_name}.flac"
            convert_audio(wav_path, flac_path)
            report_task_output(task_id, source_name, flac_path)
            wav_path.unlink()  # the FLAC copy replaces the intermediate WAV
    finally:
        # On success every WAV is already gone; on failure remove leftovers
        # instead of leaking them into the upload directory.
        for wav_path in wav_paths.values():
            wav_path.unlink(missing_ok=True)
|
|
|
|
|
|
def worker_main():
    """Run the separation worker loop; intended as a daemon-thread target.

    Loads the ``htdemucs_ft`` model once. It then re-enqueues every task whose
    status is neither 'done' nor 'errored' (presumably work left over from a
    previous run). After that it loops forever, taking tasks off
    ``work_queue`` one at a time. Each task is marked "in progress" and then
    "done" or "errored".
    """
    import traceback

    model = get_model("htdemucs_ft")
    model.eval()

    # ``with conn:`` alone would not close the connection, so close it here.
    conn = sqlite3.connect(DATABASE)
    try:
        rows = conn.execute(
            "SELECT id, orig_name, song_location FROM tasks "
            "WHERE status NOT IN ('done', 'errored')"
        ).fetchall()
    finally:
        conn.close()

    for task_id, orig_name, song_location in rows:
        work_queue.put(DemucsTask(UUID(task_id), orig_name, song_location))

    # NOTE(review): if request handlers enqueue tasks while the model above is
    # still loading, the recovery query may enqueue them a second time. Confirm
    # against the caller of start_worker.
    while True:
        task = work_queue.get()
        try:
            report_task_status(task.database_id, "in progress")
            process_task(task, model)
            report_task_status(task.database_id, "done")
        except BaseException:
            # Deliberately broad: one bad task must not stop the worker.
            # Print the full traceback; str(e) alone is often uninformative.
            traceback.print_exc()
            try:
                report_task_status(task.database_id, "errored")
            except Exception:
                # If this escaped, the worker thread would die and every later
                # task would wait forever.
                traceback.print_exc()
        finally:
            # Always balance get(), so get_queue_size() stays accurate.
            work_queue.task_done()
|
|
|
|
|
|
def start_worker():
    """Launch ``worker_main`` on a background daemon thread and return that thread.

    As a daemon thread it does not keep the interpreter alive at shutdown.
    """
    import threading

    worker_thread = threading.Thread(daemon=True, target=worker_main)
    worker_thread.start()
    return worker_thread
|