Last active
June 18, 2022 13:49
-
-
Save sjoerdvisscher/03299f8b0c8f208f6239bc75a42b004e to your computer and use it in GitHub Desktop.
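Two alternative ways to write the linear quicksort from https://www.tweag.io/blog/2021-02-10-linear-base/. The first uses linear-base's linear `State` monad with `QualifiedDo` (`Linear.do`). The second wraps that monad in a `UrT` newtype so that ordinary `do` notation and `when` can be used.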
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LinearTypes #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE BlockArguments #-} | |
{-# LANGUAGE NoImplicitPrelude #-} | |
{-# LANGUAGE QualifiedDo #-} | |
import Prelude.Linear hiding (partition) | |
import qualified Control.Functor.Linear as Linear | |
import Control.Functor.Linear (State, state, execState, pure) | |
import qualified Data.Array.Mutable.Linear as Array | |
import Data.Array.Mutable.Linear (Array) | |
-- | Sort a list of 'Int's. The list is copied into a linear mutable array,
-- sorted in place by 'arrQuicksort', and read back out as a list.
quickSort :: [Int] -> [Int]
quickSort xs = unur (Array.fromList xs (\arr -> Array.toList (arrQuicksort arr)))
-- | Sort a linear array in place. It queries the array's length, then sorts
-- the inclusive index range @[0, len - 1]@ with 'go'.
arrQuicksort :: Array Int %1 -> Array Int
arrQuicksort =
  execState (state Array.size Linear.>>= \(Ur len) -> go 0 (len - 1))
-- | Quicksort the inclusive index range @[lo, hi]@ of the array in the state.
-- The element at @lo@ is the pivot. After partitioning, the pivot is swapped
-- into the split slot and both sides are sorted recursively. Ranges with
-- fewer than two elements are left untouched.
go :: Int -> Int -> State (Array Int) ()
go lo hi
  | hi <= lo = pure ()
  | otherwise = Linear.do
      Ur pivot <- readST lo
      Ur mid <- partition pivot lo hi
      swap lo mid
      go lo (mid - 1)
      go (mid + 1) hi
-- | Partition the inclusive index range @[lx, rx]@ around @pivot@.
--
-- The two indices scan towards each other. @lx@ only advances past values
-- that are @<= pivot@, and @rx@ only retreats past values that are
-- @> pivot@. A pair that is on the wrong side for both scans is exchanged.
-- When the scans cross (@rx < lx@), the result is @lx - 1@, wrapped in 'Ur'.
partition :: Int -> Int -> Int -> State (Array Int) (Ur Int)
partition pivot lx rx
  | rx < lx = pure (Ur (lx - 1))
  | otherwise = Linear.do
      Ur lVal <- readST lx
      Ur rVal <- readST rx
      case (lVal <= pivot, pivot < rVal) of
        (True, True) -> partition pivot (lx + 1) (rx - 1)
        (True, False) -> partition pivot (lx + 1) rx
        (False, True) -> partition pivot lx (rx - 1)
        (False, False) -> Linear.do
          -- Both values are already in hand, so store them crossed over
          -- directly. Calling 'swap' here would read both slots again.
          setST lx rVal
          setST rx lVal
          partition pivot (lx + 1) (rx - 1)
-- | Exchange the elements stored at indices @i@ and @j@.
swap :: Int -> Int -> State (Array Int) ()
swap i j =
  readST i Linear.>>= \(Ur atI) ->
    readST j Linear.>>= \(Ur atJ) ->
      setST i atJ Linear.>> setST j atI
-- | Read the element at index @ix@, leaving the array unchanged.
-- The value is returned wrapped in 'Ur'.
readST :: Int -> State (Array a) (Ur a)
readST ix = state \arr -> Array.read arr ix
-- | Overwrite the element at index @ix@ with @x@.
setST :: Int -> a -> State (Array a) ()
setST ix x = state \arr -> ((), Array.set ix x arr)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LinearTypes #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE BlockArguments #-} | |
import Prelude (Functor(..), Applicative(..), Monad(..)) | |
import Control.Monad (when) | |
import Prelude.Linear hiding (partition) | |
import qualified Data.Array.Mutable.Linear as Array | |
import Data.Array.Mutable.Linear (Array) | |
import qualified Control.Functor.Linear as Linear | |
-- | Wraps a linear monad @m@ whose results are always unrestricted ('Ur').
-- The instances below make it an ordinary (non-linear) 'Functor',
-- 'Applicative' and 'Monad'. Each result is unwrapped from 'Ur' before it
-- reaches the next step, so that step may use it freely.
newtype UrT m a = UrT (m (Ur a))

instance Linear.Functor m => Functor (UrT m) where
  fmap g (UrT inner) = UrT (Linear.fmap (overUr g) inner)
    where
      -- Apply an ordinary function underneath the 'Ur' wrapper.
      overUr :: (p -> q) -> Ur p %1 -> Ur q
      overUr h (Ur x) = Ur (h x)

instance Linear.Applicative m => Applicative (UrT m) where
  pure x = UrT (Linear.pure (Ur x))
  UrT mf <*> UrT mx = UrT (Linear.liftA2 applyUr mf mx)
    where
      -- Unwrap both results and re-wrap the application.
      applyUr :: Ur (p -> q) %1 -> Ur p %1 -> Ur q
      applyUr (Ur f) (Ur a) = Ur (f a)

instance Linear.Monad m => Monad (UrT m) where
  UrT action >>= k = UrT (action Linear.>>= \(Ur x) -> unwrap (k x))
    where
      -- Strip the 'UrT' newtype to get back the underlying linear action.
      unwrap :: UrT n b %1 -> n (Ur b)
      unwrap (UrT inner) = inner
-- | A state monad over @s@, built on linear-base's 'Linear.State', in which
-- every result is unrestricted, so ordinary @do@ notation can be used.
type State s = UrT (Linear.State s)

-- | Lift a linear state transition whose result is wrapped in 'Ur'.
state :: (s %1 -> (Ur a, s)) -> State s a
state transition = UrT $ Linear.state transition

-- | Run a 'State' computation for its effect and return the final state.
execState :: State s () %1 -> s %1 -> s
execState (UrT action) = Linear.execState (Linear.fmap unur action)
-- | Sort a list of 'Int's. The list is copied into a linear mutable array,
-- sorted in place by 'arrQuicksort', and read back out as a list.
quickSort :: [Int] -> [Int]
quickSort input = unur (Array.fromList input sortAndCollect)
  where
    -- Continuation handed the freshly built array; it must consume it once.
    sortAndCollect :: Array Int %1 -> Ur [Int]
    sortAndCollect arr = Array.toList (arrQuicksort arr)
-- | Sort a linear array in place. It queries the array's length, then sorts
-- the inclusive index range @[0, len - 1]@ with 'go'.
arrQuicksort :: Array Int %1 -> Array Int
arrQuicksort = execState (state Array.size >>= \len -> go 0 (len - 1))
-- | Quicksort the inclusive index range @[lo, hi]@ of the array in the state.
-- The element at @lo@ is the pivot. After partitioning, the pivot is swapped
-- into the split slot and both sides are sorted recursively. Ranges with
-- fewer than two elements are left untouched.
go :: Int -> Int -> State (Array Int) ()
go lo hi
  | lo >= hi = pure ()
  | otherwise = do
      pivot <- readST lo
      split <- partition pivot lo hi
      swap lo split
      go lo (split - 1)
      go (split + 1) hi
-- | Partition the inclusive index range @[lx, rx]@ around @pivot@.
--
-- The two indices scan towards each other. @lx@ only advances past values
-- that are @<= pivot@, and @rx@ only retreats past values that are
-- @> pivot@. A pair that is on the wrong side for both scans is exchanged.
-- When the scans cross (@rx < lx@), the result is @lx - 1@.
partition :: Int -> Int -> Int -> State (Array Int) Int
partition pivot lx rx
  | rx < lx = pure (lx - 1)
  | otherwise = do
      lVal <- readST lx
      rVal <- readST rx
      case (lVal <= pivot, pivot < rVal) of
        (True, True) -> partition pivot (lx + 1) (rx - 1)
        (True, False) -> partition pivot (lx + 1) rx
        (False, True) -> partition pivot lx (rx - 1)
        (False, False) -> do
          -- Both values are already in hand, so store them crossed over
          -- directly. Calling 'swap' here would read both slots again.
          setST lx rVal
          setST rx lVal
          partition pivot (lx + 1) (rx - 1)
-- | Exchange the elements stored at indices @i@ and @j@.
swap :: Int -> Int -> State (Array Int) ()
swap i j =
  readST i >>= \atI ->
    readST j >>= \atJ ->
      setST i atJ >> setST j atI
-- | Read the element at index @ix@, leaving the array unchanged.
readST :: Int -> State (Array a) a
readST ix = state \arr -> Array.read arr ix
-- | Overwrite the element at index @ix@ with @x@.
setST :: Int -> a -> State (Array a) ()
setST ix x = state \arr -> (Ur (), Array.set ix x arr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment