demucs-server/src/demucs_server/routes.py

169 lines
4.5 KiB
Python

import json
from uuid import uuid4

from flask import (
    Response,
    redirect,
    render_template,
    request,
    send_file,
    send_from_directory,
)
from werkzeug.security import safe_join
from werkzeug.utils import secure_filename

from . import app, UPLOAD_DIR
from .auth import get_user
from .db import get_db
from .torch_worker import get_device_name, get_queue_size, work_queue, DemucsTask
from .sse import bus
@app.route("/")
@app.route("/from-url")
def index_page():
    """Render the upload page for the logged-in user, or the login page.

    ``/`` renders ``demix.html.j2`` and ``/from-url`` renders
    ``demix_ytdlp.html.j2``; both receive the current user and a status dict
    with the worker's device name and queue size.
    """
    user_id = request.cookies.get("demucs_server_user_id")
    # Treat a missing *or empty* cookie as logged out. Using
    # ``user_id and get_user(user_id)`` would yield "" (not None) for an empty
    # cookie and slip past the ``is None`` check below.
    user = get_user(user_id) if user_id else None
    if user is None:
        return render_template("login.html.j2")
    status = {"device": get_device_name(), "queue_size": get_queue_size()}
    template = (
        "demix_ytdlp.html.j2" if request.path == "/from-url" else "demix.html.j2"
    )
    return render_template(template, user=user, status=status)
@app.route("/login", methods=["GET", "POST"])
def login():
    """Store the submitted user id in a cookie and redirect to ``/``.

    The id comes from the ``user-id`` form field and is not validated here;
    the routes that need a user resolve it with ``get_user`` on each request.
    """
    response = redirect("/")
    user_id = request.form.get("user-id")
    # Skip a missing or empty field so an empty submission never leaves an
    # empty-string login cookie behind.
    if user_id:
        response.set_cookie(
            "demucs_server_user_id",
            user_id,
            samesite="Strict",
            secure=True,
        )
    return response
@app.route("/logout", methods=["GET", "POST"])
def logout():
    """Show the logout confirmation page, or log out on POST.

    GET renders ``logout.html.j2``. POST deletes the
    ``demucs_server_user_id`` cookie and redirects to the index page.
    """
    if request.method != "POST":
        return render_template("logout.html.j2")
    to_index = redirect("/")
    to_index.delete_cookie("demucs_server_user_id")
    return to_index
# TODO: demix routes (involves multipart form parsing!! maybe ytdlp!!)
@app.route("/demix", methods=["POST"])
def demix():
    """Accept an uploaded audio file and queue it for separation.

    Saves the first uploaded file under ``UPLOAD_DIR`` as
    ``<task-id>-<sanitized name>``, records a ``queued`` row in ``tasks``,
    puts a ``DemucsTask`` on the worker queue and redirects to the task page.

    Returns 403 when the request carries no valid login cookie. Redirects
    back to ``/`` when no file, or a part with an empty filename, was sent.
    """
    user_id = request.cookies.get("demucs_server_user_id")
    # A missing or empty cookie means "not logged in".
    user = get_user(user_id) if user_id else None
    if user is None:
        return "You are not logged in!", 403

    audio_file = next(iter(request.files.values()), None)
    # Browsers send a file part with filename "" when no file was chosen.
    if audio_file is None or not audio_file.filename:
        return redirect("/")

    # secure_filename() can reduce a name to "" (e.g. one made only of
    # non-ASCII characters); fall back to a fixed stem so the stored name is
    # never just "<task-id>-".
    orig_name = secure_filename(audio_file.filename).replace("_", "-") or "upload"
    task_id = uuid4()
    song_location = f"{task_id}-{orig_name}"
    # Display name: the sanitized file name without its last extension.
    orig_name = orig_name.rsplit(".", 1)[0]

    db = get_db()
    db.execute(
        "INSERT INTO tasks (id, orig_name, status, song_location) VALUES (?, ?, ?, ?)",
        [str(task_id), orig_name, "queued", song_location],
    )
    audio_file.save(UPLOAD_DIR / song_location)
    # Commit only after the file has been written, so a failed save never
    # leaves a committed task row pointing at a missing upload.
    db.commit()

    work_queue.put(DemucsTask(task_id, orig_name, song_location))
    return redirect(f"/task/{task_id}")
# Display order of output tracks on the task page. Any other track name is
# listed after these, alphabetically, instead of raising ValueError.
_TRACK_ORDER = ("master", "vocals", "other", "bass", "drums")


def _track_sort_key(output):
    """Sort key for an output dict: known tracks in ``_TRACK_ORDER`` first."""
    track = output["track"]
    if track in _TRACK_ORDER:
        return (_TRACK_ORDER.index(track), "")
    return (len(_TRACK_ORDER), str(track))


def _fetch_outputs(db, task_id):
    """Return the task's outputs as ``{"track", "path"}`` dicts in display order.

    Best-effort: a database error yields an empty list instead of failing
    the whole page.
    """
    try:
        cur = db.execute(
            "SELECT track, output_path FROM task_outputs WHERE task_id = ?", [task_id]
        )
        rows = cur.fetchall()
    except Exception:
        return []
    outputs = [{"track": track, "path": path} for track, path in rows]
    return sorted(outputs, key=_track_sort_key)


def _fetch_progress(db, task_id):
    """Return the task's progress row as a dict, or ``{}`` when there is none.

    Best-effort: a database error also yields ``{}``.
    """
    try:
        cur = db.execute(
            "SELECT stage, stages, step, stage_steps FROM task_progress WHERE task_id = ?",
            [task_id],
        )
        row = cur.fetchone()
    except Exception:
        return {}
    if row is None:
        return {}
    stage, stages, step, stage_steps = row
    return {
        "stage": stage,
        "stages": stages,
        "step": step,
        "stage-steps": stage_steps,
    }


@app.route("/task/<task_id>")
def task_view(task_id):
    """Render the status page for one demix task.

    Loads the task's status and display name, its outputs and its progress
    row, and renders ``task.html.j2``. Responds 404 when no task has this id.
    """
    db = get_db()
    row = db.execute(
        "SELECT status, orig_name FROM tasks WHERE id = ?", [task_id]
    ).fetchone()
    if row is None:
        # TODO: real error page
        return "Task with the given ID was not found", 404
    status, orig_name = row
    return render_template(
        "task.html.j2",
        task_id=task_id,
        status=status,
        orig_name=orig_name,
        outputs=_fetch_outputs(db, task_id),
        progress=_fetch_progress(db, task_id),
    )
@app.route("/uploads/<path:upload>")
def send_upload(upload):
    """Serve a file from ``UPLOAD_DIR``, supporting conditional/range requests.

    ``send_from_directory`` joins ``upload`` onto the directory safely and
    raises ``NotFound`` (404) when the path escapes ``UPLOAD_DIR`` or does not
    name an existing file. Calling ``safe_join`` directly returns ``None`` in
    the first case, and ``send_file(None)`` then fails with a 500.
    """
    return send_from_directory(UPLOAD_DIR.absolute(), upload, conditional=True)
@app.route("/sse")
def sse():
    """Stream bus messages for a single task as a ``text/event-stream``.

    Requires the ``task_id`` query parameter and answers 400 without it. Each
    bus message whose ``"task"`` field equals that id is JSON-encoded and
    written as one chunk followed by a blank line. Messages for other tasks
    are dropped.
    """
    wanted_task = request.args.get("task_id")
    if wanted_task is None:
        return "bad request", 400

    def event_stream():
        # The subscription is created lazily, once the body starts streaming.
        subscription = bus.listen()
        # NOTE(review): nothing here releases `subscription` when the client
        # disconnects; if `bus` retains every queue returned by listen(),
        # each closed connection leaks one — confirm against the bus API.
        # NOTE(review): chunks carry no "data: " prefix, so a browser
        # EventSource would not dispatch them as message events; presumably
        # the client parses the stream itself — confirm before changing.
        while True:
            event = subscription.get()
            if event.get("task") == wanted_task:
                payload = json.dumps(event)
                yield f"{payload}\n\n"

    sse_headers = {"cache-control": "no-cache", "x-accel-buffering": "no"}
    return Response(event_stream(), mimetype="text/event-stream", headers=sse_headers)