diff --git a/client/main.tsx b/client/main.tsx index a5b37e4..4bb7188 100644 --- a/client/main.tsx +++ b/client/main.tsx @@ -1,3 +1,5 @@ +import { Signal } from "@char/aftercare"; + const main = document.querySelector("main")!; async function nav() { @@ -9,7 +11,7 @@ async function nav() { button.addEventListener("click", e => { e.preventDefault(); - main.append(conversationUI(conversation.id)); + main.append(conversationUI(conversation.id, conversation.name)); nav.remove(); }); @@ -22,7 +24,7 @@ async function nav() { _tap={b => b.addEventListener("click", e => { e.preventDefault(); - main.append(conversationUI("new")); + main.append(conversationUI("new", "New conversation")); nav.remove(); }) } @@ -37,7 +39,37 @@ async function nav() { function conversationUI(id: string) { window.location.hash = `#${id}`; - const socket = new WebSocket(`/api/conversation/${id}/connect`); + let socket: WebSocket; + let connected = false; + const connect = () => { + socket = new WebSocket( + `/api/conversation/${id}/connect` + (connected ? "?continue=1" : ""), + ); + socket.addEventListener("open", () => (connected = true)); + socket.addEventListener("close", () => (socket = connect())); + socket.addEventListener("error", ev => { + console.warn(ev); + // TODO: handle errors + }); + return socket; + }; + socket = connect(); + + const header = ( +
+

+ +

+ ); + const name = new Signal(""); + name.subscribeImmediate(it => (header.querySelector("h1")!.textContent = it)); + header.querySelector("button")!.addEventListener("click", async e => { + e.preventDefault(); + await fetch(`/api/conversation/${id}`, { method: "DELETE" }); + window.location.hash = ""; + }); const chatlog =
; const inFlightMessages = new Map(); @@ -64,6 +96,8 @@ function conversationUI(id: string) { article.append(message.c); } else if ("d" in message) { inFlightMessages.delete(message.d); + } else if ("n" in message) { + name.set(message.n); } if (scrolledToBottom) chatlog.scrollTop = chatlog.scrollHeight - chatlog.clientHeight; @@ -83,6 +117,7 @@ function conversationUI(id: string) { return (
+ {header} {chatlog} {form}
diff --git a/client/web/css/styles.css b/client/web/css/styles.css index 9d343f2..220fb14 100644 --- a/client/web/css/styles.css +++ b/client/web/css/styles.css @@ -51,6 +51,19 @@ main, } } +.conversation header { + display: flex; + flex-direction: row; + justify-content: space-between; + align-items: center; + font-size: smaller; + + gap: 1em; + * { + margin: 0; + } +} + .chatlog { flex: 1; diff --git a/server/__init__.py b/server/__init__.py index fcaa1a2..33e308e 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -1,5 +1,5 @@ from server.http import Starlette, Route, Request, Response, JSONResponse, WebSocketRoute, Mount -from server.inference import list_conversations, connect_to_conversation +from server.inference import list_conversations, delete_conversation, connect_to_conversation from starlette.staticfiles import StaticFiles async def status(request: Request) -> Response: @@ -8,6 +8,7 @@ async def status(request: Request) -> Response: app = Starlette(debug=True, routes=[ Route("/api/", status), Route("/api/conversation", list_conversations, methods=["GET"]), + Route("/api/conversation/{conversation}", delete_conversation, methods=["DELETE"]), WebSocketRoute("/api/conversation/{conversation}/connect", connect_to_conversation), Mount("/", app=StaticFiles(directory="client/web", html=True), name="client") ]) diff --git a/server/inference.py b/server/inference.py index 332a2ba..625b07f 100644 --- a/server/inference.py +++ b/server/inference.py @@ -14,6 +14,13 @@ async def list_conversations(request: Request): return JSONResponse(conversations) +async def delete_conversation(request: Request): + conversation_id = request.path_params["conversation"] + db["responses"].delete_where("conversation_id = ?", [conversation_id]) + db["conversations"].delete(conversation_id) + return JSONResponse({"status": "ok"}) + + async def connect_to_conversation(ws: WebSocket): continuing = bool(ws.query_params.get("continue")) conversation_id = ws.path_params["conversation"] @@ -42,6 +49,8 @@ async def connect_to_conversation(ws: WebSocket): await ws.send_text(json({"u": response.prompt.prompt})) # user await ws.send_text(json({"f": response.text_or_raise()})) # full + await ws.send_text(json({"n": conversation.name})) + if conversation_id == "new": await ws.send_text(json({"i": conversation.id}))