From 48c3eb1830638b7cb6acd0e58d692b2bd822bc8f Mon Sep 17 00:00:00 2001 From: videogame hacker Date: Sat, 16 Apr 2022 10:16:51 +0100 Subject: [PATCH] Fix Discord-to-Discord replies --- phoebe-main/src/main.rs | 25 +++++++++++-------- phoebe/Cargo.toml | 1 - phoebe/src/lib.rs | 36 ++++++++++++++++++++++----- phoebe/src/prelude.rs | 5 ++-- services/phoebe-discord/src/lib.rs | 26 ++++++++++++++----- services/phoebe-discord/src/lookup.rs | 18 ++++++++++++++ services/phoebe-discord/src/sender.rs | 30 ++++++++++++++++------ 7 files changed, 107 insertions(+), 34 deletions(-) create mode 100644 services/phoebe-discord/src/lookup.rs diff --git a/phoebe-main/src/main.rs b/phoebe-main/src/main.rs index 10f59fc..6583b97 100644 --- a/phoebe-main/src/main.rs +++ b/phoebe-main/src/main.rs @@ -6,17 +6,11 @@ use phoebe::{ get_linked_channels, link_messages, prelude::{ChatEvent, SqlitePool}, service::Service, + DynServiceLookup, }; -fn dyn_service(service: &str) -> &'static str { - match service { - "discord" => "discord", - "matrix" => "matrix", - _ => panic!("Unsupported service: {}", service), - } -} - async fn handle_events( + dyn_service: DynServiceLookup, db: SqlitePool, mut service: Box, mut rx: tokio::sync::broadcast::Receiver, @@ -32,7 +26,7 @@ async fn handle_events( match event { phoebe::prelude::ChatEvent::NewMessage(message) => { let linked_channels = - get_linked_channels(&mut conn, &message.origin.channel, dyn_service).await; + get_linked_channels(&mut conn, dyn_service, &message.origin.channel).await; let mut resulting_messages = vec![]; for destination_channel in linked_channels { @@ -64,13 +58,22 @@ async fn main() -> Result<()> { let (tx, _) = tokio::sync::broadcast::channel(512); let db = phoebe::open_core_db().await?; + + fn dyn_service(service: &str) -> &'static str { + match service { + "discord" => "discord", + "matrix" => "matrix", + _ => panic!("Unsupported service: {}", service), + } + } + let services: Vec> = vec![Box::new( - phoebe_discord::setup(db.clone(), tx.clone()).await?, + phoebe_discord::setup(db.clone(), tx.clone(), dyn_service).await?, )]; let handles = services .into_iter() - .map(|srv| tokio::spawn(handle_events(db.clone(), srv, tx.subscribe()))); + .map(|srv| tokio::spawn(handle_events(dyn_service, db.clone(), srv, tx.subscribe()))); let _ = futures::future::join_all(handles).await; diff --git a/phoebe/Cargo.toml b/phoebe/Cargo.toml index 87b6658..649ce22 100644 --- a/phoebe/Cargo.toml +++ b/phoebe/Cargo.toml @@ -11,5 +11,4 @@ sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "sqlite"] } tracing = "0.1" async-trait = "0.1.53" eyre = "0.6.8" -tokio-stream = "0.1.8" futures = "0.3.21" diff --git a/phoebe/src/lib.rs b/phoebe/src/lib.rs index 7741940..700d687 100644 --- a/phoebe/src/lib.rs +++ b/phoebe/src/lib.rs @@ -1,10 +1,10 @@ -use futures::stream::BoxStream; +use futures::{stream::BoxStream, Future}; pub use mid_chat; use mid_chat::{ChatMessageReference, ChatReference}; +use futures::StreamExt; use sqlx::{Row, SqliteConnection, SqlitePool}; use tokio::sync::broadcast::*; -use tokio_stream::StreamExt; pub mod db; pub mod prelude; @@ -23,8 +23,8 @@ pub async fn open_core_db() -> sqlx::Result { pub async fn get_linked_channels( conn: &mut SqliteConnection, - channel: &ChatReference, dyn_service: DynServiceLookup, + channel: &ChatReference, ) -> Vec { let from_service = channel.service; let from_channel = &channel.id; @@ -36,7 +36,7 @@ pub async fn get_linked_channels( query .fetch(&mut *conn) - .filter_map(Result::ok) + .filter_map(|r| async { r.ok() }) .map(|r| ChatReference { service: dyn_service(&r.to_service), id: r.to_channel, @@ -76,8 +76,8 @@ pub async fn link_messages( pub async fn get_linked_messages<'a>( conn: &'a mut SqliteConnection, - message: &ChatMessageReference, dyn_service: DynServiceLookup, + message: &ChatMessageReference, ) -> sqlx::Result> { let link_id = { let service = &message.channel.service; @@ -96,7 +96,7 @@ pub async fn get_linked_messages<'a>( let stream = sqlx::query("SELECT * FROM messages WHERE link_id = ?") .bind(link_id) .fetch(&mut *conn) - .filter_map(Result::ok) + .filter_map(|r| futures::future::ready(r.ok())) .map(move |r| { ChatMessageReference::new( ChatReference { @@ -109,3 +109,27 @@ pub async fn get_linked_messages<'a>( Ok(Box::pin(stream)) } + +pub async fn lookup_message( + conn: &mut SqliteConnection, + dyn_service: DynServiceLookup, + linked_message: &ChatMessageReference, + filter: F, +) -> Option +where + F: FnMut(&ChatMessageReference) -> Fut, + Fut: Future, +{ + let references = get_linked_messages(&mut *conn, dyn_service, linked_message) + .await + .ok()? + .filter(filter) + .collect::>() + .await; + + if let [reference] = references.as_slice() { + Some(reference.clone()) + } else { + None + } +} diff --git a/phoebe/src/prelude.rs b/phoebe/src/prelude.rs index 2d2b381..aa30c44 100644 --- a/phoebe/src/prelude.rs +++ b/phoebe/src/prelude.rs @@ -1,6 +1,7 @@ pub use crate::{service::Service, ChatEventReceiver, ChatEventSender}; pub use async_trait::async_trait; -pub use eyre::Result; +pub use eyre::{self, Result}; +pub use futures::{self, prelude::*}; pub use mid_chat::event::ChatEvent; -pub use sqlx::{SqliteConnection, SqlitePool}; +pub use sqlx::{self, SqliteConnection, SqlitePool}; diff --git a/services/phoebe-discord/src/lib.rs b/services/phoebe-discord/src/lib.rs index 82086ad..c87cfce 100644 --- a/services/phoebe-discord/src/lib.rs +++ b/services/phoebe-discord/src/lib.rs @@ -1,12 +1,14 @@ use phoebe::{ mid_chat::{self, ChatMessage, ChatMessageReference, ChatReference}, prelude::*, + DynServiceLookup, }; use serenity::{client::Context, Client}; use tracing::{debug, info}; mod chat_conv; mod handler; +mod lookup; mod sender; pub fn discord_reference(id: impl ToString) -> mid_chat::ChatReference { @@ -17,10 +19,17 @@ pub fn discord_reference(id: impl ToString) -> mid_chat::ChatReference { } pub struct DiscordService { - pub discord_ctx: Context, + pub core_db: SqlitePool, + pub discord_media_db: SqlitePool, + pub ctx: Context, + pub dyn_service: DynServiceLookup, } -pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result { +pub async fn setup( + core_db: SqlitePool, + tx: ChatEventSender, + dyn_service: DynServiceLookup, +) -> Result { info!("Setting up Discord serviceā€¦"); let discord_media_db = phoebe::db::open("discord_media").await?; @@ -29,8 +38,8 @@ pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result(); let discord_handler = handler::DiscordHandler { - core_db, - discord_media_db, + core_db: core_db.clone(), + discord_media_db: discord_media_db.clone(), chat_event_tx: tx, ctx_tx, }; @@ -53,7 +62,12 @@ pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result Vec { assert_eq!(destination_channel.service, "discord"); - sender::send_discord_message(&mut self.discord_ctx, source, destination_channel) + sender::send_discord_message(self, source, destination_channel) .await .ok() .into_iter() diff --git a/services/phoebe-discord/src/lookup.rs b/services/phoebe-discord/src/lookup.rs new file mode 100644 index 0000000..06e5919 --- /dev/null +++ b/services/phoebe-discord/src/lookup.rs @@ -0,0 +1,18 @@ +use phoebe::{lookup_message, mid_chat::ChatMessageReference, prelude::*}; + +use crate::DiscordService; + +impl DiscordService { + pub async fn lookup_message( + &self, + linked_message: &ChatMessageReference, + filter: F, + ) -> Option + where + F: FnMut(&ChatMessageReference) -> Fut, + Fut: Future, + { + let mut conn = self.core_db.acquire().await.ok()?; + lookup_message(&mut conn, self.dyn_service, linked_message, filter).await + } +} diff --git a/services/phoebe-discord/src/sender.rs b/services/phoebe-discord/src/sender.rs index aa6c30f..ba4b9a5 100644 --- a/services/phoebe-discord/src/sender.rs +++ b/services/phoebe-discord/src/sender.rs @@ -1,14 +1,14 @@ use phoebe::{ mid_chat::{ChatMessage, ChatMessageReference, ChatReference}, - prelude::Result, + prelude::{future, Result}, }; use serenity::{model::prelude::*, prelude::*}; -use crate::{chat_conv, discord_reference}; +use crate::{chat_conv, discord_reference, DiscordService}; pub async fn send_discord_message( - context: &mut Context, + discord: &mut DiscordService, source: &ChatMessage, destination_channel: ChatReference, ) -> Result { @@ -21,13 +21,27 @@ pub async fn send_discord_message( source.author.display_name, source.author.reference.service, formatted_message ); + let discord_reply = if let Some(reply) = &source.replying { + if let Some(reply_ref) = discord + .lookup_message(reply, |r| future::ready(r.channel == destination_channel)) + .await + { + assert_eq!(reply_ref.channel.service, "discord"); + let channel_id: ChannelId = reply_ref.channel.id.parse().unwrap(); + let message_id: MessageId = reply_ref.message_id.parse::().unwrap().into(); + Some((channel_id, message_id)) + } else { + None + } + } else { + None + }; + let sent_message = channel_id - .send_message(&context, move |m| { + .send_message(&discord.ctx, move |m| { let m = m.content(content); - if let Some(reply) = &source.replying { - let channel_id: ChannelId = reply.channel.id.parse().unwrap(); - let message_id: MessageId = reply.message_id.parse::().unwrap().into(); - m.reference_message((channel_id, message_id)) + if let Some(reply) = discord_reply { + m.reference_message(reply) } else { m }