demucs-server/src/demucs_server/torch_worker.py

204 lines
5.5 KiB
Python

import torch
import queue
from pathlib import Path
from dataclasses import dataclass
from uuid import UUID
from . import UPLOAD_DIR
from .db import DATABASE
from .sse import bus
import sqlite3
@dataclass
class DemucsTask:
    """One separation job as it travels through ``work_queue``."""

    # Value of the ``id`` column of the task's row in the ``tasks`` table.
    database_id: UUID
    # Base name used when naming the generated ``.flac`` output files.
    orig_name: str
    # Location of the uploaded song, resolved relative to ``UPLOAD_DIR``.
    song_location: str
# Queue of tasks awaiting separation. ``worker_main`` drains it and, at
# startup, re-queues every database task that is not yet 'done' or 'errored'.
work_queue: queue.Queue[DemucsTask] = queue.Queue()
def get_device_name() -> str:
    """Return a human-readable name for the device inference will run on.

    Returns:
        ``"<GPU name> (CUDA)"`` when CUDA is available. Otherwise
        ``"<model name> (CPU)"`` using the first ``model name`` entry found in
        ``/proc/cpuinfo``, or ``"Unknown CPU"`` when that file cannot be read
        or contains no such entry.
    """
    if torch.cuda.is_available():
        return f"{torch.cuda.get_device_name()} (CUDA)"

    try:
        cpuinfo = Path("/proc/cpuinfo").read_text()
    except (OSError, ValueError):
        # OSError: file missing (non-Linux) or unreadable.
        # ValueError: covers UnicodeDecodeError on undecodable content.
        return "Unknown CPU"

    for line in cpuinfo.splitlines():
        # partition() keeps everything after the first colon, so a model name
        # that itself contains ':' is not truncated.
        key, sep, value = line.partition(":")
        if sep and key.strip() == "model name":
            return f"{value.strip()} (CPU)"

    return "Unknown CPU"
def get_queue_size() -> int:
    """Return the number of queued tasks that are not finished yet.

    This is ``work_queue.unfinished_tasks``: every item that has been put on
    the queue and for which ``task_done()`` has not been called yet, so a task
    that is currently being processed is still counted.
    """
    pending = work_queue.unfinished_tasks
    return pending
def report_task_status(task_id: UUID, status: str) -> None:
    """Announce a task's new status on the event bus and persist it.

    Args:
        task_id: Identifier of the task (``tasks.id``).
        status: Status string to broadcast and store in ``tasks.status``.
    """
    bus.announce({"task": str(task_id), "status": status})
    # short-lived sqlite connections since we spend most of the time
    # on the worker thread just waiting for stuff
    conn = sqlite3.connect(DATABASE)
    try:
        # ``with conn`` only scopes the transaction (commit on success,
        # rollback on error); it does NOT close the connection, hence the
        # explicit close() below.
        with conn:
            conn.execute(
                "UPDATE tasks SET status = ? WHERE id = ?",
                [status, str(task_id)],
            )
    finally:
        conn.close()
def report_task_output(task_id: UUID, track: str, output_path: Path) -> None:
    """Announce a finished output track and record it in ``task_outputs``.

    Args:
        task_id: Identifier of the task the output belongs to.
        track: Name of the track (e.g. ``"master"`` or a separated source).
        output_path: Where the produced file was written; stored as a string.
    """
    bus.announce({"task": str(task_id), "output": track})
    # short-lived sqlite connections since we spend most of the time
    # on the worker thread just waiting for stuff
    conn = sqlite3.connect(DATABASE)
    try:
        # ``with conn`` commits or rolls back the transaction but leaves the
        # connection open, so it is closed explicitly afterwards.
        with conn:
            conn.execute(
                "INSERT INTO task_outputs (task_id, track, output_path) VALUES (?, ?, ?)",
                [str(task_id), track, str(output_path)],
            )
    finally:
        conn.close()
def report_task_progress(
    task_id: UUID, stage: int, stages: int, step: int, stage_steps: int
) -> None:
    """Announce separation progress and upsert it into ``task_progress``.

    Args:
        task_id: Identifier of the task being processed.
        stage: Current stage number.
        stages: Total number of stages.
        step: Current step within the stage.
        stage_steps: Total number of steps in the current stage.
    """
    bus.announce(
        {
            "task": str(task_id),
            "progress": {
                "stage": stage,
                "stages": stages,
                "step": step,
                "stage-steps": stage_steps,
            },
        }
    )
    conn = sqlite3.connect(DATABASE)
    try:
        # ``with conn`` handles commit/rollback only; the connection itself
        # must be closed explicitly or it stays open until garbage-collected.
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO task_progress (task_id, stage, stages, step, stage_steps) VALUES (?, ?, ?, ?, ?)",
                [str(task_id), stage, stages, step, stage_steps],
            )
    finally:
        conn.close()
