超高速 In-Place Merge Sort (クイックソートに比肩するインプレースなマージソート)


クイックソートに比肩する速度へ (クイックソートが最速という常識への挑戦)

一般のマージソートでは,データの記憶領域と同じ程度の追加の記憶領域が必要です。in-place merge sort(インプレースマージソート)とは,追加の記憶領域をほとんど使わずに行うマージソートです。
今までのインプレースマージソートの多くは,平均計算時間と最悪計算時間が共に O(n (log n)2) で,一般のマージソートと同じ安定ソートだったようです。(Wikipedia ソート)
また,In-Place Sorting With Merge Sort に,平均計算時間と最悪計算時間が共に O(n log n) で安定ソートでないインプレースマージソートがありますが,それは一般のマージソートより遅いものです。
一般のマージソートはクイックソートよりも遅いので,今までのインプレースマージソートはクイックソートより遅かったのです。(私が作ったマージソートはクイックソートよりも速いです。)

私も平均計算時間と最悪計算時間が共にO(n log n)で安定ソートでないインプレースマージソートを作りました。再帰呼び出しを用いていて,追加の記憶領域は O(log n) です。(2025年1月)
その後,改良を重ねたところ,平均計算時間はクイックソートに比肩するまでになりました。
そのソートは,In-Place Sorting With Merge Sort にあるインプレースマージソートを基本として,それに幾つかの速度を上げる改良を施した形になっているので,やや複雑です。
In-Place Sorting With Merge Sort での計算時間O(n log n)のインプレースマージソートの実現方法と,このインプレースマージソートを比較しながら読んで頂ければ幸いです。

ここで示すインプレースマージソートは,In-Place Sorting With Merge Sort にあるインプレースマージソートよりも高速ですが,それは,部分列のソート法が次のように異なるからです。
    データ列のある部分をソートする → データ列のある部分をソートされたデータ列にする
すなわち,データ列のある部分について,そこに存在するデータをソートするのではなく,その部分を,他の部分にあるデータも用いてソートされたデータ列に したのです。

In-Place Sorting With Merge Sort では,整列されたm個のデータと整列されたk個のデータを,整列されていないk個のデータを上手く用いて,整列されたm+k個のデータにしていました。
    (整列されたm個,整列されていないk 個,整列されたk個) → (整列されたm+k個,整列されていないk個)
    (整列されたk個,整列されていないk 個,整列されたm個) → (整列されていないk個,整列されたm+k個)
ここでのインプレースマージソートは,上記の操作を再帰的に繰り返しました。このことが,In-Place Sorting With Merge Sort のソート法より高速になった最大の理由です。

さらに,高速化のために,次の2つの手法を用いました。
1 データ数が少ないデータ群をソートするときには挿入ソートを用いた。
2 インプレースマージソートの最終段階でソートされていないデータ が僅かな数になったとき,ソートされていないデータをソートされたデータ列にマージするのに新たな手法を用いた。

このインプレースマージソートを,まずは改良した挿入ソートを用いて高速化したところ,ランダムなデータに対して,高速化されていないクイックソートと競 い合える速度になりました。
その後,さらなる工夫を重ねたところ,高速化されたクイックソートやイントロソートに比肩する速度までになりました。
このソートは,最悪計算時間において,クイックソートやイントロソートよりも優れていて,私の知る限りインプレースなソートの内では最悪計算時間において 最も 優れています。
なお,ほとんど整列されたデータに外れ値が入り込んでいるデータに対しては,ランダムなデータに対してより高速です。

最初のコードは,高速化していないマージソートよりは速く,高速化していないクイックソートと競い合えるレベルでした。
そのコードに,次の2つの変形を加えたのが,以下にあるコードです。

