Initial commit

main
Charlotte Som 2023-07-06 21:15:54 +00:00
commit 3138e96ae0
31 changed files with 1243 additions and 0 deletions

9
.editorconfig Normal file
View File

@ -0,0 +1,9 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = false
insert_final_newline = true

13
.gitignore vendored Normal file
View File

@ -0,0 +1,13 @@
# python generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# venv
.venv
src/data.db
src/uploads

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.11.3

6
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,6 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}

19
README.md Normal file
View File

@ -0,0 +1,19 @@
# demucs-server
Web service which lets you demix stems using remote compute resources.
## Setup
```shell
$ rye sync
$ rye shell
[rye] $ cd src/
src/ [rye] $ ./dev_run.py
```
## Architecture
- flask application mostly lives in `demucs_server.routes`
- demucs / ffmpeg is run on a separate thread at `demucs_server.torch_worker`
- songs are queued with `torch_worker.work_queue.put(DemucsTask(..))`
- songs are processed by the thread worker and the database is written to report progress

View File

@ -0,0 +1,24 @@
server {
listen 443 ssl http2;
listen [::]:443 ssl http2;
server_name demucs.stems.zip;
ssl_certificate /etc/letsencrypt/live/demucs.stems.zip/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/demucs.stems.zip/privkey.pem;
location / {
proxy_pass http://127.0.0.2:8000/;
}
location ^~ /uploads/ {
gzip on;
alias /path/to/demucs-server/src/uploads/;
}
location = /sse {
proxy_read_timeout 300;
proxy_send_timeout 300;
proxy_pass http://127.0.0.2:8000/;
}
}

26
pyproject.toml Normal file
View File

@ -0,0 +1,26 @@
[project]
name = "demucs-server"
version = "0.1.0"
description = "Add a short description here"
dependencies = [
"demucs>=4.0.0",
"flask>=2.3.2",
"gunicorn>=20.1.0",
"colorama>=0.4.6",
"ipython>=8.12.2",
]
readme = "README.md"
requires-python = ">= 3.8"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"black>=23.3.0",
]
[tool.hatch.metadata]
allow-direct-references = true

82
requirements-dev.lock Normal file
View File

@ -0,0 +1,82 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
-e file:.
antlr4-python3-runtime==4.9.3
asttokens==2.2.1
backcall==0.2.0
black==23.3.0
blinker==1.6.2
click==8.1.3
cloudpickle==2.2.1
cmake==3.26.4
colorama==0.4.6
cython==0.29.36
decorator==5.1.1
demucs==4.0.0
diffq==0.2.4
dora-search==0.1.12
einops==0.6.1
executing==1.2.0
filelock==3.12.2
flask==2.3.2
gunicorn==20.1.0
ipython==8.14.0
itsdangerous==2.1.2
jedi==0.18.2
jinja2==3.1.2
julius==0.2.7
lameenc==1.5.1
lit==16.0.6
markupsafe==2.1.3
matplotlib-inline==0.1.6
mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.1
numpy==1.25.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
omegaconf==2.3.0
openunmix==1.2.1
packaging==23.1
parso==0.8.3
pathspec==0.11.1
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.8.0
prompt-toolkit==3.0.39
ptyprocess==0.7.0
pure-eval==0.2.2
pygments==2.15.1
pyyaml==6.0
retrying==1.3.4
six==1.16.0
stack-data==0.6.2
submitit==1.4.5
sympy==1.12
torch==2.0.1
torchaudio==2.0.2
tqdm==4.65.0
traitlets==5.9.0
treetable==0.2.5
triton==2.0.0
typing-extensions==4.7.1
wcwidth==0.2.6
werkzeug==2.3.6
wheel==0.40.0
# The following packages are considered to be unsafe in a requirements file:
setuptools==68.0.0

77
requirements.lock Normal file
View File

