commit 3138e96ae04a692b8e4105a9e96e3724e493a97d Author: videogame hacker Date: Thu Jul 6 21:15:54 2023 +0000 Initial commit diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..73390a1 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,9 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = false +insert_final_newline = true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6c90650 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +# python generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# venv +.venv + +src/data.db +src/uploads diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..d2c96c0 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.3 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..cee7b74 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none" +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..e0cacda --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +# demucs-server + +Web service which lets you demix stems using remote compute resources. + +## Setup + +```shell +$ rye sync +$ rye shell +[rye] $ cd src/ +src/ [rye] $ ./dev_run.py +``` + +## Architecture + +- flask application mostly lives in `demucs_server.routes` +- demucs / ffmpeg is run on a separate thread at `demucs_server.torch_worker` +- songs are queued with `torch_worker.work_queue.put(DemucsTask(..))` +- songs are processed by the thread worker and the database is written to report progress diff --git a/demucs-stems-zip.nginx.conf b/demucs-stems-zip.nginx.conf new file mode 100644 index 0000000..df12d49 --- /dev/null +++ b/demucs-stems-zip.nginx.conf @@ -0,0 +1,24 @@ +server { + listen 443 ssl http2; + listen [::]:443 ssl http2; + + server_name demucs.stems.zip; + + ssl_certificate /etc/letsencrypt/live/demucs.stems.zip/fullchain.pem; + ssl_certificate_key /etc/letsencrypt/live/demucs.stems.zip/privkey.pem; + + location / { + proxy_pass http://127.0.0.2:8000/; + } + + location ^~ /uploads/ { + gzip on; + alias /path/to/demucs-server/src/uploads/; + } + + location = /sse { + proxy_read_timeout 300; + proxy_send_timeout 300; + proxy_pass http://127.0.0.2:8000/; + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..56e5b63 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[project] +name = "demucs-server" +version = "0.1.0" +description = "Add a short description here" +dependencies = [ + "demucs>=4.0.0", + "flask>=2.3.2", + "gunicorn>=20.1.0", + "colorama>=0.4.6", + "ipython>=8.12.2", +] +readme = "README.md" +requires-python = ">= 3.8" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [ + "black>=23.3.0", +] + +[tool.hatch.metadata] +allow-direct-references = true diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..706dd53 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,82 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false + +-e file:. +antlr4-python3-runtime==4.9.3 +asttokens==2.2.1 +backcall==0.2.0 +black==23.3.0 +blinker==1.6.2 +click==8.1.3 +cloudpickle==2.2.1 +cmake==3.26.4 +colorama==0.4.6 +cython==0.29.36 +decorator==5.1.1 +demucs==4.0.0 +diffq==0.2.4 +dora-search==0.1.12 +einops==0.6.1 +executing==1.2.0 +filelock==3.12.2 +flask==2.3.2 +gunicorn==20.1.0 +ipython==8.14.0 +itsdangerous==2.1.2 +jedi==0.18.2 +jinja2==3.1.2 +julius==0.2.7 +lameenc==1.5.1 +lit==16.0.6 +markupsafe==2.1.3 +matplotlib-inline==0.1.6 +mpmath==1.3.0 +mypy-extensions==1.0.0 +networkx==3.1 +numpy==1.25.0 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +omegaconf==2.3.0 +openunmix==1.2.1 +packaging==23.1 +parso==0.8.3 +pathspec==0.11.1 +pexpect==4.8.0 +pickleshare==0.7.5 +platformdirs==3.8.0 +prompt-toolkit==3.0.39 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pygments==2.15.1 +pyyaml==6.0 +retrying==1.3.4 +six==1.16.0 +stack-data==0.6.2 +submitit==1.4.5 +sympy==1.12 +torch==2.0.1 +torchaudio==2.0.2 +tqdm==4.65.0 +traitlets==5.9.0 +treetable==0.2.5 +triton==2.0.0 +typing-extensions==4.7.1 +wcwidth==0.2.6 +werkzeug==2.3.6 +wheel==0.40.0 +# The following packages are considered to be unsafe in a requirements file: +setuptools==68.0.0 diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..992ce3b --- /dev/null +++ b/requirements.lock @@ -0,0 +1,77 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false + +-e file:. +antlr4-python3-runtime==4.9.3 +asttokens==2.2.1 +backcall==0.2.0 +blinker==1.6.2 +click==8.1.3 +cloudpickle==2.2.1 +cmake==3.26.4 +colorama==0.4.6 +cython==0.29.36 +decorator==5.1.1 +demucs==4.0.0 +diffq==0.2.4 +dora-search==0.1.12 +einops==0.6.1 +executing==1.2.0 +filelock==3.12.2 +flask==2.3.2 +gunicorn==20.1.0 +ipython==8.14.0 +itsdangerous==2.1.2 +jedi==0.18.2 +jinja2==3.1.2 +julius==0.2.7 +lameenc==1.5.1 +lit==16.0.6 +markupsafe==2.1.3 +matplotlib-inline==0.1.6 +mpmath==1.3.0 +networkx==3.1 +numpy==1.25.0 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +omegaconf==2.3.0 +openunmix==1.2.1 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +prompt-toolkit==3.0.39 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pygments==2.15.1 +pyyaml==6.0 +retrying==1.3.4 +six==1.16.0 +stack-data==0.6.2 +submitit==1.4.5 +sympy==1.12 +torch==2.0.1 +torchaudio==2.0.2 +tqdm==4.65.0 +traitlets==5.9.0 +treetable==0.2.5 +triton==2.0.0 +typing-extensions==4.7.1 +wcwidth==0.2.6 +werkzeug==2.3.6 +wheel==0.40.0 +# The following packages are considered to be unsafe in a requirements file: +setuptools==68.0.0 diff --git a/src/admin.py b/src/admin.py new file mode 100755 index 0000000..39c5fd8 --- /dev/null +++ b/src/admin.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +import demucs_server + +import uuid +import IPython + +with demucs_server.app.app_context(): + + def create_user(name: str) -> str: + db = demucs_server.db.get_db() + user_id = str(uuid.uuid4()) + db.execute( + "INSERT INTO users (user_id, nickname) VALUES (?, ?)", [user_id, name] + ) + db.commit() + + return user_id + + def list_users() -> list["demucs_server.auth.User"]: + from demucs_server.auth import User + + db = demucs_server.db.get_db() + users = [] + cur = db.execute("SELECT * FROM users") + for user_id, nickname in cur.fetchall(): + users.append(User(uuid.UUID(user_id), nickname)) + cur.close() + + return users + + IPython.embed() diff --git a/src/demucs_server/__init__.py b/src/demucs_server/__init__.py new file mode 100644 index 0000000..a032bef --- /dev/null +++ b/src/demucs_server/__init__.py @@ -0,0 +1,22 @@ +from flask import Flask +import colorama +from pathlib import Path + +app = Flask(__name__) + +colorama.init() +UPLOAD_DIR = Path(".") / "uploads" +UPLOAD_DIR.mkdir(exist_ok=True) + +from . import db + +with app.app_context(): + from .migrations import apply_migrations + + apply_migrations() + +from . import routes + +from .torch_worker import start_worker + +start_worker() diff --git a/src/demucs_server/auth.py b/src/demucs_server/auth.py new file mode 100644 index 0000000..35c31a7 --- /dev/null +++ b/src/demucs_server/auth.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from uuid import UUID + +from .db import get_db + + +@dataclass +class User: + user_id: UUID + nickname: str + + +def get_user(user_id: str) -> User | None: + try: + user_id = UUID(user_id) + except: + return None + + db = get_db() + cur = db.execute("SELECT * FROM users WHERE user_id = ?", [str(user_id)]) + user_row = cur.fetchone() + if user_row is None: + return None + + return User(UUID(user_row[0]), user_row[1]) diff --git a/src/demucs_server/db.py b/src/demucs_server/db.py new file mode 100644 index 0000000..1155afb --- /dev/null +++ b/src/demucs_server/db.py @@ -0,0 +1,22 @@ +import sqlite3 +from flask import g + +import os + +DATABASE = os.environ.get("DATABASE_PATH") or "data.db" + +from . import app + + +def get_db(): + db = getattr(g, "_database", None) + if db is None: + db = g._database = sqlite3.connect(DATABASE) + return db + + +@app.teardown_appcontext +def close_connection(exception): + db = getattr(g, "_database", None) + if db is not None: + db.close() diff --git a/src/demucs_server/migrations.py b/src/demucs_server/migrations.py new file mode 100644 index 0000000..9ff6350 --- /dev/null +++ b/src/demucs_server/migrations.py @@ -0,0 +1,29 @@ +from .db import get_db +from pathlib import Path + +import colorama + + +def apply_migrations(): + migrations_dir = Path("./migrations") + if not migrations_dir.exists(): + raise Exception("migrations directory was not found") + + db = get_db() + already_run_migrations = set([]) + try: + cur = db.execute("SELECT applied FROM migrations") + for (applied,) in cur.fetchall(): + already_run_migrations.add(applied) + except: + pass + + for migration in sorted(migrations_dir.iterdir()): + if migration.name not in already_run_migrations: + print( + f"{colorama.Fore.LIGHTBLUE_EX}[*]{colorama.Fore.RESET} " + f"Applied migration: {colorama.Fore.GREEN}{migration.name}{colorama.Fore.RESET}" + ) + db.executescript(migration.read_text()) + db.execute("INSERT INTO migrations VALUES (?)", [migration.name]) + db.commit() diff --git a/src/demucs_server/routes.py b/src/demucs_server/routes.py new file mode 100644 index 0000000..16ba183 --- /dev/null +++ b/src/demucs_server/routes.py @@ -0,0 +1,168 @@ +from flask import render_template, request, redirect, send_file, Response + +from werkzeug.utils import secure_filename +from werkzeug.security import safe_join + +from uuid import uuid4 +import json + +from . import app, UPLOAD_DIR +from .auth import get_user +from .db import get_db +from .torch_worker import get_device_name, get_queue_size, work_queue, DemucsTask +from .sse import bus + + +@app.route("/") +@app.route("/from-url") +def index_page(): + user_id = request.cookies.get("demucs_server_user_id") + user = user_id and get_user(user_id) + if user is None: + return render_template("login.html.j2") + + status = {"device": get_device_name(), "queue_size": get_queue_size()} + + if request.path == "/from-url": + return render_template("demix_ytdlp.html.j2", user=user, status=status) + return render_template("demix.html.j2", user=user, status=status) + + +@app.route("/login", methods=["GET", "POST"]) +def login(): + response = redirect("/") + user_id = request.form.get("user-id", type=str) + if user_id is not None: + response.set_cookie( + "demucs_server_user_id", + user_id, + samesite="Strict", + secure=True, + ) + + return response + + +@app.route("/logout", methods=["GET", "POST"]) +def logout(): + if request.method == "POST": + response = redirect("/") + response.delete_cookie("demucs_server_user_id") + return response + return render_template("logout.html.j2") + + +# TODO: demix routes (involves multipart form parsing!! maybe ytdlp!!) +@app.route("/demix", methods=["POST"]) +def demix(): + user_id = request.cookies.get("demucs_server_user_id") + user = user_id and get_user(user_id) + if user is None: + return "You are not logged in!", 403 + + audio_file = next(request.files.values(), None) + if audio_file is None: + return redirect("/") + + orig_name = secure_filename(audio_file.filename).replace("_", "-") + task_id = uuid4() + song_location = f"{task_id}-{orig_name}" + orig_name = orig_name.rsplit(".", 1)[0] + + db = get_db() + db.execute( + "INSERT INTO tasks (id, orig_name, status, song_location) VALUES (?, ?, ?, ?)", + [str(task_id), orig_name, "queued", song_location], + ) + audio_file.save(UPLOAD_DIR / song_location) + db.commit() + + demucs_task = DemucsTask(task_id, orig_name, song_location) + work_queue.put(demucs_task) + + return redirect(f"/task/{task_id}") + + +@app.route("/task/") +def task_view(task_id): + db = get_db() + + try: + cur = db.execute("SELECT status, orig_name FROM tasks WHERE id = ?", [task_id]) + (status, orig_name) = cur.fetchone() + except: + # TODO: real error page + return "Task with the given ID was not found", 404 + + outputs = [] + try: + cur = db.execute( + "SELECT track, output_path FROM task_outputs WHERE task_id = ?", [task_id] + ) + outputs = [ + {"track": track, "path": path} + for ( + track, + path, + ) in cur.fetchall() + ] + except: + pass + + outputs = sorted( + outputs, + key=lambda o: ["master", "vocals", "other", "bass", "drums"].index(o["track"]), + ) + + progress = {} + try: + cur = db.execute( + "SELECT stage, stages, step, stage_steps FROM task_progress WHERE task_id = ?", + [task_id], + ) + stage, stages, step, stage_steps = cur.fetchone() + progress = { + "stage": stage, + "stages": stages, + "step": step, + "stage-steps": stage_steps, + } + except: + pass + + return render_template( + "task.html.j2", + task_id=task_id, + status=status, + orig_name=orig_name, + outputs=outputs, + progress=progress, + ) + + +@app.route("/uploads/") +def send_upload(upload): + upload_file = safe_join(UPLOAD_DIR.absolute(), upload) + return send_file(upload_file, conditional=True) + + +@app.route("/sse") +def sse(): + task_id = request.args.get("task_id") + if task_id is None: + return "bad request", 400 + + def stream(): + queue = bus.listen() + while True: + message = queue.get() + if message.get("task") != task_id: + continue + message = json.dumps(message) + yield f"{message}\n\n" + + return Response( + stream(), + mimetype="text/event-stream", + headers={"cache-control": "no-cache", "x-accel-buffering": "no"}, + ) diff --git a/src/demucs_server/sse.py b/src/demucs_server/sse.py new file mode 100644 index 0000000..7836490 --- /dev/null +++ b/src/demucs_server/sse.py @@ -0,0 +1,21 @@ +import queue + + +class AnnounceBus: + def __init__(self): + self.listeners = [] + + def listen(self): + q = queue.Queue(maxsize=5) + self.listeners.append(q) + return q + + def announce(self, msg): + for i in reversed(range(len(self.listeners))): + try: + self.listeners[i].put_nowait(msg) + except queue.Full: + del self.listeners[i] + + +bus = AnnounceBus() diff --git a/src/demucs_server/static/styles.css b/src/demucs_server/static/styles.css new file mode 100644 index 0000000..a8c4a07 --- /dev/null +++ b/src/demucs_server/static/styles.css @@ -0,0 +1,265 @@ +:root { + --col-bg: #25181b; + --col-bg-dark: #1e1417; + --col-bg-light: #331f23; + --col-accent: #44cf6c; + --col-accent-dark: #32a287; + --col-accent-light: #a9fdac; + --col-fg: #f5fff2; + + --font-main: "Inter", "Arial", sans-serif; + --font-mono: ui-monospace, "Cascadia Mono", "Segoe UI Mono", "Courier New", + monospace; +} + +html { + color-scheme: dark; + font-family: var(--font-main); + font-size: 1.25rem; + color: var(--col-fg); + background-color: var(--col-bg); +} + +html, +body { + padding: 0; + margin: 0; +} + +body { + display: flex; + flex-direction: column; + min-height: 100vh; +} + +main { + flex: 1; + width: calc(100% - 1em); + max-width: 80ch; + margin: 0 auto; + padding: 0.5em; +} + +code { + font-family: var(--font-mono); + font-size: 1rem; +} + +a { + color: var(--col-accent); + text-decoration: none; +} + +a:hover { + border-bottom: 1px solid var(--col-accent); +} + +.logo { + display: inline; + height: 3.5em; +} + +header { + display: flex; + flex-direction: row; + justify-content: center; + align-items: center; + gap: 1em; +} + +.tagline { + text-align: center; + margin-top: 0.25em; +} + +header.vertical { + flex-direction: column; + gap: 0.5em; +} + +header h1 { + margin: 0; + margin-bottom: 0.5em; +} + +header + h2 { + margin-top: 0; +} + +header small { + font-size: 1em; +} + +form { + display: flex; + flex-direction: column; + + padding: 1em; + border: 2.5px solid var(--col-accent-dark); + border-radius: 8px; +} + +form :first-child { + margin-top: 0; +} + +form :last-child { + margin-bottom: 0; +} + +label { + margin-bottom: 0.5em; +} + +input { + background-color: transparent; + border: 1.5px solid var(--col-accent-light); + outline: none; + border-radius: 8px; + + font-family: var(--font-main); + font-size: 1em; + padding: 0.75rem 1rem; + margin-bottom: 0.75rem; + + box-shadow: none; + width: calc(100% - 2rem); + + color: var(--col-fg); +} + +.file-upload { + margin: 0.75rem 0; + overflow-x: clip; + white-space: nowrap; +} + +input[type="url"] { + margin-bottom: 0.25em; +} + +.file-upload + p, +input[type="url"] + p { + margin-top: 0.25em; + padding-left: 1.5em; +} + +input:active, +input:focus { + border-color: var(--col-accent); +} + +input[type="file"] { + display: none; +} + +button, +.file-upload-button { + font-family: var(--font-main); + font-size: 1em; + + background-color: var(--col-accent-light); + border: 1.5px solid var(--col-accent-light); + border-radius: 8px; + + outline: none; + color: var(--col-bg); + + padding: 0.5rem; +} + +.file-upload-button { + margin-right: 0.5em; +} + +button:hover { + cursor: pointer; + background-color: var(--col-accent); +} + +input::placeholder { + opacity: 0.33; + text-transform: lowercase; +} + +dl { + display: flex; + flex-direction: row; + gap: 0.5ch; + margin: 0; +} + +dt { + font-weight: bold; +} + +dt::after { + content: ":"; +} + +dd { + margin: 0; + display: inline; +} + +footer { + display: flex; + flex-direction: row; + justify-content: space-between; + + padding: 1em; + background-color: var(--col-bg-dark); +} + +nav ul { + list-style: none; + display: flex; + flex-direction: row; + gap: 0.5em; + margin: 0; + padding: 0; +} + +nav ul li { + display: table-cell; + padding: 0.1em 0; + vertical-align: middle; +} + +nav ul li + li { + border-inline-start: 1px solid var(--col-fg); + padding-inline-start: 0.5em; +} + +.track-list { + padding: 0; +} + +.track-list li { + display: flex; + flex-direction: row; + list-style: none; + padding: 0; + margin: 1em; + align-items: center; + justify-content: space-between; +} + +.track-list a { + display: inline-block; + padding: 0.5em 0.5em; + margin-right: 1em; + border: 1px solid var(--col-accent); + border-radius: 4px; + width: 8em; + text-align: right; +} + +.track-list a:hover { + background: var(--col-accent); + color: var(--col-bg); +} + +.track-list audio { + flex: 1; +} diff --git a/src/demucs_server/templates/_base.html.j2 b/src/demucs_server/templates/_base.html.j2 new file mode 100644 index 0000000..87c7b75 --- /dev/null +++ b/src/demucs_server/templates/_base.html.j2 @@ -0,0 +1,21 @@ + + + + + + + + demucs.stems.zip + + + +
+
+

demucs.stems.zip

+
+ + {% block content %}{% endblock %} +
+ + + diff --git a/src/demucs_server/templates/demix.html.j2 b/src/demucs_server/templates/demix.html.j2 new file mode 100644 index 0000000..7f29f89 --- /dev/null +++ b/src/demucs_server/templates/demix.html.j2 @@ -0,0 +1,41 @@ +{% extends "_base.html.j2" %} + +{% block content %} +

demix song

+

+ use demucs + to generate 'vocal', 'bass', 'drums', and 'other' tracks: +

+ +{% block form %} +
+
+ + + + + + + +
    +
  • songs in queue: {{ status['queue_size'] }}
  • +
  • this node is using: {{ status['device'] }}
  • +
+
+ + +{% endblock %} +{% endblock %} diff --git a/src/demucs_server/templates/demix_ytdlp.html.j2 b/src/demucs_server/templates/demix_ytdlp.html.j2 new file mode 100644 index 0000000..30f0a54 --- /dev/null +++ b/src/demucs_server/templates/demix_ytdlp.html.j2 @@ -0,0 +1,19 @@ +{% extends "demix.html.j2" %} + +{% block form %} +
+
+ + + + +

or upload an audio file.

+ + + +
    +
  • songs in queue: {{ status['queue_size'] }}
  • +
  • this node is using: {{ status['device'] }}
  • +