1 最初のコードでは,できるだけデータ数がほぼ等しいデータ群を マージしようとしていました。すなわち,m≒k となる割合を多くすれば高速になると思い込んでいました。
    ところが,コードの一部の「k=n/2」を「k=n*(4.0/5.0)」に変えたところ,不思議にも大きく速くなりました。
    データ数比 1:1でマージしていたのが,1:4でマージするように変わったので遅くなると思っていたのに,速くなるなんてびっくり。
    私には理由は分かりませんが,「クイックソートのピボットは中央値でなく四分位数を選択したほうが高速」というブログを見て,もしかしたら分岐予測ミス確 率の関係かなと思っています。

2 次の変形により,少しだけ速くなりました。コンパイラの関係で しょうが,これも理由が分かりません。
    d_type *b=a+(m-1),*c=a+(m-1)+k,*d=a+(m-1)+k+k; → d_type *b=a+(m-1),*c=a+(m-1)+k,*d=a+(m-1)+2*k;

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

typedef int d_type;// ソートするキーの型

void merge_sortD(d_type a[], int n);
void merge_sortU(d_type a[], int n);

void swap(d_type *x,d_type *y){
    d_type tmp=*x;
    *x=*y;
    *y=tmp;
}

void insertion_sort(d_type array[],int n) {
    int i;
    for (i=1;i<n;i++) {
        int j=i;
        d_type tmp=array[i];
        if (tmp<array[0]){
            do {
                array[j]=array[j-1];
                j--;
            } while (j>0);
        }else{
            while (array[j-1]>tmp) {
                array[j]=array[j-1];
                j--;
            }
        }
        array[j]=tmp;
    }
}

void mergeD(d_type a[], int m, int k){
    d_type *b=a+(m-1),*c=a+(m-1)+k,*d=a+(m-1)+2*k;
    while (1) {
        if (*b>*d) {
            swap(b,c);
            b--;
            c--;
            if (b<a) {
                while(c>=a){
                    swap(c,d);
                    c--;
                    d--;
                }
                break;
            }
        } else {
            swap(c,d);
            c--;
            d--;
            if (c<=b) break;
        }
    }
}

void mergeU(d_type c[], int m, int k){
    d_type *b=c+k,*d=c-k,*a=c+m+k;
    while (1) {
        if (*b<*d) {
            swap(b,c);
            b++;
            c++;
            if (b>=a) {
                while(c<a){
                    swap(c,d);
                    c++;
                    d++;          
                }
                break;
            }
        } else {
            swap(c,d);
            c++;
            d++;
            if (c>=b) break;
        }
    }
}

#define RATIO 0.8
#define INSERT_THRESH 60
void merge_sortD(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4に変えたら、なぜか高速化
    merge_sortD(a,m);
    merge_sortU(a+n,k);
    mergeD(a,m,k);
}

void merge_sortU(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4に変えたら、なぜか高速化
    merge_sortU(a+k,m);
    merge_sortD(a-k,k);
    mergeU(a,m,k);
}
#undef INSERT_THRESH

d_type maxim(d_type array[],int last){
    int i,max_i=last;
    for (i=last-1;i>=0;i--){
        if (array[i]>array[max_i]) max_i=i;
    }
    d_type maxim=array[max_i];
    array[max_i]=array[last];
    return maxim;
}

void merge_sort(d_type array[], int n){
    if (n<60) {
        insertion_sort(array,n);
        return;
    }

    int m=n/(1.0+RATIO);//mを最大限まで大きくした
    merge_sortD(array,m);

    int k=(n-m)/2,end=sqrt(n);
    while(k>=end) {
        merge_sortU(array+m+k,k);         
        mergeD(array,m,k);
        m=m+k;
        k=(n-m)/2;
    }

    //マージソートされていないデータの処理
    int i=m,j=n-1;
    d_type max=maxim(array+i,j-i);
    while (1) {
        if (array[i-1]>max){
            array[j]=array[i-1];
            array[i-1]=array[j-1];
            i--;
            j--;
            if (i==0){
                merge_sort(array,j);
                array[j]=max;
                break;
            }
        }else{   
            array[j]=max;
            j--;
            if (j<i) break;
            max=maxim(array+i,j-i);
        }
    }
}
#undef RATIO

