diff --git a/effers-derive/src/lette.rs b/effers-derive/src/lette.rs index 7970a0d..52af450 100644 --- a/effers-derive/src/lette.rs +++ b/effers-derive/src/lette.rs @@ -1,26 +1,40 @@ +const LETTERS: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + #[derive(Clone, Copy)] pub struct LettersIter { - idx: u32, + idx: usize, } impl LettersIter { pub fn new() -> Self { - Self { - idx: 'A' as u32 - 1, - } + Self { idx: 0 } } } impl Iterator for LettersIter { - type Item = char; + type Item = String; fn next(&mut self) -> Option { - for _ in 0..100 { - self.idx += 1; - if let Some(c) = char::from_u32(self.idx) { - return Some(c); - } - } + let l = LETTERS.chars().nth(self.idx % LETTERS.len()).unwrap(); + let c = self.idx / LETTERS.len(); - None + self.idx += 1; + + Some(l.to_string().repeat(c + 1)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn iter() { + let mut i = LettersIter::new(); + + assert_eq!(i.next(), Some("A".to_string())); + assert_eq!(i.next(), Some("B".to_string())); + assert_eq!(i.nth(23), Some("Z".to_string())); + assert_eq!(i.next(), Some("AA".to_string())); + assert_eq!(i.next(), Some("BB".to_string())); } } diff --git a/effers-derive/src/lib.rs b/effers-derive/src/lib.rs index 63aae2f..cfe00b1 100644 --- a/effers-derive/src/lib.rs +++ b/effers-derive/src/lib.rs @@ -1,9 +1,13 @@ use convert_case::{Case, Casing}; +use lette::LettersIter; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::token::{Mut, SelfValue}; +use syn::token::{Dot, Mut, SelfValue}; use syn::visit_mut::VisitMut; -use syn::{parse_macro_input, Expr, ExprCall, FnArg, Ident, ItemFn, PathSegment, Receiver}; +use syn::{ + parse_macro_input, Expr, ExprCall, ExprField, FnArg, Ident, Index, ItemFn, Member, PathSegment, + QSelf, Receiver, +}; mod parse; use parse::Args; @@ -17,6 +21,7 @@ pub fn program( item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { let item = parse_macro_input!(item as syn::ItemFn); + // dbg!(&item); let mut args = parse_macro_input!(attr as Args); if args.name.is_none() { @@ -152,7 +157,7 @@ impl<'a> syn::visit_mut::VisitMut for FuncRewriter<'a> { // check if the function name is in args // if it is, replace it with the correct name if let Expr::Path(path) = &mut *node.func { - for (i, effect) in self.args.effects.iter().enumerate() { + for ((i, effect), l) in self.args.effects.iter().enumerate().zip(LettersIter::new()) { for func in &effect.functions { let ident = func.alias.clone().unwrap_or(func.ident.clone()); if path.path.is_ident(&ident) { @@ -163,15 +168,58 @@ impl<'a> syn::visit_mut::VisitMut for FuncRewriter<'a> { arguments: syn::PathArguments::None, }); + let span = [Span::call_site()]; + path.path = effect_path; - // then change the parameters so the handler is the first - // get the effect's index, and add the inverse num of `.0`s - let idx = eff_len - (i + 1); - let m = if func.mut_token.is_some() { "mut " } else { "" }; - let s = format!("&{}self{}.1", m, ".0".repeat(idx)); - let expr: Expr = syn::parse_str(&s).unwrap(); - node.args.insert(0, expr); + // qualify the trait se we get + // ::print + // instead of Printer::print + let ty: syn::Type = syn::parse_str(&l).unwrap(); + path.qself = Some(QSelf { + lt_token: syn::token::Lt { + spans: span.clone(), + }, + ty: Box::new(ty), + position: path.path.segments.len() - 1, + as_token: Some(syn::token::As { span: span[0] }), + gt_token: syn::token::Gt { spans: span }, + }); + + // if the effect function takes a self, add it to the list of params + if let Some(mut expr) = func.self_reference.clone() { + // then change the parameters so the handler is the first + // get the effect's index, and add the inverse num of `.0`s + let idx = eff_len - (i + 1); + + for _ in 0..idx { + expr = Expr::Field(ExprField { + attrs: vec![], + base: Box::new(expr), + dot_token: Dot { + spans: [Span::call_site()], + }, + member: Member::Unnamed(Index { + index: 0, + span: Span::call_site(), + }), + }); + } + + expr = Expr::Field(ExprField { + attrs: vec![], + base: Box::new(expr), + dot_token: Dot { + spans: [Span::call_site()], + }, + member: Member::Unnamed(Index { + index: 1, + span: Span::call_site(), + }), + }); + + node.args.insert(0, expr); + } } } } diff --git a/effers-derive/src/parse.rs b/effers-derive/src/parse.rs index 814dc31..6bb0d82 100644 --- a/effers-derive/src/parse.rs +++ b/effers-derive/src/parse.rs @@ -1,7 +1,7 @@ use syn::parse::{Parse, ParseStream, Result}; use syn::punctuated::Punctuated; use syn::token::Paren; -use syn::{parenthesized, Ident, Path, Token}; +use syn::{parenthesized, Expr, ExprPath, ExprReference, Ident, Path, Token}; #[derive(Debug)] pub struct Args { @@ -21,7 +21,7 @@ pub struct Effect { pub struct EffectFunction { pub ident: Ident, pub alias: Option, - pub mut_token: Option, + pub self_reference: Option, } impl Parse for Args { @@ -68,14 +68,29 @@ impl Parse for Effect { } impl Parse for EffectFunction { fn parse(input: ParseStream) -> Result { - let mut_token: Option = if input.peek(Token![mut]) { - input.parse()? + let ident = input.parse()?; + + let self_reference: Option = if input.peek(Paren) { + let content; + parenthesized!(content in input); + + // &mut self + if content.peek(Token![&]) && content.peek2(Token![mut]) && content.peek3(Token![self]) + { + Some(Expr::Reference(content.parse::()?)) + } else + // &self + if content.peek(Token![&]) && content.peek2(Token![self]) { + Some(Expr::Reference(content.parse::()?)) + } else if content.peek(Token![self]) { + Some(Expr::Path(content.parse::()?)) + } else { + None + } } else { None }; - let ident = input.parse()?; - let alias: Option = if input.peek(Token![as]) { input.parse::()?; input.parse()? @@ -86,7 +101,7 @@ impl Parse for EffectFunction { Ok(EffectFunction { ident, alias, - mut_token, + self_reference, }) } } diff --git a/examples/main.rs b/examples/main.rs index 24cda27..16f0b2a 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -1,8 +1,10 @@ use effers::program; -#[program(Smth => Printer(print as p), Logger(mut debug, mut info))] +#[program(Smth => Printer(print(&self) as p, check as check_printer), Logger(debug(&mut self), info(self)))] fn smth(val: u8) -> u8 { - p("hey hi hello"); + if check_printer() { + p("hey hi hello"); + } debug("this is a debug-level log"); info("this is a info-level log"); @@ -10,7 +12,7 @@ fn smth(val: u8) -> u8 { val + 3 } -#[program(Printer(mut print as p))] +#[program(Printer(print(&self) as p))] fn other_program() { p("hey hi hello"); } @@ -33,10 +35,11 @@ fn main() { trait Printer { fn print(&self, s: &str); + fn check() -> bool; } trait Logger { fn debug(&mut self, s: &str); - fn info(&mut self, s: &str); + fn info(self, s: &str); } struct IoPrinter; @@ -44,6 +47,9 @@ impl Printer for IoPrinter { fn print(&self, s: &str) { println!("{}", s) } + fn check() -> bool { + true + } } struct FileLogger; @@ -51,7 +57,7 @@ impl Logger for FileLogger { fn debug(&mut self, s: &str) { println!("debug: {}", s) } - fn info(&mut self, s: &str) { + fn info(self, s: &str) { println!("info: {}", s) } } @@ -66,7 +72,7 @@ impl Logger for NetworkLogger { s, self.credentials ) } - fn info(&mut self, s: &str) { + fn info(self, s: &str) { println!( "info through network: {}; with password {}", s, self.credentials diff --git a/examples/path.rs b/examples/path.rs index 23afb9b..313cb0e 100644 --- a/examples/path.rs +++ b/examples/path.rs @@ -1,7 +1,7 @@ use effers::program; // Effects can be referenced from inside a module -#[program(inc::Incrementer(increment))] +#[program(inc::Incrementer(increment(&self)))] fn prog(val: u8) -> u8 { let x = increment(val); let y = increment(x); @@ -11,10 +11,15 @@ fn prog(val: u8) -> u8 { mod inc { pub trait Incrementer { fn increment(&self, v: u8) -> u8; + + fn check() -> bool; } pub struct TestInc; impl Incrementer for TestInc { + fn check() -> bool { + false + } fn increment(&self, v: u8) -> u8 { v + 3 } diff --git a/src/lib.rs b/src/lib.rs index ca2854a..b6e8ef2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,12 @@ pub use effers_derive::program; mod test { use super::*; - #[program(Smth => Printer(print as p), Logger(mut debug, mut info), inc::Incrementer(mut increment))] + #[program(Smth => Printer(print(&self) as p), Logger(debug(self), info(&mut self)), inc::Incrementer(increment))] fn smth(val: u8) -> u8 { let s = p("hey hi hello"); - debug("this is a debug-level log"); info("this is a info-level log"); + debug("this is a debug-level log"); let _s = p("hey hi hello"); @@ -24,16 +24,15 @@ mod test { fn print(&self, s: &str) -> &str; } trait Logger { - fn debug(&mut self, s: &str); + fn debug(self, s: &str); fn info(&mut self, s: &str); } mod inc { pub trait Incrementer { - fn increment(&mut self, v: u8) -> u8; + fn increment(v: u8) -> u8; } } - // TODO make nameless programs work #[program(Printer(print as p))] fn ohter() { let _s = p("hey hi hello");