This post is part of a series of paper reviews, covering the ~30 papers Ilya Sutskever sent to John Carmack to learn about AI. To see the rest of the reviews, go here.
Paper 9: Order Matters: Sequence to Sequence for Sets
High Level
A sequence-to-sequence model is a type of model that specializes in turning one sequence into another. Originally popularized for machine translation in 2014, this model became a cornerstone of the deep learning toolkit through the mid-2010s, eventually being the basis for Google Translate.
A sequence-to-sequence model is characterized by two RNNs (generally LSTMs). One is an ‘encoder’ that compresses the entire input sequence into a single vector representing that input. The other is a ‘decoder’ that takes in that vector and produces the output one token at a time.
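To make that concrete, here's a minimal sketch of the encoder/decoder pairing, assuming a PyTorch-style LSTM setup (my own illustration, not code from any particular paper; the sizes and names are arbitrary):

```python
import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, src_tokens, tgt_tokens):
        # Encode the entire input sequence into a final (hidden, cell) state.
        _, state = self.encoder(self.embed(src_tokens))
        # Decode conditioned on that state, one step per target token
        # (teacher forcing: the ground-truth previous token is the input).
        dec_out, _ = self.decoder(self.embed(tgt_tokens), state)
        return self.out(dec_out)  # logits over the vocabulary at each step

model = Seq2Seq(vocab_size=1000)
logits = model(torch.randint(0, 1000, (2, 7)), torch.randint(0, 1000, (2, 5)))
print(logits.shape)  # torch.Size([2, 5, 1000])
```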
As it turns out, a whole bunch of things can be represented as sequences, even if they aren’t sequential. For example, there is a long history of processing images with models that are ostensibly for sequences (PixelRNN, ViTs, etc.). That often raises the question: what is the right order to use for the sequence? In many problem settings, we have sets of data instead of sequences. Even inputs that are ostensibly sequential aren’t actually pure sequences — for example, sentences often have dependencies between words that sit far apart in the linear order. Order matters, but it isn’t purely sequential. So being able to reason about sets matters, and has a lot of potential impact in downstream tasks¹.
The first thing the paper explores is whether, for sets where there is no ‘natural’ ordering, the input order still matters. Empirically, there are already tasks where we know changing the ordering matters. For example, going back to Pointer Nets, if you sort the input the model learns better, faster. Intuitively this makes some sense. A sorted input may allow the model to learn representations of the individual set elements faster, which in turn will help it understand the entire set better. This is roughly the opposite of the intuition for why you want to randomize your training input. Imagine you were training a color detector model, and you didn’t randomize your input data. If your first 100 mini-batches are all red, your model will get stuck in a ‘local minimum’ and will only ever predict red in the future.
Ok so maybe order matters. How do we then figure out what the optimal input ordering is for a set? Ideally we could learn some kind of mapping, a function of the input, that determines which element of the set we should look at next.
Spoiler: it's attention.
Attention is an order invariant operation. You take a bunch of different vectors and do a weighted sum across them to get the output. Basic math: you can reorder the elements of a sum and get the same output.
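Here's a tiny numpy check of that claim (my own sketch): permuting the memory vectors permutes the scores, and therefore the weights, in exactly the same way, so the weighted sum is unchanged.

```python
import numpy as np

def attend(query, memories):
    scores = memories @ query                        # one score per memory vector
    weights = np.exp(scores) / np.exp(scores).sum()  # softmax over the set
    return weights @ memories                        # weighted sum = context vector

rng = np.random.default_rng(0)
q, M = rng.normal(size=4), rng.normal(size=(5, 4))
perm = rng.permutation(5)
print(np.allclose(attend(q, M), attend(q, M[perm])))  # True: order doesn't matter
```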
The authors define a model where each element of a set is turned into a 'memory' vector. An 'LSTM' then runs for t steps (it's not really an LSTM in the way we usually understand it: there is no input sequence, they simply run the same set of weights t times, each time pulling in additional context for the next step). At each step, the LSTM generates an embedding that is used to 'query' the different 'memory' vectors, producing attention weights; the memory vectors are then summed according to those weights and passed back into the LSTM. After doing this t times, the final output is fed into the decoder — in the paper they use a PointerNet, but it can be any decoder, including another LSTM. The decoder also has the ability to generate attention scores to create context vectors from the input (they call this a 'glimpse' in the paper).
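Here's a rough sketch of how I read that 'process' block (my own reconstruction in PyTorch, so treat the names and details as assumptions rather than the authors' code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProcessBlock(nn.Module):
    """LSTM with no sequence input: it just queries the memories t times."""
    def __init__(self, dim, steps=5):
        super().__init__()
        self.steps = steps
        # The only 'input' at each step is the previous attention readout.
        self.cell = nn.LSTMCell(dim, dim)

    def forward(self, memories):               # memories: (batch, set_size, dim)
        b, _, d = memories.shape
        h = memories.new_zeros(b, d)            # hidden state doubles as the query
        c = memories.new_zeros(b, d)
        read = memories.new_zeros(b, d)         # attention readout
        for _ in range(self.steps):
            h, c = self.cell(read, (h, c))                        # step with no real input
            scores = torch.einsum('bd,bnd->bn', h, memories)      # query the memories
            weights = F.softmax(scores, dim=-1)
            read = torch.einsum('bn,bnd->bd', weights, memories)  # weighted sum
        return h, read  # only now is anything handed to the decoder

block = ProcessBlock(dim=32, steps=5)
h, read = block(torch.randn(2, 10, 32))
```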
And as usual this all works well on some specific example task, in this case taking an unsorted input and sorting it.
Ok so all of this is about input sets. What about output sets?
First the authors again show that order matters in the output too. They do this by constructing a bunch of arbitrary tasks — language modelling, sentence parsing, etc. In each setting they vary the expected ordering of the output, and find that some orderings work much better than others. For example, in the sentence parsing setting, they discover that it is much easier for models to learn depth first parse trees than breadth first ones, even though they are logically equivalent. And in general, the authors motivate that it is important to pick a reasonable ordering — otherwise, if any output set is valid, "there are n! possible outputs for a given X, all of which are perfectly valid. If our training set is generated with any of these permutations picked uniformly at random, our mapping (when perfectly trained) will have to place equal probability on n! output configurations for the same input X. Thus, this formulation is much less statistically efficient."
How do you find the right output ordering, though? There are n! possible options. The authors propose a two-step training regime. First, they 'pretrain' the model to predict any of the n! possible orderings. Then, during training, they sample a particular ordering for each example, weighted by the losses the model assigns to the candidates. So if the model finds it easy to output an ordered sequence like [1, 2, 3, 4, 5], that ordering will appear more frequently in training because its loss is lower. (And, in fact, the authors then show that where possible the model 'prefers' to converge on natural orderings like numeric-increasing or numeric-decreasing.)
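In rough Python, the idea looks something like this (a hedged sketch of the scheme as I understand it; `model_loss` is a hypothetical callable, and enumerating all n! orderings is only feasible for small sets):

```python
import itertools
import math
import random

def pick_training_order(model_loss, x, target_set):
    # Score every candidate ordering of the target set with the current model,
    # then sample one to train on, favoring orderings with lower loss.
    orderings = list(itertools.permutations(target_set))
    losses = [model_loss(x, order) for order in orderings]
    weights = [math.exp(-loss) for loss in losses]  # lower loss -> higher weight
    return random.choices(orderings, weights=weights, k=1)[0]

# If the model already finds (1, 2, 3, 4, 5) easy, that ordering gets sampled
# most often, and training converges on the 'natural' sorted output.
```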
Woof. Long paper.
Insights
For the most part, the specific mechanics of this paper are not really that interesting. Especially in the year 2024. No one uses models like this, no one does this weird pretraining 'find the optimal sequence' step. This is not a paper that, in my opinion, has groundbreaking insights that will impact your day to day as an ML practitioner.
I think this paper is most interesting because it is trying to grapple with a huge problem: very few data problems are purely sequential. You never really have a pure sequence; for any interesting problem there are non-sequential dependencies that the model somehow has to…well, model!
And in order to solve that problem, this paper expands on some attention-based patterns that became bread and butter in later models.
One that stood out to me is the concept of the LSTM just doing 'processing' without inputs. This is a pattern that basically exploits the 'scratchpad' model of LSTMs. Each step, the model has the ability to read from and write to its own hidden state, pulling in relevant context from the input tokens as needed. At each timestep t, the LSTM uses the same weights. But this doesn't have to be the case, and in fact as we will later see with ResNets, it is very natural to use different weights to read from and write to this state. Later work in transformers calls this a 'residual stream'; this paper seems like a prototype for that kind of reasoning. I also see the hints of what would later become diffusion models — using the same weights to continuously modify a stream of data is exactly how diffusion models work for image generation.
Another pattern that stood out was the 'query' semantics deployed in the paper. Each step of the LSTM produces a 'query' that generates a weighted sum of the 'memory' vectors. This is, of course, just standard attention! But it's interesting that the query language was being used even before Attention is All You Need. I always personally thought the query/key/value language was arbitrary and confusing; now I see that it's coming from a much longer lineage of research.
I think when looking over this paper, it becomes clear that the sequential nature of the LSTM is a hindrance, or at least not particularly useful. The authors stick an LSTM in the middle of the model to generate these set-description vectors — vectors that somehow represent the input set in an order-agnostic way. That makes sense, and I get how it works, but my natural next question is 'why use an LSTM at all? What benefit does the sequential nature of an LSTM actually give us here?' And I would be right to ask! We get basically no benefit from using an LSTM; it's just what was popular back then. (Transformers are 'better' because their native structure is a graph. The QK attention matrix is a fully connected graph over tokens. And that is simply a much better representation for most tasks.)
One last miscellaneous thought. This paper really emphasizes how attention is useful in order invariant settings, because attention as an operation is itself order invariant. That makes sense, but importantly language generation isn't order invariant. You really do benefit from having a rough sequential ordering — Barack Obama is simply much more common than Obama Barack. This helped me better understand why positional encoding is so important for transformers, which only have attention and so have no native way to represent structure.
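For reference, this is the standard sinusoidal positional encoding from the transformer paper, not anything from this one: each position gets a deterministic vector that is added to the token embedding, re-injecting the order that attention alone throws away.

```python
import numpy as np

def sinusoidal_positions(seq_len, dim):
    # Classic sin/cos encoding: even columns get sines, odd columns get cosines.
    pos = np.arange(seq_len)[:, None]            # (seq_len, 1)
    i = np.arange(dim // 2)[None, :]             # (1, dim/2)
    angles = pos / (10000 ** (2 * i / dim))
    enc = np.zeros((seq_len, dim))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

tokens = np.random.randn(7, 16)                # token embeddings for a sentence
tokens = tokens + sinusoidal_positions(7, 16)  # now 'Barack Obama' != 'Obama Barack'
```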
¹ Naively, one way to reason about sets is to just pick a random ordering and call it a sequence. If the set is truly unordered, any sequence should work as well as any other. But intuitively this doesn’t quite work the way we want for a deep neural network. One important property of a set is that its representation should be order-invariant — that is, we should be able to shuffle the elements and still get the same representation. So if the words “quick brown fox” produce a vector embedding [0.1, 0.35, 6.231], I should be able to swap the words around to “fox quick brown” and get the same vector, [0.1, 0.35, 6.231].