int main(){
    int n=100000000;
    int i;
    clock_t start,end;
    d_type *array=malloc(n*sizeof(d_type));
    srand((unsigned) time(NULL));
    for (i=0;i<n;i++){
        array[i]=rand()*(RAND_MAX+1)+rand();
        //整列されたデータに外れ値が入り込んでいるデータの例
        //array[i]=i;
        //if (rand()%10==0)
        //    array[i]=(int)((double)rand()/(RAND_MAX+1)*n);
    }
    start=clock();
    merge_sort(array,n);
    end=clock();
    printf("%f秒  \n",(double)(end-start)/CLOCKS_PER_SEC);
    for (i=1;i<n;i++) if(array[i-1]>array[i]) printf("%d %d\n",i,array[i]);
    free(array);
    return 0;
}

理由が分からなくても,これらによりランダムなデータに対しては,高速化されたクイックソートやイントロソートに近い速度にな りまし た。
ただし,別のコンパイラではどこまで速くなるか不明です。

最悪計算時間がO(n log n) となることは,次の様に分かります。
 ① ほぼ同じ大きさのデータをマージする操作の計算時間の合計はO(n log n) で抑えられる。
 ② 大きさが大きく異なるデータをマージする回数はO(log n)で抑えられるから,それらの計算時間の和はO(n log n) で抑えられる。
 ③ 高速化のテクニックは計算時間を減少させるだけである。

注1 整列されたデータに外れ値が入り込んでいるデータでは,「k=n*(4.0/5.0)」の場合は「k=n/2」の場合より遅くなった。
注2 並列化により更に高速化する場合は,「k=n*(4.0/5.0)」では難しく,「k=n/5」に変えるのが良い。


クイックソートが最速という常識の終わり
実用では上記のコードの速度で十分だと思いますが,もう少しだけ速くして高速化されたクイックソートと拮抗させたいと思い,次の改良を加えました。

このソートの終わりの方にあるソートされていない要素の処理に,ヒープソートを用いました。その基本的なアイデアは次の様です。

    (整列されたm個,整列されていない2k 個) → (整列されたm個,整列されていないk個,整列されたk個) → (整列されたm+k個,整列されていないk個)
と変形するときに,整列されたk個の要素の全てが,整列されていないk個のどの要素よりも大きく無いようにしたい。
そうすれば,整列されたm+k個の要素の半数位は,整列されていないk個のどの要素よりも大きく無くなるから,次に整列されていないk個を処理するときに 高速化される。
そのためには,整列されていない2k個の要素を整列されていないk個と整列されたk個の要素に分けるときに,ヒープソートを用いればよい。

この改良によって,1億個のランダムなデータのソートでは前記のコードより数%は速く成り,高速化されたクイックソートやイントロソートに対抗で きる速度に なりました。
最悪計算時間については,クイックソートはO(n2)で,イントロソートとヒープソートとこのソートは同じO(n log n)です。
イントロソートは,最悪の場合に途中からヒープソートに切り替えるため,最悪計算時間はヒープソートの最悪計算時間より大きくなります。
ヒープソートは,平均計算時間がO(n log n)のソートの内で比較的遅いソートです。
これらのことより,このソートは,私が知る限り,インプレースなソートのうちでデータ数が多いときの最悪計算時間が最小と考えられま す。

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

typedef int d_type;// ソートするキーの型

void merge_sortD(d_type a[], int n);
void merge_sortU(d_type a[], int n);

void swap(d_type *x,d_type *y){
    d_type tmp=*x;
    *x=*y;
    *y=tmp;
}

void insertion_sort(d_type array[],int n) {
    int i;
    for (i=1;i<n;i++) {
        int j=i;
        d_type tmp=array[i];
        if (tmp<array[0]){
            do {
                array[j]=array[j-1];
                j--;
            } while (j>0);
        }else{
            while (array[j-1]>tmp) {
                array[j]=array[j-1];
                j--;
            }
        }
        array[j]=tmp;
    }
}

