2014年3月5日 星期三

Binary search

演算法類型

divide and conquer, 搜尋演算法

演算法目的

利用比對 key value 來對資料進行搜尋

演算法描述

使用 binary search 的前提是資料必須是已經排序好的,接著再藉由比對資料堆中間位置的 key value 來將資料堆做二分法搜尋。舉例說明,假設有一排序好的數列如下圖:



現在假設我們要搜尋的數字是 3,就我們目前所得知的資訊,3 有可能出現在此數列的任何位置,所以目前的搜尋範圍是整個數列,我們先找出在搜尋範圍內中間位置的數字,如下圖:


我們在中間位置找到 4,很明顯的 4 不是我們要的 3,不過因為數列是已經排序好的而且 3 < 4,故我們可以確定 3 只有可能落在 4 的左邊而不可能會落在 4 的右邊,如下圖:


如此我們就可以拋棄右邊,接著繼續向左邊搜尋,所以現在的搜尋範圍變成只剩 4 的左邊,同之前的方法,直接去比對搜尋範圍中間位置的數字,如下圖:


我們在搜尋範圍中間找到 2,很明顯 2 不是我們要的 3,不過因為數列是已經排序好的而且 3 > 2,故我們可以確定 3 只有可能落在 2 的右邊而不可能會落在 2 的左邊,如下圖:


同之前的方法,我們繼續在搜尋範圍內找中間的數字做比對,以目前的情況來看,我們只剩下一個數字,如下圖:


恭喜,我們很幸運的找到 3 了,binary search 圓滿結束。如果一直二分到沒有範圍可以搜尋,則代表此數字不在數列裡。

最壞情況效率分析

設總資料量為 n, 分析單位為比較次數
binary search 的最壞情況就是要搜尋的資料不在資料堆裡,依據最壞情況可以列出以下遞迴式:

T(n) = T(n / 2) + 1
T(1) = 1

根據 master theorem,可推出此演算法效率的最壞情況為 O(lgn)

程式實作


C
#include <stdio.h>

int binarySearch(int *, int, int);

int main() {
    int arr[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    int arrLength = 8;
    int findNum = 3;
    int findIndex = binarySearch(arr, arrLength, findNum);

    if (findIndex == -1)
        printf("The number %d is not in the sequence\n", findNum);
    else
        printf("The number %d is at index %d\n", findNum, findIndex);

    return 0;
}

int binarySearch(int *arr, int arrLength, int findNum) {
    int low = 0;
    int high = arrLength - 1;

    while (low <= high) {
        int mid = (low + high) / 2;

        if (findNum > arr[mid])
            low = mid + 1;
        else if (findNum == arr[mid])
            return mid;
        else
            high = mid - 1;
    }
    return -1;
}

Java
public class BinarySearch {
    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5, 6, 7, 8};
        int findNum = 3;
        int findIndex = binarySearch(arr, findNum);

        if (findIndex == -1)
            System.out.printf("The number %d is not in the sequence\n", findNum);
        else
            System.out.printf("The number %d is at index %d\n", findNum, findIndex);
    }

    private static int binarySearch(int[] arr, int findNum) {
        int low = 0;
        int high = arr.length - 1;

        while (low <= high) {
            int mid = (low + high) / 2;

            if (findNum > arr[mid])
                low = mid + 1;
            else if (findNum == arr[mid])
                return mid;
            else
                high = mid - 1;
        }
        return -1;
    }
}

Python
def binarySearch(arr, findNum):
    low = 0
    high = len(arr) - 1
    while low <= high:
        mid = (low + high) / 2
        if findNum > arr[mid]:
            low = mid + 1
        elif findNum == arr[mid]:
            return mid
        else:
            high = mid - 1
    return -1

arr = [1, 2, 3, 4, 5, 6, 7, 8]
findNum = 3
findIndex = binarySearch(arr, findNum)
if findIndex == -1:
    print "The number %d is not in the sequence" % (findNum)
else:
    print "The number %d is at index %d" % (findNum, findIndex)

2014年3月1日 星期六

Merge sort

演算法類型

divide and conquer, 排序演算法

演算法目的

利用比較 key value 來將資料做排序

演算法描述

merge sort 的核心觀念是將大筆資料切割成很多小筆資料做排序,接著利用已經排序好的小筆資料合併成排序好的大筆資料。merge sort 的概觀流程圖如下:


分割步驟相信大家應該沒有什麼問題,就是一次將資料切一半。比較需要解釋的應該是合併步驟,以下舉例說明合併步驟如何進行,假設要合併 A 和 B 兩個已排序好的數列,如下圖:


現在我們將兩個箭頭各指著 A 數列和 B 數列的第一個元素,如下圖:


將箭頭指到的數字做比較,把比較小的數字複製到另外一個 C 數列,並且將指向比較小的數字的箭頭往前移一格,如下圖:


繼續重複剛剛的動作,如下圖:


繼續重複這個動作直到有其中一個箭頭超出數列的範圍為止,如下圖:





現在 B 數列的箭頭已經超出範圍了,所以我們剩下要做的事情就只是把 A 數列箭頭開始以後的數字全部複製到 C 數列就可以了,如下圖:



這樣 C 數列就是一個合併完成的數列了。

