Nearest Neighbor and Cross-Validation (with Python)

    Technology, 2022-07-14

    k-Nearest Neighbor

    Nearest Neighbor

    Compare the test example with every example in the training set using a distance metric, and predict the label of the closest one. Common distances:

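    L1 (Manhattan) distance: $d_1(I_1, I_2) = \sum_p \lvert I_1^p - I_2^p \rvert$

    L2 (Euclidean) distance: $d_2(I_1, I_2) = \sqrt{\sum_p \left(I_1^p - I_2^p\right)^2}$

    where $p$ runs over all pixels (feature dimensions) of the two images.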
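    A minimal NumPy sketch of the two distances for a pair of flattened images. The function names `l1_distance` and `l2_distance` are illustrative and do not come from the original post.

```python
import numpy as np

def l1_distance(I1, I2):
    # Cast to float so uint8 pixel values do not wrap around when subtracted
    I1, I2 = np.asarray(I1, dtype=np.float64), np.asarray(I2, dtype=np.float64)
    # Sum of absolute pixel-wise differences
    return np.sum(np.abs(I1 - I2))

def l2_distance(I1, I2):
    I1, I2 = np.asarray(I1, dtype=np.float64), np.asarray(I2, dtype=np.float64)
    # Square root of the sum of squared pixel-wise differences
    return np.sqrt(np.sum((I1 - I2) ** 2))

# Two tiny "images" flattened into vectors
a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 6.0, 3.0])
print(l1_distance(a, b))  # 7.0
print(l2_distance(a, b))  # 5.0
```

    The float cast matters for real image data: raw pixels are often stored as `uint8`, where `10 - 20` silently wraps to 246.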

    Algorithm:

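    import numpy as np

    class NearestNeighbour(object):
        def __init__(self):
            pass

        def train(self, X_train, y_train):
            # Nearest neighbour simply memorizes the training data
            self.X_train = X_train
            self.y_train = y_train

        def predict(self, X_test):
            num_train = self.X_train.shape[0]
            num_test = X_test.shape[0]
            # Squared L2 distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
            l2 = np.zeros((num_test, num_train))
            l2 += np.sum(self.X_train ** 2, axis=1).reshape(1, num_train)
            l2 += np.sum(X_test ** 2, axis=1).reshape(num_test, 1)
            l2 -= 2 * np.dot(X_test, self.X_train.T)
            # Clip tiny negative values caused by floating-point error
            l2 = np.sqrt(np.maximum(l2, 0))
            # The nearest neighbour is the one with the smallest distance
            y_prd = self.y_train[np.argmin(l2, axis=1)]
            return y_prd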
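    The vectorized distance computation relies on expanding $\lVert x - y\rVert^2 = \lVert x\rVert^2 + \lVert y\rVert^2 - 2\,x \cdot y$. One way to check it is to compare it against a naive double loop on random data; the helper names `l2_vectorized` and `l2_two_loops` below are hypothetical.

```python
import numpy as np

def l2_vectorized(X_test, X_train):
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, evaluated for all pairs at once
    sq = (np.sum(X_test ** 2, axis=1)[:, None]
          + np.sum(X_train ** 2, axis=1)[None, :]
          - 2 * X_test @ X_train.T)
    return np.sqrt(np.maximum(sq, 0))

def l2_two_loops(X_test, X_train):
    # Reference implementation: one pair of examples at a time
    d = np.zeros((X_test.shape[0], X_train.shape[0]))
    for i in range(X_test.shape[0]):
        for j in range(X_train.shape[0]):
            d[i, j] = np.linalg.norm(X_test[i] - X_train[j])
    return d

rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 8))
X_test = rng.normal(size=(10, 8))
print(np.allclose(l2_vectorized(X_test, X_train), l2_two_loops(X_test, X_train)))  # True
```

    The vectorized version replaces the Python loops with one matrix product, which is what makes brute-force nearest neighbour practical on image data sets.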

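    k-Nearest Neighbor

    Instead of copying the label of the single nearest training example, find the k closest training examples and let them vote: the predicted label is the most common label among them.

    Algorithm: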

    import numpy as np

    class kNearestNeighbour(object):
        def __init__(self):
            pass

        def train(self, X_train, y_train):
            self.X_train = X_train
            self.y_train = y_train

        def predict(self, X_test, k=1):
            num_train = self.X_train.shape[0]
            num_test = X_test.shape[0]
            # Distance matrix computed the same way as in NearestNeighbour above
            l2 = np.zeros((num_test, num_train))
            l2 += np.sum(self.X_train ** 2, axis=1).reshape(1, num_train)
            l2 += np.sum(X_test ** 2, axis=1).reshape(num_test, 1)
            l2 -= 2 * np.dot(X_test, self.X_train.T)
            l2 = np.sqrt(np.maximum(l2, 0))
            y_prd = np.zeros(num_test, dtype=self.y_train.dtype)
            for i in range(num_test):
                # Indices of the k training examples closest to test example i
                l2_index = np.argsort(l2[i])[:k]
                y_closest = self.y_train[l2_index]
                # Majority vote (labels must be non-negative integers);
                # ties go to the smallest label
                y_prd[i] = np.bincount(y_closest).argmax()
            return y_prd
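    To see why voting helps, here is a small self-contained sketch in which a single mislabeled training point decides the prediction when k = 1 but is outvoted when k = 3. The function `knn_predict` and the toy points are made up for illustration.

```python
import numpy as np

def knn_predict(X_train, y_train, X_test, k):
    # Hypothetical helper: brute-force k-NN with L2 distance and majority vote
    dists = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=2)
    y_pred = np.zeros(X_test.shape[0], dtype=y_train.dtype)
    for i in range(X_test.shape[0]):
        nearest = np.argsort(dists[i])[:k]                  # k closest indices
        y_pred[i] = np.bincount(y_train[nearest]).argmax()  # most common label
    return y_pred

# Class 0 clustered near the origin, class 1 near (5, 5),
# plus one mislabeled class-1 point inside the class-0 cluster
X_train = np.array([[0, 0], [0, 1], [1, 0], [0.2, 0.2],
                    [5, 5], [5, 6], [6, 5]], dtype=float)
y_train = np.array([0, 0, 0, 1, 1, 1, 1])
X_test = np.array([[0.3, 0.3]])

print(knn_predict(X_train, y_train, X_test, k=1))  # [1]  the outlier wins
print(knn_predict(X_train, y_train, X_test, k=3))  # [0]  outvoted by its neighbours
```

    Larger k smooths the decision boundary and makes the classifier less sensitive to noisy labels, which is exactly why k is worth tuning.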

    Cross-Validation

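    Cross-validation is a more sophisticated technique for hyperparameter tuning, here the choice of k. The training data is split into several equal folds; each fold in turn serves as the validation set while the remaining folds are used for training, and the validation accuracies are averaged over the folds. Typical choices are 3-fold, 5-fold or 10-fold cross-validation.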
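    The plotting code in the next step expects `k_choices` and a dictionary `k_to_accuracies` mapping each k to its per-fold accuracies. The post does not show the loop that fills them, so the following is one possible sketch, assuming 5 folds, synthetic two-class data, an arbitrary list of k values, and a brute-force `knn_predict` helper written here for illustration.

```python
import numpy as np

def knn_predict(X_train, y_train, X_test, k):
    # Illustrative helper: brute-force k-NN with L2 distance and majority vote
    dists = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]  # (num_test, k) indices
    return np.array([np.bincount(row).argmax() for row in y_train[nearest]])

# Synthetic two-class data (an assumption; the post does not show its data set)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, size=(60, 2)), rng.normal(2, 1, size=(60, 2))])
y = np.array([0] * 60 + [1] * 60)
perm = rng.permutation(len(y))
X, y = X[perm], y[perm]

num_folds = 5
k_choices = [1, 3, 5, 8, 10, 12, 15, 20]
X_folds = np.array_split(X, num_folds)  # list of num_folds arrays
y_folds = np.array_split(y, num_folds)

k_to_accuracies = {}
for k in k_choices:
    k_to_accuracies[k] = []
    for i in range(num_folds):
        # Fold i is the validation set; the remaining folds form the training set
        X_tr = np.concatenate(X_folds[:i] + X_folds[i + 1:])
        y_tr = np.concatenate(y_folds[:i] + y_folds[i + 1:])
        y_val_pred = knn_predict(X_tr, y_tr, X_folds[i], k)
        k_to_accuracies[k].append(np.mean(y_val_pred == y_folds[i]))

for k in k_choices:
    print(k, np.mean(k_to_accuracies[k]))
```

    Every example is used for validation exactly once, so the averaged accuracy is a less noisy estimate than a single train/validation split.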

    Visualizing the cross-validation results:

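    import numpy as np
    import matplotlib.pyplot as plt

    # k_choices (the k values tried) and k_to_accuracies (dict: k -> list of
    # per-fold validation accuracies) come from the cross-validation loop

    # plot the raw observations
    for k in k_choices:
        accuracies = k_to_accuracies[k]
        plt.scatter([k] * len(accuracies), accuracies)

    # plot the trend line with error bars that correspond to standard deviation
    accuracies_mean = np.array([np.mean(v) for k, v in sorted(k_to_accuracies.items())])
    accuracies_std = np.array([np.std(v) for k, v in sorted(k_to_accuracies.items())])
    # sort k_choices too so the x values line up with the sorted means
    plt.errorbar(sorted(k_choices), accuracies_mean, yerr=accuracies_std)
    plt.title('Cross-validation on k')
    plt.xlabel('k')
    plt.ylabel('Cross-validation accuracy')
    plt.show()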
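    Besides reading the plot, k can be picked programmatically as the value with the highest mean cross-validation accuracy. A minimal sketch; the accuracies below are hypothetical placeholders, not results from the post.

```python
import numpy as np

# Hypothetical per-fold accuracies for a few values of k (placeholders only)
k_to_accuracies = {
    1: [0.80, 0.78, 0.82],
    5: [0.85, 0.86, 0.84],
    10: [0.83, 0.84, 0.82],
}

# Pick the k with the highest mean cross-validation accuracy
mean_acc = {k: np.mean(v) for k, v in k_to_accuracies.items()}
best_k = max(mean_acc, key=mean_acc.get)
print(best_k)  # 5
```

    Once k is chosen, the classifier is typically retrained on the full training set with that k and evaluated on the test set only once.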