@ -0,0 +1,77 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
-e file:.
antlr4-python3-runtime==4.9.3
asttokens==2.2.1
backcall==0.2.0
blinker==1.6.2
click==8.1.3
cloudpickle==2.2.1
cmake==3.26.4
colorama==0.4.6
cython==0.29.36
decorator==5.1.1
demucs==4.0.0
diffq==0.2.4
dora-search==0.1.12
einops==0.6.1
executing==1.2.0
filelock==3.12.2
flask==2.3.2
gunicorn==20.1.0
ipython==8.14.0
itsdangerous==2.1.2
jedi==0.18.2
jinja2==3.1.2
julius==0.2.7
lameenc==1.5.1
lit==16.0.6
markupsafe==2.1.3
matplotlib-inline==0.1.6
mpmath==1.3.0
networkx==3.1
numpy==1.25.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
omegaconf==2.3.0
openunmix==1.2.1
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
prompt-toolkit==3.0.39
ptyprocess==0.7.0
pure-eval==0.2.2
pygments==2.15.1
pyyaml==6.0
retrying==1.3.4
six==1.16.0
stack-data==0.6.2
submitit==1.4.5
sympy==1.12
torch==2.0.1
torchaudio==2.0.2
tqdm==4.65.0
traitlets==5.9.0
treetable==0.2.5
triton==2.0.0
typing-extensions==4.7.1
wcwidth==0.2.6
werkzeug==2.3.6
wheel==0.40.0
# The following packages are considered to be unsafe in a requirements file:
setuptools==68.0.0

32
src/admin.py Executable file
View File

@ -0,0 +1,32 @@
#!/usr/bin/env python
import demucs_server
import uuid
import IPython
with demucs_server.app.app_context():
def create_user(name: str) -> str:
db = demucs_server.db.get_db()
user_id = str(uuid.uuid4())
db.execute(
"INSERT INTO users (user_id, nickname) VALUES (?, ?)", [user_id, name]
)
db.commit()
return user_id
def list_users() -> list["demucs_server.auth.User"]:
from demucs_server.auth import User
db = demucs_server.db.get_db()
users = []
cur = db.execute("SELECT * FROM users")
for user_id, nickname in cur.fetchall():
users.append(User(uuid.UUID(user_id), nickname))
cur.close()
return users
IPython.embed()

View File

@ -0,0 +1,22 @@
from flask import Flask
import colorama
from pathlib import Path
app = Flask(__name__)
colorama.init()
UPLOAD_DIR = Path(".") / "uploads"
UPLOAD_DIR.mkdir(exist_ok=True)
from . import db
with app.app_context():
from .migrations import apply_migrations
apply_migrations()
from . import routes
from .torch_worker import start_worker
start_worker()

25
src/demucs_server/auth.py Normal file
View File

@ -0,0 +1,25 @@
from dataclasses import dataclass
from uuid import UUID
from .db import get_db
@dataclass
class User:
user_id: UUID
nickname: str
def get_user(user_id: str) -> User | None:
try:
user_id = UUID(user_id)
except:
return None
db = get_db()
cur = db.execute("SELECT * FROM users WHERE user_id = ?", [str(user_id)])
user_row = cur.fetchone()
if user_row is None:
return None
return User(UUID(user_row[0]), user_row[1])

22
src/demucs_server/db.py Normal file
View File

@ -0,0 +1,22 @@
import sqlite3
from flask import g
import os
DATABASE = os.environ.get("DATABASE_PATH") or "data.db"
from . import app
def get_db():
db = getattr(g, "_database", None)
if db is None:
db = g._database = sqlite3.connect(DATABASE)
return db
@app.teardown_appcontext
def close_connection(exception):
db = getattr(g, "_database", None)
if db is not None:
db.close()

View File

@ -0,0 +1,29 @@
from .db import get_db
from pathlib import Path
import colorama
def apply_migrations():
migrations_dir = Path("./migrations")
if not migrations_dir.exists():
raise Exception("migrations directory was not found")
db = get_db()
already_run_migrations = set([])
try:
cur = db.execute("SELECT applied FROM migrations")
for (applied,) in cur.fetchall():
already_run_migrations.add(applied)
except:
pass
for migration in sorted(migrations_dir.iterdir()):
if migration.name not in already_run_migrations:
print(
f"{colorama.Fore.LIGHTBLUE_EX}[*]{colorama.Fore.RESET} "
f"Applied migration: {colorama.Fore.GREEN}{migration.name}{colorama.Fore.RESET}"
)
db.executescript(migration.read_text())
db.execute("INSERT INTO migrations VALUES (?)", [migration.name])
db.commit()

