{-# LANGUAGE GADTs #-}

module RedBlackGADT where

-- Note: these imports are only used for pretty printing
import Data.List (intercalate, isSuffixOf)
import qualified Data.Tree as T

-- | Type-level tag for trees whose root is red.
data Red

-- | Type-level tag for trees whose root is black ('Leaf' counts as black).
data Black

-- | A red-black tree that statically enforces that there are no
-- red-red violations.
--
-- The index @c@ is the colour of the root. A red node ('NodeR') may only
-- have black children, while a black node ('NodeB') may have children of
-- either colour. The type does not track black-height, so that part of the
-- invariant is maintained by 'insert' rather than by the type checker.
data Tree c a where
  Leaf  :: Tree Black a
  NodeB :: a -> Tree cl a -> Tree cr a -> Tree Black a
  NodeR :: a -> Tree Black a -> Tree Black a -> Tree Red a

-- | Does the tree contain a particular element? O(log n)
contains :: Ord a => a -> Tree c a -> Bool
contains _ Leaf          = False
contains x (NodeB y l r) = containsHelp x y l r
contains x (NodeR y l r) = containsHelp x y l r

-- | Helper function for 'contains', since we have two kinds of nodes.
--
-- Compares the target @x@ with the node's element @y@. It then either stops
-- or continues the search in the left subtree @l@ or the right subtree @r@.
containsHelp :: Ord a => a -> a -> Tree c1 a -> Tree c2 a -> Bool
containsHelp x y l r =
  case compare x y of
    EQ -> True
    LT -> contains x l
    GT -> contains x r

-- | Insert an element into the tree. O(log n)
--
-- Equal elements are not merged. Inserting a value that is already present
-- adds another copy, so the tree behaves like a multiset.
insert :: Ord a => a -> Tree Black a -> Tree Black a
insert x t = toTreeB (insertHelp x t)

-- | Helper function for 'insert'.
--
-- Returns the replacement for the given subtree as a 'Temp' node. Elements
-- less than or equal to a node's element go left.
--
-- Below a black node, a red-red violation coming up from a child is
-- repaired at once by 'checkL' / 'checkR'. Below a red node it cannot be
-- repaired here. Instead it is passed up (as a 'TempR' with a red child),
-- and the parent, which must be black, fixes it.
insertHelp :: Ord a => a -> Tree c a -> Temp a
insertHelp x Leaf = TempR x Leaf Leaf
insertHelp x (NodeB y l r)
  | x <= y    = checkL y (insertHelp x l) r
  | otherwise = checkR y l (insertHelp x r)
insertHelp x (NodeR y l r)
  | x <= y = case insertHelp x l of
      -- The black child was rebuilt without a violation at its root, so it
      -- stays black. (Handing a 'TempB' to 'toTreeR' would crash.)
      TempB z zl zr -> TempR y (NodeB z zl zr) r
      -- The child came back red. 'toTreeR' accepts it, because a red result
      -- of inserting into a black subtree always has black children.
      t@TempR{}     -> TempR y (toTreeR t) r
  | otherwise = case insertHelp x r of
      TempB z zl zr -> TempR y l (NodeB z zl zr)
      t@TempR{}     -> TempR y l (toTreeR t)

-- | A data type for representing temporary nodes,
-- which may violate the invariants.
--
-- Unlike 'Tree', neither constructor constrains the colours of its
-- children. This means a 'TempR' can hold a red child, i.e. a red-red
-- violation that still has to be repaired.
data Temp a where
  TempB :: a -> Tree c1 a -> Tree c2 a -> Temp a
  TempR :: a -> Tree c1 a -> Tree c2 a -> Temp a

-- | Convert from a temporary red node to a permanent red node.
--
-- Only a 'TempR' whose two children are black can be converted. Any other
-- input means the caller broke an invariant. In that case 'error' is raised
-- with a descriptive message, instead of an opaque pattern-match failure.
toTreeR :: Temp a -> Tree Red a
toTreeR (TempR y l@Leaf r@Leaf) = NodeR y l r
toTreeR (TempR y l@Leaf r@(NodeB _ _ _)) = NodeR y l r
toTreeR (TempR y l@(NodeB _ _ _) r@Leaf) = NodeR y l r
toTreeR (TempR y l@(NodeB _ _ _) r@(NodeB _ _ _)) = NodeR y l r
toTreeR (TempR _ _ _) =
  error "RedBlackGADT.toTreeR: red node would have a red child"
toTreeR (TempB _ _ _) =
  error "RedBlackGADT.toTreeR: expected a TempR node, got TempB"

-- | Convert a temporary node of either color to a permanent black node.
-- Only to be applied to the root of a tree.
--
-- Recolouring the root black removes any red-red violation at the root.
-- Both subtrees are kept unchanged.
toTreeB :: Temp a -> Tree Black a
toTreeB (TempB y l r) = NodeB y l r
toTreeB (TempR y l r) = NodeB y l r

-- | Construct a balanced node as part of rebalancing.
-- Refer to the "Rebalancing" slide for details.
--
-- The arguments are three elements @x@, @y@, @z@ and four subtrees
-- @a b c d@, both listed in left-to-right order. The result is a red @y@
-- with two black children: @x@ (over @a@ and @b@) and @z@ (over @c@ and @d@).
balance :: a -> a -> a -> Tree c1 a -> Tree c2 a -> Tree c3 a -> Tree c4 a -> Temp a
balance x y z a b c d = TempR y (NodeB x a b) (NodeB z c d)

-- | Check for a Red-Red violation on the left branch
-- and rebalance if necessary.
--
-- @checkL z t d@ rebuilds a black node with element @z@, the freshly
-- updated left child @t@ and the untouched right child @d@. If @t@ is red
-- with a red child, the subtrees are redistributed by 'balance'. Otherwise
-- the result is a black node (a 'TempB').
checkL :: a -> Temp a -> Tree c a -> Temp a
checkL z (TempR y (NodeR x a b) c) d = balance x y z a b c d
checkL z (TempR x a (NodeR y b c)) d = balance x y z a b c d
checkL z (TempB x a b) d = TempB z (NodeB x a b) d
-- wish we had dependent types here!
-- checkL z (TempR x a b) d = TempB z (NodeR x a b) d
-- That single clause does not typecheck, because NodeR needs both children
-- to be known black. Each black combination (Leaf / NodeB) is therefore
-- matched explicitly.
checkL z (TempR x a@Leaf b@Leaf) d = TempB z (NodeR x a b) d
checkL z (TempR x a@Leaf b@(NodeB _ _ _)) d = TempB z (NodeR x a b) d
checkL z (TempR x a@(NodeB _ _ _) b@Leaf) d = TempB z (NodeR x a b) d
checkL z (TempR x a@(NodeB _ _ _) b@(NodeB _ _ _)) d = TempB z (NodeR x a b) d

-- | Check for a Red-Red violation on the right branch
-- and rebalance if necessary.
--
-- This mirrors 'checkL'. @checkR x a t@ rebuilds a black node with element
-- @x@, the untouched left child @a@ and the freshly updated right child @t@.
checkR :: a -> Tree c a -> Temp a -> Temp a
checkR x a (TempR y b (NodeR z c d)) = balance x y z a b c d
checkR x a (TempR z (NodeR y b c) d) = balance x y z a b c d
checkR x a (TempB z b c) = TempB x a (NodeB z b c)
-- wish we had dependent types here!
-- checkR x a (TempR z b c) = TempB x a (NodeR z b c)
-- That single clause does not typecheck for the same reason as in checkL,
-- so each black combination of children is matched explicitly.
checkR x a (TempR z b@Leaf c@Leaf) = TempB x a (NodeR z b c)
checkR x a (TempR z b@Leaf c@(NodeB _ _ _)) = TempB x a (NodeR z b c)
checkR x a (TempR z b@(NodeB _ _ _) c@Leaf) = TempB x a (NodeR z b c)
checkR x a (TempR z b@(NodeB _ _ _) c@(NodeB _ _ _)) = TempB x a (NodeR z b c)

--
-- * Pretty printing
--

-- | Constructor-style rendering. Every node is wrapped in parentheses; the
-- element itself is rendered with plain 'show'.
instance Show a => Show (Tree c a) where
  show Leaf          = "Leaf"
  show (NodeB x l r) = intercalate " " ["(NodeB", show x, show l, show r ++ ")"]
  show (NodeR x l r) = intercalate " " ["(NodeR", show x, show l, show r ++ ")"]

-- | Print the tree as an ASCII diagram on stdout.
pretty :: Show a => Tree c a -> IO ()
pretty = putStrLn . condense . T.drawTree . toDataTree

-- | Convert to a "Data.Tree" rose tree. Each node is labelled with its
-- colour and element; leaves become childless nodes labelled "•".
toDataTree :: Show a => Tree c a -> T.Tree String
toDataTree Leaf = T.Node "•" []
toDataTree (NodeB a l r) = T.Node ("B: " ++ show a) [toDataTree l, toDataTree r]
toDataTree (NodeR a l r) = T.Node ("R: " ++ show a) [toDataTree l, toDataTree r]

-- | Compact a drawn tree by dropping lines that contain only spaces and
-- @|@ characters (connector-only lines). Blank lines are dropped too.
condense :: String -> String
condense = unlines . {- filter notLeaf . -} filter notEmpty . lines
  where
    notEmpty = not . all (\c -> c == ' ' || c == '|')
    -- Currently unused. Re-enable the commented-out filter above to hide
    -- leaf lines as well.
    notLeaf = not . isSuffixOf "•"