HaskellでCPSのお勉強

モナドの合成をマジメに勉強しようと思って次のページ
All About Monads
に行って、第III部の頭で面食らった人は僕だけではないはず...
Continuationモナドは使わない方が良いって書いてあったのに。
そして、僕はContinuationモナドは知らないんだ。
なので、CPS(継続渡しスタイル)から勉強しなおし。

参考にしたページ
Continuation-passing style - Wikipedia, the free encyclopedia
Haskell/Continuation passing style - Wikibooks, open books for an open world
Haskell で継続渡しスタイル (CPS) | すぐに忘れる脳みそのためのメモ

CPSって何?

CPS(Continuation-Passing Style)ってのはその名の通り、ソースコードの「書き方」の一種。
参考ページなどを見るに、次みたいに書くのがHaskell式のCPSだと思って良さそう。

somefunc_cps x k =
    f1_cps x $ \result1 -> -- xに対する最初の処理
    f2_cps x result1 $ \result2 -> -- xと最初の処理の結果に対する処理
    {-- ... --}
    fn_cps x result1 {-- ... --} resultn' \resultn ->
    k resultn -- 処理を終了

ただし、f1_cps...fn_cpsは全て

f1_cps x k = k $ f1 x -- xに対するなんか処理
f2_cps x y k = k $ f2 x y
{-- ... --}
fn_cps x y z k = k $ fn x y z

みたいな形で、このように最後の引数に関数を取って、処理の結果をその関数に渡して処理を終了するような関数をCPS関数と言う。
最後の引数に処理の結果を渡すというところがミソで、そこに次にして欲しい処理を記述していくことで関数を書くのがCPS。

注意してみると、上述のCPS関数の型は

f1_cps :: a1 -> (b->c) -> c
f2_cps :: a1 -> a2 -> (b->c) -> c
{-- ... --}
fn_cps :: a1 -> a2 -> {-- ... --} an -> (b->c) -> c

であり、末尾は皆同じ形をしている。
そこで、以下ではこの共通の型にContと名をつける。

type Cont a b = (b->a) -> a

2つの型a, bの順番は、後でこれがモナドになる時の都合。
すると、CPS関数とは、結果としてCont a b型を返す関数であるということもできる。

例えば、Fibonacci数を計算するような関数は次のように書ける。
書き方を工夫して、いかにもCont a b 型を返しているように見せる。

add_cps :: Num a => a -> a -> Cont b a
add_cps x y = \k -> k $ x+y

fib_cps :: Integer -> Cont b Integer
fib_cps 0 = \k -> k 1
fib_cps 1 = \k -> k 1
fib_cps n = \k ->
    fib_cps (n-1) $ \x ->
    fib_cps (n-2) $ \y ->
    add_cps x y $ \z ->
    k z

CPSで疑似手続き型プログラミング

上の例でわかるように、CPSで書くと手続き型っぽくなって、Haskellぽくない。
しかし、これは逆に言えば、純粋関数型言語であるHaskellにおいて、手続き型言語っぽい記述ができるということで、これはもしかしたら利点なのかも知れない。
そこで、手続き型に特有(?)の次の制御文を実現してみる。

分岐

値としての分岐ではなく、処理の分岐。
ある条件を満たす時のみ処理をしたい、という場合、それが最後の処理ならばHaskellの普通のif分を用いれば良いのだが、その後に共通の処理があったりする時には、それではちょっと汚く見える。
そもそも、Haskellのifは、「処理」ではなく「値」の分岐なので、手続き型言語のifとは趣が異なる。
そこで、「処理」の分岐をCPSで実現する。

when' :: Bool -> Cont b a -> a -> Cont b a
when' False _ x = \k -> k x
when' True f x = \k -> f x k

