169 lines
4.5 KiB
Python
169 lines
4.5 KiB
Python
from flask import render_template, request, redirect, send_file, Response
|
|
|
|
from werkzeug.utils import secure_filename
|
|
from werkzeug.security import safe_join
|
|
|
|
from uuid import uuid4
|
|
import json
|
|
|
|
from . import app, UPLOAD_DIR
|
|
from .auth import get_user
|
|
from .db import get_db
|
|
from .torch_worker import get_device_name, get_queue_size, work_queue, DemucsTask
|
|
from .sse import bus
|
|
|
|
|
|
@app.route("/")
@app.route("/from-url")
def index_page():
    """Render the main demix page, or the login page for anonymous visitors.

    ``/`` renders ``demix.html.j2`` and ``/from-url`` renders
    ``demix_ytdlp.html.j2``. Both templates receive the resolved ``user`` and a
    ``status`` dict holding the device name and queue size reported by
    ``torch_worker``.
    """
    user_id = request.cookies.get("demucs_server_user_id")
    # A missing *or empty* cookie means "not logged in". The expression
    # ``user_id and get_user(user_id)`` would evaluate to "" for an empty
    # cookie, which is not None and would slip past the check below.
    user = get_user(user_id) if user_id else None
    if user is None:
        return render_template("login.html.j2")

    status = {"device": get_device_name(), "queue_size": get_queue_size()}

    if request.path == "/from-url":
        template = "demix_ytdlp.html.j2"
    else:
        template = "demix.html.j2"
    return render_template(template, user=user, status=status)