168
src/demucs_server/routes.py Normal file
View File

@ -0,0 +1,168 @@
from flask import render_template, request, redirect, send_file, Response
from werkzeug.utils import secure_filename
from werkzeug.security import safe_join
from uuid import uuid4
import json
from . import app, UPLOAD_DIR
from .auth import get_user
from .db import get_db
from .torch_worker import get_device_name, get_queue_size, work_queue, DemucsTask
from .sse import bus
@app.route("/")
@app.route("/from-url")
def index_page():
user_id = request.cookies.get("demucs_server_user_id")
user = user_id and get_user(user_id)
if user is None:
return render_template("login.html.j2")
status = {"device": get_device_name(), "queue_size": get_queue_size()}
if request.path == "/from-url":
return render_template("demix_ytdlp.html.j2", user=user, status=status)
return render_template("demix.html.j2", user=user, status=status)
@app.route("/login", methods=["GET", "POST"])
def login():
response = redirect("/")
user_id = request.form.get("user-id", type=str)
if user_id is not None:
response.set_cookie(
"demucs_server_user_id",
user_id,
samesite="Strict",
secure=True,
)
return response
@app.route("/logout", methods=["GET", "POST"])
def logout():
if request.method == "POST":
response = redirect("/")
response.delete_cookie("demucs_server_user_id")
return response
return render_template("logout.html.j2")
# TODO: demix routes (involves multipart form parsing!! maybe ytdlp!!)
@app.route("/demix", methods=["POST"])
def demix():
user_id = request.cookies.get("demucs_server_user_id")
user = user_id and get_user(user_id)
if user is None:
return "You are not logged in!", 403
audio_file = next(request.files.values(), None)
if audio_file is None:
return redirect("/")
orig_name = secure_filename(audio_file.filename).replace("_", "-")
task_id = uuid4()
song_location = f"{task_id}-{orig_name}"
orig_name = orig_name.rsplit(".", 1)[0]
db = get_db()
db.execute(
"INSERT INTO tasks (id, orig_name, status, song_location) VALUES (?, ?, ?, ?)",
[str(task_id), orig_name, "queued", song_location],
)
audio_file.save(UPLOAD_DIR / song_location)
db.commit()
demucs_task = DemucsTask(task_id, orig_name, song_location)
work_queue.put(demucs_task)
return redirect(f"/task/{task_id}")
@app.route("/task/<task_id>")
def task_view(task_id):
db = get_db()
try:
cur = db.execute("SELECT status, orig_name FROM tasks WHERE id = ?", [task_id])
(status, orig_name) = cur.fetchone()
except:
# TODO: real error page
return "Task with the given ID was not found", 404
outputs = []
try:
cur = db.execute(
"SELECT track, output_path FROM task_outputs WHERE task_id = ?", [task_id]
)
outputs = [
{"track": track, "path": path}
for (
track,
path,
) in cur.fetchall()
]
except:
pass
outputs = sorted(
outputs,
key=lambda o: ["master", "vocals", "other", "bass", "drums"].index(o["track"]),
)
progress = {}
try:
cur = db.execute(
"SELECT stage, stages, step, stage_steps FROM task_progress WHERE task_id = ?",
[task_id],
)
stage, stages, step, stage_steps = cur.fetchone()
progress = {
"stage": stage,
"stages": stages,
"step": step,
"stage-steps": stage_steps,
}
except:
pass
return render_template(
"task.html.j2",
task_id=task_id,
status=status,
orig_name=orig_name,
outputs=outputs,
progress=progress,
)
@app.route("/uploads/<path:upload>")
def send_upload(upload):
upload_file = safe_join(UPLOAD_DIR.absolute(), upload)
return send_file(upload_file, conditional=True)
@app.route("/sse")
def sse():
task_id = request.args.get("task_id")
if task_id is None:
return "bad request", 400
def stream():
queue = bus.listen()
while True:
message = queue.get()
if message.get("task") != task_id:
continue
message = json.dumps(message)
yield f"{message}\n\n"
return Response(
stream(),
mimetype="text/event-stream",
headers={"cache-control": "no-cache", "x-accel-buffering": "no"},
)

