Neural networks and boosted decision trees have proven to be among the best-performing machine learning methods. When it comes to dense continuous data such as audio signals and images, it’s the networks that win. When it comes to sparse discrete data, it’s typically the trees. I’ve seen this a couple of times, and it is also illustrated in this blog post:
How are trees and networks different and how are they the same?
This raises some interesting questions:
- How are trees and networks different?
- Do trees have some fundamental advantage when it comes to discrete data?
- Do weak but correlated features require a fundamentally different technique than discrete (and strong) features?
Understanding the differences might be beneficial not only from a theoretical perspective, but may also lead to practical improvements. For example, deep neural networks are believed to require much more training data than other classifiers (including trees). Neural networks are also much more sensitive to irrelevant features and don’t tolerate features on different scales, while trees are fine. By relating networks to trees, we might be able to create neural networks that are more robust.
What I want to discuss in this blog post is that networks and trees share some fundamental building blocks but follow different topologies. They are also trained with different strategies. The discussion also touches on how search in a high-dimensional space relates to machine learning in the same space, and when search is more advantageous than machine learning and vice versa. A look at the literature shows that both trees and networks can be used for search and for machine learning. However, so far neural networks have not been used for nearest neighbor search.
Deep neural networks are heavily hyped, while decision trees are not. Should trees be hyped as well? After all, they share similar foundations.
There are two obvious similarities between trees and networks:
- both trees and networks are types of graphs
- both can be trained with gradient descent (for boosted trees, gradient descent in function space)
In this blog post, I’m going to concentrate on the graph aspect and ignore the training method.
From nearest neighbor search to advanced machine learning
Given that I have spent a few months creating a fast and scalable nearest neighbor search library, it’s unsurprising that I’m starting here. But here’s what happened: as I was working on the nearest neighbor problem, I kept encountering various kinds of issues. In the end, many of them were resolved by the same techniques that are widely used in machine learning. As it turns out, a deep understanding of the nearest neighbor problem can lead to a deeper understanding of how SVMs, regression, decision trees and finally neural networks work.
The nearest neighbor problem explained
As one of the running examples, I am using the MNIST dataset (as prepared by Kaggle). It is a very simple dataset, but nevertheless very useful for illustrating concepts such as the dimensionality of the data, sparsity, the value of thresholding, specifying patterns with dot products and so on.
The MNIST problem is a typical supervised problem: given an image of a digit, we want software to output the “label” of the digit (as text). There are ten possible digits, so this is a 10-class problem.
This is important! Machine learning is most effective on problems with a small number of classes. In the case of MNIST, we have to predict one of ten classes. Other problems, such as whether a customer will buy after receiving our email, require predicting one of two classes. The more classes the output can fall into, the harder the problem. When there is a very large number of classes, there are tricks that will be mentioned later on.
The nearest neighbor method is one of the simplest machine learning algorithms, yet it is deeply connected to the most complex ones. It works like this:
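To make sparsity and thresholding concrete, here is a minimal loader sketch. It assumes the Kaggle layout of `train.csv` (a `label` column followed by one column per pixel, with intensities from 0 to 255); the function names and the threshold of 128 are my own illustrative choices, not part of any library.

```python
import csv
import io


def load_kaggle_mnist(file_obj, threshold=None):
    """Parse a Kaggle-style MNIST CSV: a 'label' column followed by pixel columns (0-255).

    If `threshold` is given, every pixel is mapped to 1 (ink) or 0 (background).
    """
    reader = csv.reader(file_obj)
    header = next(reader)
    if header[0] != "label":
        raise ValueError("expected the first column to be 'label'")
    labels, images = [], []
    for row in reader:
        labels.append(int(row[0]))
        pixels = [int(p) for p in row[1:]]
        if threshold is not None:
            pixels = [1 if p >= threshold else 0 for p in pixels]
        images.append(pixels)
    return labels, images


def sparsity(images):
    """Fraction of pixel values that are exactly zero across all images."""
    total = sum(len(img) for img in images)
    zeros = sum(1 for img in images for p in img if p == 0)
    return zeros / total if total else 0.0


# Tiny in-memory stand-in for train.csv (4 pixels instead of 784).
sample = "label,pixel0,pixel1,pixel2,pixel3\n7,0,200,0,90\n1,0,0,255,0\n"
labels, images = load_kaggle_mnist(io.StringIO(sample), threshold=128)
print(labels, images, sparsity(images))
```

With the real file one would call `load_kaggle_mnist(open("train.csv", newline=""), threshold=128)`. Thresholding discards faint strokes (the 90 above becomes 0), which raises sparsity.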
Nearest Neighbor Algorithm
--------------------------
input: a query image
algorithm:
1. find the most similar images to the query in the training database (here the cosine similarity will do; more advanced methods will be mentioned later)
2. take the $k$ most similar examples and compute a (weighted) majority vote over their labels
3. return the most common label
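The three steps above can be sketched in plain Python as follows. This assumes each image is a flat list of pixel values and uses similarity-weighted voting as one possible reading of “(weighted) majority”; the function names are hypothetical.

```python
import math
from collections import Counter


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def knn_classify(query, train_vectors, train_labels, k=3, weighted=True):
    # Step 1: brute-force scan of the whole training set (the slow part).
    scored = [(cosine_similarity(query, v), label)
              for v, label in zip(train_vectors, train_labels)]
    # Step 2: keep the k most similar examples.
    top_k = sorted(scored, key=lambda pair: pair[0], reverse=True)[:k]
    # Each neighbor votes with its similarity (weighted) or with 1 (plain majority).
    votes = Counter()
    for sim, label in top_k:
        votes[label] += sim if weighted else 1.0
    # Step 3: return the label with the largest vote.
    return votes.most_common(1)[0][0]


train = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = ["a", "a", "b", "b"]
print(knn_classify([1.0, 0.2], train, labels, k=3))  # prints "a"
```

Step 1 touches every training vector for every query, which is exactly the cost problem discussed next.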
Now there are two problems with this method:
First, it’s slow, because every query requires a search over the whole database. For the Kaggle version of MNIST (42,000 training images of 28 × 28 = 784 pixels each), that is roughly 42,000 × 784 ≈ 33 million multiply-adds per query.
Second, under certain conditions that are typical in practice (discussed later), the similarity used to find the nearest neighbors becomes quite noisy.
Both issues stem from what is known as the curse of dimensionality. They are really two aspects of the same problem: the first is computational, while the second concerns the quality of the similarity measure.
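One way to see the noisy-similarity aspect is to measure how spread out cosine similarities are between random vectors as the dimension grows. The sketch below uses random Gaussian directions rather than MNIST (real images have structure, so the effect there is weaker); for such vectors the spread is known to be about 1/sqrt(dim), so in high dimensions almost every pair looks equally (dis)similar. The helper names are hypothetical.

```python
import math
import random
import statistics


def random_unit_vector(dim, rng):
    """Uniformly random direction: normalize a vector of standard Gaussians."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_spread(dim, n_pairs=300, seed=0):
    """Standard deviation of the cosine similarity between random pairs of unit vectors."""
    rng = random.Random(seed)
    sims = []
    for _ in range(n_pairs):
        a = random_unit_vector(dim, rng)
        b = random_unit_vector(dim, rng)
        sims.append(sum(x * y for x, y in zip(a, b)))  # unit vectors: dot product = cosine
    return statistics.pstdev(sims)


for dim in (2, 10, 100, 1000):
    print(dim, round(cosine_spread(dim), 3), round(1 / math.sqrt(dim), 3))
```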
Cosine as similarity measure
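For reference, the standard definition of the cosine similarity between a query vector $x$ and a training vector $y$ with $d$ features is given below, together with its link to Euclidean distance for unit-length vectors. The second identity means that ranking neighbors by cosine is the same as ranking them by Euclidean distance once every vector is normalized. Since MNIST pixel intensities are non-negative, the cosine between two MNIST images always lies between 0 and 1.

$$
\cos(x, y) = \frac{x \cdot y}{\lVert x \rVert \, \lVert y \rVert} = \frac{\sum_{i=1}^{d} x_i y_i}{\sqrt{\sum_{i=1}^{d} x_i^2}\,\sqrt{\sum_{i=1}^{d} y_i^2}}
$$

$$
\lVert \hat{x} - \hat{y} \rVert^2 = \lVert \hat{x} \rVert^2 + \lVert \hat{y} \rVert^2 - 2\,\hat{x} \cdot \hat{y} = 2 - 2\cos(x, y), \qquad \hat{x} = \frac{x}{\lVert x \rVert},\ \hat{y} = \frac{y}{\lVert y \rVert}
$$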
From nearest neighbor search to machine learning
To see how nearest neighbor search might work, imagine the geometry of the data points. They live in a high-dimensional space (dimension = number of features), and the closest points lie in a small sphere around the query (see curse of dimensionality, point 1). Efficient search is about reaching the points in that sphere without doing a brute-force search.
One way to solve this problem is to build a tree structure, just like a decision tree in machine learning, and use it to navigate the geometric space. Walking a tree means making a series of decisions: go left or right. Every decision is a cosine-similarity comparison against two points, one for the left branch and one for the right; we are simply asking whether the query is closer to the left point or to the right point (see curse of dimensionality, point 2). In the end we reach a leaf that contains a predefined number of points (for example 100).

In nearest neighbor search we need each of those points so that we can examine their labels. However, since only the labels are needed, we can discard the actual points and keep a distribution (a table of counts) at each leaf. Here we already reach one difference between search and machine learning: in search we want the actual points, while in machine learning we are satisfied with an aggregate histogram. From a quality perspective,
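The tree walk described above can be sketched as follows. This is a minimal sketch under my own assumptions: the two points stored at each node are picked at random from that node’s data (the text does not say how they are chosen), nodes are plain dicts, and `build_tree` and `find_leaf` are hypothetical names. Each leaf keeps both the point indices (what search needs) and a label histogram (what machine learning needs).

```python
import math
import random
from collections import Counter


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def make_leaf(indices, labels):
    # Search keeps the points; machine learning only needs the label counts.
    return {"points": indices, "histogram": Counter(labels[i] for i in indices)}


def build_tree(indices, vectors, labels, leaf_size=100, rng=None):
    """Recursively split `indices` by which of two random pivots each point is more similar to."""
    rng = rng or random.Random(0)
    if len(indices) <= leaf_size:
        return make_leaf(indices, labels)
    i, j = rng.sample(indices, 2)
    left_pivot, right_pivot = vectors[i], vectors[j]
    left, right = [], []
    for idx in indices:
        goes_left = cosine(vectors[idx], left_pivot) >= cosine(vectors[idx], right_pivot)
        (left if goes_left else right).append(idx)
    if not left or not right:  # degenerate split (e.g. identical pivots): stop here
        return make_leaf(indices, labels)
    return {
        "left_pivot": left_pivot,
        "right_pivot": right_pivot,
        "left": build_tree(left, vectors, labels, leaf_size, rng),
        "right": build_tree(right, vectors, labels, leaf_size, rng),
    }


def find_leaf(tree, query):
    """Walk the tree: at each node, go toward the pivot the query is more similar to."""
    node = tree
    while "points" not in node:
        if cosine(query, node["left_pivot"]) >= cosine(query, node["right_pivot"]):
            node = node["left"]
        else:
            node = node["right"]
    return node
```

For search, one would rerank `find_leaf(tree, q)["points"]` by exact similarity; for classification, `find_leaf(tree, q)["histogram"].most_common(1)[0][0]` is the prediction, and the points themselves could be dropped from the leaves.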
Facts about the curse of dimensionality
- Sphere with a small radius has a lot of volume
close in feature space, close in label space
each is easy to do on its own, but doing both simultaneously is not so easy