PMAJEWSKI


Understanding K-Nearest-Neighbors (KNN) – Essential Machine Learning Algorithm - Software Developer's Tour


A quick introduction to this algorithm

KNN is used to determine the class of a new sample, based on samples whose classes we already know. To understand this more clearly, please take a look at the plot below. We have two groups of samples: blue ones and red ones. We know a priori which samples are blue and which are red. Now, however, a new sample has arrived: a green one. Our task is to determine whether this sample fits better with the reds or the blues.

[Plot: blue and red samples of known class, plus a new green sample to classify]

As you might guess, the simplest way to assign a class to a new sample is to calculate the distance between it and every labeled point. Why does KNN have a 'K' in its name? The 'K' is the number of nearest neighbors used to make the prediction: the new sample gets the class that occurs most often among its K closest points.
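The whole procedure fits in a few lines. The following is a minimal sketch, assuming NumPy is available; the function name `knn_predict` and the toy points are hypothetical, chosen only for illustration. On a tie in the vote, `Counter.most_common` keeps first-seen order, so the label of the closest tied neighbor wins.

```python
from collections import Counter

import numpy as np


def knn_predict(X_train, y_train, x_new, k):
    """Hypothetical helper: classify x_new by majority vote of its k nearest neighbors."""
    # Euclidean distance from x_new to every training sample (one row per sample)
    distances = np.sqrt(np.sum((X_train - x_new) ** 2, axis=1))
    # Indices of the k smallest distances, closest first
    nearest = np.argsort(distances)[:k]
    # Majority vote over the labels of those neighbors
    votes = Counter(y_train[nearest].tolist())
    return votes.most_common(1)[0][0]


# Toy data (hypothetical): label 0 = blue, label 1 = red
X_train = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0], [6.0, 5.0], [7.0, 6.0], [6.5, 7.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
x_new = np.array([2.0, 2.0])

print(knn_predict(X_train, y_train, x_new, k=3))  # → 0
```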

How do we calculate the distance between samples?

The easiest way to calculate the distance is to use a formula we all know from primary school: the Euclidean distance, which follows from the Pythagorean theorem. Of course, other distance metrics can work better, depending on the distribution of our data.

$$d = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$$

where $(x_1, y_1)$ and $(x_2, y_2)$ are the two points.
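As a quick worked example with two made-up points, $(1, 2)$ and $(4, 6)$:

$$d = \sqrt{(1 - 4)^2 + (2 - 6)^2} = \sqrt{9 + 16} = \sqrt{25} = 5$$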

In code, the formula becomes the following function:

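import numpy as np
def euclidean_distance(a, b):
    # a, b: NumPy arrays of equal length; square each difference, sum, then take the root
    return np.sqrt(np.sum((a - b) ** 2))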
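A quick sanity check on the points (1, 2) and (4, 6), whose distance is 5 by hand. The parentheses matter: each coordinate difference has to be squared before summing, while `np.sum(a - b) ** 2` squares the sum of the differences, which is a different quantity. NumPy's `np.linalg.norm` and Python's `math.dist` (Python 3.8+) compute the same Euclidean distance and make handy cross-checks.

```python
import math

import numpy as np


def euclidean_distance(a, b):
    # Square each coordinate difference first, then sum, then take the square root
    return np.sqrt(np.sum((a - b) ** 2))


a = np.array([1.0, 2.0])
b = np.array([4.0, 6.0])

print(euclidean_distance(a, b))     # → 5.0
print(np.sqrt(np.sum(a - b) ** 2))  # → 7.0 (sums first: (-3 + -4)^2 = 49, not a distance)
print(np.linalg.norm(a - b))        # → 5.0
print(math.dist(a, b))              # → 5.0
```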

Now we need to find the K nearest neighbors. First, we calculate the distance between the test point and every training point; then we sort the samples by that distance and keep the K closest.

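def get_neighbors(X, y, test_point, k):
    # Pair every training sample with its label and its distance to the test point
    distances = []
    for i in range(len(X)):
        distance = euclidean_distance(X[i], test_point)
        distances.append((X[i], y[i], distance))
    # Sort by distance (third element of each tuple), closest first
    distances.sort(key=lambda x: x[2])
    # Keep only the k closest samples
    neighbors = distances[:k]
    return neighbors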
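The loop in get_neighbors compares the test point with one sample at a time. A minimal vectorized sketch of the same step, assuming `X` is a 2-D NumPy array with one sample per row, computes all distances at once with broadcasting and orders them with `np.argsort`. The name `get_neighbors_vectorized` and the sample data are hypothetical, and this variant returns indices, labels and distances as arrays instead of a list of tuples.

```python
import numpy as np


def get_neighbors_vectorized(X, y, test_point, k):
    """Hypothetical variant: indices, labels and distances of the k nearest samples."""
    # Broadcasting subtracts test_point from every row of X at once
    distances = np.sqrt(np.sum((X - test_point) ** 2, axis=1))
    # A stable sort keeps the original order for equal distances, like list.sort
    order = np.argsort(distances, kind="stable")[:k]
    return order, y[order], distances[order]


X = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])
y = np.array([0, 1, 0, 1])
idx, labels, dists = get_neighbors_vectorized(X, y, np.array([0.0, 0.0]), k=3)
print(idx, labels, dists)  # → [0 2 3] [0 0 1] [0. 1. 2.]
```

Because Python's `list.sort` is also stable, both versions order samples with tied distances the same way: by their position in `X`.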

Next, we determine the class of our test point.

This is the result of the get_neighbors function for K = 5:

[Output: the five nearest neighbors returned by get_neighbors]

As you can see, we have three instances of class 0 and two instances of class 1. Therefore, when KNN is run on this data with K equal to 5, it classifies the test point as class 0, because class 0 holds the majority (three out of five) among its nearest neighbors.
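The majority vote described above can be sketched with `collections.Counter`. This is a minimal sketch, assuming the input is the list of `(sample, label, distance)` tuples that get_neighbors returns; the name `predict_class` and the example tuples (five made-up neighbors of a test point at (2, 2)) are hypothetical.

```python
from collections import Counter


def predict_class(neighbors):
    """Hypothetical helper: majority vote over (sample, label, distance) tuples."""
    # Count how often each label occurs among the k neighbors
    votes = Counter(label for _, label, _ in neighbors)
    # most_common(1) returns [(label, count)] for the most frequent label
    return votes.most_common(1)[0][0]


# Hypothetical neighbors of a test point at (2, 2), sorted by distance:
# three of class 0 and two of class 1
neighbors = [
    ((2.0, 1.5), 0, 0.5),
    ((2.0, 2.6), 1, 0.6),
    ((1.3, 2.0), 0, 0.7),
    ((2.8, 2.0), 1, 0.8),
    ((2.0, 1.1), 0, 0.9),
]
print(predict_class(neighbors))  # → 0
```

With two classes, an odd K such as 5 cannot produce a tie. With an even K, `most_common` falls back to the label it saw first, which here is the label of the closer neighbor, since the list is sorted by distance.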

This is what the classification looks like on the plot:

[Plot: the test point connected to its five nearest neighbors]

Our test point is connected to three blue samples and only two red ones, so the majority vote assigns it to the blue class.

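In general, for two points $P = (p_1, \ldots, p_n)$ and $Q = (q_1, \ldots, q_n)$ in $n$ dimensions, the Euclidean distance is:

$$d(P, Q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}$$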

Our euclidean_distance function is not limited to two dimensions: because it works on whole NumPy arrays, it applies the same formula to samples with any number of features, which is what real-life problems usually require.
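For example, the same function works unchanged on 3-D or 4-D samples, because NumPy applies the subtraction and squaring element-wise over arrays of any length. A short check with made-up points:

```python
import numpy as np


def euclidean_distance(a, b):
    # Works for arrays of any (equal) length: sums all n squared differences
    return np.sqrt(np.sum((a - b) ** 2))


p3 = np.array([0.0, 0.0, 0.0])
q3 = np.array([1.0, 2.0, 2.0])
print(euclidean_distance(p3, q3))  # → 3.0  (sqrt(1 + 4 + 4))

p4 = np.array([1.0, 1.0, 1.0, 1.0])
q4 = np.array([2.0, 2.0, 2.0, 2.0])
print(euclidean_distance(p4, q4))  # → 2.0  (sqrt(1 + 1 + 1 + 1))
```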