void mergeD(d_type a[], int m, int k){
    d_type *b=a+(m-1),*c=a+(m-1)+k,*d=a+(m-1)+2*k;
    while (1) {
        if (*b>*d) {
            swap(b,c);
            b--;
            c--;
            if (b<a) {
                while(c>=a){
                    swap(c,d);
                    c--;
                    d--;
                }
                break;
            }
        } else {
            swap(c,d);
            c--;
            d--;
            if (c<=b) break;
        }
    }
}

void mergeU(d_type c[], int m, int k){
    d_type *b=c+k,*d=c-k,*a=c+m+k;
    while (1) {
        if (*b<*d) {
            swap(b,c);
            b++;
            c++;
            if (b>=a) {
                while(c<a){
                    swap(c,d);
                    c++;
                    d++;
                }
                break;
            }
        } else {
            swap(c,d);
            c++;
            d++;
            if (c>=b) break;
        }
    }
}

#define RATIO 0.8
#define INSERT_THRESH 60
void merge_sortD(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4に変えたら、なぜか高速化
    merge_sortD(a,m);
    merge_sortU(a+n,k);
    mergeD(a,m,k);
}

void merge_sortU(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4 に変えたら、なぜか高速化
    merge_sortU(a+k,m);
    merge_sortD(a-k,k);
    mergeU(a,m,k);
}
#undef INSERT_THRESH

#define N 5 //N分木ヒープソート N>=3
void downheap(d_type *array,d_type work,int parent,int child,int last) {
    while (1) {
        if (child<last) {
            int i=child,end=child+N-1;
            if (last<end) end=last;
            do {
                if (array[++i]<array[child]) child=i;
            } while (i<end);
        } else if (child>last) {
            break;
        }
        if (work<=array[child]) break;
        array[parent]=array[child];
        parent=child;
        child=N*parent+1;
    }
    array[parent]=work;
}

void heap_small(d_type array[],int k,int last){
    int i;
    for (i=(last-1)/N;i>=0;i--)
        downheap(array,array[i],i,N*i+1,last);

    //小さいk個を大きい順に並べる
    int end=last-k+1;
    for (i=last;;){
        d_type work=array[i];
        array[i]=array[0];
        if (i==end){
            array[0]=work;
            break;
        }
        downheap(array,work,0,1,--i);
    }

    //小さい順に並べ替える
    d_type *ak=array+k,*al=array+last;
    do{
        swap(ak,al);
    }while (++ak < --al);
}
#undef N

d_type maxim(d_type array[],int last){
    int i,max_i=last;
    for (i=last-1;i>=0;i--){
        if (array[i]>array[max_i]) max_i=i;
    }
    d_type maxim=array[max_i];
    array[max_i]=array[last];
    return maxim;
}

void merge_sort(d_type array[], int n){
    if (n<60) {
        insertion_sort(array,n);
        return;
    }

    int m=n/(1.0+RATIO);//mを最大限まで大きくした
    merge_sortD(array,m);

    int k=(n-m)/2,end=n/80+7;
    while (k>=end){
        merge_sortU(array+m+k,k);
        mergeD(array,m,k);
        m=m+k;
        k=(n-m)/2;
    }
    while(k>=40){
        heap_small(array+m,k,n-m-1);
        mergeD(array,m,k);
        m=m+k;
        k=(n-m)/2;
    }

    //マージソートされていないデータの処理
    int i=m,j=n-1;
    d_type max=maxim(array+i,j-i);
    while (1) {
        if (array[i-1]>max){
            array[j]=array[i-1];
            array[i-1]=array[j-1];
            i--;
            j--;
            if (i==0){
                merge_sort(array,j);
                array[j]=max;
                break;
            }
        }else{
            array[j]=max;
            j--;
            if (j<i) break;
            max=maxim(array+i,j-i);
        }
    }
}
#undef RATIO

