Initial commit
This commit is contained in:
commit
3138e96ae0
31 changed files with 1243 additions and 0 deletions
9
.editorconfig
Normal file
9
.editorconfig
Normal file
|
@ -0,0 +1,9 @@
|
|||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = false
|
||||
insert_final_newline = true
|
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
# python generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# venv
|
||||
.venv
|
||||
|
||||
src/data.db
|
||||
src/uploads
|
1
.python-version
Normal file
1
.python-version
Normal file
|
@ -0,0 +1 @@
|
|||
3.11.3
|
6
.vscode/settings.json
vendored
Normal file
6
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter"
|
||||
},
|
||||
"python.formatting.provider": "none"
|
||||
}
|
19
README.md
Normal file
19
README.md
Normal file
|
@ -0,0 +1,19 @@
|
|||
# demucs-server
|
||||
|
||||
Web service which lets you demix stems using remote compute resources.
|
||||
|
||||
## Setup
|
||||
|
||||
```shell
|
||||
$ rye sync
|
||||
$ rye shell
|
||||
[rye] $ cd src/
|
||||
src/ [rye] $ ./dev_run.py
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- flask application mostly lives in `demucs_server.routes`
|
||||
- demucs / ffmpeg is run on a separate thread at `demucs_server.torch_worker`
|
||||
- songs are queued with `torch_worker.work_queue.put(DemucsTask(..))`
|
||||
- songs are processed by the thread worker and the database is written to report progress
|
24
demucs-stems-zip.nginx.conf
Normal file
24
demucs-stems-zip.nginx.conf
Normal file
|
@ -0,0 +1,24 @@
|
|||
server {
|
||||
listen 443 ssl http2;
|
||||
listen [::]:443 ssl http2;
|
||||
|
||||
server_name demucs.stems.zip;
|
||||
|
||||
ssl_certificate /etc/letsencrypt/live/demucs.stems.zip/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/demucs.stems.zip/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.2:8000/;
|
||||
}
|
||||
|
||||
location ^~ /uploads/ {
|
||||
gzip on;
|
||||
alias /path/to/demucs-server/src/uploads/;
|
||||
}
|
||||
|
||||
location = /sse {
|
||||
proxy_read_timeout 300;
|
||||
proxy_send_timeout 300;
|
||||
proxy_pass http://127.0.0.2:8000/;
|
||||
}
|
||||
}
|
26
pyproject.toml
Normal file
26
pyproject.toml
Normal file
|
@ -0,0 +1,26 @@
|
|||
[project]
|
||||
name = "demucs-server"
|
||||
version = "0.1.0"
|
||||
description = "Add a short description here"
|
||||
dependencies = [
|
||||
"demucs>=4.0.0",
|
||||
"flask>=2.3.2",
|
||||
"gunicorn>=20.1.0",
|
||||
"colorama>=0.4.6",
|
||||
"ipython>=8.12.2",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = [
|
||||
"black>=23.3.0",
|
||||
]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
82
requirements-dev.lock
Normal file
82
requirements-dev.lock
Normal file
|
@ -0,0 +1,82 @@
|
|||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
|
||||
-e file:.
|
||||
antlr4-python3-runtime==4.9.3
|
||||
asttokens==2.2.1
|
||||
backcall==0.2.0
|
||||
black==23.3.0
|
||||
blinker==1.6.2
|
||||
click==8.1.3
|
||||
cloudpickle==2.2.1
|
||||
cmake==3.26.4
|
||||
colorama==0.4.6
|
||||
cython==0.29.36
|
||||
decorator==5.1.1
|
||||
demucs==4.0.0
|
||||
diffq==0.2.4
|
||||
dora-search==0.1.12
|
||||
einops==0.6.1
|
||||
executing==1.2.0
|
||||
filelock==3.12.2
|
||||
flask==2.3.2
|
||||
gunicorn==20.1.0
|
||||
ipython==8.14.0
|
||||
itsdangerous==2.1.2
|
||||
jedi==0.18.2
|
||||
jinja2==3.1.2
|
||||
julius==0.2.7
|
||||
lameenc==1.5.1
|
||||
lit==16.0.6
|
||||
markupsafe==2.1.3
|
||||
matplotlib-inline==0.1.6
|
||||
mpmath==1.3.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.1
|
||||
numpy==1.25.0
|
||||
nvidia-cublas-cu11==11.10.3.66
|
||||
nvidia-cuda-cupti-cu11==11.7.101
|
||||
nvidia-cuda-nvrtc-cu11==11.7.99
|
||||
nvidia-cuda-runtime-cu11==11.7.99
|
||||
nvidia-cudnn-cu11==8.5.0.96
|
||||
nvidia-cufft-cu11==10.9.0.58
|
||||
nvidia-curand-cu11==10.2.10.91
|
||||
nvidia-cusolver-cu11==11.4.0.1
|
||||
nvidia-cusparse-cu11==11.7.4.91
|
||||
nvidia-nccl-cu11==2.14.3
|
||||
nvidia-nvtx-cu11==11.7.91
|
||||
omegaconf==2.3.0
|
||||
openunmix==1.2.1
|
||||
packaging==23.1
|
||||
parso==0.8.3
|
||||
pathspec==0.11.1
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
platformdirs==3.8.0
|
||||
prompt-toolkit==3.0.39
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
pygments==2.15.1
|
||||
pyyaml==6.0
|
||||
retrying==1.3.4
|
||||
six==1.16.0
|
||||
stack-data==0.6.2
|
||||
submitit==1.4.5
|
||||
sympy==1.12
|
||||
torch==2.0.1
|
||||
torchaudio==2.0.2
|
||||
tqdm==4.65.0
|
||||
traitlets==5.9.0
|
||||
treetable==0.2.5
|
||||
triton==2.0.0
|
||||
typing-extensions==4.7.1
|
||||
wcwidth==0.2.6
|
||||
werkzeug==2.3.6
|
||||
wheel==0.40.0
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
setuptools==68.0.0
|
77
requirements.lock
Normal file
77
requirements.lock
Normal file
|
@ -0,0 +1,77 @@
|
|||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
|
||||
-e file:.
|
||||
antlr4-python3-runtime==4.9.3
|
||||
asttokens==2.2.1
|
||||
backcall==0.2.0
|
||||
blinker==1.6.2
|
||||
click==8.1.3
|
||||
cloudpickle==2.2.1
|
||||
cmake==3.26.4
|
||||
colorama==0.4.6
|
||||
cython==0.29.36
|
||||
decorator==5.1.1
|
||||
demucs==4.0.0
|
||||
diffq==0.2.4
|
||||
dora-search==0.1.12
|
||||
einops==0.6.1
|
||||
executing==1.2.0
|
||||
filelock==3.12.2
|
||||
flask==2.3.2
|
||||
gunicorn==20.1.0
|
||||
ipython==8.14.0
|
||||
itsdangerous==2.1.2
|
||||
jedi==0.18.2
|
||||
jinja2==3.1.2
|
||||
julius==0.2.7
|
||||
lameenc==1.5.1
|
||||
lit==16.0.6
|
||||
markupsafe==2.1.3
|
||||
matplotlib-inline==0.1.6
|
||||
mpmath==1.3.0
|
||||
networkx==3.1
|
||||
numpy==1.25.0
|
||||
nvidia-cublas-cu11==11.10.3.66
|
||||
nvidia-cuda-cupti-cu11==11.7.101
|
||||
nvidia-cuda-nvrtc-cu11==11.7.99
|
||||
nvidia-cuda-runtime-cu11==11.7.99
|
||||
nvidia-cudnn-cu11==8.5.0.96
|
||||
nvidia-cufft-cu11==10.9.0.58
|
||||
nvidia-curand-cu11==10.2.10.91
|
||||
nvidia-cusolver-cu11==11.4.0.1
|
||||
nvidia-cusparse-cu11==11.7.4.91
|
||||
nvidia-nccl-cu11==2.14.3
|
||||
nvidia-nvtx-cu11==11.7.91
|
||||
omegaconf==2.3.0
|
||||
openunmix==1.2.1
|
||||
parso==0.8.3
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
prompt-toolkit==3.0.39
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
pygments==2.15.1
|
||||
pyyaml==6.0
|
||||
retrying==1.3.4
|
||||
six==1.16.0
|
||||
stack-data==0.6.2
|
||||
submitit==1.4.5
|
||||
sympy==1.12
|
||||
torch==2.0.1
|
||||
torchaudio==2.0.2
|
||||
tqdm==4.65.0
|
||||
traitlets==5.9.0
|
||||
treetable==0.2.5
|
||||
triton==2.0.0
|
||||
typing-extensions==4.7.1
|
||||
wcwidth==0.2.6
|
||||
werkzeug==2.3.6
|
||||
wheel==0.40.0
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
setuptools==68.0.0
|
32
src/admin.py
Executable file
32
src/admin.py
Executable file
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import demucs_server
|
||||
|
||||
import uuid
|
||||
import IPython
|
||||
|
||||
with demucs_server.app.app_context():
|
||||
|
||||
def create_user(name: str) -> str:
|
||||
db = demucs_server.db.get_db()
|
||||
user_id = str(uuid.uuid4())
|
||||
db.execute(
|
||||
"INSERT INTO users (user_id, nickname) VALUES (?, ?)", [user_id, name]
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return user_id
|
||||
|
||||
def list_users() -> list["demucs_server.auth.User"]:
|
||||
from demucs_server.auth import User
|
||||
|
||||
db = demucs_server.db.get_db()
|
||||
users = []
|
||||
cur = db.execute("SELECT * FROM users")
|
||||
for user_id, nickname in cur.fetchall():
|
||||
users.append(User(uuid.UUID(user_id), nickname))
|
||||
cur.close()
|
||||
|
||||
return users
|
||||
|
||||
IPython.embed()
|
22
src/demucs_server/__init__.py
Normal file
22
src/demucs_server/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from flask import Flask
|
||||
import colorama
|
||||
from pathlib import Path
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
colorama.init()
|
||||
UPLOAD_DIR = Path(".") / "uploads"
|
||||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
|
||||
from . import db
|
||||
|
||||
with app.app_context():
|
||||
from .migrations import apply_migrations
|
||||
|
||||
apply_migrations()
|
||||
|
||||
from . import routes
|
||||
|
||||
from .torch_worker import start_worker
|
||||
|
||||
start_worker()
|
25
src/demucs_server/auth.py
Normal file
25
src/demucs_server/auth.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from .db import get_db
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
user_id: UUID
|
||||
nickname: str
|
||||
|
||||
|
||||
def get_user(user_id: str) -> User | None:
|
||||
try:
|
||||
user_id = UUID(user_id)
|
||||
except:
|
||||
return None
|
||||
|
||||
db = get_db()
|
||||
cur = db.execute("SELECT * FROM users WHERE user_id = ?", [str(user_id)])
|
||||
user_row = cur.fetchone()
|
||||
if user_row is None:
|
||||
return None
|
||||
|
||||
return User(UUID(user_row[0]), user_row[1])
|
22
src/demucs_server/db.py
Normal file
22
src/demucs_server/db.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import sqlite3
|
||||
from flask import g
|
||||
|
||||
import os
|
||||
|
||||
DATABASE = os.environ.get("DATABASE_PATH") or "data.db"
|
||||
|
||||
from . import app
|
||||
|
||||
|
||||
def get_db():
|
||||
db = getattr(g, "_database", None)
|
||||
if db is None:
|
||||
db = g._database = sqlite3.connect(DATABASE)
|
||||
return db
|
||||
|
||||
|
||||
@app.teardown_appcontext
|
||||
def close_connection(exception):
|
||||
db = getattr(g, "_database", None)
|
||||
if db is not None:
|
||||
db.close()
|
29
src/demucs_server/migrations.py
Normal file
29
src/demucs_server/migrations.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from .db import get_db
|
||||
from pathlib import Path
|
||||
|
||||
import colorama
|
||||
|
||||
|
||||
def apply_migrations():
|
||||
migrations_dir = Path("./migrations")
|
||||
if not migrations_dir.exists():
|
||||
raise Exception("migrations directory was not found")
|
||||
|
||||
db = get_db()
|
||||
already_run_migrations = set([])
|
||||
try:
|
||||
cur = db.execute("SELECT applied FROM migrations")
|
||||
for (applied,) in cur.fetchall():
|
||||
already_run_migrations.add(applied)
|
||||
except:
|
||||
pass
|
||||
|
||||
for migration in sorted(migrations_dir.iterdir()):
|
||||
if migration.name not in already_run_migrations:
|
||||
print(
|
||||
f"{colorama.Fore.LIGHTBLUE_EX}[*]{colorama.Fore.RESET} "
|
||||
f"Applied migration: {colorama.Fore.GREEN}{migration.name}{colorama.Fore.RESET}"
|
||||
)
|
||||
db.executescript(migration.read_text())
|
||||
db.execute("INSERT INTO migrations VALUES (?)", [migration.name])
|
||||
db.commit()
|
168
src/demucs_server/routes.py
Normal file
168
src/demucs_server/routes.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
from flask import render_template, request, redirect, send_file, Response
|
||||
|
||||
from werkzeug.utils import secure_filename
|
||||
from werkzeug.security import safe_join
|
||||
|
||||
from uuid import uuid4
|
||||
import json
|
||||
|
||||
from . import app, UPLOAD_DIR
|
||||
from .auth import get_user
|
||||
from .db import get_db
|
||||
from .torch_worker import get_device_name, get_queue_size, work_queue, DemucsTask
|
||||
from .sse import bus
|
||||
|
||||
|
||||
@app.route("/")
|
||||
@app.route("/from-url")
|
||||
def index_page():
|
||||
user_id = request.cookies.get("demucs_server_user_id")
|
||||
user = user_id and get_user(user_id)
|
||||
if user is None:
|
||||
return render_template("login.html.j2")
|
||||
|
||||
status = {"device": get_device_name(), "queue_size": get_queue_size()}
|
||||
|
||||
if request.path == "/from-url":
|
||||
return render_template("demix_ytdlp.html.j2", user=user, status=status)
|
||||
return render_template("demix.html.j2", user=user, status=status)
|
||||
|
||||
|
||||
@app.route("/login", methods=["GET", "POST"])
|
||||
def login():
|
||||
response = redirect("/")
|
||||
user_id = request.form.get("user-id", type=str)
|
||||
if user_id is not None:
|
||||
response.set_cookie(
|
||||
"demucs_server_user_id",
|
||||
user_id,
|
||||
samesite="Strict",
|
||||
secure=True,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/logout", methods=["GET", "POST"])
|
||||
def logout():
|
||||
if request.method == "POST":
|
||||
response = redirect("/")
|
||||
response.delete_cookie("demucs_server_user_id")
|
||||
return response
|
||||
return render_template("logout.html.j2")
|
||||
|
||||
|
||||
# TODO: demix routes (involves multipart form parsing!! maybe ytdlp!!)
|
||||
@app.route("/demix", methods=["POST"])
|
||||
def demix():
|
||||
user_id = request.cookies.get("demucs_server_user_id")
|
||||
user = user_id and get_user(user_id)
|
||||
if user is None:
|
||||
return "You are not logged in!", 403
|
||||
|
||||
audio_file = next(request.files.values(), None)
|
||||
if audio_file is None:
|
||||
return redirect("/")
|
||||
|
||||
orig_name = secure_filename(audio_file.filename).replace("_", "-")
|
||||
task_id = uuid4()
|
||||
song_location = f"{task_id}-{orig_name}"
|
||||
orig_name = orig_name.rsplit(".", 1)[0]
|
||||
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"INSERT INTO tasks (id, orig_name, status, song_location) VALUES (?, ?, ?, ?)",
|
||||
[str(task_id), orig_name, "queued", song_location],
|
||||
)
|
||||
audio_file.save(UPLOAD_DIR / song_location)
|
||||
db.commit()
|
||||
|
||||
demucs_task = DemucsTask(task_id, orig_name, song_location)
|
||||
work_queue.put(demucs_task)
|
||||
|
||||
return redirect(f"/task/{task_id}")
|
||||
|
||||
|
||||
@app.route("/task/<task_id>")
|
||||
def task_view(task_id):
|
||||
db = get_db()
|
||||
|
||||
try:
|
||||
cur = db.execute("SELECT status, orig_name FROM tasks WHERE id = ?", [task_id])
|
||||
(status, orig_name) = cur.fetchone()
|
||||
except:
|
||||
# TODO: real error page
|
||||
return "Task with the given ID was not found", 404
|
||||
|
||||
outputs = []
|
||||
try:
|
||||
cur = db.execute(
|
||||
"SELECT track, output_path FROM task_outputs WHERE task_id = ?", [task_id]
|
||||
)
|
||||
outputs = [
|
||||
{"track": track, "path": path}
|
||||
for (
|
||||
track,
|
||||
path,
|
||||
) in cur.fetchall()
|
||||
]
|
||||
except:
|
||||
pass
|
||||
|
||||
outputs = sorted(
|
||||
outputs,
|
||||
key=lambda o: ["master", "vocals", "other", "bass", "drums"].index(o["track"]),
|
||||
)
|
||||
|
||||
progress = {}
|
||||
try:
|
||||
cur = db.execute(
|
||||
"SELECT stage, stages, step, stage_steps FROM task_progress WHERE task_id = ?",
|
||||
[task_id],
|
||||
)
|
||||
stage, stages, step, stage_steps = cur.fetchone()
|
||||
progress = {
|
||||
"stage": stage,
|
||||
"stages": stages,
|
||||
"step": step,
|
||||
"stage-steps": stage_steps,
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
return render_template(
|
||||
"task.html.j2",
|
||||
task_id=task_id,
|
||||
status=status,
|
||||
orig_name=orig_name,
|
||||
outputs=outputs,
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/uploads/<path:upload>")
|
||||
def send_upload(upload):
|
||||
upload_file = safe_join(UPLOAD_DIR.absolute(), upload)
|
||||
return send_file(upload_file, conditional=True)
|
||||
|
||||
|
||||
@app.route("/sse")
|
||||
def sse():
|
||||
task_id = request.args.get("task_id")
|
||||
if task_id is None:
|
||||
return "bad request", 400
|
||||
|
||||
def stream():
|
||||
queue = bus.listen()
|
||||
while True:
|
||||
message = queue.get()
|
||||
if message.get("task") != task_id:
|
||||
continue
|
||||
message = json.dumps(message)
|
||||
yield f"{message}\n\n"
|
||||
|
||||
return Response(
|
||||
stream(),
|
||||
mimetype="text/event-stream",
|
||||
headers={"cache-control": "no-cache", "x-accel-buffering": "no"},
|
||||
)
|
21
src/demucs_server/sse.py
Normal file
21
src/demucs_server/sse.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import queue
|
||||
|
||||
|
||||
class AnnounceBus:
|
||||
def __init__(self):
|
||||
self.listeners = []
|
||||
|
||||
def listen(self):
|
||||
q = queue.Queue(maxsize=5)
|
||||
self.listeners.append(q)
|
||||
return q
|
||||
|
||||
def announce(self, msg):
|
||||
for i in reversed(range(len(self.listeners))):
|
||||
try:
|
||||
self.listeners[i].put_nowait(msg)
|
||||
except queue.Full:
|
||||
del self.listeners[i]
|
||||
|
||||
|
||||
bus = AnnounceBus()
|
265
src/demucs_server/static/styles.css
Normal file
265
src/demucs_server/static/styles.css
Normal file
|
@ -0,0 +1,265 @@
|
|||
:root {
|
||||
--col-bg: #25181b;
|
||||
--col-bg-dark: #1e1417;
|
||||
--col-bg-light: #331f23;
|
||||
--col-accent: #44cf6c;
|
||||
--col-accent-dark: #32a287;
|
||||
--col-accent-light: #a9fdac;
|
||||
--col-fg: #f5fff2;
|
||||
|
||||
--font-main: "Inter", "Arial", sans-serif;
|
||||
--font-mono: ui-monospace, "Cascadia Mono", "Segoe UI Mono", "Courier New",
|
||||
monospace;
|
||||
}
|
||||
|
||||
html {
|
||||
color-scheme: dark;
|
||||
font-family: var(--font-main);
|
||||
font-size: 1.25rem;
|
||||
color: var(--col-fg);
|
||||
background-color: var(--col-bg);
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
main {
|
||||
flex: 1;
|
||||
width: calc(100% - 1em);
|
||||
max-width: 80ch;
|
||||
margin: 0 auto;
|
||||
padding: 0.5em;
|
||||
}
|
||||
|
||||
code {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
a {
|
||||
color: var(--col-accent);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
a:hover {
|
||||
border-bottom: 1px solid var(--col-accent);
|
||||
}
|
||||
|
||||
.logo {
|
||||
display: inline;
|
||||
height: 3.5em;
|
||||
}
|
||||
|
||||
header {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 1em;
|
||||
}
|
||||
|
||||
.tagline {
|
||||
text-align: center;
|
||||
margin-top: 0.25em;
|
||||
}
|
||||
|
||||
header.vertical {
|
||||
flex-direction: column;
|
||||
gap: 0.5em;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
margin: 0;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
header + h2 {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
header small {
|
||||
font-size: 1em;
|
||||
}
|
||||
|
||||
form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
padding: 1em;
|
||||
border: 2.5px solid var(--col-accent-dark);
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
form :first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
form :last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
label {
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
input {
|
||||
background-color: transparent;
|
||||
border: 1.5px solid var(--col-accent-light);
|
||||
outline: none;
|
||||
border-radius: 8px;
|
||||
|
||||
font-family: var(--font-main);
|
||||
font-size: 1em;
|
||||
padding: 0.75rem 1rem;
|
||||
margin-bottom: 0.75rem;
|
||||
|
||||
box-shadow: none;
|
||||
width: calc(100% - 2rem);
|
||||
|
||||
color: var(--col-fg);
|
||||
}
|
||||
|
||||
.file-upload {
|
||||
margin: 0.75rem 0;
|
||||
overflow-x: clip;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
input[type="url"] {
|
||||
margin-bottom: 0.25em;
|
||||
}
|
||||
|
||||
.file-upload + p,
|
||||
input[type="url"] + p {
|
||||
margin-top: 0.25em;
|
||||
padding-left: 1.5em;
|
||||
}
|
||||
|
||||
input:active,
|
||||
input:focus {
|
||||
border-color: var(--col-accent);
|
||||
}
|
||||
|
||||
input[type="file"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
button,
|
||||
.file-upload-button {
|
||||
font-family: var(--font-main);
|
||||
font-size: 1em;
|
||||
|
||||
background-color: var(--col-accent-light);
|
||||
border: 1.5px solid var(--col-accent-light);
|
||||
border-radius: 8px;
|
||||
|
||||
outline: none;
|
||||
color: var(--col-bg);
|
||||
|
||||
padding: 0.5rem;
|
||||
}
|
||||
|
||||
.file-upload-button {
|
||||
margin-right: 0.5em;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
cursor: pointer;
|
||||
background-color: var(--col-accent);
|
||||
}
|
||||
|
||||
input::placeholder {
|
||||
opacity: 0.33;
|
||||
text-transform: lowercase;
|
||||
}
|
||||
|
||||
dl {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 0.5ch;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
dt {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
dt::after {
|
||||
content: ":";
|
||||
}
|
||||
|
||||
dd {
|
||||
margin: 0;
|
||||
display: inline;
|
||||
}
|
||||
|
||||
footer {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: space-between;
|
||||
|
||||
padding: 1em;
|
||||
background-color: var(--col-bg-dark);
|
||||
}
|
||||
|
||||
nav ul {
|
||||
list-style: none;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 0.5em;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
nav ul li {
|
||||
display: table-cell;
|
||||
padding: 0.1em 0;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
nav ul li + li {
|
||||
border-inline-start: 1px solid var(--col-fg);
|
||||
padding-inline-start: 0.5em;
|
||||
}
|
||||
|
||||
.track-list {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.track-list li {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
margin: 1em;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.track-list a {
|
||||
display: inline-block;
|
||||
padding: 0.5em 0.5em;
|
||||
margin-right: 1em;
|
||||
border: 1px solid var(--col-accent);
|
||||
border-radius: 4px;
|
||||
width: 8em;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
.track-list a:hover {
|
||||
background: var(--col-accent);
|
||||
color: var(--col-bg);
|
||||
}
|
||||
|
||||
.track-list audio {
|
||||
flex: 1;
|
||||
}
|
21
src/demucs_server/templates/_base.html.j2
Normal file
21
src/demucs_server/templates/_base.html.j2
Normal file
|
@ -0,0 +1,21 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<link rel="stylesheet" href="/static/styles.css" />
|
||||
<title>demucs.stems.zip</title>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<main>
|
||||
<header>
|
||||
<h1>demucs.stems.zip</h1>
|
||||
</header>
|
||||
|
||||
{% block content %}{% endblock %}
|
||||
</main>
|
||||
</body>
|
||||
|
||||
</html>
|
41
src/demucs_server/templates/demix.html.j2
Normal file
41
src/demucs_server/templates/demix.html.j2
Normal file
|
@ -0,0 +1,41 @@
|
|||
{% extends "_base.html.j2" %}
|
||||
|
||||
{% block content %}
|
||||
<h2>demix song</h2>
|
||||
<p>
|
||||
use <a href="https://github.com/facebookresearch/demucs">demucs</a>
|
||||
to generate 'vocal', 'bass', 'drums', and 'other' tracks:
|
||||
</p>
|
||||
|
||||
{% block form %}
|
||||
<form action="/demix" method="POST" enctype="multipart/form-data">
|
||||
<br>
|
||||
|
||||
<label class="file-upload">
|
||||
<span class="file-upload-button">browse...</span>
|
||||
<span class="file-upload-status">upload audio file (.aic, .aif, .flac, .m4a, .mp3, .ogg, .opus, .wav)</span>
|
||||
<input type="file" name="song" accept="audio/*" required />
|
||||
</label>
|
||||
|
||||
<!-- NYI TODO: <p>or <a href="/from-url">download from online</a> (e.g. youtube).</p> -->
|
||||
|
||||
<button>enqueue</button>
|
||||
|
||||
<ul class="status">
|
||||
<li>songs in queue: <strong>{{ status['queue_size'] }}</strong></li>
|
||||
<li>this node is using: <strong>{{ status['device'] }}</strong></li>
|
||||
</ul>
|
||||
</form>
|
||||
|
||||
<script>
|
||||
const fileInput = document.querySelector(".file-upload input[type=file]")
|
||||
fileInput.addEventListener("change", e => {
|
||||
if (fileInput.files.length > 0) {
|
||||
document.querySelector(".file-upload-status").textContent = `selected: ${fileInput.files[0].name}`
|
||||
} else {
|
||||
document.querySelector(".file-upload-status").textContent = `no file selected.`
|
||||
}
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
{% endblock %}
|
19
src/demucs_server/templates/demix_ytdlp.html.j2
Normal file
19
src/demucs_server/templates/demix_ytdlp.html.j2
Normal file
|
@ -0,0 +1,19 @@
|
|||
{% extends "demix.html.j2" %}
|
||||
|
||||
{% block form %}
|
||||
<form action="/demix-ytlp" method="POST" enctype="multipart/form-data">
|
||||
<br>
|
||||
|
||||
<label for="url">song url:</label>
|
||||
<input type="url" id="url" name="url" placeholder="https://youtu.be/…" />
|
||||
|
||||
<p>or <a href="/">upload an audio file</a>.</p>
|
||||
|
||||
<button>enqueue</button>
|
||||
|
||||
<ul class="status">
|
||||
<li>songs in queue: <strong>{{ status['queue_size'] }}</strong></li>
|
||||
<li>this node is using: <strong>{{ status['device'] }}</strong></li>
|
||||
</ul>
|
||||
</form>
|
||||
{% endblock %}
|
12
src/demucs_server/templates/login.html.j2
Normal file
12
src/demucs_server/templates/login.html.j2
Normal file
|
@ -0,0 +1,12 @@
|
|||
{% extends "_base.html.j2" %}
|
||||
|
||||
{% block content %}
|
||||
<h2>log in</h2>
|
||||
|
||||
<form action="/login" method="POST">
|
||||
<label for="user-id">user id:</label>
|
||||
<input type="password" name="user-id" id="user-id" placeholder="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" required />
|
||||
<button>submit</button>
|
||||
<p><code>demucs.stems.zip</code> is <em>invite-only</em>.</p>
|
||||
</form>
|
||||
{% endblock %}
|
9
src/demucs_server/templates/logout.html.j2
Normal file
9
src/demucs_server/templates/logout.html.j2
Normal file
|
@ -0,0 +1,9 @@
|
|||
{% extends "_base.html.j2" %}
|
||||
|
||||
{% block content %}
|
||||
<form action="/logout" method="POST">
|
||||
<h2>are you sure you want to log out?</h2>
|
||||
<button>log out</button>
|
||||
<p>you will need your user ID to log back in.</p>
|
||||
</form>
|
||||
{% endblock %}
|
61
src/demucs_server/templates/task.html.j2
Normal file
61
src/demucs_server/templates/task.html.j2
Normal file
|
@ -0,0 +1,61 @@
|
|||
{% extends "_base.html.j2" %}
|
||||
|
||||
{% block content %}
|
||||
<h2>task</h2>
|
||||
|
||||
{% if status in ["queued"] %}
|
||||
<p>your upload is currently in the queue, and hasn't started processing yet.</p>
|
||||
{% elif status in ["in progress", "saving"] %}
|
||||
<p>your upload is currently processing.</p>
|
||||
{% endif %}
|
||||
|
||||
<p>status: <strong class="status">{{ status }}</strong></p>
|
||||
<p>originally: <strong>{{ orig_name }}</strong></p>
|
||||
|
||||
<section class="progress-section" {% if status !="in progress" %}style="display: none" {% endif %}>
|
||||
<h3>progress</h3>
|
||||
<p>
|
||||
stage:
|
||||
<data id="stage">{{ progress.get('stage') or '' }}</data> /
|
||||
<data id="stages">{{ progress.get('stages') or '' }}</data>
|
||||
</p>
|
||||
<p>
|
||||
step:
|
||||
<data id="step">{{ progress.get('step') or '' }}</data> /
|
||||
<data id="stage-steps">{{ progress.get('stage-steps') or '' }}</data>
|
||||
</p>
|
||||
</section>
|
||||
|
||||
<script>
|
||||
const source = new EventSource("/sse?task_id={{ task_id }}")
|
||||
source.onmessage = (event) => {
|
||||
console.log(event);
|
||||
|
||||
const message = JSON.parse(event.data);
|
||||
if (message.status != null) {
|
||||
document.querySelector("#status").textContent = message.status;
|
||||
}
|
||||
|
||||
if (message.progress != null) {
|
||||
document.querySelector("progress-section").style.display = "unset";
|
||||
for (const item in message.progress) {
|
||||
document.getElementById(item).textContent = message.progress[item];
|
||||
}
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
{% if outputs %}
|
||||
<ul class="track-list">
|
||||
{% for output in outputs %}
|
||||
<li>
|
||||
<a href="/{{ output['path'] }}">{{ output['track'] }}</a>
|
||||
<audio controls>
|
||||
<source src="/{{ output['path'] }}" />
|
||||
</audio>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
{% endif %}
|
||||
|
||||
{% endblock %}
|
203
src/demucs_server/torch_worker.py
Normal file
203
src/demucs_server/torch_worker.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
import torch
|
||||
import queue
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from . import UPLOAD_DIR
|
||||
from .db import DATABASE
|
||||
from .sse import bus
|
||||
|
||||
import sqlite3
|
||||
|
||||
|
||||
@dataclass
|
||||
class DemucsTask:
|
||||
database_id: UUID
|
||||
orig_name: str
|
||||
song_location: str
|
||||
|
||||
|
||||
work_queue: queue.Queue[DemucsTask] = queue.Queue()
|
||||
|
||||
|
||||
def get_device_name() -> str:
|
||||
if torch.cuda.is_available():
|
||||
return f"{torch.cuda.get_device_name()} (CUDA)"
|
||||
|
||||
try:
|
||||
cpu_info = Path("/proc/cpuinfo").read_text().splitlines()
|
||||
cpu_info = [l.split(":") for l in cpu_info]
|
||||
cpu_info = [l for l in cpu_info if l[0].strip() == "model name"]
|
||||
cpu_info = [l[1].strip() for l in cpu_info][0]
|
||||
return f"{cpu_info} (CPU)"
|
||||
except:
|
||||
return "Unknown CPU"
|
||||
|
||||
|
||||
def get_queue_size() -> int:
|
||||
return work_queue.unfinished_tasks
|
||||
|
||||
|
||||
def report_task_status(task_id: UUID, status: str):
|
||||
bus.announce({"task": str(task_id), "status": status})
|
||||
# short-lived sqlite connections since we spend most of the time
|
||||
# on the worker thread just waiting for stuff
|
||||
with sqlite3.connect(DATABASE) as conn:
|
||||
conn.execute(
|
||||
"UPDATE tasks SET status = ? WHERE id = ?",
|
||||
[status, str(task_id)],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def report_task_output(task_id: UUID, track: str, output_path: Path):
|
||||
bus.announce({"task": str(task_id), "output": track})
|
||||
# short-lived sqlite connections since we spend most of the time
|
||||
# on the worker thread just waiting for stuff
|
||||
with sqlite3.connect(DATABASE) as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO task_outputs (task_id, track, output_path) VALUES (?, ?, ?)",
|
||||
[str(task_id), track, str(output_path)],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def report_task_progress(
|
||||
task_id: UUID, stage: int, stages: int, step: int, stage_steps: int
|
||||
):
|
||||
bus.announce(
|
||||
{
|
||||
"task": str(task_id),
|
||||
"progress": {
|
||||
"stage": stage,
|
||||
"stages": stages,
|
||||
"step": step,
|
||||
"stage-steps": stage_steps,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with sqlite3.connect(DATABASE) as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO task_progress (task_id, stage, stages, step, stage_steps) VALUES (?, ?, ?, ?, ?)",
|
||||
[str(task_id), stage, stages, step, stage_steps],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
from demucs.pretrained import get_model
|
||||
from demucs.apply import apply_model, tqdm, BagOfModels
|
||||
from demucs.separate import load_track
|
||||
from demucs.audio import save_audio
|
||||
|
||||
import subprocess
|
||||
|
||||
|
||||
def convert_audio(source: Path, output: Path):
|
||||
command = ["ffmpeg", "-y", "-loglevel", "panic"]
|
||||
command += ["-i", str(source)]
|
||||
command += [str(output)]
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
|
||||
def process_task(task: DemucsTask, model):
|
||||
task_id = task.database_id
|
||||
|
||||
song_path = UPLOAD_DIR / task.song_location
|
||||
|
||||
task_dir = UPLOAD_DIR / str(task.database_id)
|
||||
task_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
master_path = task_dir / f"{task.orig_name}-master.flac"
|
||||
convert_audio(song_path, master_path)
|
||||
report_task_output(task.database_id, "master", master_path)
|
||||
|
||||
wav = load_track(song_path, model.audio_channels, model.samplerate)
|
||||
|
||||
ref = wav.mean(0)
|
||||
wav -= ref.mean()
|
||||
wav /= ref.std()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# hook the demucs progress function
|
||||
stages = 1
|
||||
if isinstance(model, BagOfModels):
|
||||
stages = len(model.models)
|
||||
stage = 0
|
||||
|
||||
def progress(futures, **kwargs):
|
||||
nonlocal stage, task_id
|
||||
|
||||
stage += 1
|
||||
stage_steps = len(futures)
|
||||
|
||||
for step, future in enumerate(futures, start=1):
|
||||
report_task_progress(task_id, stage, stages, step, stage_steps)
|
||||
yield future
|
||||
|
||||
tqdm.tqdm = progress
|
||||
sources = apply_model(
|
||||
model,
|
||||
wav[None],
|
||||
device=device,
|
||||
shifts=1,
|
||||
split=True,
|
||||
overlap=0.25,
|
||||
progress=True,
|
||||
num_workers=0,
|
||||
)[0]
|
||||
|
||||
# denormalize amplitude / DC
|
||||
sources *= ref.std()
|
||||
sources += ref.mean()
|
||||
|
||||
report_task_status(task.database_id, "saving")
|
||||
|
||||
outputs = {}
|
||||
for source, source_name in zip(sources, model.sources):
|
||||
path = song_path.parent / f"{song_path.name}.{source_name}.wav"
|
||||
save_audio(source, str(path), samplerate=model.samplerate)
|
||||
outputs[source_name] = path
|
||||
|
||||
song_path.unlink()
|
||||
|
||||
for source, path in sorted(outputs.items()):
|
||||
flac_path = task_dir / f"{task.orig_name}-{source}.flac"
|
||||
convert_audio(path, flac_path)
|
||||
report_task_output(task.database_id, source, flac_path)
|
||||
path.unlink() # delete wav
|
||||
|
||||
|
||||
def worker_main():
|
||||
model = get_model("htdemucs_ft")
|
||||
model.eval()
|
||||
|
||||
with sqlite3.connect(DATABASE) as conn:
|
||||
cur = conn.execute(
|
||||
"SELECT id, orig_name, song_location FROM tasks WHERE status NOT IN ('done', 'errored')"
|
||||
)
|
||||
for id, orig_name, song_location in cur.fetchall():
|
||||
work_queue.put(DemucsTask(UUID(id), orig_name, song_location))
|
||||
|
||||
while True:
|
||||
task = work_queue.get()
|
||||
try:
|
||||
report_task_status(task.database_id, "in progress")
|
||||
process_task(task, model)
|
||||
report_task_status(task.database_id, "done")
|
||||
except BaseException as e:
|
||||
report_task_status(task.database_id, "errored")
|
||||
print(e)
|
||||
work_queue.task_done()
|
||||
|
||||
|
||||
def start_worker():
|
||||
import threading
|
||||
|
||||
thread = threading.Thread(target=worker_main, daemon=True)
|
||||
thread.start()
|
||||
|
||||
return thread
|
6
src/dev_run.py
Executable file
6
src/dev_run.py
Executable file
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from demucs_server import app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(port=8000, debug=True)
|
5
src/migrations/0000_init.sql
Normal file
5
src/migrations/0000_init.sql
Normal file
|
@ -0,0 +1,5 @@
|
|||
CREATE TABLE migrations (
|
||||
applied TEXT PRIMARY KEY
|
||||
) STRICT;
|
||||
|
||||
INSERT INTO migrations VALUES ('0000_init');
|
4
src/migrations/0001_auth.sql
Normal file
4
src/migrations/0001_auth.sql
Normal file
|
@ -0,0 +1,4 @@
|
|||
CREATE TABLE users (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
nickname TEXT NOT NULL
|
||||
) STRICT;
|
6
src/migrations/0002_tasks.sql
Normal file
6
src/migrations/0002_tasks.sql
Normal file
|
@ -0,0 +1,6 @@
|
|||
CREATE TABLE tasks (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
orig_name TEXT NOT NULL,
|
||||
status TEXT NOT NULL, -- "queued / in progress / done"
|
||||
song_location TEXT NOT NULL
|
||||
) STRICT;
|
5
src/migrations/0003_task_outputs.sql
Normal file
5
src/migrations/0003_task_outputs.sql
Normal file
|
@ -0,0 +1,5 @@
|
|||
CREATE TABLE task_outputs (
|
||||
task_id TEXT NOT NULL,
|
||||
track TEXT NOT NULL,
|
||||
output_path TEXT NOT NULL
|
||||
) STRICT;
|
7
src/migrations/0004_task_progress.sql
Normal file
7
src/migrations/0004_task_progress.sql
Normal file
|
@ -0,0 +1,7 @@
|
|||
CREATE TABLE task_progress (
|
||||
task_id TEXT NOT NULL UNIQUE,
|
||||
stage INTEGER NOT NULL,
|
||||
stages INTEGER NOT NULL,
|
||||
step INTEGER NOT NULL,
|
||||
stage_steps INTEGER NOT NULL
|
||||
) STRICT;
|
3
src/run.sh
Executable file
3
src/run.sh
Executable file
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
gunicorn -b '127.0.0.2:8000' demucs_server:app
|
Loading…
Reference in a new issue