21
src/demucs_server/sse.py Normal file
View File

@ -0,0 +1,21 @@
import queue
class AnnounceBus:
def __init__(self):
self.listeners = []
def listen(self):
q = queue.Queue(maxsize=5)
self.listeners.append(q)
return q
def announce(self, msg):
for i in reversed(range(len(self.listeners))):
try:
self.listeners[i].put_nowait(msg)
except queue.Full:
del self.listeners[i]
bus = AnnounceBus()

View File

@ -0,0 +1,265 @@
:root {
--col-bg: #25181b;
--col-bg-dark: #1e1417;
--col-bg-light: #331f23;
--col-accent: #44cf6c;
--col-accent-dark: #32a287;
--col-accent-light: #a9fdac;
--col-fg: #f5fff2;
--font-main: "Inter", "Arial", sans-serif;
--font-mono: ui-monospace, "Cascadia Mono", "Segoe UI Mono", "Courier New",
monospace;
}
html {
color-scheme: dark;
font-family: var(--font-main);
font-size: 1.25rem;
color: var(--col-fg);
background-color: var(--col-bg);
}
html,
body {
padding: 0;
margin: 0;
}
body {
display: flex;
flex-direction: column;
min-height: 100vh;
}
main {
flex: 1;
width: calc(100% - 1em);
max-width: 80ch;
margin: 0 auto;
padding: 0.5em;
}
code {
font-family: var(--font-mono);
font-size: 1rem;
}
a {
color: var(--col-accent);
text-decoration: none;
}
a:hover {
border-bottom: 1px solid var(--col-accent);
}
.logo {
display: inline;
height: 3.5em;
}
header {
display: flex;
flex-direction: row;
justify-content: center;
align-items: center;
gap: 1em;
}
.tagline {
text-align: center;
margin-top: 0.25em;
}
header.vertical {
flex-direction: column;
gap: 0.5em;
}
header h1 {
margin: 0;
margin-bottom: 0.5em;
}
header + h2 {
margin-top: 0;
}
header small {
font-size: 1em;
}
form {
display: flex;
flex-direction: column;
padding: 1em;
border: 2.5px solid var(--col-accent-dark);
border-radius: 8px;
}
form :first-child {
margin-top: 0;
}
form :last-child {
margin-bottom: 0;
}
label {
margin-bottom: 0.5em;
}
input {
background-color: transparent;
border: 1.5px solid var(--col-accent-light);
outline: none;
border-radius: 8px;
font-family: var(--font-main);
font-size: 1em;
padding: 0.75rem 1rem;
margin-bottom: 0.75rem;
box-shadow: none;
width: calc(100% - 2rem);
color: var(--col-fg);
}
.file-upload {
margin: 0.75rem 0;
overflow-x: clip;
white-space: nowrap;
}
input[type="url"] {
margin-bottom: 0.25em;
}
.file-upload + p,
input[type="url"] + p {
margin-top: 0.25em;
padding-left: 1.5em;
}
input:active,
input:focus {
border-color: var(--col-accent);
}
input[type="file"] {
display: none;
}
button,
.file-upload-button {
font-family: var(--font-main);
font-size: 1em;
background-color: var(--col-accent-light);
border: 1.5px solid var(--col-accent-light);
border-radius: 8px;
outline: none;
color: var(--col-bg);
padding: 0.5rem;
}
.file-upload-button {
margin-right: 0.5em;
}
button:hover {
cursor: pointer;
background-color: var(--col-accent);
}
input::placeholder {
opacity: 0.33;
text-transform: lowercase;
}
dl {
display: flex;
flex-direction: row;
gap: 0.5ch;
margin: 0;
}
dt {
font-weight: bold;
}
dt::after {
content: ":";
}
dd {
margin: 0;
display: inline;
}
footer {
display: flex;
flex-direction: row;
justify-content: space-between;
padding: 1em;
background-color: var(--col-bg-dark);
}
nav ul {
list-style: none;
display: flex;
flex-direction: row;
gap: 0.5em;
margin: 0;
padding: 0;
}
nav ul li {
display: table-cell;
padding: 0.1em 0;
vertical-align: middle;
}
nav ul li + li {
border-inline-start: 1px solid var(--col-fg);
padding-inline-start: 0.5em;
}
.track-list {
padding: 0;
}
.track-list li {
display: flex;
flex-direction: row;
list-style: none;
padding: 0;
margin: 1em;
align-items: center;
justify-content: space-between;
}
.track-list a {
display: inline-block;
padding: 0.5em 0.5em;
margin-right: 1em;
border: 1px solid var(--col-accent);
border-radius: 4px;
width: 8em;
text-align: right;
}
.track-list a:hover {
background: var(--col-accent);
color: var(--col-bg);
}
.track-list audio {
flex: 1;
}