|
|
|
|
|
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Store the submitted user ID in the login cookie and redirect to ``/``.

    The ID is read from the ``user-id`` form field and is not validated here;
    it is only looked up (via ``get_user``) on later requests. When the field
    is missing or empty, the redirect is returned without touching the cookie.
    """
    response = redirect("/")

    user_id = request.form.get("user-id", type=str)
    # An empty ID is not a meaningful login; don't store an empty cookie.
    if user_id:
        response.set_cookie(
            "demucs_server_user_id",
            user_id,
            samesite="Strict",
            # Only sent over HTTPS connections.
            secure=True,
        )

    return response
|
|
|
|
|
|
@app.route("/logout", methods=["GET", "POST"])
def logout():
    """Show the logout confirmation page, or log out on POST.

    A GET renders ``logout.html.j2``. A POST clears the
    ``demucs_server_user_id`` cookie and redirects to ``/``.
    """
    if request.method != "POST":
        return render_template("logout.html.j2")

    redirect_home = redirect("/")
    redirect_home.delete_cookie("demucs_server_user_id")
    return redirect_home
|
|
|
|
|
|
# TODO: demix routes — multipart form upload parsing, possibly yt-dlp URL downloads.
|
|
@app.route("/demix", methods=["POST"])
def demix():
    """Accept an uploaded audio file and queue it for separation.

    Behaviour:

    * The ``demucs_server_user_id`` cookie must resolve to a user via
      ``get_user``; otherwise the response is ``403``.
    * The first file part of the request is used. When there is none, or its
      filename is empty, the request is redirected back to ``/``.
    * The file is saved under ``UPLOAD_DIR`` as ``<task-id>-<sanitised-name>``,
      and a ``queued`` row is recorded in ``tasks``.
    * A ``DemucsTask`` is put on ``work_queue``, and the response redirects to
      ``/task/<task-id>``.
    """
    user_id = request.cookies.get("demucs_server_user_id")
    # A missing *or empty* cookie means "not logged in" ("" would otherwise
    # survive ``user_id and get_user(user_id)`` and pass the None check).
    user = get_user(user_id) if user_id else None
    if user is None:
        return "You are not logged in!", 403

    audio_file = next(iter(request.files.values()), None)
    # Browsers send an empty part (filename "") when no file was selected.
    if audio_file is None or not audio_file.filename:
        return redirect("/")

    # secure_filename() may strip a name down to "" (e.g. a name made only of
    # non-ASCII characters); fall back to a placeholder so the stored path is
    # never just "<uuid>-". Underscores become dashes — presumably because "_"
    # is significant elsewhere in the pipeline; TODO confirm.
    safe_name = secure_filename(audio_file.filename).replace("_", "-") or "upload"
    task_id = uuid4()
    song_location = f"{task_id}-{safe_name}"
    # Display name: the sanitised filename without its last extension.
    orig_name = safe_name.rsplit(".", 1)[0]

    db = get_db()
    db.execute(
        "INSERT INTO tasks (id, orig_name, status, song_location) VALUES (?, ?, ?, ?)",
        [str(task_id), orig_name, "queued", song_location],
    )
    # The file is written before the commit, so if saving raises, the task row
    # is never committed.
    audio_file.save(UPLOAD_DIR / song_location)
    db.commit()

    work_queue.put(DemucsTask(task_id, orig_name, song_location))

    return redirect(f"/task/{task_id}")
|
|
|
|
|
|
# Display order of demixed stems on the task page.
_TRACK_ORDER = ("master", "vocals", "other", "bass", "drums")


def _track_sort_key(output):
    """Return the display position of ``output["track"]``.

    Tracks not listed in ``_TRACK_ORDER`` sort after all known ones (keeping
    their relative order, since ``sorted`` is stable) instead of raising
    ``ValueError`` and failing the whole page.
    """
    track = output["track"]
    if track in _TRACK_ORDER:
        return _TRACK_ORDER.index(track)
    return len(_TRACK_ORDER)


def _fetch_outputs(db, task_id):
    """Return the task's outputs as ``{"track", "path"}`` dicts in display order.

    Best-effort: a database error is logged and yields an empty list so the
    task page still renders.
    """
    try:
        cur = db.execute(
            "SELECT track, output_path FROM task_outputs WHERE task_id = ?", [task_id]
        )
        rows = cur.fetchall()
    except Exception:
        app.logger.exception("Could not load outputs for task %s", task_id)
        return []

    outputs = [{"track": track, "path": path} for track, path in rows]
    return sorted(outputs, key=_track_sort_key)


def _fetch_progress(db, task_id):
    """Return the task's progress dict, or ``{}`` when none is recorded.

    Best-effort: a database error is logged and yields ``{}``.
    """
    try:
        cur = db.execute(
            "SELECT stage, stages, step, stage_steps FROM task_progress WHERE task_id = ?",
            [task_id],
        )
        row = cur.fetchone()
    except Exception:
        app.logger.exception("Could not load progress for task %s", task_id)
        return {}

    # No progress row yet (e.g. the task has not started) is a normal state.
    if row is None:
        return {}

    stage, stages, step, stage_steps = row
    return {
        "stage": stage,
        "stages": stages,
        "step": step,
        "stage-steps": stage_steps,
    }


@app.route("/task/<task_id>")
def task_view(task_id):
    """Render ``task.html.j2`` for one demix task.

    Returns ``404`` when ``tasks`` has no row with this ID. The task's outputs
    and progress are loaded best-effort and may be empty.
    """
    db = get_db()

    cur = db.execute("SELECT status, orig_name FROM tasks WHERE id = ?", [task_id])
    row = cur.fetchone()
    if row is None:
        # TODO: real error page
        return "Task with the given ID was not found", 404
    status, orig_name = row

    return render_template(
        "task.html.j2",
        task_id=task_id,
        status=status,
        orig_name=orig_name,
        outputs=_fetch_outputs(db, task_id),
        progress=_fetch_progress(db, task_id),
    )
|
|
|
|
|
|
@app.route("/uploads/<path:upload>")
def send_upload(upload):
    """Serve a file from ``UPLOAD_DIR``, with conditional-request support.

    Responds ``404`` when ``upload`` would escape ``UPLOAD_DIR`` or does not
    name an existing file. The previous ``safe_join`` + ``send_file`` pair
    passed ``None`` to ``send_file`` for unsafe paths and raised on missing
    files, producing a 500 in both cases.
    """
    from flask import send_from_directory

    # Pass an absolute directory: Flask resolves relative directories against
    # the app's root path rather than the working directory.
    return send_from_directory(UPLOAD_DIR.absolute(), upload, conditional=True)
|
|
|
|
|
|
@app.route("/sse")
def sse():
    """Stream server-sent events for one task.

    The ``task_id`` query parameter is required (``400`` otherwise). Every
    message from ``bus`` whose ``"task"`` key equals that ID is JSON-encoded
    and emitted as a ``data:`` event. Other messages are dropped.
    """
    wanted_task = request.args.get("task_id")
    if wanted_task is None:
        return "bad request", 400

    def event_stream():
        # NOTE(review): nothing here unregisters this listener when the client
        # disconnects; if ``bus`` keeps a reference to every queue it hands
        # out, they accumulate — confirm against the sse module.
        listener = bus.listen()
        while True:
            event = listener.get()
            if event.get("task") == wanted_task:
                yield "data:" + json.dumps(event) + "\n\n"

    # Disable client caching and reverse-proxy buffering (the
    # ``x-accel-buffering`` header) so events are delivered immediately.
    response_headers = {"cache-control": "no-cache", "x-accel-buffering": "no"}
    return Response(
        event_stream(),
        mimetype="text/event-stream",
        headers=response_headers,
    )
|