Fix Discord-to-Discord replies

main
Charlotte Som 2022-04-16 10:16:51 +01:00
parent b9daf95a04
commit 48c3eb1830
7 changed files with 107 additions and 34 deletions

View File

@ -6,17 +6,11 @@ use phoebe::{
get_linked_channels, link_messages, get_linked_channels, link_messages,
prelude::{ChatEvent, SqlitePool}, prelude::{ChatEvent, SqlitePool},
service::Service, service::Service,
DynServiceLookup,
}; };
fn dyn_service(service: &str) -> &'static str {
match service {
"discord" => "discord",
"matrix" => "matrix",
_ => panic!("Unsupported service: {}", service),
}
}
async fn handle_events( async fn handle_events(
dyn_service: DynServiceLookup,
db: SqlitePool, db: SqlitePool,
mut service: Box<dyn Service + Send + Sync>, mut service: Box<dyn Service + Send + Sync>,
mut rx: tokio::sync::broadcast::Receiver<ChatEvent>, mut rx: tokio::sync::broadcast::Receiver<ChatEvent>,
@ -32,7 +26,7 @@ async fn handle_events(
match event { match event {
phoebe::prelude::ChatEvent::NewMessage(message) => { phoebe::prelude::ChatEvent::NewMessage(message) => {
let linked_channels = let linked_channels =
get_linked_channels(&mut conn, &message.origin.channel, dyn_service).await; get_linked_channels(&mut conn, dyn_service, &message.origin.channel).await;
let mut resulting_messages = vec![]; let mut resulting_messages = vec![];
for destination_channel in linked_channels { for destination_channel in linked_channels {
@ -64,13 +58,22 @@ async fn main() -> Result<()> {
let (tx, _) = tokio::sync::broadcast::channel(512); let (tx, _) = tokio::sync::broadcast::channel(512);
let db = phoebe::open_core_db().await?; let db = phoebe::open_core_db().await?;
fn dyn_service(service: &str) -> &'static str {
match service {
"discord" => "discord",
"matrix" => "matrix",
_ => panic!("Unsupported service: {}", service),
}
}
let services: Vec<Box<dyn Service + Send + Sync>> = vec![Box::new( let services: Vec<Box<dyn Service + Send + Sync>> = vec![Box::new(
phoebe_discord::setup(db.clone(), tx.clone()).await?, phoebe_discord::setup(db.clone(), tx.clone(), dyn_service).await?,
)]; )];
let handles = services let handles = services
.into_iter() .into_iter()
.map(|srv| tokio::spawn(handle_events(db.clone(), srv, tx.subscribe()))); .map(|srv| tokio::spawn(handle_events(dyn_service, db.clone(), srv, tx.subscribe())));
let _ = futures::future::join_all(handles).await; let _ = futures::future::join_all(handles).await;

View File

@ -11,5 +11,4 @@ sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "sqlite"] }
tracing = "0.1" tracing = "0.1"
async-trait = "0.1.53" async-trait = "0.1.53"
eyre = "0.6.8" eyre = "0.6.8"
tokio-stream = "0.1.8"
futures = "0.3.21" futures = "0.3.21"

View File

@ -1,10 +1,10 @@
use futures::stream::BoxStream; use futures::{stream::BoxStream, Future};
pub use mid_chat; pub use mid_chat;
use mid_chat::{ChatMessageReference, ChatReference}; use mid_chat::{ChatMessageReference, ChatReference};
use futures::StreamExt;
use sqlx::{Row, SqliteConnection, SqlitePool}; use sqlx::{Row, SqliteConnection, SqlitePool};
use tokio::sync::broadcast::*; use tokio::sync::broadcast::*;
use tokio_stream::StreamExt;
pub mod db; pub mod db;
pub mod prelude; pub mod prelude;
@ -23,8 +23,8 @@ pub async fn open_core_db() -> sqlx::Result<SqlitePool> {
pub async fn get_linked_channels( pub async fn get_linked_channels(
conn: &mut SqliteConnection, conn: &mut SqliteConnection,
channel: &ChatReference,
dyn_service: DynServiceLookup, dyn_service: DynServiceLookup,
channel: &ChatReference,
) -> Vec<ChatReference> { ) -> Vec<ChatReference> {
let from_service = channel.service; let from_service = channel.service;
let from_channel = &channel.id; let from_channel = &channel.id;
@ -36,7 +36,7 @@ pub async fn get_linked_channels(
query query
.fetch(&mut *conn) .fetch(&mut *conn)
.filter_map(Result::ok) .filter_map(|r| async { r.ok() })
.map(|r| ChatReference { .map(|r| ChatReference {
service: dyn_service(&r.to_service), service: dyn_service(&r.to_service),
id: r.to_channel, id: r.to_channel,
@ -76,8 +76,8 @@ pub async fn link_messages(
pub async fn get_linked_messages<'a>( pub async fn get_linked_messages<'a>(
conn: &'a mut SqliteConnection, conn: &'a mut SqliteConnection,
message: &ChatMessageReference,
dyn_service: DynServiceLookup, dyn_service: DynServiceLookup,
message: &ChatMessageReference,
) -> sqlx::Result<BoxStream<'a, ChatMessageReference>> { ) -> sqlx::Result<BoxStream<'a, ChatMessageReference>> {
let link_id = { let link_id = {
let service = &message.channel.service; let service = &message.channel.service;
@ -96,7 +96,7 @@ pub async fn get_linked_messages<'a>(
let stream = sqlx::query("SELECT * FROM messages WHERE link_id = ?") let stream = sqlx::query("SELECT * FROM messages WHERE link_id = ?")
.bind(link_id) .bind(link_id)
.fetch(&mut *conn) .fetch(&mut *conn)
.filter_map(Result::ok) .filter_map(|r| futures::future::ready(r.ok()))
.map(move |r| { .map(move |r| {
ChatMessageReference::new( ChatMessageReference::new(
ChatReference { ChatReference {
@ -109,3 +109,27 @@ pub async fn get_linked_messages<'a>(
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }
pub async fn lookup_message<F, Fut>(
conn: &mut SqliteConnection,
dyn_service: DynServiceLookup,
linked_message: &ChatMessageReference,
filter: F,
) -> Option<ChatMessageReference>
where
F: FnMut(&ChatMessageReference) -> Fut,
Fut: Future<Output = bool>,
{
let references = get_linked_messages(&mut *conn, dyn_service, linked_message)
.await
.ok()?
.filter(filter)
.collect::<Vec<_>>()
.await;
if let [reference] = references.as_slice() {
Some(reference.clone())
} else {
None
}
}

View File

@ -1,6 +1,7 @@
pub use crate::{service::Service, ChatEventReceiver, ChatEventSender}; pub use crate::{service::Service, ChatEventReceiver, ChatEventSender};
pub use async_trait::async_trait; pub use async_trait::async_trait;
pub use eyre::Result; pub use eyre::{self, Result};
pub use futures::{self, prelude::*};
pub use mid_chat::event::ChatEvent; pub use mid_chat::event::ChatEvent;
pub use sqlx::{SqliteConnection, SqlitePool}; pub use sqlx::{self, SqliteConnection, SqlitePool};

View File

@ -1,12 +1,14 @@
use phoebe::{ use phoebe::{
mid_chat::{self, ChatMessage, ChatMessageReference, ChatReference}, mid_chat::{self, ChatMessage, ChatMessageReference, ChatReference},
prelude::*, prelude::*,
DynServiceLookup,
}; };
use serenity::{client::Context, Client}; use serenity::{client::Context, Client};
use tracing::{debug, info}; use tracing::{debug, info};
mod chat_conv; mod chat_conv;
mod handler; mod handler;
mod lookup;
mod sender; mod sender;
pub fn discord_reference(id: impl ToString) -> mid_chat::ChatReference { pub fn discord_reference(id: impl ToString) -> mid_chat::ChatReference {
@ -17,10 +19,17 @@ pub fn discord_reference(id: impl ToString) -> mid_chat::ChatReference {
} }
pub struct DiscordService { pub struct DiscordService {
pub discord_ctx: Context, pub core_db: SqlitePool,
pub discord_media_db: SqlitePool,
pub ctx: Context,
pub dyn_service: DynServiceLookup,
} }
pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result<DiscordService> { pub async fn setup(
core_db: SqlitePool,
tx: ChatEventSender,
dyn_service: DynServiceLookup,
) -> Result<DiscordService> {
info!("Setting up Discord service…"); info!("Setting up Discord service…");
let discord_media_db = phoebe::db::open("discord_media").await?; let discord_media_db = phoebe::db::open("discord_media").await?;
@ -29,8 +38,8 @@ pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result<DiscordSe
let (ctx_tx, mut ctx_rx) = tokio::sync::mpsc::unbounded_channel::<Context>(); let (ctx_tx, mut ctx_rx) = tokio::sync::mpsc::unbounded_channel::<Context>();
let discord_handler = handler::DiscordHandler { let discord_handler = handler::DiscordHandler {
core_db, core_db: core_db.clone(),
discord_media_db, discord_media_db: discord_media_db.clone(),
chat_event_tx: tx, chat_event_tx: tx,
ctx_tx, ctx_tx,
}; };
@ -53,7 +62,12 @@ pub async fn setup(core_db: SqlitePool, tx: ChatEventSender) -> Result<DiscordSe
let discord_ctx = ctx_rx.recv().await.expect("Couldn't get Discord context"); let discord_ctx = ctx_rx.recv().await.expect("Couldn't get Discord context");
debug!("Logged in!"); debug!("Logged in!");
Ok(DiscordService { discord_ctx }) Ok(DiscordService {
core_db,
discord_media_db,
ctx: discord_ctx,
dyn_service,
})
} }
#[async_trait] #[async_trait]
@ -64,7 +78,7 @@ impl Service for DiscordService {
destination_channel: ChatReference, destination_channel: ChatReference,
) -> Vec<ChatMessageReference> { ) -> Vec<ChatMessageReference> {
assert_eq!(destination_channel.service, "discord"); assert_eq!(destination_channel.service, "discord");
sender::send_discord_message(&mut self.discord_ctx, source, destination_channel) sender::send_discord_message(self, source, destination_channel)
.await .await
.ok() .ok()
.into_iter() .into_iter()

View File

@ -0,0 +1,18 @@
use phoebe::{lookup_message, mid_chat::ChatMessageReference, prelude::*};
use crate::DiscordService;
impl DiscordService {
pub async fn lookup_message<F, Fut>(
&self,
linked_message: &ChatMessageReference,
filter: F,
) -> Option<ChatMessageReference>
where
F: FnMut(&ChatMessageReference) -> Fut,
Fut: Future<Output = bool>,
{
let mut conn = self.core_db.acquire().await.ok()?;
lookup_message(&mut conn, self.dyn_service, linked_message, filter).await
}
}

View File

@ -1,14 +1,14 @@
use phoebe::{ use phoebe::{
mid_chat::{ChatMessage, ChatMessageReference, ChatReference}, mid_chat::{ChatMessage, ChatMessageReference, ChatReference},
prelude::Result, prelude::{future, Result},
}; };
use serenity::{model::prelude::*, prelude::*}; use serenity::{model::prelude::*, prelude::*};
use crate::{chat_conv, discord_reference}; use crate::{chat_conv, discord_reference, DiscordService};
pub async fn send_discord_message( pub async fn send_discord_message(
context: &mut Context, discord: &mut DiscordService,
source: &ChatMessage, source: &ChatMessage,
destination_channel: ChatReference, destination_channel: ChatReference,
) -> Result<ChatMessageReference> { ) -> Result<ChatMessageReference> {
@ -21,13 +21,27 @@ pub async fn send_discord_message(
source.author.display_name, source.author.reference.service, formatted_message source.author.display_name, source.author.reference.service, formatted_message
); );
let discord_reply = if let Some(reply) = &source.replying {
if let Some(reply_ref) = discord
.lookup_message(reply, |r| future::ready(r.channel == destination_channel))
.await
{
assert_eq!(reply_ref.channel.service, "discord");
let channel_id: ChannelId = reply_ref.channel.id.parse().unwrap();
let message_id: MessageId = reply_ref.message_id.parse::<u64>().unwrap().into();
Some((channel_id, message_id))
} else {
None
}
} else {
None
};
let sent_message = channel_id let sent_message = channel_id
.send_message(&context, move |m| { .send_message(&discord.ctx, move |m| {
let m = m.content(content); let m = m.content(content);
if let Some(reply) = &source.replying { if let Some(reply) = discord_reply {
let channel_id: ChannelId = reply.channel.id.parse().unwrap(); m.reference_message(reply)
let message_id: MessageId = reply.message_id.parse::<u64>().unwrap().into();
m.reference_message((channel_id, message_id))
} else { } else {
m m
} }