From deb2f6e8e54bb0dd39bace749301dcb3fc762edc Mon Sep 17 00:00:00 2001 From: Charlotte Som Date: Wed, 26 Feb 2025 04:19:04 +0000 Subject: [PATCH] get inference over websocket working --- requirements.txt | Bin 136 -> 156 bytes server/__init__.py | 2 +- server/inference.py | 29 ++++++++++++++++++++++++----- server/tid.py | 13 +++++++++++++ 4 files changed, 38 insertions(+), 6 deletions(-) create mode 100644 server/tid.py diff --git a/requirements.txt b/requirements.txt index 620b0c7e037c17e90338a4fd7a840be91da1124f..216d9c6898a4a3847a3260ff0e049aba369dadb3 100644 GIT binary patch delta 47 wcmeBRoWnSwL?fD^n4yFrks*&E1xOYF$yf$kAT(gmV=!hg0g?s`ybN3n0Quzz5C8xG delta 27 hcmbQk*uglVM8uZCmcf8QkHMJ11V|b%@G@{Q002yD1M>g? diff --git a/server/__init__.py b/server/__init__.py index 935c59b..927ecdf 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -7,5 +7,5 @@ async def status(request: Request) -> Response: app = Starlette(debug=True, routes=[ Route("/api/", status), Route("/api/conversation", list_conversations, methods=["GET"]), - WebSocketRoute("/api/conversation/:conversation/connect", connect_to_conversation) + WebSocketRoute("/api/conversation/{conversation}/connect", connect_to_conversation) ]) diff --git a/server/inference.py b/server/inference.py index 0c1fd0d..5c6e7b3 100644 --- a/server/inference.py +++ b/server/inference.py @@ -1,5 +1,6 @@ import llm, llm.cli, sqlite_utils -from .http import Request, JSONResponse, WebSocket, RedirectResponse +from .http import Request, JSONResponse, WebSocket +from .tid import tid_now import json db = sqlite_utils.Database(llm.cli.logs_db_path()) @@ -22,12 +23,30 @@ async def connect_to_conversation(ws: WebSocket): except: await ws.send_denial_response(JSONResponse({ "error": "unable to load conversation {}".format(conversation_id) - })) + }, status_code=404)) return await ws.accept() + + # only send the system prompt at the start of a conversation + system_prompt = girlypop_prompt + + for response in conversation.responses: + response: llm.AsyncResponse = response + if not response._done: continue + if response.prompt.system: + system_prompt = None + await ws.send_text(json.dumps({"u": response.prompt.prompt})) # user + await ws.send_text(json.dumps({"f": response.text_or_raise()})) # full + async for message in ws.iter_text(): - response = conversation.prompt(message, system=girlypop_prompt) + response = conversation.prompt(message, system=system_prompt, stream=True) + system_prompt = None + + response_tid = tid_now() + await ws.send_text(json.dumps({"u": message})) + await ws.send_text(json.dumps({"s": response_tid})) # start async for chunk in response: - ws.send_text(json.dumps({"c": chunk})) - ws.send_text(json.dumps({"d": True})) # done + await ws.send_text(json.dumps({"r": response_tid, "c": chunk})) + await ws.send_text(json.dumps({"d": response_tid})) # done + (await response.to_sync_response()).log_to_db(db) diff --git a/server/tid.py b/server/tid.py new file mode 100644 index 0000000..442a427 --- /dev/null +++ b/server/tid.py @@ -0,0 +1,13 @@ +# i stole this from david + +import time + +B32_CHARSET = "234567abcdefghijklmnopqrstuvwxyz" + +def tid_now(): + micros, nanos = divmod(int(time.time() * 1_000_000_000), 1000) + clkid = nanos + tid_int = (micros << 10) | clkid + return "".join( + B32_CHARSET[(tid_int >> (60 - (i * 5))) & 31] for i in range(13) + )