int main(){
    int n=100000000;
    int i;
    clock_t start,end;
    d_type *array=malloc(n*sizeof(d_type));
    srand((unsigned) time(NULL));
    for (i=0;i<n;i++){
        array[i]=rand()*(RAND_MAX+1)+rand();
        //整列されたデータに外れ値が入り込んでいるデータの例
        //array[i]=i;
        //if (rand()%10==0)
        //    array[i]=(int)((double)rand()/(RAND_MAX+1)*n);
    }
    start=clock();
    merge_sort(array,n);
    end=clock();
    printf("%f秒  \n",(double)(end-start)/CLOCKS_PER_SEC);
    for (i=1;i<n;i++) if(array[i-1]>array[i]) printf("%d %d\n",i,array[i]);
    free(array);
    return 0;
}

理由が分からずマジックのように速くなったソートなので,冗談でマージックソートとでも名前をつけたくなります。マージソートの考案者のフォン・ノイマン もマジマジと見ているかもしれません。


クイックソートとの比較
最高速と言われているクイックソートと速度 比較ができるように,上記のコードに,挿入ソートで高速化したクイックソートのコードを加えて掲載します。
私のパソコンで速度を測定したところでは,インプレースマージソートとクイックソートの差は,コンパイラのご機嫌によって変わる程度です。
実際にコンパイルして実行してみて下さい。

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

typedef int d_type;// ソートするキーの型

void merge_sortD(d_type a[], int n);
void merge_sortU(d_type a[], int n);

void swap(d_type *x,d_type *y){
    d_type tmp=*x;
    *x=*y;
    *y=tmp;
}

void insertion_sort(d_type array[],int n) {
    int i;
    for (i=1;i<n;i++) {
        int j=i;
        d_type tmp=array[i];
        if (tmp<array[0]){
            do {
                array[j]=array[j-1];
                j--;
            } while (j>0);
        }else{
            while (array[j-1]>tmp) {
                array[j]=array[j-1];
                j--;
            }
        }
        array[j]=tmp;
    }
}

d_type med3(d_type x,d_type y,d_type z){
    if (x<y){
        if (y<z)      return y;
        else if (z<x) return x;
    }else{
        if (z<y)      return y;
        else if (x<z) return x;
    }
    return z;
}

void quick_sort(d_type array[], int begin, int end){
    int n=end-begin+1;
    if (n<60){
        insertion_sort(array+begin,n);
        return;       
    }
    d_type pivot=med3(array[begin],array[begin+n/2],array[end]);
    int l=begin,r=end;
    while(1){
        while(array[l]<pivot) l++;
        while(array[r]>pivot) r--;
        if(l>=r) break;
        swap(array+l,array+r);
        l++;
        r--;
    }
    quick_sort(array,begin,r);//l=r+1
    quick_sort(array,r+1,end);
}

void mergeD(d_type a[], int m, int k){
    d_type *b=a+(m-1),*c=a+(m-1)+k,*d=a+(m-1)+2*k;
    while (1) {
        if (*b>*d) {
            swap(b,c);
            b--;
            c--;
            if (b<a) {
                while(c>=a){
                    swap(c,d);
                    c--;
                    d--;
                }
                break;
            }
        } else {
            swap(c,d);
            c--;
            d--;
            if (c<=b) break;
        }
    }
}

void mergeU(d_type c[], int m, int k){
    d_type *b=c+k,*d=c-k,*a=c+m+k;
    while (1) {
        if (*b<*d) {
            swap(b,c);
            b++;
            c++;
            if (b>=a) {
                while(c<a){
                    swap(c,d);
                    c++;
                    d++;
                }
                break;
            }
        } else {
            swap(c,d);
            c++;
            d++;
            if (c>=b) break;
        }
    }
}

#define RATIO 0.8
#define INSERT_THRESH 60
void merge_sortD(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4に変えたら、なぜか高速化
    merge_sortD(a,m);
    merge_sortU(a+n,k);
    mergeD(a,m,k);
}

