
The Black Box of Deep Learning

Deep Learning has many different definitions depending on who you ask. I like to think of deep learning as machine learning using deep neural networks — i.e., networks with more than one hidden layer. But that’s just one possible definition; I’m sure there are much better ones out there.

Deep Learning is responsible for numerous breakthroughs in machine learning and artificial intelligence. It’s being used to solve problems previously thought to be unsolvable by a computer. But it’s not without its problems, chief amongst which is the so-called black box problem.

In machine learning, particularly supervised learning, an algorithm looks at a bunch of data, finds patterns in it, and builds a model that can be used for prediction or classification tasks. What sets machine learning apart from traditional statistics is that the computer builds the model, not us. That’s a huge benefit, but depending on the learning algorithm we choose, that model might in fact be a black box: something we can feed inputs to and get outputs from, but whose inner workings we can’t meaningfully interpret.

How is this even possible? It sounds a bit crazy. Even though a computer can build a model for us, there’s nothing stopping us from peering inside, right? Why can’t we look at the model and get some insights? In order to explain, let’s start with something that isn’t a black box: linear regression.

The core idea of linear regression is pretty simple: find a line that best fits the points in our data set. Let’s say we want to predict a person’s height based on their weight. In mathematical terms, we start with a hypothesis: height = a * weight + b. The job of the linear regression algorithm is to iterate over our data and find the best values for the parameters a and b so that, if we plot a * weight + b, the line will fit our data points as well as possible. That’s all there is to it.
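To make this concrete, here is a minimal sketch of what fitting that hypothesis might look like in code. I’m using scikit-learn and a handful of made-up weight/height values (in arbitrary units); neither the data nor the library choice is essential.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Made-up weight/height values (arbitrary units), purely for illustration.
weights = np.array([[55.0], [62.0], [70.0], [81.0], [90.0]])
heights = np.array([168.0, 188.0, 214.0, 240.0, 268.0])

# Fit the hypothesis: height = a * weight + b.
model = LinearRegression()
model.fit(weights, heights)

a, b = model.coef_[0], model.intercept_
print(f"height = {a:.2f} * weight + {b:.2f}")
```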

Linear regression is one of the simplest machine learning algorithms out there. Unfortunately, it’s not suitable for solving more complex, non-linear problems. But the models created by linear regression are easy to interpret, and that’s a profound benefit. In the example above, let’s say our learning algorithm found the best values for a and b to be 2.8 and 14, respectively. Based on this information, we can say that for a unit of increase in weight, the height increases by 2.8 units. We can also calculate how much of the variance of our dependent variable (height) is explained by our independent variable (weight).
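If you’re curious where that “variance explained” number comes from, here’s a small sketch that computes R² by hand for the hypothetical fit a = 2.8, b = 14. The data points are the same made-up ones as above.

```python
import numpy as np

# The same made-up data as above, with the hypothetical fitted values a = 2.8, b = 14.
weights = np.array([55.0, 62.0, 70.0, 81.0, 90.0])
heights = np.array([168.0, 188.0, 214.0, 240.0, 268.0])
a, b = 2.8, 14.0

predicted = a * weights + b

# R² = 1 - (residual variance / total variance): the share of the
# variance in height that the fitted line explains.
ss_res = np.sum((heights - predicted) ** 2)
ss_tot = np.sum((heights - np.mean(heights)) ** 2)
print(f"R² = {1 - ss_res / ss_tot:.3f}")
```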

In industry, this type of interpretability is often very useful. Let’s say you’ve used machine learning to build a linear regression model that is great at predicting what product a certain customer is likely to buy. You nervously present your work to the management team. Chances are someone will ask “so what exactly contributes to what product a customer is likely to buy?”. Someone might even want a detailed breakdown of the rules by which the algorithm makes its prediction. If you are using a white box model, answering such questions is possible.

If you are using neural networks, it is not.

The term neural network describes a learning algorithm architecture wherein neurons, each doing some simple calculation, are connected to each other via a set of weights. These neurons are organised into a number of layers (one input layer, one output layer and one or more hidden layers). For any given input, a neuron does a relatively simple calculation, but the result of that calculation gets multiplied by a weight and fed as input into all the other neurons it’s connected to. Those then perform a new calculation, and the cycle is repeated. Since it’s relatively common for a neuron to be connected to tens or even hundreds of other neurons, if you attempt to figure out what a neuron in the deeper layers actually calculates, you’ll end up with a ridiculously complex function. It’s nigh on impossible to figure out the true form of the function that is being approximated.
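To see why that composition gets out of hand, here’s a toy forward pass in NumPy. The layer sizes, weights and activation function are all arbitrary choices on my part; the point is only that every layer’s output is a non-linear function of a weighted mix of everything in the layer before it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary layer sizes: 3 inputs -> 5 hidden -> 5 hidden -> 1 output.
sizes = [3, 5, 5, 1]
weights = [rng.normal(size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=n) for n in sizes[1:]]

def forward(x):
    # Each layer: a weighted sum of everything in the previous layer,
    # passed through a non-linearity. Unrolling this composition back
    # to the raw inputs is what makes the "formula" a deep neuron
    # computes so unwieldy.
    for W, b in zip(weights, biases):
        x = np.tanh(x @ W + b)
    return x

print(forward(np.array([0.2, -1.0, 0.5])))
```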

As if that wasn’t enough, it’s also entirely possible for a bunch of neural networks with the same architecture but different weights to solve a task with the same accuracy. Why neural networks work as well as they do is still an active area of research.
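You can convince yourself of this with a quick experiment, sketched below using scikit-learn’s MLPClassifier on a toy dataset (the dataset and hyperparameters are arbitrary): two networks with identical architectures but different random initialisations typically end up with very different weights and near-identical accuracy.

```python
from sklearn.datasets import make_moons
from sklearn.neural_network import MLPClassifier

# An arbitrary toy dataset; two identically-shaped networks that differ
# only in their random initialisation.
X, y = make_moons(n_samples=500, noise=0.2, random_state=42)

for seed in (0, 1):
    net = MLPClassifier(hidden_layer_sizes=(16, 16), max_iter=2000,
                        random_state=seed)
    net.fit(X, y)
    # Both runs typically land at very similar accuracy, even though the
    # learned weights in net.coefs_ look nothing alike.
    print(f"seed {seed}: training accuracy = {net.score(X, y):.3f}")
```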

Why use neural networks, then, if they suffer from the black box problem? Simple: for some tasks, neural networks (deep neural networks in particular) perform better than other learning algorithms. Much, much better. The vast majority of recent breakthroughs in machine learning use deep neural networks in one form or another. As we try to solve increasingly complex problems, this situation isn’t likely to change any time soon. The black box is here to stay, at least for a while.

And that’s fine, in my opinion. In the end, choosing to use a neural network over some other learning algorithm/architecture is a tradeoff between interpretability and performance. If interpretability is important to your business, neural networks aren’t a good fit. I, however, maintain that understanding exactly what contributes to a learning algorithm’s decision-making should almost never be valued over model accuracy. It’s a somewhat controversial opinion to have, but let me leave you with an example:

You’ve trained a neural network that can detect malignant tumours more accurately than any other model in the world, but you don’t know how it decides whether a tumour is malignant or not. Do you replace it with a less accurate model?

I know what my answer would be.