From bf85419c16ba7b31507948cb8aeade1121bec198 Mon Sep 17 00:00:00 2001 From: annieversary Date: Thu, 11 Nov 2021 14:47:55 +0000 Subject: [PATCH] fix prisms --- src/combinations.rs | 21 +++++++++++ src/fns.rs | 63 ++++++++++++++++++++++++++++++++ src/prisms/mod.rs | 69 ++++++++++++++++++++++++----------- src/prisms/option.rs | 12 ++---- src/prisms/result.rs | 14 ++----- src/traversals/mod.rs | 85 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 223 insertions(+), 41 deletions(-) diff --git a/src/combinations.rs b/src/combinations.rs index 9fd4a10..3309faf 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -1,5 +1,6 @@ use crate::{ lenses::{Lens, LensOver, LensView}, + prisms::Prism, traversals::{Traversal, TraversalOver, TraversalTraverse}, }; @@ -8,6 +9,7 @@ pub struct Combination(A, B); // additions +// lens + lens impl std::ops::Add> for Lens { type Output = Lens, Lens>>; @@ -15,6 +17,7 @@ impl std::ops::Add> for Lens { Lens(Combination(self, rhs)) } } +// traversal + traversal impl std::ops::Add> for Traversal { type Output = Traversal, Traversal>>; @@ -22,6 +25,7 @@ impl std::ops::Add> for Traversal { Traversal(Combination(self, rhs)) } } +// traversal + lens impl std::ops::Add> for Traversal { type Output = Traversal, Traversal>>>; @@ -29,6 +33,7 @@ impl std::ops::Add> for Traversal { Traversal(Combination(self, rhs.to_traversal())) } } +// lens + traversal impl std::ops::Add> for Lens { type Output = Traversal>, Traversal>>; @@ -36,6 +41,22 @@ impl std::ops::Add> for Lens { Traversal(Combination(self.to_traversal(), rhs)) } } +// traversal + prism +impl std::ops::Add> for Traversal { + type Output = Traversal, Traversal>>>; + + fn add(self, rhs: Prism) -> Self::Output { + Traversal(Combination(self, rhs.to_traversal())) + } +} +// prism + traversal +impl std::ops::Add> for Prism { + type Output = Traversal>, Traversal>>; + + fn add(self, rhs: Traversal) -> Self::Output { + Traversal(Combination(self.to_traversal(), rhs)) + } +} // trait impls for Combination diff --git a/src/fns.rs b/src/fns.rs index 88b0cc7..1f1f3c9 100644 --- a/src/fns.rs +++ b/src/fns.rs @@ -1,5 +1,6 @@ use crate::{ lenses::{Lens, LensOver, LensView}, + prisms::{Prism, PrismPreview}, traversals::{Traversal, TraversalOver, TraversalTraverse}, }; @@ -120,3 +121,65 @@ where L::over(&self.0, args.0, args.1) } } + +// prism preview +impl std::ops::FnOnce<(A,)> for Prism +where + L: PrismPreview, +{ + type Output = Option; + + extern "rust-call" fn call_once(self, args: (A,)) -> Self::Output { + L::preview(&self.0, args.0) + } +} +impl std::ops::FnMut<(A,)> for Prism +where + L: PrismPreview, +{ + extern "rust-call" fn call_mut(&mut self, args: (A,)) -> Self::Output { + L::preview(&self.0, args.0) + } +} +impl std::ops::Fn<(A,)> for Prism +where + L: PrismPreview, +{ + extern "rust-call" fn call(&self, args: (A,)) -> Self::Output { + L::preview(&self.0, args.0) + } +} + +// prism over +impl std::ops::FnOnce<(A, F)> for Prism +where + A: Clone, + L: PrismPreview, + F: FnMut(L::Field) -> L::Field, +{ + type Output = A; + + extern "rust-call" fn call_once(self, args: (A, F)) -> Self::Output { + L::over(&self.0, args.0, args.1) + } +} +impl std::ops::FnMut<(A, F)> for Prism +where + A: Clone, + L: PrismPreview, + F: FnMut(L::Field) -> L::Field, +{ + extern "rust-call" fn call_mut(&mut self, args: (A, F)) -> Self::Output { + L::over(&self.0, args.0, args.1) + } +} +impl std::ops::Fn<(A, F)> for Prism +where + A: Clone, + L: PrismPreview, + F: FnMut(L::Field) -> L::Field, +{ + extern "rust-call" fn call(&self, args: (A, F)) -> Self::Output { + L::over(&self.0, args.0, args.1) + } +} diff --git a/src/prisms/mod.rs b/src/prisms/mod.rs index c89bc36..536ac1c 100644 --- a/src/prisms/mod.rs +++ b/src/prisms/mod.rs @@ -12,10 +12,28 @@ pub struct Prism