void merge_sortU(d_type a[], int n){
    if (n<INSERT_THRESH) {
        insertion_sort(a,n);
        return;
    }
    int k=n*RATIO,m=n-k;//k=n/2をk=n/5*4 に変えたら、なぜか高速化
    merge_sortU(a+k,m);
    merge_sortD(a-k,k);
    mergeU(a,m,k);
}
#undef INSERT_THRESH

#define N 5 //N分木ヒープソート N>=3
void downheap(d_type *array,d_type work,int parent,int child,int last) {
    while (1) {
        if (child<last) {
            int i=child,end=child+N-1;
            if (last<end) end=last;
            do {
                if (array[++i]<array[child]) child=i;
            } while (i<end);
        } else if (child>last) {
            break;
        }
        if (work<=array[child]) break;
        array[parent]=array[child];
        parent=child;
        child=N*parent+1;
    }
    array[parent]=work;
}

void heap_small(d_type array[],int k,int last){
    int i;
    for (i=(last-1)/N;i>=0;i--)
        downheap(array,array[i],i,N*i+1,last);

    //小さいk個を大きい順に並べる
    int end=last-k+1;
    for (i=last;;){
        d_type work=array[i];
        array[i]=array[0];
        if (i==end){
            array[0]=work;
            break;
        }
        downheap(array,work,0,1,--i);
    }

    //小さい順に並べ替える
    d_type *ak=array+k,*al=array+last;
    do{
        swap(ak,al);
    }while (++ak < --al);
}
#undef N

d_type maxim(d_type array[],int last){
    int i,max_i=last;
    for (i=last-1;i>=0;i--){
        if (array[i]>array[max_i]) max_i=i;
    }
    d_type maxim=array[max_i];
    array[max_i]=array[last];
    return maxim;
}

void merge_sort(d_type array[], int n){
    if (n<60) {
        insertion_sort(array,n);           
        return;
    }

    int m=n/(1.0+RATIO);//mを最大限まで大きくした
    merge_sortD(array,m);

    int k=(n-m)/2,end=n/80+7;
    while (k>=end){
        merge_sortU(array+m+k,k);
        mergeD(array,m,k);
        m=m+k;
        k=(n-m)/2;
    }
    while(k>=40){
        heap_small(array+m,k,n-m-1);
        mergeD(array,m,k);
        m=m+k;
        k=(n-m)/2;
    }

    //マージソートされていないデータの処理
    int i=m,j=n-1;
    d_type max=maxim(array+i,j-i);
    while (1) {
        if (array[i-1]>max){
            array[j]=array[i-1];
            array[i-1]=array[j-1];
            i--;
            j--;
            if (i==0){
                merge_sort(array,j);
                array[j]=max;
                break;
            }
        }else{
            array[j]=max;
            j--;
            if (j<i) break;
            max=maxim(array+i,j-i);
        }
    }
}
#undef RATIO

int main(){
    int n=100000000;
    int i;
    clock_t start,end,end2;
    d_type *array=malloc(n*sizeof(d_type));
    d_type *array2=malloc(n*sizeof(d_type));
    srand((unsigned) time(NULL));
    for (i=0;i<n;i++){
        array[i]=rand()*(RAND_MAX+1)+rand();
        //整列されたデータに外れ値が入り込んでいるデータの例
        //array[i]=i;
        //if (rand()%10==0)
        //    array[i]=(int)((double)rand()/(RAND_MAX+1)*n);
        array2[i]=array[i];
    }
    start=clock();
    merge_sort(array,n);
    end=clock();
    quick_sort(array2,0,n-1);
    end2=clock();
    printf("merge_sort %f秒  \n",(double)(end-start)/CLOCKS_PER_SEC);
    printf("quick_sort %f秒  \n",(double)(end2-end)/CLOCKS_PER_SEC);
    for (i=0;i<n;i++) if(array[i]!=array2[i]) printf("%d %d\n",array[i],array2[i]);
    free(array);
    free(array2);
    return 0;
}


最初のページ に戻る