At this point I'm going to start using 'I' rather than 'We'. The 'We' was only really interested in the 'what if we stuffed audio into it?' question, and has now gone back to building complicated DSP things in Plogue.
So. Reading around on what was causing the loss of the network to fluctuate wildly over time, basically diverging from a solution rather than converging to one, it turns out this has something to do with gradient descent and the learning rate. To explain what was going on, I'll walk through how gradient descent works on a simple problem. This isn't specifically related to neural networks (it turns out gradient descent is a technique with much more general application), but it will hopefully illustrate the basic idea behind what was going wrong in our attempt to train the char-generating RNN with audio data.
Simple linear regression by gradient descent.
(relevant XKCD: Linear Regression)
From wikipedia, "linear regression is an approach for modeling the relationship between a scalar dependent variable y and one or more explanatory variables (or independent variables) denoted X. The case of one explanatory variable is called simple linear regression."
Most people know it as 'fitting a straight line to the data'.
Given a set of (x,y) data, we would like to plot a straight line through it to either see a trend, or make possible predictions about what y may be given a particular value for x, or make a decision about what y may be in the future. An example from my day job could be 'Given data on the number of tapes used by our daily backups, at what point in the future might we consider buying more tapes to ensure we don't run out'. Now, there are other ways to achieve this than using gradient descent, but they are not of interest to us here. We want to use gradient descent to hopefully throw (some) light on what (may) have been happening when our network was failing to train.
Let's assume we have been measuring the tape usage of our backup system over time, and we have a dataset that looks like this:
Week | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
Tape use | 13 | 15 | 21 | 18 | 21 | 22 | 21 | 24 | 22 | 25 | 26 | 28 | 25 | 29 |
We can see tape use fluctuates over time - some of our data expires and tapes are freed up, but tape use does seem to increase in general - overall we are generating more and more data that needs to be backed up.
(Note, this is an entirely fictional situation - we use very different techniques to actually analyse things like tape use in the real world)
When we plot this out as a graph, we get the following (weeks along the x axis, tapes on the y axis)
It's a fairly simple plot, and it's pretty easy to see where the straight line fit should be; so we can be sure with a visual check that our gradient descent method is actually coming up with the correct fitting line.
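Now, we know that the equation for a straight line is
y = c + mx
where c is the value of y where the line crosses the y axis (at x = 0), and m is the slope of the line (how much y goes up for each step of 1 along x).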
So we are looking for the values of c and m that will produce the straight line that best fits our data.
We could, if we really wanted to, start trying values of c and m by hand, and keep drawing out the lines until we found one we are happy with. That might even end up being pretty straightforward for our simple dataset. But judging the line by eye means we may not get it exactly right, and what we really want is an algorithm that will find it for us with the least amount of work possible on our part. I'm a developer, and a lazy developer is a good developer, because they are always looking for ways to get the computer to do as much of the work as possible.
If we want to do this automatically, then first we need a way of judging how good any given pair of c,m values is at generating a line that fits the data. One good way of doing this is to measure how far away all the data points are from the line: the best possible line keeps the data points, taken all together, as close to it as possible. We can measure the distance of a datapoint from the line by calculating the y value of our line for a given x value, and subtracting that from the real y (tape use) value of the datapoint at that x (week) position.
So for example, we know that the data says at an x (week) value of 10, the y (tape use) value is 25. Let's suppose we had chosen a value of 2 for c and 3 for m, and we want to see how well the line produced by y=c+mx, or y=2+3x, 'fits' the data at x (week) = 10.
We calculate the value of y=2+3x for x=10, and find how far away this is from the real data value for week 10 (x), which is 25 tapes used (y). Our c and m values give us;
y=2+(3*10)
y=32
Our real data tells us that y is 25. So, our 'prediction' using c=2 and m=3 is wrong by 7 tapes (it predicts 32, but the real value is 25).
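Here is that week-10 arithmetic as a tiny Python sketch. The function name predict and the variable names are just ones I've made up for illustration.

```python
# A minimal sketch of the worked example: predict y from a guessed
# c (intercept) and m (slope), then compare it with the real week-10 value.

def predict(c, m, x):
    """Value of the straight line y = c + m*x at position x."""
    return c + m * x


c, m = 2, 3                # our guessed line, y = 2 + 3x
week, real_tapes = 10, 25  # from the data table

predicted = predict(c, m, week)
difference = predicted - real_tapes

print(predicted)   # 32
print(difference)  # 7
```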
If we repeat that process for all of the real data points we have, and add up all the differences, we will get the total distance from the line formed by y=2+3x to all of our data points.
Or will we?
Some of the values may result in negative distances! If our line predicted a value of 20 but the real data point had a value of y=15, then we would get a distance of 15-20 = -5! Negative distances don't really make sense here, and worse, when we add everything up, the positive and negative distances cancel each other out, so a line that misses every point badly could still add up to a small total. So we use a mathematical trick to ensure the value is always positive: we square the result for each datapoint. This ensures the distances are all positive (there are some other reasons this is a good thing, but we won't go into that right now. Just trust me and be happy that we have decided to square the resulting distances to make sure they are all positive).
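A quick sketch of why the squaring matters, using two made-up distances rather than our tape data: one point sits 5 tapes above the line and one sits 5 tapes below it.

```python
# Two made-up distances (real y minus predicted y): one datapoint sits
# 5 tapes above the line, the other sits 5 tapes below it.
distances = [5, -5]

# Adding the raw distances lets the +5 and -5 cancel each other out,
# so a line that misses both points looks like a perfect fit.
raw_total = sum(distances)

# Squaring each distance first keeps every miss positive.
squared_total = sum(d ** 2 for d in distances)

print(raw_total)      # 0
print(squared_total)  # 50
```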
So, we have figured out a way to calculate how badly our line fits the data for a given x value. We will call this the 'Cost function', and we can write it like this; (I'm going to use pseudo-style code, rather than real mathematical symbols. This might be confusing if you come from a mathematical background, but should be fine if you are from a coding background. Deal with it.)
j = ((c+mx)-y)^2
We'll use j because we already used c. y is our real data value for any given x. (tapes used for any given week, in our example)
We can calculate the total cost for any line generated by a guessed c,m pair by adding up the costs for all the real data points we have. So, the full cost (we'll call it J) of a given guess of our c,m values can be written as;
J = sum of ((c+mx)-y)^2
for all the x and y values in our real set of data.
This is known as the square (or squared error) cost function.
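Putting the pieces together, here's a rough Python sketch of the square cost function J run over the tape data from the table. point_cost and total_cost are hypothetical names of my own, and the second guess (c=15, m=1) is just a line I picked by eye for comparison.

```python
# Tape-use data from the table above: week number (x) and tapes used (y).
weeks = list(range(1, 15))
tapes = [13, 15, 21, 18, 21, 22, 21, 24, 22, 25, 26, 28, 25, 29]


def point_cost(c, m, x, y):
    """j = ((c+mx)-y)^2: the squared miss for a single datapoint."""
    return ((c + m * x) - y) ** 2


def total_cost(c, m, xs, ys):
    """J: the sum of j over every (x, y) pair in the data."""
    return sum(point_cost(c, m, x, y) for x, y in zip(xs, ys))


print(total_cost(2, 3, weeks, tapes))   # 1017 for our y=2+3x guess
print(total_cost(15, 1, weeks, tapes))  # 41 for the eyeballed y=15+x line
```

The much smaller number for the second line is exactly what we want from a cost function: the better the fit, the lower the cost.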
Ok. Now we know how to figure out how bad of a guess a particular c,m pair is. Our goal is to minimise that cost, and when we have the smallest cost, we know we have the line represented by c,m that best fits our data.
So now what? Do we just guess loads of values for c & m, calculate the costs and then pick the guess with the smallest cost?
Nope. This is where gradient descent comes in (finally!).
Now... one more thing. The actual formula for gradient descent in the case of our square cost function (it is different for different cost functions) requires working out some derivatives. Calculus is beyond the scope of this blog (and me, mostly ;) ) so trust me on this next step.
Because of the calculus we don't want to get into right now, we are going to add a 1/(2*n)* to the front of our square cost function J. Trust me, a mathematician told me it's fine. (She whispered something about the differential of x^2 in my ear, and when I woke up she was gone). This makes our cost function J look like this;
J = 1/(2*n) * sum of ((c+mx)-y)^2
for all the x and y values in our real set of data.
You might notice an 'n' has crept in there. 'n' is the total number of datapoints we have in our real data set. (14 in our example, for the 14 weeks)
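The 1/(2*n) version is a one-line change. Here's a sketch, again with a made-up function name, cost. Dividing by 2*n shrinks every cost by the same factor, so the c,m pair with the lowest cost is still the same pair.

```python
weeks = list(range(1, 15))
tapes = [13, 15, 21, 18, 21, 22, 21, 24, 22, 25, 26, 28, 25, 29]


def cost(c, m, xs, ys):
    """J = 1/(2*n) * sum of ((c+mx)-y)^2 over all n datapoints."""
    n = len(xs)
    return sum(((c + m * x) - y) ** 2 for x, y in zip(xs, ys)) / (2 * n)


print(cost(2, 3, weeks, tapes))   # 1017 / 28, about 36.32
print(cost(15, 1, weeks, tapes))  # 41 / 28, about 1.46
```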
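Now, back to gradient descent. It turns out that if we draw a graph of the full cost J for all possible c,m pairs, we get a graph that looks sort of like this (strictly speaking, with two values to choose it's a bowl-shaped surface rather than a curve, but a curve is much easier to picture);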
The line J represents the full costs for all of the possible c,m values we could choose, when we compare the c,m values with our real data set. What we are interested in is the values of c,m when the line J is at its lowest value, right at the bottom of that curve.
This is the point that Gradient Descent will help us find.
Now, I'm not actually going to go into the maths of gradient descent. We don't need to, to get an intuitive feel for what it does, and besides it involves more calculus.
So what does it do?
Well. In the case of gradient descent for simple linear regression (that's what we're doing: fitting a line to the data) we start by giving gradient descent a guessed pair of c,m values. Gradient descent takes them, works out whereabouts on the curve above that guess falls, and then looks at the slope of the curve at that point. If the slope is positive (i.e., it goes from bottom left to top right) it produces a new guess for c,m that is a little farther left along the curve. If the slope is negative (it goes from top left to bottom right) it produces a new guess for c,m a little to the right. This picture hopefully helps illustrate that;
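For the curious, here's roughly what one of those nudges looks like in Python. The slope formulas in gradient are the standard partial derivatives of the 1/(2*n) square cost with respect to c and to m; I'm supplying them here rather than deriving them, since that's the calculus we're skipping. The learning_rate of 0.01 is an assumption on my part (the learning rate gets its own section further down), and all the function names are hypothetical.

```python
weeks = list(range(1, 15))
tapes = [13, 15, 21, 18, 21, 22, 21, 24, 22, 25, 26, 28, 25, 29]


def cost(c, m, xs, ys):
    """J = 1/(2*n) * sum of ((c+mx)-y)^2."""
    n = len(xs)
    return sum(((c + m * x) - y) ** 2 for x, y in zip(xs, ys)) / (2 * n)


def gradient(c, m, xs, ys):
    """Slope of J with respect to c, and with respect to m."""
    n = len(xs)
    errors = [(c + m * x) - y for x, y in zip(xs, ys)]
    slope_c = sum(errors) / n
    slope_m = sum(e * x for e, x in zip(errors, xs)) / n
    return slope_c, slope_m


def step(c, m, xs, ys, learning_rate):
    """One gradient descent nudge: a positive slope moves that value down
    (left along the curve), a negative slope moves it up (right)."""
    slope_c, slope_m = gradient(c, m, xs, ys)
    return c - learning_rate * slope_c, m - learning_rate * slope_m


c, m = 2, 3
new_c, new_m = step(c, m, weeks, tapes, learning_rate=0.01)
print(gradient(c, m, weeks, tapes))  # both slopes positive, so both shrink
print(cost(c, m, weeks, tapes), cost(new_c, new_m, weeks, tapes))
```

The second print shows the cost of the new guess coming out lower than the cost of the one we started with.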
Simple, right?.. Gradient descent will take our guess, figure out how much it costs, and then produce a new guess that will nudge the cost value along the curve towards the curve's minimum.
Then, you take the new guess produced by gradient descent and feed it right back into the gradient descent formula. You do this iteratively, and eventually it will produce a guess that is very close to the minimum point of the curve. At this point, the slope of the curve is neither negative nor positive, and gradient descent will produce a guess that is the same as the one we plugged back into it. Et voilà, we have discovered the c,m values that minimise our cost function. We have found the values that will produce a best fit for our data.
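And here's that feed-it-back-in loop as a sketch. The starting guess of c=0, m=0, the learning rate of 0.02 and the fixed 20000 iterations are all my own assumptions, picked to suit this small dataset; this isn't necessarily how the implementation linked at the end of this post does it.

```python
weeks = list(range(1, 15))
tapes = [13, 15, 21, 18, 21, 22, 21, 24, 22, 25, 26, 28, 25, 29]


def gradient_descent(xs, ys, c=0.0, m=0.0, learning_rate=0.02, steps=20000):
    """Repeatedly feed each new c,m guess back in to produce the next one."""
    n = len(xs)
    for _ in range(steps):
        errors = [(c + m * x) - y for x, y in zip(xs, ys)]
        slope_c = sum(errors) / n
        slope_m = sum(e * x for e, x in zip(errors, xs)) / n
        # Near the minimum both slopes are close to zero, so the new guess
        # is almost identical to the one we fed in.
        c, m = c - learning_rate * slope_c, m - learning_rate * slope_m
    return c, m


c, m = gradient_descent(weeks, tapes)
print(round(c, 2), round(m, 2))  # 14.56 1.01
```

A fixed number of steps keeps the sketch short; a variation could instead stop once the guesses stop changing by more than some small tolerance.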
Hooray! Let's go home!...
Not so fast!
Remember, we are going through all of this to try and figure out why our neural network didn't appear to be 'learning' when we fed it audio data - the value of the 'cost' for the network was getting bigger! not smaller!
There's one thing we haven't talked about yet: the learning rate. Inside the gradient descent formula there is a term called the learning rate. We will call it L for now. Don't panic, there isn't any more maths coming, it's just easier than typing Learning Rate all the time.
L is a number in the gradient descent formula that scales the size of the 'step' taken along the curve when gradient descent makes a new guess for the c,m values. Now, if L is too big, gradient descent might produce a new guess that is much too far along the curve, and it ends up missing the minimum point completely! We carry on feeding the guesses back into gradient descent, and it carries on producing new guesses that overshoot the minimum. If things are really bad, each overshoot lands even farther from the minimum than the last, so the guesses climb right back up the curve, ignoring the minimum point altogether. Its guesses just get worse and worse and bigger and bigger until computers fall over and the sky falls down, or something. It doesn't matter. What matters is, it's broken and it isn't going to work.
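To see the overshoot in action, here's a sketch that records the cost at every step for two learning rates on the tape data. The values 0.02 and 0.03 are my own picks: by my reckoning, for this particular dataset 0.02 is small enough to settle down, while 0.03 is just big enough that each step overshoots further than the last. The function name cost_history is hypothetical.

```python
weeks = list(range(1, 15))
tapes = [13, 15, 21, 18, 21, 22, 21, 24, 22, 25, 26, 28, 25, 29]


def cost_history(xs, ys, learning_rate, steps, c=0.0, m=0.0):
    """Run gradient descent and record the cost J before every step."""
    n = len(xs)
    history = []
    for _ in range(steps):
        errors = [(c + m * x) - y for x, y in zip(xs, ys)]
        history.append(sum(e * e for e in errors) / (2 * n))
        slope_c = sum(errors) / n
        slope_m = sum(e * x for e, x in zip(errors, xs)) / n
        c, m = c - learning_rate * slope_c, m - learning_rate * slope_m
    return history


settles = cost_history(weeks, tapes, learning_rate=0.02, steps=200)
overshoots = cost_history(weeks, tapes, learning_rate=0.03, steps=200)

print(settles[0], settles[-1])        # cost goes down
print(overshoots[0], overshoots[-1])  # cost blows up
```

Both runs start from the same guess and the same cost; the only difference is L.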
And essentially, we can see that intuitively as what was happening when our char predicting RNN was having so much trouble trying to learn our audio data. Our 'Cost' was not able to settle on a minimum, and kept hopping about all over the curve. In reality, I'm pretty sure it's much more complex than this, but this essentially captures the problem. Probably.
I say probably because... even when we made the learning rate very, very small indeed, it still had trouble, although the cost would reach lower values than it did before. So I'm sure I'm still missing something, probably something to do with trying to make the character predicting RNN work with audio data by ramming an alphabet of 65000 letters into its inputs and hoping for the best.
Setting aside our attempts to abuse a character predicting RNN to generate audio though, this general intuition regarding gradient descent isn't a bad approximation of what happens when the cost function for a particular neural network starts to get bigger and bigger, or just fails to settle down to a minimum. So that's something!
Next Time...
Next time, we will use all of this to actually fit a straight line to our tape usage data, and we can see some real examples of gradient descent and the learning rate in action! If you're curious and know a bit of Python, I have an implementation of simple linear regression by gradient descent here: Simple Linear regression by Gradient Descent
Have fun!