最壞情況時間複雜度分析

設總資料量為 n, 分析單位為比較次數
依據此演算法的最壞情況,我們可以列出以下遞迴式:

T(n) = 2T(n / 2) + n - 1
T(1) = 0

根據 master theorem,可以求出此演算法為 O(nlgn) 的演算法

程式實作


C
#include <stdio.h>

void mergeSort(int *, int, int);
void merge(int *, int, int, int);

int main() {
    int arr[8] = {4, 6, 1, 9, 5, 3, 0, 2};
    int dataNum = 8;
    int i;

    printf("before sorting: ");
    for (i=0; i<dataNum; i++)
        printf("%d " , arr[i]);
    printf("\n");
    mergeSort(arr, 0, dataNum - 1);
    printf("after sorting: ");
    for (i=0; i<dataNum; i++)
        printf("%d " , arr[i]);
    printf("\n");

    return 0;
}

void mergeSort(int *arr, int low, int high) {
    if (low < high) {
        int mid = (low + high) / 2;

        mergeSort(arr, low, mid);
        mergeSort(arr, mid + 1, high);
        merge(arr, low, mid, high);
    }
}

void merge(int *arr, int low, int mid, int high) {
    int leftIndex = low;
    int rightIndex = mid + 1;
    int tempArrLength = high - low + 1;
    int tempArr[tempArrLength];
    int tempIndex = 0;

    while (leftIndex <= mid && rightIndex <= high) {
        if (arr[leftIndex] <= arr[rightIndex]) {
            tempArr[tempIndex] = arr[leftIndex];
            leftIndex++;
        }
        else {
            tempArr[tempIndex] = arr[rightIndex];
            rightIndex++;
        }
        tempIndex++;
    }
    if (leftIndex > mid) {
        while (rightIndex <= high) {
            tempArr[tempIndex] = arr[rightIndex];
            rightIndex++;
            tempIndex++;
        }
    }
    else {
        while (leftIndex <= mid) {
            tempArr[tempIndex] = arr[leftIndex];
            leftIndex++;
            tempIndex++;
        }
    }
    leftIndex = low;
    for (tempIndex=0; tempIndex<tempArrLength; tempIndex++) {
        arr[leftIndex] = tempArr[tempIndex];
        leftIndex++;
    }
}

Java
public class MergeSort {
    public static void main(String[] args) {
        int[] arr = {4, 6, 1, 9, 5, 3, 0, 2};

        System.out.print("before sorting: ");
        for (int num: arr)
            System.out.printf("%d " , num);
        System.out.println();
        mergeSort(arr, 0, arr.length - 1);
        System.out.print("after sorting: ");
        for (int num: arr)
            System.out.printf("%d " , num);
        System.out.println();
    }

    private static void mergeSort(int[] arr, int low, int high) {
        if (low < high) {
            int mid = (low + high) / 2;

            mergeSort(arr, low, mid);
            mergeSort(arr, mid + 1, high);
            merge(arr, low, mid, high);
        }
    }

    private static void merge(int[] arr, int low, int mid, int high) {
        int leftIndex = low;
        int rightIndex = mid + 1;
        int[] tempArr = new int[high - low + 1];
        int tempIndex = 0;

        while (leftIndex <= mid && rightIndex <= high) {
            if (arr[leftIndex] <= arr[rightIndex]) {
                tempArr[tempIndex] = arr[leftIndex];
                leftIndex++;
            }
            else {
                tempArr[tempIndex] = arr[rightIndex];
                rightIndex++;
            }
            tempIndex++;
        }
        if (leftIndex > mid) {
            while (rightIndex <= high) {
                tempArr[tempIndex] = arr[rightIndex];
                rightIndex++;
                tempIndex++;
            }
        }
        else {
            while (leftIndex <= mid) {
                tempArr[tempIndex] = arr[leftIndex];
                leftIndex++;
                tempIndex++;
            }
        }
        leftIndex = low;
        for (int temp: tempArr) {
            arr[leftIndex] = temp;
            leftIndex++;
        }
    }
}

Python
def merge(arr, low, mid, high):
    leftIndex = low
    rightIndex = mid + 1
    tempArr = []
    while leftIndex <= mid and rightIndex <= high:
        if arr[leftIndex] <= arr[rightIndex]:
            tempArr.append(arr[leftIndex])
            leftIndex += 1
        else:
            tempArr.append(arr[rightIndex])
            rightIndex += 1
    if leftIndex > mid:
        while rightIndex <= high:
            tempArr.append(arr[rightIndex])
            rightIndex += 1
    else:
        while leftIndex <= mid:
            tempArr.append(arr[leftIndex])
            leftIndex += 1
    leftIndex = low
    for temp in tempArr:
        arr[leftIndex] = temp
        leftIndex += 1

def mergeSort(arr, low, high):
    if low < high:
        mid = (low + high) / 2
        mergeSort(arr, low, mid)
        mergeSort(arr, mid + 1, high)
        merge(arr, low, mid, high)

arr = [4, 6, 1, 9, 5, 3, 0, 2]
print "before sorting:",
for num in arr:
    print "%d" % (num),
print ""
mergeSort(arr, 0, len(arr) - 1)
print "after sorting:",
for num in arr:
    print "%d" % (num),
print ""