-- | Persistent red-black trees: membership test, insertion with
-- rebalancing, and a small pretty printer for inspecting tree shape.
module RedBlack where

-- Note: these imports are only used for pretty printing
import Data.List (isSuffixOf)
import qualified Data.Tree as T

data Color = Red | Black
  deriving (Eq, Show)

-- | A red-black tree. Leaves are assumed to be black.
data Tree a
  = Node Color a (Tree a) (Tree a)
  | Leaf
  deriving (Eq, Show)

-- | Does the tree contain a particular element? O(log n)
contains :: Ord a => a -> Tree a -> Bool
contains _ Leaf = False
contains x (Node _ y l r) = case compare x y of
  EQ -> True
  LT -> contains x l
  GT -> contains x r

-- | Insert an element into the tree. O(log n)
--
-- By inserting a red node, we ensure by construction that the
-- balanced-black invariant will be preserved, so we just have to check
-- for red-red violations. A red-red violation that propagates all the
-- way up is resolved by recolouring the root black.
--
-- Elements equal to an existing one are not deduplicated: they are
-- inserted again, descending into the left subtree (@x <= y@).
insert :: Ord a => a -> Tree a -> Tree a
insert x = setBlack . ins
  where
    -- 'ins' never returns 'Leaf'; the 'Leaf' equation just keeps
    -- 'setBlack' total (no incomplete-pattern warning / runtime crash).
    setBlack (Node _ y l r) = Node Black y l r
    setBlack Leaf = Leaf

    -- Only the side we descended into can gain a red-red violation,
    -- so only that side is checked on the way back up.
    ins Leaf = Node Red x Leaf Leaf
    ins (Node c y l r)
      | x <= y = checkL (Node c y (ins l) r)
      | otherwise = checkR (Node c y l (ins r))

-- | Construct a balanced node as part of rebalancing.
-- Refer to the "Rebalancing" slide for details.
--
-- Arguments are given in in-order sequence: subtrees @a b c d@ interleave
-- with elements @x < y < z@ as @a x b y c z d@. The result is a red @y@
-- with black children @x@ (over @a@, @b@) and @z@ (over @c@, @d@).
balance :: a -> a -> a -> Tree a -> Tree a -> Tree a -> Tree a -> Tree a
balance x y z a b c d = Node Red y (Node Black x a b) (Node Black z c d)

-- | Check for a red-red violation on the left branch
-- and rebalance if necessary.
--
-- Only a black node with a red child and red grandchild is rewritten;
-- any other node is returned unchanged, so a violation under a red node
-- is left for its (black) parent, or the final root recolouring, to fix.
checkL :: Tree a -> Tree a
checkL (Node Black z (Node Red y (Node Red x a b) c) d) = balance x y z a b c d
checkL (Node Black z (Node Red x a (Node Red y b c)) d) = balance x y z a b c d
checkL n = n

-- | Check for a red-red violation on the right branch
-- and rebalance if necessary. Mirror image of 'checkL'.
checkR :: Tree a -> Tree a
checkR (Node Black x a (Node Red y b (Node Red z c d))) = balance x y z a b c d
checkR (Node Black x a (Node Red z (Node Red y b c) d)) = balance x y z a b c d
checkR n = n

--
-- * Pretty printing
--

-- | Print the tree as an ASCII diagram. Nodes are labelled @R: value@ or
-- @B: value@ by colour; leaves are drawn as @•@.
pretty :: Show a => Tree a -> IO ()
pretty = putStrLn . condense . T.drawTree . toDataTree

-- | Convert to a "Data.Tree" rose tree of labels for 'T.drawTree'.
-- Every 'Node' gets exactly two children: left subtree, then right.
toDataTree :: Show a => Tree a -> T.Tree String
toDataTree Leaf = T.Node "•" []
toDataTree (Node c a l r) = T.Node (label c) [toDataTree l, toDataTree r]
  where
    label Red = "R: " ++ show a
    label Black = "B: " ++ show a

-- | Drop lines consisting only of spaces and @|@ (the connector-only lines
-- in 'T.drawTree' output). Re-enable the commented-out @filter notLeaf@
-- stage to also hide the @•@ leaf lines.
condense :: String -> String
condense = unlines . {- filter notLeaf . -} filter notEmpty . lines
  where
    notEmpty = not . all (\c -> c == ' ' || c == '|')
    notLeaf = not . isSuffixOf "•"