(pub(crate) P); pub trait PrismPreview { type Field; - fn preview(thing: T) -> Option; -} -pub trait PrismReview: PrismPreview { - fn review(thing: Self::Field) -> T; + fn preview(&self, thing: T) -> Option; + fn review(&self, thing: Self::Field) -> T; + // TODO id like for this to not need clone + fn over(&self, thing: T, f: F) -> T + where + F: FnOnce(Self::Field) -> Self::Field, + T: Clone, + { + if let Some(a) = Self::preview(&self, thing.clone()) { + Self::review(&self, f(a)) + } else { + thing + } + } + + fn set(&self, thing: T, v: Self::Field) -> T + where + T: Clone, + Self::Field: Clone, + { + Self::over(self, thing, move |_| v.clone()) + } } impl PrismPreview for Prism

@@ -24,25 +42,27 @@ where { type Field = P::Field; - fn preview(thing: T) -> Option { - P::preview(thing) + fn preview(&self, thing: T) -> Option { + P::preview(&self.0, thing) + } + + fn review(&self, thing: Self::Field) -> T { + P::review(&self.0, thing) } } -impl PrismReview for Prism

-where - P: PrismReview, -{ - fn review(thing: Self::Field) -> T { - P::review(thing) - } +pub fn preview>(prism: P, thing: T) -> Option { + P::preview(&prism, thing) } - -pub fn preview>(_prism: P, thing: T) -> Option { - P::preview(thing) +pub fn review>(prism: P, thing: P::Field) -> T { + P::review(&prism, thing) } -pub fn review>(_prism: P, thing: P::Field) -> T { - P::review(thing) +pub fn over>( + prism: P, + thing: T, + f: impl FnOnce(P::Field) -> P::Field, +) -> T { + P::over(&prism, thing, f) } #[cfg(test)] @@ -52,13 +72,13 @@ mod tests { #[test] fn preview_result() { let a: Result = Ok(3); - assert_eq!(preview(_Ok, a), Some(3)); + assert_eq!(_Ok(a), Some(3)); let a: Result = Err(3); assert_eq!(preview(_Ok, a), None); let a: Result = Ok(3); - assert_eq!(preview(_Err, a), None); + assert_eq!(_Err(a), None); let a: Result = Err(3); assert_eq!(preview(_Err, a), Some(3)); @@ -67,7 +87,7 @@ mod tests { #[test] fn preview_option() { let a = Some(3); - assert_eq!(preview(_Some, a), Some(3)); + assert_eq!(_Some(a), Some(3)); let a = Some(3); assert_eq!(preview(_None, a), Some(())); @@ -90,4 +110,11 @@ mod tests { assert_eq!(review(_Some, 3), Some(3)); assert_eq!(review(_None, ()), None::<()>); } + + #[test] + fn over_option() { + assert_eq!(over(_Some, Some(3), |v| v + 1), Some(4)); + assert_eq!(_Some(Some(3), |v| v + 1), Some(4)); + assert_eq!(over(_None, None, |_v: ()| ()), None::<()>); + } } diff --git a/src/prisms/option.rs b/src/prisms/option.rs index 16e6436..e856542 100644 --- a/src/prisms/option.rs +++ b/src/prisms/option.rs @@ -8,13 +8,11 @@ pub const _Some: Prism = Prism(SomeInner); impl PrismPreview> for SomeInner { type Field = T; - fn preview(thing: Option) -> Option { + fn preview(&self, thing: Option) -> Option { thing } -} -impl PrismReview> for SomeInner { - fn review(thing: Self::Field) -> Option { + fn review(&self, thing: Self::Field) -> Option { Some(thing) } } @@ -26,13 +24,11 @@ pub const _None: Prism = Prism(NoneInner); impl PrismPreview> for NoneInner { type Field = (); - fn preview(_thing: Option) -> Option { + fn preview(&self, _thing: Option) -> Option { Some(()) } -} -impl PrismReview> for NoneInner { - fn review(_thing: Self::Field) -> Option { + fn review(&self, _thing: Self::Field) -> Option { None } } diff --git a/src/prisms/result.rs b/src/prisms/result.rs index 6dd5cdc..d9fd074 100644 --- a/src/prisms/result.rs +++ b/src/prisms/result.rs @@ -7,13 +7,10 @@ pub const _Ok: Prism = Prism(OkInner); impl PrismPreview> for OkInner { type Field = T; - fn preview(thing: Result) -> Option { + fn preview(&self, thing: Result) -> Option { thing.ok() } -} - -impl PrismReview> for OkInner { - fn review(thing: Self::Field) -> Result { + fn review(&self, thing: Self::Field) -> Result { Ok(thing) } } @@ -26,13 +23,10 @@ pub const _Err: Prism = Prism(ErrInner); impl PrismPreview> for ErrInner { type Field = E; - fn preview(thing: Result) -> Option { + fn preview(&self, thing: Result) -> Option { thing.err() } -} - -impl PrismReview> for ErrInner { - fn review(thing: Self::Field) -> Result { + fn review(&self, thing: Self::Field) -> Result { Err(thing) } } diff --git a/src/traversals/mod.rs b/src/traversals/mod.rs index 19043ac..48ded0f 100644 --- a/src/traversals/mod.rs +++ b/src/traversals/mod.rs @@ -4,7 +4,10 @@ pub use both::both; mod each; pub use each::each; -use crate::lenses::{Lens, LensOver, LensView}; +use crate::{ + lenses::{Lens, LensOver, LensView}, + prisms::{Prism, PrismPreview}, +}; /// Wrapper type #[derive(Clone, Copy)] @@ -87,6 +90,43 @@ where } } +// all prisms are traversals, so we can freely transform them into a traversal +impl Prism { + /// Returns this lens as a traversal + pub fn to_traversal(self) -> Traversal> { + Traversal(self) + } +} +// we can go back to a lens from a "traversal-ed" lens +impl Traversal> { + /// Returns the wrapped lens + pub fn to_prism(self) -> Prism { + self.0 + } +} +impl TraversalTraverse for Prism +where + L: PrismPreview, +{ + type Field = L::Field; + + fn traverse(&self, thing: T) -> Vec { + L::preview(&self.0, thing).into_iter().collect() + } +} +impl TraversalOver for Prism +where + T: Clone, + L: PrismPreview, +{ + fn over(&self, thing: T, f: F) -> T + where + F: FnMut(Self::Field) -> Self::Field, + { + L::over(&self.0, thing, f) + } +} + pub fn traverse>(lens: L, thing: T) -> Vec { L::traverse(&lens, thing) } @@ -102,7 +142,7 @@ pub fn over>(lens: L, thing: T, f: impl FnMut(L::Field) - #[cfg(test)] mod tests { - use crate::lenses::_0; + use crate::{lenses::_0, prisms::_Some}; use super::*; @@ -157,4 +197,45 @@ mod tests { let res = t(array, |v| v + 1); assert_eq!(res, [(2, 3), (3, 4), (5, 6)]); } + + #[test] + fn can_combine_prism_with_traversal() { + let array = [Some(1), None, Some(3), None, Some(5)]; + + // combine a traversal with a lens + let t = each + _Some; + + // traverse + let res = t(array); + assert_eq!(res, vec![1, 3, 5]); + + // over + let res = t(array, |v| v + 1); + assert_eq!(res, [Some(2), None, Some(4), None, Some(6)]); + } + + #[test] + fn can_combine_traversal_with_prism() { + let array = Some([1, 2, 3]); + + // combine a traversal with a lens + let t = _Some + each; + + // traverse + let res = t(array); + assert_eq!(res, vec![1, 2, 3]); + + // over + let res = t(array, |v| v + 1); + assert_eq!(res, Some([2, 3, 4])); + + let array: Option<[i32; 3]> = None; + // traverse + let res = t(array); + assert_eq!(res, vec![]); + + // over + let res = t(array, |v| v + 1); + assert_eq!(res, None); + } }