View File

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<link rel="stylesheet" href="/static/styles.css" />
<title>demucs.stems.zip</title>
</head>
<body>
<main>
<header>
<h1>demucs.stems.zip</h1>
</header>
{% block content %}{% endblock %}
</main>
</body>
</html>

View File

@ -0,0 +1,41 @@
{% extends "_base.html.j2" %}
{% block content %}
<h2>demix song</h2>
<p>
use <a href="https://github.com/facebookresearch/demucs">demucs</a>
to generate 'vocal', 'bass', 'drums', and 'other' tracks:
</p>
{% block form %}
<form action="/demix" method="POST" enctype="multipart/form-data">
<br>
<label class="file-upload">
<span class="file-upload-button">browse...</span>
<span class="file-upload-status">upload audio file (.aic, .aif, .flac, .m4a, .mp3, .ogg, .opus, .wav)</span>
<input type="file" name="song" accept="audio/*" required />
</label>
<!-- NYI TODO: <p>or <a href="/from-url">download from online</a> (e.g. youtube).</p> -->
<button>enqueue</button>
<ul class="status">
<li>songs in queue: <strong>{{ status['queue_size'] }}</strong></li>
<li>this node is using: <strong>{{ status['device'] }}</strong></li>
</ul>
</form>
<script>
const fileInput = document.querySelector(".file-upload input[type=file]")
fileInput.addEventListener("change", e => {
if (fileInput.files.length > 0) {
document.querySelector(".file-upload-status").textContent = `selected: ${fileInput.files[0].name}`
} else {
document.querySelector(".file-upload-status").textContent = `no file selected.`
}
});
</script>
{% endblock %}
{% endblock %}

View File

@ -0,0 +1,19 @@
{% extends "demix.html.j2" %}
{% block form %}
<form action="/demix-ytlp" method="POST" enctype="multipart/form-data">
<br>
<label for="url">song url:</label>
<input type="url" id="url" name="url" placeholder="https://youtu.be/…" />
<p>or <a href="/">upload an audio file</a>.</p>
<button>enqueue</button>
<ul class="status">
<li>songs in queue: <strong>{{ status['queue_size'] }}</strong></li>
<li>this node is using: <strong>{{ status['device'] }}</strong></li>
</ul>
</form>
{% endblock %}

View File

@ -0,0 +1,12 @@
{% extends "_base.html.j2" %}
{% block content %}
<h2>log in</h2>
<form action="/login" method="POST">
<label for="user-id">user id:</label>
<input type="password" name="user-id" id="user-id" placeholder="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" required />
<button>submit</button>
<p><code>demucs.stems.zip</code> is <em>invite-only</em>.</p>
</form>
{% endblock %}

View File

@ -0,0 +1,9 @@
{% extends "_base.html.j2" %}
{% block content %}
<form action="/logout" method="POST">
<h2>are you sure you want to log out?</h2>
<button>log out</button>
<p>you will need your user ID to log back in.</p>
</form>
{% endblock %}

View File

