Переглянути джерело

AK: Guarantee a maximum stack depth for dual_pivot_quick_sort

When the two chosen pivots happen to be the smallest and largest
elements of the array, three partitions will be created, two of
size 0 and one of size n-2. If this happens on each recursive call
to dual_pivot_quick_sort, the stack depth will reach approximately n/2.

To avoid the stack from deepening, iteration can be used for the
largest of the three partitions. This ensures the stack depth
will only increase for partitions of size n/2 or smaller, which
results in a maximum stack depth of log(n).
Mart G 4 роки тому
батько
коміт
c9f3cc6dcc
1 змінених файлів з 58 додано та 45 видалено
  1. 58 45
      AK/QuickSort.h

+ 58 - 45
AK/QuickSort.h

@@ -18,62 +18,75 @@ namespace AK {
 template<typename Collection, typename LessThan>
 void dual_pivot_quick_sort(Collection& col, int start, int end, LessThan less_than)
 {
-    int size = end - start + 1;
-    if (size <= 1) {
-        return;
-    }
-
-    if (size > 3) {
-        int third = size / 3;
-        if (less_than(col[start + third], col[end - third])) {
-            swap(col[start + third], col[start]);
-            swap(col[end - third], col[end]);
+    while (start < end) {
+        int size = end - start + 1;
+        if (size > 3) {
+            int third = size / 3;
+            if (less_than(col[start + third], col[end - third])) {
+                swap(col[start + third], col[start]);
+                swap(col[end - third], col[end]);
+            } else {
+                swap(col[start + third], col[end]);
+                swap(col[end - third], col[start]);
+            }
         } else {
-            swap(col[start + third], col[end]);
-            swap(col[end - third], col[start]);
-        }
-    } else {
-        if (!less_than(col[start], col[end])) {
-            swap(col[start], col[end]);
+            if (!less_than(col[start], col[end])) {
+                swap(col[start], col[end]);
+            }
         }
-    }
 
-    int j = start + 1;
-    int k = start + 1;
-    int g = end - 1;
+        int j = start + 1;
+        int k = start + 1;
+        int g = end - 1;
 
-    auto&& left_pivot = col[start];
-    auto&& right_pivot = col[end];
+        auto&& left_pivot = col[start];
+        auto&& right_pivot = col[end];
 
-    while (k <= g) {
-        if (less_than(col[k], left_pivot)) {
-            swap(col[k], col[j]);
-            j++;
-        } else if (!less_than(col[k], right_pivot)) {
-            while (!less_than(col[g], right_pivot) && k < g) {
-                g--;
-            }
-            swap(col[k], col[g]);
-            g--;
+        while (k <= g) {
             if (less_than(col[k], left_pivot)) {
                 swap(col[k], col[j]);
                 j++;
+            } else if (!less_than(col[k], right_pivot)) {
+                while (!less_than(col[g], right_pivot) && k < g) {
+                    g--;
+                }
+                swap(col[k], col[g]);
+                g--;
+                if (less_than(col[k], left_pivot)) {
+                    swap(col[k], col[j]);
+                    j++;
+                }
             }
+            k++;
+        }
+        j--;
+        g++;
+
+        swap(col[start], col[j]);
+        swap(col[end], col[g]);
+
+        int left_pointer = j;
+        int right_pointer = g;
+
+        int left_size = left_pointer - start;
+        int middle_size = right_pointer - (left_pointer + 1);
+        int right_size = (end + 1) - (right_pointer + 1);
+
+        if (left_size >= middle_size && left_size >= right_size) {
+            dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
+            dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
+            end = left_pointer - 1;
+        } else if (middle_size >= right_size) {
+            dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
+            dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
+            start = left_pointer + 1;
+            end = right_pointer - 1;
+        } else {
+            dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
+            dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
+            start = right_pointer + 1;
         }
-        k++;
     }
-    j--;
-    g++;
-
-    swap(col[start], col[j]);
-    swap(col[end], col[g]);
-
-    int left_pointer = j;
-    int right_pointer = g;
-
-    dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
-    dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
-    dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
 }
 
 template<typename Iterator, typename LessThan>