2014年3月10日

erlang - tail recursion 尾端遞迴

學習程式語言C/Java/組合語言等等,都會用到迴圈,以變數重複相同的步驟,就可以完成一些計算,但 erlang 沒有迴圈的語法,只有遞迴,也就是要把重複的步驟換成一個重複執行的函數,為了追求高運算速度,遞迴程式必須盡量以尾端遞迴實作,因為編譯器會自動將遞迴展開,避免在重複計算時,消耗大量的 stack 記憶體。

老式迴圈

以下是用 C/Java 撰寫從 0 ~ n 加總的程式。

int sum(int n) {
    int total=0;
    while(n!=0) {
        total=total+n;
        n=n-1;
    }
    return total;
}

但是我們可以改用這樣的概念:用不同的參數,重複相同的步驟。上面的程序,其中的 while 迴圈中,其實就代表著,把 total 換成 total+n,把 n 換成 n-1 這兩個一直重複的步驟,因此,如果沒有 while,而是把它換成一個重複執行的函數。

do_sum(N, Total) when N>0 ->
    do_sum(N-1, Total+N);
do_sum(0, Total) -> Total.

一開始 total 初始化為 0,其實換句話說,要計算 do_sum(N) 就等於 Total 一開始為 0,也就是計算 do_sum(N, 0)。

do_sum(N) -> do_sum(N, 0).

把順序調整一下,就變成以下的結果,因為函數會依照順序被比對並呼叫,當第二句 do_sum(0, Total) 無法匹配時,就表示 N 不是 0,因此第三句就不用寫 when N>0 了。

do_sum(N) -> do_sum(N, 0).
do_sum(0, Total) -> Total;
do_sum(N, Total) ->    do_sum(N-1, Total+N).

tail recursion

上面的 do_sum,其實也可以改成下面的 sum 這樣。

sum(0) -> 0;
sum(N) -> sum(N-1) + N.

把這個版本跟 do_sum 比較一下

do_sum(N, Total) -> do_sum(N-1, Total+N).

會發現遞迴出現的地方,sum 在中間,do_sum 在尾端,而尾端遞迴因為在最後一行呼叫函數後,原本這個函數後面,不再有其他需要估算的子句,因此編譯器可以將 tail recursion 的程式最佳化,不需要將變數重複地放置到 stack 中,不會消耗 stack 記憶體,速度也遠比中間遞迴快很多。

因此我們在撰寫遞迴程式時,盡量要將重複呼叫的函數,放在最後面。

如果 sum 改寫成這樣,是不是 tail recursion 呢?

sum(0) -> 0;
sum(N) -> N + sum(N-1).

這種寫法不是 tail recursion,尾端遞迴的用意,是在最後呼叫函數後,不需要再做任何其他的估算,就直接將遞迴函數的結果回傳回去,雖然 sum(N-1) 寫在後面,但是 sum(N) -> N + sum(N-1) 卻必須要在得到 sum(N-1) 的結果後,再加上前面的 N,所以這種寫法並不是 tail recursion。

累加器參數 accumulator parameter

sum 將工作拆成兩半,一半是將 N 倒數到 0,同時要在 stack 上記錄後續要累加的數值,另一半負責從 stack 找出對應數值累加,直到 stack 被清空。

do_sum 則是只在 stack 上保留一份記錄,它會不斷地更新該紀錄,直到 N 遞減為 0,然後丟棄記錄,並直接取得結果 Total。Total 扮演了 accumulator parameter 的角色,用來將過程中的結果累加。

撰寫 tail recursion 通常需要額外的參數,而這些參數會在遞迴開始前先初始化,因此常常會需要兩個函數,一個用來當作前端的 API 界面,另一個用來作為主循環。

效率問題

通常尾端遞迴的方案比非尾端遞迴來得有效率,但不絕對是這樣,這取決於運算內容。尾端遞迴函數,就像是將所有家當都放在推車上的旅者,非尾端遞迴則是靠扔紙片作為路標,才能回得了家的醉漢。

如果像剛剛 do_sum 的例子,最終結果只是一個簡單的值,那麼尾端遞迴效率比較高,因為負擔不重。但如果要追蹤的訊息很複雜,旅者必須耗費一些時間處理資料管理工作,最後反而會導致比非尾端遞迴來得慢。

有些問題不一定要選擇用尾端遞迴實作,有些則明顯該使用尾端遞迴。但如果是無窮循環的狀況,就只能採用尾端遞迴的方式,避免 stack overflow。

撰寫尾端遞迴的訣竅