from demucs.pretrained import get_model
from demucs.apply import apply_model, tqdm, BagOfModels
from demucs.separate import load_track
from demucs.audio import save_audio
import subprocess
def convert_audio(source: Path, output: Path):
    """Transcode ``source`` to ``output`` by invoking ``ffmpeg``.

    No codec options are passed, so the target format is left to ffmpeg to
    derive from ``output``. ``-y`` overwrites an existing output file and
    ``-loglevel panic`` silences ffmpeg's logging.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    argv = [
        "ffmpeg",
        "-y",
        "-loglevel",
        "panic",
        "-i",
        str(source),
        str(output),
    ]
    subprocess.run(argv, check=True)
def process_task(task: DemucsTask, model):
    """Separate one uploaded song into its sources and publish the results.

    Steps:
      1. Transcode the upload to ``<orig_name>-master.flac`` inside the task
         directory and report it as the ``"master"`` output.
      2. Load and normalise the audio, run ``apply_model`` and forward the
         per-chunk progress through ``report_task_progress``.
      3. Write every separated source to a temporary WAV next to the upload,
         delete the upload, transcode each WAV to FLAC in the task directory,
         report it as an output and delete the WAV.

    Args:
        task: The job to process.
        model: Demucs model; ``audio_channels``, ``samplerate`` and
            ``sources`` are read from it, plus ``models`` when it is a
            ``BagOfModels``.

    Raises:
        subprocess.CalledProcessError: If an ffmpeg conversion fails.
        Exception: Whatever loading, separation or saving raises; status
            handling is left to the caller.
    """
    task_id = task.database_id
    song_path = UPLOAD_DIR / task.song_location
    task_dir = UPLOAD_DIR / str(task_id)
    task_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): orig_name is interpolated into file names unchanged. If it
    # can come from a client-supplied filename, confirm it is sanitised (no
    # path separators) before it reaches this point.
    master_path = task_dir / f"{task.orig_name}-master.flac"
    convert_audio(song_path, master_path)
    report_task_output(task_id, "master", master_path)

    wav = load_track(song_path, model.audio_channels, model.samplerate)

    # Normalise amplitude / DC offset. The statistics are computed once and
    # reused for de-normalisation; the epsilon keeps an all-silent input from
    # turning into NaNs through a division by zero.
    ref = wav.mean(0)
    mean = ref.mean()
    std = ref.std() + 1e-8
    wav -= mean
    wav /= std

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A bag of models runs each member in turn; each run is one "stage".
    stages = len(model.models) if isinstance(model, BagOfModels) else 1
    stage = 0

    def progress(futures, **kwargs):
        """Stand-in for ``tqdm.tqdm`` that reports progress for this task."""
        nonlocal stage
        stage += 1
        stage_steps = len(futures)
        for step, future in enumerate(futures, start=1):
            report_task_progress(task_id, stage, stages, step, stage_steps)
            yield future

    # hook the demucs progress function, and always undo the hook so the
    # process-wide ``tqdm.tqdm`` does not keep pointing at this task's closure
    original_tqdm = tqdm.tqdm
    tqdm.tqdm = progress
    try:
        sources = apply_model(
            model,
            wav[None],
            device=device,
            shifts=1,
            split=True,
            overlap=0.25,
            progress=True,
            num_workers=0,
        )[0]
    finally:
        tqdm.tqdm = original_tqdm

    # denormalize amplitude / DC
    sources *= std
    sources += mean

    report_task_status(task_id, "saving")

    wav_paths: dict[str, Path] = {}
    try:
        for source, source_name in zip(sources, model.sources):
            wav_path = song_path.parent / f"{song_path.name}.{source_name}.wav"
            # Register before writing so a partially written file is cleaned
            # up as well if save_audio fails.
            wav_paths[source_name] = wav_path
            save_audio(source, str(wav_path), samplerate=model.samplerate)

        song_path.unlink()

        for source_name, wav_path in sorted(wav_paths.items()):
            flac_path = task_dir / f"{task.orig_name}-{source_name}.flac"
            convert_audio(wav_path, flac_path)
            report_task_output(task_id, source_name, flac_path)
            wav_path.unlink()  # delete wav
    finally:
        # Remove any intermediate WAVs still around after a failure; on the
        # success path they are already gone and this is a no-op.
        for wav_path in wav_paths.values():
            wav_path.unlink(missing_ok=True)
def worker_main():
    """Worker-thread entry point: load the model and process tasks forever.

    On startup every task whose status is neither ``'done'`` nor
    ``'errored'`` is put back on ``work_queue``. The function then loops
    indefinitely, taking one task at a time, marking it ``"in progress"``,
    processing it, and marking it ``"done"`` or ``"errored"``. It never
    returns.
    """
    import traceback

    model = get_model("htdemucs_ft")
    model.eval()

    # Re-queue work that was pending or interrupted when the process last
    # stopped. The connection is closed explicitly because sqlite3's context
    # manager would only end the transaction.
    conn = sqlite3.connect(DATABASE)
    try:
        rows = conn.execute(
            "SELECT id, orig_name, song_location FROM tasks WHERE status NOT IN ('done', 'errored')"
        ).fetchall()
    finally:
        conn.close()

    for task_id, orig_name, song_location in rows:
        work_queue.put(DemucsTask(UUID(task_id), orig_name, song_location))

    while True:
        task = work_queue.get()
        try:
            report_task_status(task.database_id, "in progress")
            process_task(task, model)
            report_task_status(task.database_id, "done")
        except Exception:
            # Keep the worker alive on any task failure, but log the full
            # traceback rather than just the exception message.
            traceback.print_exc()
            try:
                report_task_status(task.database_id, "errored")
            except Exception:
                # Failing to record the error must not kill the worker.
                traceback.print_exc()
        finally:
            # Always balance get() so unfinished_tasks stays accurate.
            work_queue.task_done()
def start_worker():
    """Spawn ``worker_main`` on a daemon thread and return the started thread.

    Being a daemon, the thread does not keep the interpreter alive at exit.
    """
    import threading

    worker = threading.Thread(daemon=True, target=worker_main)
    worker.start()
    return worker