get inference over websocket working
This commit is contained in:
parent
75c5c2db63
commit
deb2f6e8e5
4 changed files with 38 additions and 6 deletions
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
|
@ -7,5 +7,5 @@ async def status(request: Request) -> Response:
|
||||||
app = Starlette(debug=True, routes=[
|
app = Starlette(debug=True, routes=[
|
||||||
Route("/api/", status),
|
Route("/api/", status),
|
||||||
Route("/api/conversation", list_conversations, methods=["GET"]),
|
Route("/api/conversation", list_conversations, methods=["GET"]),
|
||||||
WebSocketRoute("/api/conversation/:conversation/connect", connect_to_conversation)
|
WebSocketRoute("/api/conversation/{conversation}/connect", connect_to_conversation)
|
||||||
])
|
])
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import llm, llm.cli, sqlite_utils
|
import llm, llm.cli, sqlite_utils
|
||||||
from .http import Request, JSONResponse, WebSocket, RedirectResponse
|
from .http import Request, JSONResponse, WebSocket
|
||||||
|
from .tid import tid_now
|
||||||
import json
|
import json
|
||||||
|
|
||||||
db = sqlite_utils.Database(llm.cli.logs_db_path())
|
db = sqlite_utils.Database(llm.cli.logs_db_path())
|
||||||
|
@ -22,12 +23,30 @@ async def connect_to_conversation(ws: WebSocket):
|
||||||
except:
|
except:
|
||||||
await ws.send_denial_response(JSONResponse({
|
await ws.send_denial_response(JSONResponse({
|
||||||
"error": "unable to load conversation {}".format(conversation_id)
|
"error": "unable to load conversation {}".format(conversation_id)
|
||||||
}))
|
}, status_code=404))
|
||||||
return
|
return
|
||||||
|
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
|
|
||||||
|
# only send the system prompt at the start of a conversation
|
||||||
|
system_prompt = girlypop_prompt
|
||||||
|
|
||||||
|
for response in conversation.responses:
|
||||||
|
response: llm.AsyncResponse = response
|
||||||
|
if not response._done: continue
|
||||||
|
if response.prompt.system:
|
||||||
|
system_prompt = None
|
||||||
|
await ws.send_text(json.dumps({"u": response.prompt.prompt})) # user
|
||||||
|
await ws.send_text(json.dumps({"f": response.text_or_raise()})) # full
|
||||||
|
|
||||||
async for message in ws.iter_text():
|
async for message in ws.iter_text():
|
||||||
response = conversation.prompt(message, system=girlypop_prompt)
|
response = conversation.prompt(message, system=system_prompt, stream=True)
|
||||||
|
system_prompt = None
|
||||||
|
|
||||||
|
response_tid = tid_now()
|
||||||
|
await ws.send_text(json.dumps({"u": message}))
|
||||||
|
await ws.send_text(json.dumps({"s": response_tid})) # start
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
ws.send_text(json.dumps({"c": chunk}))
|
await ws.send_text(json.dumps({"r": response_tid, "c": chunk}))
|
||||||
ws.send_text(json.dumps({"d": True})) # done
|
await ws.send_text(json.dumps({"d": response_tid})) # done
|
||||||
|
(await response.to_sync_response()).log_to_db(db)
|
||||||
|
|
13
server/tid.py
Normal file
13
server/tid.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# i stole this from david
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 32-symbol alphabet used to encode 5 bits per character. The symbols are in
# ascending ASCII order, so encoded strings compare the same way as the
# integers they were built from.
B32_CHARSET = "234567abcdefghijklmnopqrstuvwxyz"


def tid_now():
    """Return a 13-character, time-ordered identifier for the current instant.

    The identifier packs the wall-clock time in microseconds into the high
    bits and the sub-microsecond remainder (0-999 ns) into the low 10 bits,
    then encodes the result most-significant-group first as 13 characters of
    ``B32_CHARSET`` (5 bits per character).

    Returns:
        str: a 13-character string drawn from ``B32_CHARSET``.
    """
    # time.time_ns() yields an exact integer. Going through the float
    # time.time() * 1e9 loses the low nanosecond digits to rounding at
    # present-day epoch values.
    micros, nanos = divmod(time.time_ns(), 1000)
    # nanos is always < 1000, so it fits in the 10-bit low field.
    clkid = nanos
    tid_int = (micros << 10) | clkid
    # Shifts 60, 55, ..., 0: thirteen 5-bit groups, most significant first.
    return "".join(
        B32_CHARSET[(tid_int >> shift) & 31] for shift in range(60, -1, -5)
    )
|
Loading…
Reference in a new issue