例えば、FizzBuzz問題(Ruby Quiz - FizzBuzz (#126))は次のように解ける。

fizzbuzz_cps :: Integer -> Cont a String
fizzbuzz_cps n = \k ->
    when' (mod n 3==0) (\k'->k' "Fizz") "" $ \x->
    when' (mod n 5==0) (\k'->k' "Buzz") "" $ \y ->
    when' (length (x++y)==0) (\k'->k' $ show n) (x++y) -> \z
    k z

ついでに、継続する処理に渡す値の無い場合のwhen_

when_ :: Bool -> Cont b () -> Cont b ()
when_ False _ = \k -> k ()
when_ True f = f
中断

途中で処理をやめてしまいたい時の構文。
例として、整数を取ってその奇数因子を返す関数をCPSで書く。
最後の引数に処理を「渡さなけれ」ば、処理は進行しないという単純な発想で、試しに書いてみる。

exit_proto :: b -> Cont b a
exit_proto x = \_ -> x

しかし、これでうっかり次のように書いてしまったとする。

oddfact_cps n = \k ->
    when_ (mod n 2/=0) (exit_proto n) $ \_ ->
    oddfact_cps (floor $ fromIntegral n / 2) $ \x ->
    k x

一見うまく書けていて、うまく動いているように見える。
しかし、この関数を他の関数からCPSで呼ぶと問題が起きる。
例えば、正多角形が作図可能か判定する関数。
Gaussにより、正n角形が作図可能であることと、nがフェルマー素数の2のべき倍であることが同値であることが証明されているので、奇数因子がフェルマー素数かどうかを判定すれば良い。

isPrime n = foldl (&&) True $ map ((/=0).mod n) [2..n-1]
isPow2 n = n==head (snd$break (>=n) [2^l|l<-[1..]])
isFermatPrime_cps n = \k -> k $ isPrime n && isPow2 (n-1)

constructible_cps n = \k ->
    oddfact_cps n $ \x ->
    isFermatPrime_cps x $ \y ->
    k y

oddfact_cpsが前述のような定義になっている限り、この関数は正しく動作しない。
その原因は、oddfact_cpsの定義の2行目、exit_protoの引数にある。
良く見ると、CPSの条件である「結果を最後の引数に渡して終わる」を満さず、生の値を返していることがわかる。
正しくは、

oddfact_cps n = \k ->
    when_ (mod n 2/=0) (exit_proto (k n)) $ \_ -> -- exit_protoの引数を修正
    oddfact_cps (floor $ fromIntegral n / 2) $ \x ->
    k x

とすれば良い。

ところで、このようにkを適用することがわかりきっているのに、わざわざk nなどと書かなければいけないのは面倒臭い上に、上で見たようなほとんどの場合、コンパイルレベルで排除できないので、バグの温床になる。
わざわざexit_proto (k n)のように書かなくて良いためには、kもexit_protoの引数にしてしまうという手も考えられるが、それでは冗長であることに変わりはない。
そこで、kを渡す責任をCPS関数の外部に委ねてしまえれば良い。
その変わり、関数は自動的にkを適用するような「出口」を引数として受けとるようにする。
呼び出し元は、継続先kを知っているわけだから、出口として次のラムダ式をわたしてやれば、上述のexitのように機能する。

\x -> \_ -> k x

この方針で、oddfact_cpsを本体oddfact_cps_coreとそれを呼び出すcallOddfact_cpsに分割する。

callOddfact_cps :: Integer -> Cont b Integer
callOddfact_cps n = \k -> oddfact_cps_core n (\x -> \_ -> k x) k

oddfact_cps_core :: Integer -> (Integer->Cont b ()) -> Cont b Integer
oddfact_cps_core n exit = \k ->
    when_ (mod n 2/=0) (exit n) $ \_ ->
    callOddfact_cps (floor $ fromIntegral n / 2) $ \x ->
    k x

ここまで来ると、callOddfact_cpsの定義式を一般の場合に拡張するのは用意である。
中断を含む処理を実行する関数callCont_cpsは次のように定義できる。

callCont_cps :: ((a->Cont b c)->Cont b a)->Cont b c
callCont_cps f = \k -> f (\x -> \_ -> k x) k

この関数は次のように用いる。

oddfact_cps' :: Integer -> Cont b Integer
oddfact_cps' n = callCont_cps $ \exit -> \k ->
    when_ (mod n 2/=0) (exit n) $ \_ ->
    oddfact_cps' (floor $ fromIntegral n / 2) $ \x ->
    k x

継続はモナドなり

さて、CPSの書き方を見た瞬間に、モナドのdo構文に似てる!と思った人も少なくないはず。
そこで、Contにモナドインスタンスを与えてみる。

まず、上のままのContの定義でモナドインスタンスを与えようとするとコンパイラに怒られるので修正

newtype Cont a b = Cont {runCont :: (b->a)->a}

そして、モナドインスタンスを次のように与える。

instance Monad (Cont a) where
    (Cont f) >>= g = Cont $ \k -> f $ \x -> runCont (g x) k
    return x = Cont $ \k -> k x

変更したこれらの定義で、上に例示してある関数をいくつか書き換えてみる。
ただし、この場合、when_はControl.Monadで定義されているwhenと全く同等なので、それで置き換える。

module CPS where

import Control.Monad

newtype Cont a b = Cont {runCont :: (b->a) -> a}

instance Monad (Cont a) where
    (Cont f) >>= g = Cont $ \k -> f $ \x -> runCont (g x) k
    return x = Cont $ \k -> k x

when' :: Bool -> Cont a b -> b -> Cont a b
when' True _ x = return x
when' False f _ = f

callCont_cps :: ((a->Cont b c)->Cont b a)->Cont b a
callCont_cps f = Cont $ \k -> runCont (f (\x-> Cont $ \_ -> k x)) k

fib_cps :: Integer -> Cont a Integer
fib_cps 0 = return 1
fib_cps 1 = return 1
fib_cps n = do
    x <- fib_cps (n-1)
    y <- fib_cps (n-2)
    return (x+y)

fib_cps' n = callCont_cps $ \exit -> do
    when (n==0) (exit 1)
    when (n==1) (exit 1)
    x <- fib_cps' (n-1)
    y <- fib_cps' (n-2)
    return (x+y)

fizzbuzz_cps :: Integer -> Cont a String
fizzbuzz_cps n = do
    x <- when' (mod n 3==0) (return "Fizz") ""
    y <- when' (mod n 5==0) (return "Buzz") ""
    z <- when' (length (x++y)==0) (return $ show n) (x++y)
    return z

oddfact_cps :: Integer -> Cont b Integer
oddfact_cps n = callCont_cps $ \exit -> do
    when (mod n 2/=0) (exit n)
    x <- oddfact_cps (floor $ fromIntegral n / 2)
    return x

isPrime :: Integer -> Bool
isPrime n = foldl (&&) True $ map ((/=0).mod n) [2..n-1]
isPow2 n = n==head (snd$break (>=n) [2^l|l<-[1..]])
isFermatPrime_cps n = Cont $ \k -> k $ isPrime n && isPow2 (n-1)

-- is constructible?
constructible_cps :: Integer -> Cont b Bool
constructible_cps n = do
    x <- oddfact_cps n
    y <- isFermatPrime_cps x
    return y

Control.Monad.Contには、Cont a bが定義されており、また、callCont_cpsと同等な関数callCCも定義されている。

以上

これでContinuationモナドは理解できた気になれました。
しかし、CPSで書かれた上のコードを見ても、一瞬Haskellかどうかわからない。
上の例でも特に、fib_cps'とfizzbuzz_cpsは極めつけだと思う。
All About Monadsで、なるだけContinuationモナドは使わないように、って書いてあったのもうなずける。