@ -0,0 +1,61 @@
{% extends "_base.html.j2" %}
{% block content %}
<h2>task</h2>
{% if status in ["queued"] %}
<p>your upload is currently in the queue, and hasn't started processing yet.</p>
{% elif status in ["in progress", "saving"] %}
<p>your upload is currently processing.</p>
{% endif %}
<p>status: <strong class="status">{{ status }}</strong></p>
<p>originally: <strong>{{ orig_name }}</strong></p>
<section class="progress-section" {% if status !="in progress" %}style="display: none" {% endif %}>
<h3>progress</h3>
<p>
stage:
<data id="stage">{{ progress.get('stage') or '' }}</data> /
<data id="stages">{{ progress.get('stages') or '' }}</data>
</p>
<p>
step:
<data id="step">{{ progress.get('step') or '' }}</data> /
<data id="stage-steps">{{ progress.get('stage-steps') or '' }}</data>
</p>
</section>
<script>
const source = new EventSource("/sse?task_id={{ task_id }}")
source.onmessage = (event) => {
console.log(event);
const message = JSON.parse(event.data);
if (message.status != null) {
document.querySelector("#status").textContent = message.status;
}
if (message.progress != null) {
document.querySelector("progress-section").style.display = "unset";
for (const item in message.progress) {
document.getElementById(item).textContent = message.progress[item];
}
}
};
</script>
{% if outputs %}
<ul class="track-list">
{% for output in outputs %}
<li>
<a href="/{{ output['path'] }}">{{ output['track'] }}</a>
<audio controls>
<source src="/{{ output['path'] }}" />
</audio>
</li>
{% endfor %}
</ul>
{% endif %}
{% endblock %}

View File

@ -0,0 +1,203 @@
import torch
import queue
from pathlib import Path
from dataclasses import dataclass
from uuid import UUID
from . import UPLOAD_DIR
from .db import DATABASE
from .sse import bus
import sqlite3
@dataclass
class DemucsTask:
database_id: UUID
orig_name: str
song_location: str
work_queue: queue.Queue[DemucsTask] = queue.Queue()
def get_device_name() -> str:
if torch.cuda.is_available():
return f"{torch.cuda.get_device_name()} (CUDA)"
try:
cpu_info = Path("/proc/cpuinfo").read_text().splitlines()
cpu_info = [l.split(":") for l in cpu_info]
cpu_info = [l for l in cpu_info if l[0].strip() == "model name"]
cpu_info = [l[1].strip() for l in cpu_info][0]
return f"{cpu_info} (CPU)"
except:
return "Unknown CPU"
def get_queue_size() -> int:
return work_queue.unfinished_tasks
def report_task_status(task_id: UUID, status: str):
bus.announce({"task": str(task_id), "status": status})
# short-lived sqlite connections since we spend most of the time
# on the worker thread just waiting for stuff
with sqlite3.connect(DATABASE) as conn:
conn.execute(
"UPDATE tasks SET status = ? WHERE id = ?",
[status, str(task_id)],
)
conn.commit()
def report_task_output(task_id: UUID, track: str, output_path: Path):
bus.announce({"task": str(task_id), "output": track})
# short-lived sqlite connections since we spend most of the time
# on the worker thread just waiting for stuff
with sqlite3.connect(DATABASE) as conn:
conn.execute(
"INSERT INTO task_outputs (task_id, track, output_path) VALUES (?, ?, ?)",
[str(task_id), track, str(output_path)],
)
conn.commit()
def report_task_progress(
task_id: UUID, stage: int, stages: int, step: int, stage_steps: int
):
bus.announce(
{
"task": str(task_id),
"progress": {
"stage": stage,
"stages": stages,
"step": step,
"stage-steps": stage_steps,
},
}
)
with sqlite3.connect(DATABASE) as conn:
conn.execute(
"INSERT OR REPLACE INTO task_progress (task_id, stage, stages, step, stage_steps) VALUES (?, ?, ?, ?, ?)",
[str(task_id), stage, stages, step, stage_steps],
)
conn.commit()
from demucs.pretrained import get_model
from demucs.apply import apply_model, tqdm, BagOfModels
from demucs.separate import load_track
from demucs.audio import save_audio
import subprocess
def convert_audio(source: Path, output: Path):
command = ["ffmpeg", "-y", "-loglevel", "panic"]
command += ["-i", str(source)]
command += [str(output)]
subprocess.run(command, check=True)
def process_task(task: DemucsTask, model):
task_id = task.database_id
song_path = UPLOAD_DIR / task.song_location
task_dir = UPLOAD_DIR / str(task.database_id)
task_dir.mkdir(parents=True, exist_ok=True)
master_path = task_dir / f"{task.orig_name}-master.flac"
convert_audio(song_path, master_path)
report_task_output(task.database_id, "master", master_path)
wav = load_track(song_path, model.audio_channels, model.samplerate)
ref = wav.mean(0)
wav -= ref.mean()
wav /= ref.std()
device = "cuda" if torch.cuda.is_available() else "cpu"
# hook the demucs progress function
stages = 1
if isinstance(model, BagOfModels):
stages = len(model.models)
stage = 0
def progress(futures, **kwargs):
nonlocal stage, task_id
stage += 1
stage_steps = len(futures)
for step, future in enumerate(futures, start=1):
report_task_progress(task_id, stage, stages, step, stage_steps)
yield future
tqdm.tqdm = progress
sources = apply_model(
model,
wav[None],
device=device,
shifts=1,
split=True,
overlap=0.25,
progress=True,
num_workers=0,
)[0]
# denormalize amplitude / DC
sources *= ref.std()
sources += ref.mean()
report_task_status(task.database_id, "saving")
outputs = {}
for source, source_name in zip(sources, model.sources):
path = song_path.parent / f"{song_path.name}.{source_name}.wav"
save_audio(source, str(path), samplerate=model.samplerate)
outputs[source_name] = path
song_path.unlink()
for source, path in sorted(outputs.items()):
flac_path = task_dir / f"{task.orig_name}-{source}.flac"
convert_audio(path, flac_path)
report_task_output(task.database_id, source, flac_path)
path.unlink() # delete wav
def worker_main():
model = get_model("htdemucs_ft")
model.eval()
with sqlite3.connect(DATABASE) as conn:
cur = conn.execute(
"SELECT id, orig_name, song_location FROM tasks WHERE status NOT IN ('done', 'errored')"
)
for id, orig_name, song_location in cur.fetchall():
work_queue.put(DemucsTask(UUID(id), orig_name, song_location))
while True:
task = work_queue.get()
try:
report_task_status(task.database_id, "in progress")
process_task(task, model)
report_task_status(task.database_id, "done")
except BaseException as e:
report_task_status(task.database_id, "errored")
print(e)
work_queue.task_done()
def start_worker():
import threading
thread = threading.Thread(target=worker_main, daemon=True)
thread.start()
return thread

