Chasing the Next Transformer Killer - Part 1
Jul 20, 2024
Transformers have rocked the artificial intelligence scene since 2021 with their capabilities. ChatGPT set records for the fastest user growth, blowing every earlier consumer product out of the water by a large margin. The transformer architecture was invented by Google researchers and published in 2017 in the famous paper titled “Attention Is All You Need”. Transformer-based models had already been around for years, but OpenAI hit a sweet spot with two major enhancements:
- Scaling the model up until it was large enough to exhibit the level of intelligence that captivated everyone.
- Strategically adding human feedback into the training loop in the form of RLHF (Reinforcement Learning from Human Feedback). This latter ingredient is an extremely important “secret sauce” that sets ChatGPT apart from earlier attempts.
The motivation behind the transformer architecture was to overcome the limitations that earlier generations of models faced. The preceding models were mostly some form of RNN (Recurrent Neural Network). LSTM (Long Short-Term Memory) based architectures were improvements on plain RNNs and are still used in various deep neural networks. Two major problems with them are:
- Even though they aim to retain the important parts of the incoming information over both the long and the short term, they can still forget certain bits too early or remember other things for too long. The LSTM partially solves the vanishing gradient problem of plain RNNs, but it is still prone to the exploding gradient problem.
- The training cannot be parallelized well because of the recurrence in the building block. To reach a certain point in the processing of a text, the network has to step through it token by token, since each hidden state depends on the previous one.
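To make both problems concrete, here is a minimal sketch of a toy scalar RNN (not an LSTM; the weights and function names are hypothetical, chosen only for illustration). The loop shows why step t must wait for step t-1, and the gradient helper shows how the derivative of the last state with respect to the first one is a product of per-step factors that shrinks toward zero on long sequences.

```python
import math

def rnn_forward(inputs, w_x, w_h, b):
    """Toy scalar RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b).

    Each step needs the previous hidden state, so the loop cannot be
    split across tokens: step t must wait for step t-1.
    """
    h = 0.0  # initial hidden state h_0
    states = []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states

def grad_wrt_initial_state(states, w_h):
    """d h_T / d h_0 for the toy RNN above.

    By the chain rule it is a product of per-step factors w_h * (1 - h_t^2).
    If every factor has magnitude below 1, the product vanishes as the
    sequence grows; if every factor exceeds 1, it blows up.
    """
    g = 1.0
    for h in states:
        g *= w_h * (1.0 - h * h)
    return g

# Hypothetical weights, for illustration only.
short_states = rnn_forward([1.0, 0.5, -0.3, 0.8], w_x=0.9, w_h=0.5, b=0.0)
long_states = rnn_forward([1.0] * 50, w_x=0.9, w_h=0.5, b=0.0)
print(grad_wrt_initial_state(long_states, w_h=0.5))  # a vanishingly small number
```

With `w_h=0.5` every factor is at most 0.5, so after 50 steps the gradient is below 0.5^50, which is why the earliest tokens barely influence learning.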
The transformer’s attention mechanism doesn’t read the text in the serial fashion of RNNs. Instead, it looks at the whole context window and weighs the relationship between every pair of tokens in it. I must emphasize that transformers still generate text in an auto-regressive manner, token by token; the difference is that during training every position of a sequence can be processed at once. The result is that the attention heads are able to focus on the important sections “magically”, just by broadly looking at the whole window. The “magic” is the result of the training: provided the training process and source data are adequate, the neural network learns to do that.
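A minimal sketch of the scaled dot-product attention described in “Attention Is All You Need”, for a single head. It assumes the query, key and value vectors are already given (the learned projections are omitted) and uses plain Python lists for readability; real implementations express the same computation as batched matrix multiplications.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention: one output vector per query.

    Every query is scored against every key in the window, so the
    score table has len(queries) * len(keys) entries.
    """
    d = len(queries[0])
    outputs = []
    for q in queries:
        # Similarity of this token to every token in the window.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)  # attention weights, summing to 1
        # Output is the weighted average of the value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
    return outputs

# Two tokens whose keys match their own queries best.
eye = [[1.0, 0.0], [0.0, 1.0]]
print(attention(eye, eye, eye))
```

Each token ends up mixing mostly its own value with a smaller share of the other token’s value; the learned projections in a real model decide which pairs score high.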
The fact that each token is weighed against every other token within the window means many multiplications of big matrices. That is essentially what GPUs and AI hardware accelerators are all about under the hood: multiplying giant matrices. Conveniently, other ML workloads, such as recommendation models, fit that capability too.
The original transformer building block has an encoder and a decoder section. KiKaBeN has a good article about the details, with an illustration.
Besides models with the original encoder + decoder setup, encoder-only and decoder-only models were also available. Each of them has its strengths:
Model Type | Example | Purpose | Strengths | Weaknesses |
---|---|---|---|---|
Encoder-Only | BERT (Google) | Text understanding and contextual representation | Captures bidirectional context, excellent for understanding text, pre-trained models adaptable for various tasks | Not designed for text generation, requires additional components for specific tasks |
Decoder-Only | GPT (OpenAI) | Text generation | High-quality and coherent text generation, autoregressive nature suitable for creative tasks, pre-trained models adaptable | Limited consideration of full input context, may struggle with tasks requiring deep understanding of complex relationships |
Encoder-Decoder | T5 (Google) | General-purpose tasks involving both text understanding and generation | Combines strengths of encoder-only and decoder-only models, versatile architecture | More computationally expensive, training can be more complex |
Encoder-Decoder | BART (Facebook AI) | Primarily text generation, but also effective for text understanding | Pre-trained on noisy text, combines bidirectional encoding and autoregressive decoding, strong performance on various tasks | Can be computationally expensive, pre-training on noisy text may introduce biases |
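Much of the difference between the rows above comes down to which tokens each position may attend to. The sketch below (function names are hypothetical) builds the two attention masks: an encoder-style bidirectional mask, where every token sees the whole input, and a decoder-style causal mask, where token i sees only positions 0..i. The causal mask is what enables left-to-right generation, and also why decoder-only models see less of the input at each position. Encoder-decoder models combine a bidirectional encoder with a causal decoder plus cross-attention, which is not shown here.

```python
import math

def attention_mask(n, causal):
    """n x n boolean mask: mask[i][j] is True if token i may attend to token j."""
    return [[(j <= i) if causal else True for j in range(n)] for i in range(n)]

def masked_weights(scores, mask):
    """Softmax each row of a score matrix, giving masked-out positions zero weight."""
    rows = []
    for row_scores, row_mask in zip(scores, mask):
        # Use only the allowed scores for the stability shift.
        m = max(s for s, ok in zip(row_scores, row_mask) if ok)
        exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(row_scores, row_mask)]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

encoder_mask = attention_mask(3, causal=False)  # BERT-style: sees both sides
decoder_mask = attention_mask(3, causal=True)   # GPT-style: sees only the left
flat_scores = [[0.0] * 3 for _ in range(3)]     # equal scores, to expose the mask
print(masked_weights(flat_scores, decoder_mask))
```

With equal scores, the first decoder token can only attend to itself, the second splits its attention between two tokens, and the third between all three, while every encoder token would spread its attention over all three.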
The bottom line is that transformers turned out to be excellent at focusing on the important parts of the information and then generating results. Matrix multiplications (needed during both training and inference) are costly, but they can be accelerated, and the training procedure in general can be parallelized. That is extremely crucial, so much so that GPTs would not be feasible without it. Still, the costs add up: schoolbook multiplication of two n×n matrices is cubic (not even quadratic) in n, and computing the attention scores alone grows quadratically with the number of tokens. As a result, transformer-based services face quadratic cost inflation in several aspects, such as the context window size and the work spent reading an ever-growing KV cache during generation.
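A back-of-the-envelope sketch of these growth rates, counting scalar multiplications only. It assumes schoolbook matrix multiplication (asymptotically faster algorithms such as Strassen’s exist) and a single attention head; `d_head`, `prompt_len` and the function names are hypothetical parameters introduced for illustration.

```python
def naive_matmul_mults(n):
    """Scalar multiplications in a schoolbook (n x n) by (n x n) product: n^3."""
    return n ** 3

def attention_score_mults(n_tokens, d_head):
    """Multiplications to form Q @ K^T for one head: (n x d) by (d x n) gives n * n * d."""
    return n_tokens * n_tokens * d_head

def kv_cache_reads(n_generated, prompt_len):
    """Total key positions attended to while generating n_generated tokens one by one.

    The t-th new token attends to the prompt plus t positions (earlier outputs
    and itself), so each step reads one more cached entry than the last and
    the running total grows quadratically in n_generated.
    """
    return sum(prompt_len + t for t in range(1, n_generated + 1))

# Doubling the matrix size multiplies the schoolbook cost by 8.
print(naive_matmul_mults(200) // naive_matmul_mults(100))
# Doubling the context multiplies the attention-score cost by 4.
print(attention_score_mults(2048, 64) // attention_score_mults(1024, 64))
```

These counts ignore memory traffic, which in practice is often the real bottleneck; that observation is exactly what FlashAttention exploits.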
Long before the GPT boom, researchers were already aware of these downsides and were actively working to overcome them; see, for example, Performers or Reformers. Here is a comparison that also factors in Linear Transformers and the FlashAttention versions:
Mechanism | Source | Time Complexity | Memory Complexity | Strengths | Weaknesses |
---|---|---|---|---|---|
Standard Transformer | | O(n^2) | O(n^2) | Captures global dependencies effectively | Not scalable to long sequences due to quadratic complexity |
Linear Transformer | GitHub | O(n) | O(n) | Efficient for long sequences | May sacrifice some modeling power compared to standard attention |
Reformer | GitHub | O(n log n) | O(n) | Efficient for long sequences, uses locality-sensitive hashing so each token attends only to keys hashed into the same bucket | LSH approximation may introduce noise or miss some long-range dependencies |
Performer | | O(n) | O(n) | Balances efficiency and expressiveness, approximates softmax attention using positive random features (FAVOR+) | Approximation may be less accurate than standard attention in some cases |
FlashAttention | GitHub | O(n^2) (exact attention) | O(n) | Very fast in practice on GPUs thanks to IO-aware tiling that never materializes the full n×n attention matrix, also has a block-sparse variant | Relies on hardware-specific kernels, less portable than other methods |
FlashAttention-2 | GitHub | O(n^2) (exact attention) | O(n) | Improves upon FlashAttention with better parallelism and work partitioning on the GPU, making it substantially faster while keeping the same memory footprint | Similar to FlashAttention, it relies on hardware-specific kernels and may be less portable than other methods |
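The sketch below illustrates two of the ideas from the table, under stated assumptions. `linear_attention` follows the kernelized formulation associated with Linear Transformers, assuming the feature map phi(x) = elu(x) + 1: replacing exp(q·k) with phi(q)·phi(k) lets the sums over keys be computed once and reused for every query, so the cost is linear in the number of tokens (Performers use random features instead, not shown). `streaming_softmax_attention` is the “online softmax” trick at the heart of FlashAttention’s tiling: exact softmax attention for one query in a single pass, without storing the full score row. It is a plain-Python illustration of the math, not the GPU kernel, and all function names are hypothetical.

```python
import math

def phi(x):
    """Assumed feature map elu(x) + 1, which keeps features positive."""
    return x + 1.0 if x > 0 else math.exp(x)

def linear_attention(queries, keys, values):
    """Kernelized attention: the key/value sums are built once, O(n) overall."""
    d_k, d_v = len(keys[0]), len(values[0])
    kv = [[0.0] * d_v for _ in range(d_k)]  # sum_j phi(k_j) v_j^T
    z = [0.0] * d_k                         # sum_j phi(k_j)
    for k, v in zip(keys, values):
        fk = [phi(x) for x in k]
        for a in range(d_k):
            z[a] += fk[a]
            for b in range(d_v):
                kv[a][b] += fk[a] * v[b]
    outputs = []
    for q in queries:
        fq = [phi(x) for x in q]
        denom = sum(fq[a] * z[a] for a in range(d_k))
        outputs.append([sum(fq[a] * kv[a][b] for a in range(d_k)) / denom for b in range(d_v)])
    return outputs

def quadratic_kernel_attention(queries, keys, values):
    """Same kernel computed pairwise, O(n^2); only for checking the result above."""
    outputs = []
    for q in queries:
        fq = [phi(x) for x in q]
        sims = [sum(a * phi(b) for a, b in zip(fq, k)) for k in keys]
        total = sum(sims)
        outputs.append([sum(s * v[j] for s, v in zip(sims, values)) / total
                        for j in range(len(values[0]))])
    return outputs

def streaming_softmax_attention(q, keys, values):
    """Exact softmax attention for one query in one pass over the keys.

    Keeps a running max m, normalizer l and weighted sum acc, rescaling them
    whenever a larger score appears, so no full score row is ever stored.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for k, v in zip(keys, values):
        s = scale * sum(a * b for a, b in zip(q, k))
        m_new = max(m, s)
        correction = math.exp(m - m_new)  # 0.0 on the first step, since m = -inf
        p = math.exp(s - m_new)
        l = l * correction + p
        acc = [x * correction + p * vj for x, vj in zip(acc, v)]
        m = m_new
    return [x / l for x in acc]

qs = [[0.2, -0.1], [1.0, 0.5]]
ks = [[0.3, 0.8], [-0.5, 0.1], [0.9, -0.4]]
vs = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
print(linear_attention(qs, ks, vs))
print(streaming_softmax_attention(qs[1], ks, vs))
```

Note the different trade-offs: the linear version changes the attention function itself (a different kernel than softmax), while the streaming version computes exactly the standard softmax result and still visits every query-key pair, which is why FlashAttention saves memory and memory traffic rather than FLOPs.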
Click here for the second part of this post.
Thanks to Gemini Advanced for the comparison sheets.