以訪問 list 內所有元素為例,實作反轉 list 的函數。

  1. 觀察樣本
    先將幾個輸入與輸出的樣本資料列出來,這不單是可以讓問題更具體,還能讓我們考慮到問題的一些特殊狀況,還可以用來寫入 test case 中。

     [] -> []
     [x] -> [x]
     [x,y] -> [y,x]
     [x,y,z] -> [z,y,x]
  2. 基本狀況
    就是不涉及遞迴的狀況,通常就是遞迴的終點,這可以用來測試 rev([])、rev([17])、rev([atom]) 這些 test case。

     rev([]) -> [];
     rev([X]) -> [X].
  3. 分析輸入資料的型態
    因為 list 的樣子是 [A,B,...] 的狀態,我們也知道 list 是由 [...|...] 這種方式組成的,因此我們可以將輸入的參數改為這樣

     rev([A | [B | TheRest]]) -> not_yet_implemented;

    很明顯,這裡需要用遞迴方式解決問題

  4. 如果沒有現成的函數可使用
    先假設已經存在一個舊版的 old_rev/1。
    如果要把剛剛的 not_yet_implemented 改掉,應該怎麼寫呢?
    如果能用 old_rev 把 TheRest 反轉,在 list 後面加上 [B|A]就可以完成整個 list 反轉了。

     rev([A | [B | TheRest]]) -> old_rev(TheRest) ++ [B,A];

    既然 rev 已經可以運作了,那麼 old_rev 不也就可以換成 rev。用 rev([1,2,3,4,5]) 測試看看,的確可以運作。

     rev([A | [B | TheRest]]) -> rev(TheRest) ++ [B,A];
     rev([]) -> [];
     rev([X]) -> [X].
  5. 可終止性證明
    要驗證函數的可終止性,主要線索為「單調遞減函數」,基本狀況是在處理此函數可接受的最小輸入資料,而遞迴是在處理所有其他較多的輸入資料。

    只要每一次遞迴都持續在縮小函數呼叫的資料量,就能知道輸入參數一定會慢慢地收斂到基本狀況,於是函數就一定會終止。

    因為累加器 accumulator parameter 與判斷函數循環的條件無關,除了累加器之外,如果有某個遞迴函數傳入了想等或更大的輸入資料量,那就必須要懷疑函數是不是會終止。

    上一個步驟裡的 rev(TheRest),很明顯 TheRest 少了 A 與 B,裡面包含的元素數量少於前面輸入的參數,由此可知遞迴呼叫一定會終止。

    在以整數遞迴時,要注意整數的下限,因為輸入的參數有可能會變成 -1, -2, -3, ... ,為了避免這個狀況,可以加上 when N > 0 guard clause,確保參數不會是負數。

  6. 基本狀況最小化
    剛剛的例子列出兩個基本狀況 [] 與 [X],但其實 [X] 也可以寫成 [X | []],由於 rev 已經可以處理 [],所以 rev([X]) 就可以遞迴呼叫 rev([]) ++ [X]。

    這樣就不需要將單個元素的列表,和包含了兩個或兩個以上的元素列表,區分為兩種狀況,兩個規則可以合成一個。

     rev([X | TheRest]) -> rev(TheRest) ++ [X];
     rev([]) -> [].

    rev([]) -> [] 沒有必要放在前面,假設有一個包含 100 個元素的 list,前100次都是非空 list,只有最後一次是 [],因此把 rev([]) -> [] 放在後面,就可以節省比對的次數。

  7. 識別平方等級時間複雜度的行為
    雖然剛剛的遞迴函數實作已經可以運作了,但如果 rev 的運算是平方等級的時間複雜度,就會因為 list 長度變長而使得 rev 所消耗的時間變得非常糟糕。

    rev 因為用到了 ++ ,它所需要耗費的執行時間與 ++ 左邊的 list 長度成正比。假設 ++ 左邊的 list 長度為 1,耗費的時間為 T,則在長度為 100 的 list 上執行 rev,第一次呼叫會耗費 100T,第二次 99T...,而 100T+99T+...+T 的總和為 (101100/2)T,大約就是 NN/2 的耗時。

    如果是立方等級的時間複雜度,就只會讓狀況變得更糟。

  8. 避免平方等級的時間複雜度
    如果改成 tail recursion 是不是就能解決問題?
    先加上一個累加器 accumulator parameter 參數。

     tailrev([X|TheRest], Acc) -> not_yet_implemented;
     tailrev([], Acc) -> Acc.

    因為要構成 tail recusrion,一定要寫成 tailrev(TheRest, ...),而且我們知道 [X|Acc] 是個廉價的運算,所以我們就改成

     tailrev([X|TheRest], Acc) -> tailrev(TheRest, [X|Acc]);
     tailrev([], Acc) -> Acc.

    現在要考慮的是 Acc 的初始值是什麼,如果是 tailrev([foo], Acc),因為 X 綁定為 foo ,而 TheRest 綁定為 [],因此就變成 tailrev( [], [foo | Acc]),換句話說, Acc 必須要是 []。

     tailrev(List) -> tailrev(List, []).
    
     tailrev([X|TheRest], Acc) -> tailrev(TheRest, [X|Acc]);
     tailrev([], Acc) -> Acc.

    因為處理列表裡每個元素,都是單純地把元素加到 Acc 的左側,因此如果列表裡有 L 個元素,總耗時就是 L*C,因此這樣的運算就是線性等級的時間複雜度。

  9. 注意長度
    Java 語言的 length 跟 erlang 的 length BIF 是完全不同的。
    如果在 guard clause 裡面呼叫 length,因為 erlang 的 length 每一次呼叫都需要從頭到尾 traverse 一次 list,所有的耗時加起來,又是一個平方等級的時間複雜度。

     loop(List) when length(List) > 0 ->
         do_something();
     loop(EmptyList) ->
         done.

    要改用 pattern matching 匹配空列表

     loop([SomeElement|RestOfList]) ->
         do_something();
     loop([]) ->
         done.

    也可以用 pattern matching 來比對不同長度的 list。但要注意必須把長度較長的比對 pattern 放在前面。

     loop([A,B,C|TheRest]) -> three_or_more;
     loop([A,B|TheRest]) -> two_or_more;
     loop([A|TheRest]) -> one_or_more;
     loop([]) -> none;

參考

Erlang and OTP in Action
Programming Erlang: Software for a Concurrent World