6
src/dev_run.py Executable file
View File

@ -0,0 +1,6 @@
#!/usr/bin/env python
from demucs_server import app
if __name__ == "__main__":
app.run(port=8000, debug=True)

View File

@ -0,0 +1,5 @@
CREATE TABLE migrations (
applied TEXT PRIMARY KEY
) STRICT;
INSERT INTO migrations VALUES ('0000_init');

View File

@ -0,0 +1,4 @@
CREATE TABLE users (
user_id TEXT NOT NULL PRIMARY KEY,
nickname TEXT NOT NULL
) STRICT;

View File

@ -0,0 +1,6 @@
CREATE TABLE tasks (
id TEXT NOT NULL PRIMARY KEY,
orig_name TEXT NOT NULL,
status TEXT NOT NULL, -- "queued / in progress / done"
song_location TEXT NOT NULL
) STRICT;

View File

@ -0,0 +1,5 @@
CREATE TABLE task_outputs (
task_id TEXT NOT NULL,
track TEXT NOT NULL,
output_path TEXT NOT NULL
) STRICT;

View File

@ -0,0 +1,7 @@
CREATE TABLE task_progress (
task_id TEXT NOT NULL UNIQUE,
stage INTEGER NOT NULL,
stages INTEGER NOT NULL,
step INTEGER NOT NULL,
stage_steps INTEGER NOT NULL
) STRICT;

3
src/run.sh Executable file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env sh
gunicorn -b '127.0.0.2:8000' demucs_server:app