+
+{% endblock %} diff --git a/src/demucs_server/templates/login.html.j2 b/src/demucs_server/templates/login.html.j2 new file mode 100644 index 0000000..68ccf1f --- /dev/null +++ b/src/demucs_server/templates/login.html.j2 @@ -0,0 +1,12 @@ +{% extends "_base.html.j2" %} + +{% block content %} +

log in

+ +
+ + + +

demucs.stems.zip is invite-only.

+
+{% endblock %} diff --git a/src/demucs_server/templates/logout.html.j2 b/src/demucs_server/templates/logout.html.j2 new file mode 100644 index 0000000..520c312 --- /dev/null +++ b/src/demucs_server/templates/logout.html.j2 @@ -0,0 +1,9 @@ +{% extends "_base.html.j2" %} + +{% block content %} +
+

are you sure you want to log out?

+ +

you will need your user ID to log back in.

+
+{% endblock %} diff --git a/src/demucs_server/templates/task.html.j2 b/src/demucs_server/templates/task.html.j2 new file mode 100644 index 0000000..773ec7f --- /dev/null +++ b/src/demucs_server/templates/task.html.j2 @@ -0,0 +1,61 @@ +{% extends "_base.html.j2" %} + +{% block content %} +

task

+ +{% if status in ["queued"] %} +

your upload is currently in the queue, and hasn't started processing yet.

+{% elif status in ["in progress", "saving"] %} +

your upload is currently processing.

+{% endif %} + +

status: {{ status }}

+

originally: {{ orig_name }}

+ +
+

progress

+

+ stage: + {{ progress.get('stage') or '' }} / + {{ progress.get('stages') or '' }} +

+

+ step: + {{ progress.get('step') or '' }} / + {{ progress.get('stage-steps') or '' }} +

+
+ + + +{% if outputs %} + +{% endif %} + +{% endblock %} diff --git a/src/demucs_server/torch_worker.py b/src/demucs_server/torch_worker.py new file mode 100644 index 0000000..e24c325 --- /dev/null +++ b/src/demucs_server/torch_worker.py @@ -0,0 +1,203 @@ +import torch +import queue + +from pathlib import Path +from dataclasses import dataclass +from uuid import UUID + +from . import UPLOAD_DIR +from .db import DATABASE +from .sse import bus + +import sqlite3 + + +@dataclass +class DemucsTask: + database_id: UUID + orig_name: str + song_location: str + + +work_queue: queue.Queue[DemucsTask] = queue.Queue() + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return f"{torch.cuda.get_device_name()} (CUDA)" + + try: + cpu_info = Path("/proc/cpuinfo").read_text().splitlines() + cpu_info = [l.split(":") for l in cpu_info] + cpu_info = [l for l in cpu_info if l[0].strip() == "model name"] + cpu_info = [l[1].strip() for l in cpu_info][0] + return f"{cpu_info} (CPU)" + except: + return "Unknown CPU" + + +def get_queue_size() -> int: + return work_queue.unfinished_tasks + + +def report_task_status(task_id: UUID, status: str): + bus.announce({"task": str(task_id), "status": status}) + # short-lived sqlite connections since we spend most of the time + # on the worker thread just waiting for stuff + with sqlite3.connect(DATABASE) as conn: + conn.execute( + "UPDATE tasks SET status = ? WHERE id = ?", + [status, str(task_id)], + ) + conn.commit() + + +def report_task_output(task_id: UUID, track: str, output_path: Path): + bus.announce({"task": str(task_id), "output": track}) + # short-lived sqlite connections since we spend most of the time + # on the worker thread just waiting for stuff + with sqlite3.connect(DATABASE) as conn: + conn.execute( + "INSERT INTO task_outputs (task_id, track, output_path) VALUES (?, ?, ?)", + [str(task_id), track, str(output_path)], + ) + conn.commit() + + +def report_task_progress( + task_id: UUID, stage: int, stages: int, step: int, stage_steps: int +): + bus.announce( + { + "task": str(task_id), + "progress": { + "stage": stage, + "stages": stages, + "step": step, + "stage-steps": stage_steps, + }, + } + ) + + with sqlite3.connect(DATABASE) as conn: + conn.execute( + "INSERT OR REPLACE INTO task_progress (task_id, stage, stages, step, stage_steps) VALUES (?, ?, ?, ?, ?)", + [str(task_id), stage, stages, step, stage_steps], + ) + conn.commit() + + +from demucs.pretrained import get_model +from demucs.apply import apply_model, tqdm, BagOfModels +from demucs.separate import load_track +from demucs.audio import save_audio + +import subprocess + + +def convert_audio(source: Path, output: Path): + command = ["ffmpeg", "-y", "-loglevel", "panic"] + command += ["-i", str(source)] + command += [str(output)] + subprocess.run(command, check=True) + + +def process_task(task: DemucsTask, model): + task_id = task.database_id + + song_path = UPLOAD_DIR / task.song_location + + task_dir = UPLOAD_DIR / str(task.database_id) + task_dir.mkdir(parents=True, exist_ok=True) + + master_path = task_dir / f"{task.orig_name}-master.flac" + convert_audio(song_path, master_path) + report_task_output(task.database_id, "master", master_path) + + wav = load_track(song_path, model.audio_channels, model.samplerate) + + ref = wav.mean(0) + wav -= ref.mean() + wav /= ref.std() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # hook the demucs progress function + stages = 1 + if isinstance(model, BagOfModels): + stages = len(model.models) + stage = 0 + + def progress(futures, **kwargs): + nonlocal stage, task_id + + stage += 1 + stage_steps = len(futures) + + for step, future in enumerate(futures, start=1): + report_task_progress(task_id, stage, stages, step, stage_steps) + yield future + + tqdm.tqdm = progress + sources = apply_model( + model, + wav[None], + device=device, + shifts=1, + split=True, + overlap=0.25, + progress=True, + num_workers=0, + )[0] + + # denormalize amplitude / DC + sources *= ref.std() + sources += ref.mean() + + report_task_status(task.database_id, "saving") + + outputs = {} + for source, source_name in zip(sources, model.sources): + path = song_path.parent / f"{song_path.name}.{source_name}.wav" + save_audio(source, str(path), samplerate=model.samplerate) + outputs[source_name] = path + + song_path.unlink() + + for source, path in sorted(outputs.items()): + flac_path = task_dir / f"{task.orig_name}-{source}.flac" + convert_audio(path, flac_path) + report_task_output(task.database_id, source, flac_path) + path.unlink() # delete wav + + +def worker_main(): + model = get_model("htdemucs_ft") + model.eval() + + with sqlite3.connect(DATABASE) as conn: + cur = conn.execute( + "SELECT id, orig_name, song_location FROM tasks WHERE status NOT IN ('done', 'errored')" + ) + for id, orig_name, song_location in cur.fetchall(): + work_queue.put(DemucsTask(UUID(id), orig_name, song_location)) + + while True: + task = work_queue.get() + try: + report_task_status(task.database_id, "in progress") + process_task(task, model) + report_task_status(task.database_id, "done") + except BaseException as e: + report_task_status(task.database_id, "errored") + print(e) + work_queue.task_done() + + +def start_worker(): + import threading + + thread = threading.Thread(target=worker_main, daemon=True) + thread.start() + + return thread diff --git a/src/dev_run.py b/src/dev_run.py new file mode 100755 index 0000000..a0a5199 --- /dev/null +++ b/src/dev_run.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python + +from demucs_server import app + +if __name__ == "__main__": + app.run(port=8000, debug=True) diff --git a/src/migrations/0000_init.sql b/src/migrations/0000_init.sql new file mode 100644 index 0000000..76cf317 --- /dev/null +++ b/src/migrations/0000_init.sql @@ -0,0 +1,5 @@ +CREATE TABLE migrations ( + applied TEXT PRIMARY KEY +) STRICT; + +INSERT INTO migrations VALUES ('0000_init'); diff --git a/src/migrations/0001_auth.sql b/src/migrations/0001_auth.sql new file mode 100644 index 0000000..cbac6c6 --- /dev/null +++ b/src/migrations/0001_auth.sql @@ -0,0 +1,4 @@ +CREATE TABLE users ( + user_id TEXT NOT NULL PRIMARY KEY, + nickname TEXT NOT NULL +) STRICT; diff --git a/src/migrations/0002_tasks.sql b/src/migrations/0002_tasks.sql new file mode 100644 index 0000000..9b00869 --- /dev/null +++ b/src/migrations/0002_tasks.sql @@ -0,0 +1,6 @@ +CREATE TABLE tasks ( + id TEXT NOT NULL PRIMARY KEY, + orig_name TEXT NOT NULL, + status TEXT NOT NULL, -- "queued / in progress / done" + song_location TEXT NOT NULL +) STRICT; diff --git a/src/migrations/0003_task_outputs.sql b/src/migrations/0003_task_outputs.sql new file mode 100644 index 0000000..70279a5 --- /dev/null +++ b/src/migrations/0003_task_outputs.sql @@ -0,0 +1,5 @@ +CREATE TABLE task_outputs ( + task_id TEXT NOT NULL, + track TEXT NOT NULL, + output_path TEXT NOT NULL +) STRICT; diff --git a/src/migrations/0004_task_progress.sql b/src/migrations/0004_task_progress.sql new file mode 100644 index 0000000..203d01f --- /dev/null +++ b/src/migrations/0004_task_progress.sql @@ -0,0 +1,7 @@ +CREATE TABLE task_progress ( + task_id TEXT NOT NULL UNIQUE, + stage INTEGER NOT NULL, + stages INTEGER NOT NULL, + step INTEGER NOT NULL, + stage_steps INTEGER NOT NULL +) STRICT; diff --git a/src/run.sh b/src/run.sh new file mode 100755 index 0000000..aff82a0 --- /dev/null +++ b/src/run.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env sh + +gunicorn -b '127.0.0.2:8000' demucs_server:app