What my AI learnt on its holidays

I spent a bit of downtime recently playing with LSTMs and RNNs - Long Short-Term Memory networks and Recurrent Neural Networks. These are the kinds of things that power automatic image tagging, recognition, and generating silly names for paint colours. As is usual for me, I spent most of my time having no clue what was going on, and to a certain extent it's all still magic in there, but Andrej Karpathy has a great explanation of the concepts here.

I decided to look at character RNNs, which are essentially sequence predictors for text. For example, if the network has been trained on a list of greetings and it sees the sequence "How are" it knows that the most likely continuation is going to be " you?"

This may seem a bit useless and somewhat of a toy, but imagine if we were able to train it on a huge quantity of data from online support chats so the network learns that, "what is the support telephone number?" is always followed by "0800 811 8181".

I started out by downloading Christian Baldi's Docker images for building RNNs in Torch, because setting up Torch is a bit of an involved process and I wanted to get straight to the fun bit, which is watching the progress report of the AI trawling through the input data for several hours.

The reason it takes so long is that you usually need a lot of input data if you want an interesting RNN. Usually. There's a little bit of Shannon in this: predictable, low-entropy datasets don't need to be nearly so big for the network to figure them out.
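To put a rough number on "predictable", here's a small Python sketch (my own illustration, not part of the RNN tooling) that estimates the conditional entropy of a text: how many bits of surprise, on average, the next character carries once you know the previous one. It's a crude plug-in estimate from raw counts, so treat it as a rule of thumb rather than a measurement.

```python
import math
from collections import Counter, defaultdict


def conditional_entropy(text, context_len=1):
    """Estimate H(next char | previous `context_len` chars) in bits.

    This is a plug-in estimate from the counts in `text` itself, so on small
    samples it is optimistic: it can't account for patterns it hasn't seen.
    """
    following = defaultdict(Counter)  # context -> counts of the next character
    for i in range(context_len, len(text)):
        following[text[i - context_len:i]][text[i]] += 1

    total = len(text) - context_len  # number of (context, next char) pairs
    bits = 0.0
    for counts in following.values():
        n = sum(counts.values())
        weight = n / total  # how often this context occurs
        for c in counts.values():
            p = c / n  # probability of this next char given the context
            bits -= weight * p * math.log2(p)
    return bits


repetitive = "Toast.\n" * 18
varied = "Hello\nHola\nBonjour\nHallo\nHej\nGuten Tag\n"
print(f"repetitive: {conditional_entropy(repetitive):.3f} bits/char")
print(f"varied:     {conditional_entropy(varied):.3f} bits/char")
```

A word repeated over and over comes out at zero bits, because every character completely determines the next one; even six greetings come out clearly above zero.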

For example, if I train a network with a memory of two characters on the following dataset:

Toast.
Toast.
Toast.

(Repeat for 126 characters in total. There's not much data here.)

Within a few epochs (one "epoch" is a single run through the input data) asking the neural net to generate me 20 characters results in this:

t.
Toast.
Toast.
Toa

Note that because I haven't primed it with any input data, it picks a random character from its vocabulary, which in this case happened to be "t". It's pretty good at picking up from there, though.

Also, because the network has no concept of words or anything like that, it simply stops once it has produced the 20 characters I asked for, even if that's right in the middle of a word.
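The "memory of two characters" idea can be shown without any neural network at all. The sketch below uses a swapped-in technique, a count-based character Markov model rather than an RNN: it records which character follows each one- or two-character context and samples from those counts. On the toast data every context has exactly one continuation, so starting from "t" and asking for 20 characters reproduces the sample output above.

```python
import random
from collections import Counter, defaultdict


def train(text, order=2):
    """Count which character follows every context of 1..`order` characters."""
    counts = defaultdict(Counter)
    for i in range(len(text)):
        for n in range(1, order + 1):
            if i - n < 0:
                break
            counts[text[i - n:i]][text[i]] += 1
    return counts


def generate(counts, seed, length, order=2, rng=random):
    """Extend `seed` one character at a time until it is `length` long."""
    out = seed
    while len(out) < length:
        # Use the longest context we have statistics for, backing off to shorter ones.
        for n in range(order, 0, -1):
            context = out[-n:]
            if len(context) == n and context in counts:
                break
        else:
            break  # nothing we recognise: give up early
        chars, weights = zip(*counts[context].items())
        out += rng.choices(chars, weights=weights)[0]
    return out


model = train("Toast.\n" * 18)
print(repr(generate(model, "t", 20)))
```

It stops mid-word for the same reason the network does: all it knows is how many characters it was asked for.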

When I'm running the network I can tell it to run at a higher or lower "temperature", which controls how strongly it sticks to its most likely guess for the next character. Higher temperatures give more variation in the result and stop the network fixating on a single sequence, but at the cost of increasing the chance of a mistake. Obviously with a data set this simple and easy to predict, mistakes are rare: typically only once in 10,000 characters will it output something like "Tost."
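Temperature is usually implemented by dividing the network's raw scores (logits) for each possible next character by the temperature before turning them into probabilities with a softmax. Here's a minimal Python sketch of that common formulation; the scores are invented for illustration, not taken from my network.

```python
import math
import random


def apply_temperature(logits, temperature):
    """Softmax of logits / temperature.

    temperature < 1 sharpens the distribution towards the top choice;
    temperature > 1 flattens it, so unlikely characters get picked more often.
    """
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]  # subtract the max to avoid overflow
    total = sum(exps)
    return [e / total for e in exps]


def sample_next(chars, logits, temperature, rng=random):
    """Pick one character according to the temperature-adjusted probabilities."""
    return rng.choices(chars, weights=apply_temperature(logits, temperature))[0]


# Hypothetical scores for the character after "Toas": "t" is strongly favoured.
chars = ["t", "s", "."]
logits = [4.0, 1.0, 0.5]
for t in (0.2, 1.0, 3.0):
    probs = apply_temperature(logits, t)
    print(t, [round(p, 3) for p in probs], sample_next(chars, logits, t))
```

At 0.2 almost all the probability sits on "t"; at 3.0 the alternatives get a real chance of being picked, which is where the variety and the mistakes both come from.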

However, what if my data is less trivially predictable? Say a list of ways to say "Hello" in different languages:

Hello
Hola
Bonjour
Hallo
Hej
Guten Tag

I fed a whole 214 characters of this to the network - way more data than the toast example, which resulted in an almost perfect network!

When asked to provide 50 characters, this network gives me the following:

Zdjt
Kaeiur
eonytai
oba
raon
Gonravo
Zdravou

That's a very confident network, run at a high temperature and happily making up words. What about a less confident one, run at a lower temperature?

Hallo
Hej
He
Hallo
He
He
Hallo
He
He

The unconfident network is just picking bits of the input dataset it memorised, repeating itself, and still making mistakes. My dream of having a neural network that responds to "Hello" by saying the same thing in a different language isn't going to happen with such a small dataset, especially when I haven't even organised it into a call-and-response structure.

This problem is obvious even while I'm training the model. The learning process outputs two "loss" values, where lower is better. One is the "training loss": how far my model's predictions are from the training output, given the training input. The other, printed every time we save a checkpoint of progress, is the "validation loss": how far the model's predictions are from a set of validation output when it's run on validation input it's never seen before.
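Both numbers are typically an average cross-entropy: the negative log of the probability the model gave to the character that actually came next. A sketch of that, plus holding back part of the data for validation, might look like this in Python. The `predict` interface is hypothetical, standing in for whatever model you have, and the split is a simplification; I'm not claiming it's how torch-rnn divides its data.

```python
import math


def split_train_validation(text, validation_fraction=0.1):
    """Keep the tail of the text back as validation data the model never trains on."""
    cut = len(text) - int(len(text) * validation_fraction)
    return text[:cut], text[cut:]


def mean_cross_entropy(predict, text, context_len=1):
    """Average negative log-probability (nats per character) of the real next character.

    `predict(context)` is any model: it takes the previous `context_len`
    characters and returns a dict mapping candidate next characters to probabilities.
    """
    total, steps = 0.0, 0
    for i in range(context_len, len(text)):
        probs = predict(text[i - context_len:i])
        p = probs.get(text[i], 1e-12)  # tiny floor so an unseen character doesn't give log(0)
        total -= math.log(p)
        steps += 1
    return total / steps
```

A model that spreads its bets evenly over four characters scores ln 4 ≈ 1.386 nats per character; one that is always certain and always right scores 0. Overfitting shows up as a training score near the latter and a validation score nowhere near it.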

When I trained the toast network, both losses got almost to zero, as you'd expect. When I trained the hello network, its training loss got reasonably low, but its validation loss was hopeless. This was not unexpected. What happened was "overfitting": it didn't learn the implicit rules of the input data, it just memorised quirks that happened to work on the training set.

If I had enough input data the model might have been able to gradually learn character sequences common to Romance languages, Germanic languages and so on. With that knowledge it could get pretty close on the validation set. But with only 30 or so examples it hasn't a hope. All it can learn is that greetings usually start with a capital letter and definitely end with a line break, and that if it's not sure, it's best to start each new line with an "H".

So if I want to be able to model conversations, I need a decently big dataset that will let the model identify norms and learn useful patterns. I also need a larger and more powerful network that can deal with longer sequences, so I get meaningful conversations rather than the network just spelling out words it's seen previously. This becomes a challenge because bigger networks make it much easier for the model to overfit by remembering the things it's seen rather than the rules that produce them.

I have the ideal thing for this, which is a trove of logs from an IRC channel of University friends that's been going for years. Typical content might look something like this (names changed to protect the somewhat less than innocent):

15:04 < jim> they're using the international platforms now?
15:04 < Fox> Only for screwups atm.
15:04 < Fox> Regular usage next year.

There are megabytes and megabytes of this stuff. Enough for everyone in the channel to have their own network to reflect their distinct personality?
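Building a per-person training file mostly means filtering the logs by speaker. Here's a Python sketch of that, assuming the logs use what looks like irssi's default layout shown above (time, a one-character mode slot such as a space or "@", the nick in angle brackets, then the message). The regex is my assumption about the format, not something the training tools provide.

```python
import re
from collections import defaultdict

# Assumed irssi-style line: "HH:MM <" + one mode character (space, @ or +) + nick + "> " + message
MESSAGE = re.compile(r"^\d{2}:\d{2} <[ @+](\S+)> (.*)$")


def messages_by_nick(lines):
    """Group chat messages by speaker, giving one training text per person."""
    corpus = defaultdict(list)
    for line in lines:
        match = MESSAGE.match(line)
        if match:  # skips "--- Log opened", joins, nick changes and the like
            nick, message = match.groups()
            corpus[nick].append(message)
    return {nick: "\n".join(msgs) + "\n" for nick, msgs in corpus.items()}


log = """15:04 < jim> they're using the international platforms now?
15:04 < Fox> Only for screwups atm.
15:04 < Fox> Regular usage next year."""
for nick, text in messages_by_nick(log.splitlines()).items():
    print(nick, repr(text))
```

Keeping only the message text matches the per-person samples below, which have no timestamps or nicknames in them.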

Well, here's the output of a network trained on someone who's said a fair amount over the years:

eond this now Denowh stued of Tidy!
Hanshick poon of Jose beer extelesk?
Honter.

Ah. The network gets the structure of English, it knows a few words, and it's making an attempt to punctuate - but there's not enough there for it to figure out full sentences. Someone who's only been an occasional contributor fares far worse:

wharC ut?Fwo mhot vende.
Wade ub (udpsot. I pews wufal- inh Ho?
As.

At this point the network doesn't even have enough data to infer how English works, which isn't helped by this being a channel that often inserts URLs, bash commands and other such punctuation soup into the conversation.

So let's try training the network on the entire channel instead, leaving it running for as long as it takes the validation loss to reach its minimum.
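Since you only know where the minimum was after you've passed it, a common approach is to keep the checkpoint with the best validation loss and stop once it hasn't improved for a while ("early stopping"). Here's a Python sketch over a list of validation losses; the numbers are made up, and this isn't a description of what torch-rnn itself does.

```python
def best_checkpoint(validation_losses, patience=3):
    """Index of the checkpoint with the lowest validation loss.

    Stops looking once `patience` checkpoints in a row fail to improve on the
    best so far, which is the usual early-stopping rule.
    """
    best_index, best_loss, waited = None, float("inf"), 0
    for index, loss in enumerate(validation_losses):
        if loss < best_loss:
            best_index, best_loss, waited = index, loss, 0
        else:
            waited += 1
            if waited >= patience:
                break  # validation loss has stalled or is rising: likely overfitting
    return best_index


# Made-up losses: they fall, then creep back up as the model starts to memorise.
losses = [2.10, 1.80, 1.60, 1.65, 1.70, 1.75, 1.20]
print(best_checkpoint(losses))
print(best_checkpoint(losses, patience=10))
```

The `patience` value is a trade-off: with the default of 3 this stops at checkpoint 2 and never sees the later dip to 1.20, while a patience of 10 finds it at the cost of more training time.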

--- Log opened Sun Oct 21 00:00:00 2017
01:28 <@Bob> Ever search from the future because at the mid-wing rail names.confscore thing :-P
11:11 < Fox> I really understand Terraform packs hold it about a builder.
11:03 < Fox> Do you work :-P
20:19 < Zap_> you should always go out
11:12 < Lizard> just should not be
20:14 < Fox> The only way to find something like "it's 7000 again not entertalty."

What's interesting here is that the network learnt the structure of an IRC log, including the incidental messages. It occasionally goes on a weird detour when it prints the log-opened messages or things like nickname changes, but it's extremely rare for it to make a mistake in producing an HH:MM time followed by a nickname in angle brackets. It almost always uses a real nickname from the channel, and it works out how each person punctuates. If you type in lowercase in the channel, the network reproduces your messages in lowercase. It's learnt that certain people use certain words, and others don't.

What it's not so good at is context. It's outputting sentences that would seem plausible if you were a machine that doesn't understand conversations, meaning, or even why it can't invent new words like "entertalty" when it feels they ought to fit. It knows what times look like, but it doesn't understand that "11:12" can't come after "20:19" without a "log closed" and "log opened" message in between.
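That kind of mistake is easy to check for mechanically, even if the network can't avoid making it. Here's a Python sketch that scans generated lines for malformed entries, unknown nicknames and clocks running backwards. It assumes the same irssi-style format as above and treats any "---" status line as a point where the clock may legitimately restart; both are my assumptions. The sample messages are trimmed.

```python
import re

# Assumed irssi-style line: "HH:MM <" + mode character (space, @ or +) + nick + "> " + text
LINE = re.compile(r"^(\d{2}):(\d{2}) <[ @+](\S+)> (.*)$")


def check_log(lines, known_nicks):
    """Return (line number, problem) pairs for generated log lines."""
    problems = []
    last_minute = None
    for number, line in enumerate(lines, start=1):
        if line.startswith("--- "):
            last_minute = None  # log opened/closed or similar: the clock may restart
            continue
        match = LINE.match(line)
        if not match:
            problems.append((number, "malformed line"))
            continue
        hours, minutes, nick, _text = match.groups()
        if int(hours) > 23 or int(minutes) > 59:
            problems.append((number, "impossible time"))
        if nick not in known_nicks:
            problems.append((number, f"unknown nick {nick}"))
        minute = int(hours) * 60 + int(minutes)
        if last_minute is not None and minute < last_minute:
            problems.append((number, "time went backwards"))
        last_minute = minute
    return problems


generated = """--- Log opened Sun Oct 21 00:00:00 2017
01:28 <@Bob> Ever search from the future
11:11 < Fox> I really understand Terraform packs
11:03 < Fox> Do you work :-P
20:19 < Zap_> you should always go out
11:12 < Lizard> just should not be
20:14 < Fox> The only way to find something""".splitlines()
print(check_log(generated, {"Bob", "Fox", "Zap_", "Lizard"}))
```

On the generated sample above it flags lines 4 and 6: "11:03" straight after "11:11", and "11:12" straight after "20:19".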

Even when it prints the occasional nugget of wisdom like "you should always go out", it doesn't know what it's doing. As with voice recognition, it's the context gap which causes the problem.

Well, maybe it doesn't need to know what's going on to have a pretend conversation. Let's prime it with some start text. I'll feed it with a common question from this channel: what time are we thinking of going to the pub? There are loads of two-line conversations of this form in the history, so the network should be able to answer that...

17:33 < Fox> What time were you thinking?
13:30 < jim> oh anymore, just about 500 monitoring and it search

Oh well. Maybe I'll ask the channel for its opinion on some common stuff:

Docker is a chilt tower battery letter
Linux is not thoroughly.
AI is also exploit from your new.jpg and test.
programming is insanely that way to cause this

Well, I guess the last one is sort of apt.

The problem here is that training such a generalised network on such a badly organised data set isn't going to give us anything useful. It's doing some very impressive things here - it's learning how an IRC log works, how the English language works, even how different people type based on nothing more than the probability of character sequences. But I don't have a task for the network other than "produce likely sequences", and I don't have a well-organised set of training examples targeted at solving any one problem.

So what's the lesson here? Other than, "you can amuse your friends by having a neural network try to impersonate them"?

Firstly, that we need a decent amount of data, that the data and training examples must be structured correctly for our problem, and that the level of entropy in the data should be low. You can see that in the IRC example - the network figured out the low-entropy stuff (what a log line looks like, who is speaking, what an English sentence looks like) very effectively. If we want a conversational bot that can answer basic questions then we'll need enough question/response data for the entropy of that to be similarly low. That's a lot of data, and most likely a large network to work with it, which means a long training time even if we have some powerful GPUs to hand.

Secondly, that the totally generic character RNN isn't necessarily the right solution to the bot problem. Given how much of the network gets taken up with learning to speak English, its propensity to make mistakes even so, and the fact it doesn't understand that it's being asked a question, maybe we want to train the network to extract intents and parameters instead. That's a simpler task for the network to figure out, but one that would require a lot of data preparation to create good training and validation sets.
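To give a feel for what that data preparation involves, here's a Python sketch of labelled examples and a split that makes sure every intent shows up in both the training and validation sets. The intent names, parameters and example utterances are hypothetical, made up for illustration.

```python
import random
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Example:
    """One labelled utterance: what was said, what it's for, and its parameters."""
    text: str
    intent: str
    params: dict = field(default_factory=dict)


def split_by_intent(examples, validation_fraction=0.25, seed=0):
    """Split so that every intent with at least two examples appears in both sets."""
    by_intent = defaultdict(list)
    for example in examples:
        by_intent[example.intent].append(example)

    rng = random.Random(seed)
    train, validation = [], []
    for group in by_intent.values():
        rng.shuffle(group)
        # At least one validation example per intent, but always leave one to train on.
        n_val = min(len(group) - 1, max(1, round(len(group) * validation_fraction)))
        validation.extend(group[:n_val])
        train.extend(group[n_val:])
    return train, validation


# Hypothetical labels: these intent names and parameters are invented for illustration.
examples = [
    Example("what time are we thinking of going to the pub?", "ask_pub_time"),
    Example("pub later? what time?", "ask_pub_time"),
    Example("when are we heading to the pub", "ask_pub_time"),
    Example("what time is pub", "ask_pub_time"),
    Example("pub at 18:00?", "suggest_pub_time", {"time": "18:00"}),
    Example("how about half seven", "suggest_pub_time", {"time": "19:30"}),
    Example("shall we say 8pm", "suggest_pub_time", {"time": "20:00"}),
    Example("pub at 17:30?", "suggest_pub_time", {"time": "17:30"}),
]
train, validation = split_by_intent(examples)
print(len(train), len(validation))
```

Grouping by intent before splitting avoids a validation set that, by bad luck, contains no examples of some intent at all, which would make its validation loss meaningless for that case.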

In some ways, this takes me back to the days of doing linear regression. Back then it was all about making sure you had the right input dataset and sufficient volume of data to give you worthwhile results. RNNs pose many of the same problems in making sure your data has low entropy and the right sort of entropy - for example, a common problem in image classification is that the network learns to associate green fields with the tag "sheep", because it's rare for it to see a photo of a sheep in something other than a green field, or a green field that doesn't have at least a couple of sheep in it.

Good data and training are essential. But even in the absence of these, I at least got a network to say some preposterous things.

14:18 < jim> I need a dog and they just ignore the latest training and periods that does not use sleep