Exploring Phi-3 as an MLE
Comprehensive deep dive into the technical details and modeling choices.
Phi-3 is a family of small language models developed by Microsoft (the smallest, Phi-3-mini, has 3.8 billion parameters). It is built to deliver strong natural language understanding and generation at a modest compute cost. Its modeling code combines a standard Transformer design with several architectural choices and optimizations that let it perform well on a wide range of tasks while staying computationally efficient.
In this blog post, we'll do a thorough review of the key components and design decisions in the Phi-3 modeling code. By the end, you'll have a solid understanding of how Phi-3 works under the hood and the motivations behind its architecture.
Decoder-Only Architecture
At its core, Phi-3 is a decoder-only Transformer, the design that has become the de facto standard for large language models in recent years. There is no encoder: a stack of causally masked decoder layers predicts each token from the tokens before it. This makes the model well suited to autoregressive language modeling tasks like open-ended text generation.
The main Phi3Model class defines the high-level structure:
An embedding layer maps token IDs to dense vector representations
A stack of Phi3DecoderLayer modules applies the core self-attention and feedforward computations
A final normalization layer on top of the last decoder layer's hidden states
Flexible support for different self-attention implementations (more details below)
This modular design allows for easy experimentation with different layer configurations and building task-specific models on top of the base Phi3Model.
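The NumPy sketch below makes that data flow concrete. It treats each decoder layer and the final norm as plain callables and leaves out the attention masks, position IDs, and key-value cache that the real PyTorch Phi3Model passes through every layer. The name phi3_model_forward is hypothetical.

```python
import numpy as np

def phi3_model_forward(input_ids, embed_table, layers, final_norm):
    """Hypothetical sketch of Phi3Model's data flow (masks and KV cache omitted)."""
    hidden = embed_table[input_ids]      # (batch, seq) -> (batch, seq, hidden_size)
    for layer in layers:                 # stack of shape-preserving decoder layers
        hidden = layer(hidden)
    return final_norm(hidden)            # final normalization of the last hidden states

# Toy sizes: vocabulary of 10, hidden size of 4, two identity "layers".
embed_table = np.arange(40, dtype=np.float64).reshape(10, 4)
input_ids = np.array([[1, 2, 3]])
identity = lambda h: h
out = phi3_model_forward(input_ids, embed_table, [identity, identity], identity)
print(out.shape)  # (1, 3, 4)
```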
Phi3DecoderLayer Deep Dive
Each Phi3DecoderLayer performs the core computations that progressively refine the hidden states over multiple stacked layers. Here's a step-by-step breakdown:
1. Input normalization using Root Mean Square (RMS) layer normalization (input_layernorm). Normalizing before each sub-block (pre-norm) helps stabilize training.
2. Multi-headed self-attention (Phi3Attention):
A single fused linear layer projects the normalized hidden states to queries, keys, and values
The projections are split into multiple attention heads, so each head can capture different aspects of the context
Rotary positional embeddings (RoPE) are applied to the queries and keys
Raw attention scores are the dot products of queries and keys, scaled by 1/sqrt(head_dim), with a causal mask that hides future tokens
A softmax turns the scores into attention weights, and attention dropout is applied to those weights
The weighted sum of the value vectors is projected back to the hidden size by an output projection
3. Residual dropout and post-attention normalization:
Dropout is applied to the self-attention output to regularize it
A residual connection adds the result to the hidden state that entered the layer
A second RMS normalization (post_attention_layernorm) is applied to this sum before it enters the feedforward network
4. Position-wise feedforward network (Phi3MLP):
A fused up-projection expands the hidden dimension into two halves
A non-linear activation (SiLU by default) is applied to one half, which then gates the other half through element-wise multiplication (the SwiGLU pattern)
A down-projection returns to the original hidden dimension
Dropout is applied to the feedforward output
5. A final residual connection adds the feedforward output to the un-normalized sum from step 3, not to its normalized copy
This multi-step refinement process allows each Phi3DecoderLayer to iteratively transform the hidden representations by incorporating contextual information through self-attention and capture higher-level features via the feedforward network.
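The residual wiring is the easiest part of the layer to get wrong, so here is a minimal NumPy sketch of it. It assumes attention, the MLP, and both norms are passed in as callables, with dropout defaulting to the identity as it would at inference. The function and argument names are hypothetical, not the Hugging Face API.

```python
import numpy as np

def decoder_layer(hidden, attn, mlp, input_norm, post_attn_norm, dropout=lambda t: t):
    """Hypothetical sketch of the pre-norm residual wiring in a Phi-3 style decoder layer."""
    residual = hidden
    # Attention sub-block: normalize, attend, apply dropout, add back to the raw stream.
    hidden = residual + dropout(attn(input_norm(hidden)))
    residual = hidden
    # MLP sub-block: the residual is the un-normalized sum, not the normalized copy.
    hidden = residual + dropout(mlp(post_attn_norm(hidden)))
    return hidden

def zero_block(h):
    return np.zeros_like(h)

def unit_rms(h):
    return h / np.sqrt(np.mean(h ** 2, axis=-1, keepdims=True) + 1e-5)

x = np.random.default_rng(0).normal(size=(2, 5, 8))   # (batch, seq, hidden)
# With both sub-blocks returning zeros, the layer reduces to the identity.
print(np.allclose(decoder_layer(x, zero_block, zero_block, unit_rms, unit_rms), x))  # True
```

Because the normalized copy is only ever fed into a sub-block, the residual stream itself is never rescaled, which is a large part of why pre-norm stacks train stably at depth.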
Rotary Positional Embeddings (RoPE)
Instead of learning separate positional embeddings and adding them to the input embeddings, Phi-3 uses rotary positional embeddings (RoPE). RoPE encodes position by rotating pairs of dimensions in each query and key vector by angles proportional to the token's position, inside every self-attention layer.
The key advantages of RoPE are:
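The attention score between two tokens depends only on their relative offset, not on their absolute positions
No learned positional embedding table, so encoding position adds no parameters
It lends itself to context-length extension: rescaling the rotation frequencies lets a model handle sequences longer than those seen during training
Cheap to apply: the rotation is a few element-wise multiplies and adds on the queries and keys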
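The relative-position property can be checked numerically. This NumPy sketch follows the "rotate half" layout used by Llama-style Hugging Face code, where dimension i is paired with dimension i + head_dim/2. The base of 10000 is an assumed default, and the helper names are illustrative.

```python
import numpy as np

def rope_cos_sin(positions, head_dim, base=10000.0):
    """cos/sin tables for RoPE in the 'rotate half' layout."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))  # (head_dim/2,)
    freqs = np.outer(positions, inv_freq)                              # (seq, head_dim/2)
    emb = np.concatenate([freqs, freqs], axis=-1)                      # (seq, head_dim)
    return np.cos(emb), np.sin(emb)

def rotate_half(x):
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, cos, sin):
    """Rotate each (x[i], x[i + half]) pair by its position-dependent angle."""
    return x * cos + rotate_half(x) * sin

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
cos, sin = rope_cos_sin(np.arange(32), head_dim=8)

# Query at position 5 with key at 2 scores the same as query at 20 with key at 17.
s1 = apply_rope(q, cos[5], sin[5]) @ apply_rope(k, cos[2], sin[2])
s2 = apply_rope(q, cos[20], sin[20]) @ apply_rope(k, cos[17], sin[17])
print(np.isclose(s1, s2))  # True
```

Since each pair is only rotated, vector norms are preserved, and the score changes only through the angle difference between the two positions.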
Phi-3 supports two RoPE variants:
Phi3RotaryEmbedding: The default RoPE that uses a fixed max sequence length
Phi3LongScaledRotaryEmbedding: A long-context variant that rescales the rotation frequency of each dimension pair to better handle longer sequences. It applies one set of scaling factors to sequences within the original training length and another set to longer ones.
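Here is a minimal NumPy sketch of that idea. It assumes, based on my reading of the Hugging Face implementation, that the config holds two per-pair factor lists (short_factor and long_factor) and that the long list applies once the sequence exceeds the original training length. Dividing a frequency by a factor greater than one slows that pair's rotation, so distant positions land on angles closer to those seen in training. The attention_scaling formula is also my reading of the implementation and should be treated as an assumption. All helper names and factor values are hypothetical.

```python
import math
import numpy as np

def long_context_inv_freq(head_dim, seq_len, original_max, short_factor, long_factor, base=10000.0):
    """Hypothetical helper: RoPE inverse frequencies divided by per-pair factors."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))   # (head_dim/2,)
    # Assumed rule: switch to the long factors once the sequence exceeds the original length.
    factors = long_factor if seq_len > original_max else short_factor
    return inv_freq / np.asarray(factors, dtype=np.float64)             # larger factor = slower rotation

def attention_scaling(max_positions, original_max):
    """Assumed multiplier on cos/sin when the context window is extended."""
    scale = max_positions / original_max
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1.0 + math.log(scale) / math.log(original_max))

short_f = [1.0, 1.0, 1.0, 1.0]   # toy factors, one per rotated pair (head_dim = 8)
long_f = [1.0, 2.0, 4.0, 8.0]
within = long_context_inv_freq(8, 1000, 4096, short_f, long_f)   # uses short_f
beyond = long_context_inv_freq(8, 8000, 4096, short_f, long_f)   # uses long_f
ratio = beyond / within                                          # each pair slowed by its factor
scaling = attention_scaling(131072, 4096)                        # 32x context extension
```

In a real checkpoint the factor lists come from the config and are tuned per dimension. The toy values here only make the effect easy to read.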
The apply_rotary_pos_emb function is the core implementation that applies the rotary embeddings to the query and key vectors. Because both are rotated by position-dependent angles, their dot products, and therefore the attention scores, become position-aware.
Efficient Self-Attention Variants
Computing full self-attention incurs a quadratic memory and time cost with respect to the sequence length. To manage this, Phi-3 offers several interchangeable self-attention implementations:
- Phi3Attention: The reference ("eager") implementation, which computes attention with explicit matrix multiplications and materializes the full attention matrix. Most flexible (for example, it can return the attention weights) but least memory efficient.
- Phi3FlashAttention2: An optimized variant using the custom CUDA kernels from Tri Dao's FlashAttention library. Benefits:
Computes attention tile by tile in fast on-chip memory, never materializing the full sequence-by-sequence attention matrix
Fuses masking, softmax, and dropout into the same kernel, reducing reads and writes to GPU memory
Supports unpadding sequences so that only non-padding tokens are processed
Recomputes attention weights during backprop instead of storing them, saving memory
- Phi3SdpaAttention: Uses PyTorch's native torch.nn.functional.scaled_dot_product_attention, which dispatches to a fused FlashAttention, memory-efficient, or plain math kernel depending on the inputs and hardware.
Having several implementations lets you trade off speed, memory, and ease of debugging for your hardware and model scale. On GPUs that FlashAttention-2 supports, that variant is usually the fastest and most memory-efficient. SDPA is a solid choice elsewhere, and the eager path is the easiest to inspect.
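The core FlashAttention idea is to compute softmax attention block by block, keeping a running maximum and denominator (an "online softmax"). A few lines of NumPy are enough to check it against the naive version. This is a single-head, forward-only sketch without dropout, and the function names are hypothetical. The real kernels do the same bookkeeping in CUDA using on-chip memory and also handle the backward pass.

```python
import numpy as np

def naive_causal_attention(q, k, v):
    """Reference single-head attention; builds the full (seq, seq) score matrix."""
    seq, d = q.shape
    scores = (q @ k.T) / np.sqrt(d)
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)   # True where the key follows the query
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)            # softmax over keys
    return weights @ v

def tiled_causal_attention(q, k, v, block=4):
    """Online-softmax attention: only a (block, block) tile of scores exists at a time."""
    seq, d = q.shape
    out = np.zeros((seq, v.shape[1]))
    for qs in range(0, seq, block):
        qb = q[qs:qs + block]
        q_idx = np.arange(qs, qs + len(qb))[:, None]
        m = np.full((len(qb), 1), -np.inf)        # running max of scores per query row
        l = np.zeros((len(qb), 1))                # running softmax denominator
        acc = np.zeros((len(qb), v.shape[1]))     # running unnormalized output
        for ks in range(0, qs + len(qb), block):  # key blocks entirely in the future are skipped
            kb, vb = k[ks:ks + block], v[ks:ks + block]
            k_idx = np.arange(ks, ks + len(kb))[None, :]
            s = np.where(k_idx > q_idx, -np.inf, (qb @ kb.T) / np.sqrt(d))
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            p = np.exp(s - m_new)
            rescale = np.exp(m - m_new)           # shrink earlier partial sums to the new max
            l = l * rescale + p.sum(axis=-1, keepdims=True)
            acc = acc * rescale + p @ vb
            m = m_new
        out[qs:qs + len(qb)] = acc / l
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(10, 8)) for _ in range(3))
print(np.allclose(naive_causal_attention(q, k, v), tiled_causal_attention(q, k, v, block=4)))  # True
```

Both functions compute the same result. The tiled one never holds more than a block-by-block slice of scores, which is where the memory savings come from. The total arithmetic remains quadratic in sequence length.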
Feedforward MLP Variants
The feedforward network (Phi3MLP) is the other key component in each decoder layer. Phi-3 uses a gated linear unit (GLU) structure:
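Project the hidden states up with a single fused linear layer (gate_up_proj) to twice the intermediate size
Split the result into two chunks: a gate and an up-projected branch
Apply the non-linear activation to the gate and multiply it element-wise with the up-projected branch
Project back down to the original hidden dim with a linear layer (down_proj)
The activation is set by hidden_act in the config and defaults to SiLU (the Swish function with β = 1, sometimes called Swish-1), which makes the block a SwiGLU. GLU variants such as SwiGLU have been reported to improve quality over a plain GeLU feedforward layer with a similar parameter count.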
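Here is a minimal NumPy sketch of the gated feedforward block. It assumes a fused up-projection whose first half is the gate, SiLU as the activation, and no bias terms. The name gated_mlp is hypothetical. The check at the end confirms that one fused matrix multiply gives the same result as separate gate and up projections, so fusing them costs nothing in accuracy.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # Swish with beta = 1

def gated_mlp(x, w_gate_up, w_down):
    """SwiGLU feedforward: x (..., hidden) -> (..., hidden)."""
    gate_up = x @ w_gate_up                    # (..., 2 * intermediate) in one matmul
    gate, up = np.split(gate_up, 2, axis=-1)   # first half gates, second half carries values
    return (up * silu(gate)) @ w_down          # gate, then project back to the hidden size

rng = np.random.default_rng(0)
hidden, intermediate = 8, 16
w_gate = rng.normal(size=(hidden, intermediate))
w_up = rng.normal(size=(hidden, intermediate))
w_down = rng.normal(size=(intermediate, hidden))
x = rng.normal(size=(3, hidden))

fused = gated_mlp(x, np.concatenate([w_gate, w_up], axis=1), w_down)
separate = (silu(x @ w_gate) * (x @ w_up)) @ w_down
print(fused.shape, np.allclose(fused, separate))  # (3, 8) True
```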
Fused kernels, such as the SwiGLU implementation shipped with the FlashAttention library, can combine the split, Swish, and element-wise multiply into a single operation for greater efficiency.
The MLP output then passes through dropout and a residual connection in the surrounding decoder layer, which help training convergence and generalization.
Layer Normalization
Phi-3 employs RMSNorm (root mean square normalization) for both the input normalization and post-attention normalization in each decoder layer. RMSNorm offers a few benefits over vanilla LayerNorm:
It normalizes by the root mean square (RMS) alone, skipping LayerNorm's mean subtraction and bias term, which makes it simpler and cheaper to compute.
In practice it tends to match LayerNorm's quality in Transformers while running faster.
Fused RMSNorm kernels (the FlashAttention library ships one) can stand in for the plain PyTorch version for better compute efficiency.
The Phi3RMSNorm class provides a simple interface that takes care of the RMS computation (done in float32, then cast back to the input dtype), epsilon addition, and scaling by a learnable weight vector.
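Here is a minimal NumPy sketch of that computation. It assumes eps = 1e-5 (my understanding of the default rms_norm_eps in the Phi-3 config), and the function name is hypothetical. It also shows why the float32 upcast matters: squaring moderately large values in float16 already overflows.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    """RMSNorm: x / sqrt(mean(x**2) + eps), scaled by a learnable weight."""
    input_dtype = x.dtype
    x32 = x.astype(np.float32)                              # statistics in float32
    variance = np.mean(x32 ** 2, axis=-1, keepdims=True)    # no mean subtraction, no bias
    normed = x32 / np.sqrt(variance + eps)
    return weight * normed.astype(input_dtype)              # cast back, then per-channel scale

x = np.full((2, 8), 300.0, dtype=np.float16)
weight = np.ones(8, dtype=np.float16)

with np.errstate(over="ignore"):
    naive_variance = np.mean(x ** 2, axis=-1)   # 300**2 = 90000 exceeds the float16 max of 65504
print(np.isinf(naive_variance).all())            # True
out = rms_norm(x, weight)                         # finite, with an RMS of about 1 per row
```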
Specialized Output Heads
To make Phi-3 easy to adapt to different tasks, the code provides a set of task-specific head modules:
- Phi3ForCausalLM: A causal language modeling head that predicts the next token via softmax over the full vocab. Key pieces:
A linear projection maps the hidden states to vocab-sized logits
A cross-entropy loss over labels shifted by one position (each position predicts the next token) gives the language modeling loss
Includes specific logic for handling left and right padding
Supports fast autoregressive generation via caching of past key values
- Phi3ForSequenceClassification: A sequence classification head for tasks like sentiment analysis or textual entailment. Key aspects:
Pools by selecting the output at the last non-padding token of each sequence, located using pad_token_id
A linear projection maps to the number of output classes
Supports regression (MSE), single-label classification (cross entropy), and multi-label classification (binary cross entropy) losses
Handles padding edge cases such as left-padded inputs, and raises an error for batches larger than one when no pad token is defined
- Phi3ForTokenClassification: A token classification head for tasks like named entity recognition. Main functionality:
A linear projection maps each token's hidden state to the output classes
Optionally applies dropout for regularization
A cross entropy loss is used for the classification objective
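Two pieces of this head logic are easy to get subtly wrong, so here is a NumPy sketch of both. The first is the next-token loss, which shifts labels by one position before taking the cross entropy. The second is the last-token index used for sequence-classification pooling. The sketch assumes the Hugging Face convention that label -100 is ignored. The index arithmetic follows my understanding of recent transformers releases (older ones differ). The helper names are hypothetical.

```python
import numpy as np

IGNORE_INDEX = -100  # label value skipped by the loss (Hugging Face convention)

def causal_lm_loss(logits, labels):
    """Next-token cross entropy. logits: (batch, seq, vocab); labels: (batch, seq)."""
    vocab = logits.shape[-1]
    shift_logits = logits[:, :-1, :].reshape(-1, vocab)   # position t predicts token t + 1
    shift_labels = labels[:, 1:].reshape(-1)
    keep = shift_labels != IGNORE_INDEX
    z = shift_logits[keep]
    z = z - z.max(axis=-1, keepdims=True)                 # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(keep.sum()), shift_labels[keep]].mean()

def last_token_index(input_ids, pad_token_id):
    """Index of the last real token in each row, for last-token pooling."""
    seq_len = input_ids.shape[-1]
    first_pad = np.argmax(input_ids == pad_token_id, axis=-1)   # 0 when a row has no pad
    # Right padding: first_pad - 1 is the last real token.
    # Left or no padding: 0 - 1 wraps around to seq_len - 1, the final position.
    return (first_pad - 1) % seq_len

# Uniform logits over a 5-token vocabulary give a loss of log(5) on the kept targets.
logits = np.zeros((1, 4, 5))
labels = np.array([[1, 2, IGNORE_INDEX, 3]])
print(np.isclose(causal_lm_loss(logits, labels), np.log(5)))  # True

input_ids = np.array([[5, 6, 7, 0, 0],    # right-padded
                      [0, 0, 5, 6, 7],    # left-padded
                      [5, 6, 7, 8, 9]])   # no padding
idx = last_token_index(input_ids, pad_token_id=0)
print(idx)  # [2 4 4]
cls_logits = np.random.default_rng(0).normal(size=(3, 5, 2))   # (batch, seq, num_labels)
pooled = cls_logits[np.arange(3), idx]                          # (batch, num_labels)
```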
These output heads make it simple to fine-tune Phi-3 for common NLP tasks with minimal code changes. The shared underlying Phi-3 architecture can then be leveraged for transfer learning.
Training Optimizations
To make training Phi-3 as efficient as possible, the code includes a few key optimizations:
Mixed precision: The modeling code runs in bf16 or fp16 (for example under torch.autocast, or with weights loaded in half precision) while keeping numerically sensitive steps, such as the RMSNorm statistics and the attention softmax, in float32. Half-precision tensors take half the memory of FP32 ones and speed up compute on GPUs with tensor cores, with accuracy close to full FP32 training.
Gradient checkpointing: Phi-3 supports activation checkpointing. Only a subset of activations (the inputs to each decoder layer) is kept during the forward pass, and the rest are recomputed during the backward pass. This trades extra compute for large memory savings, which is crucial for training billion-parameter models.
PyTorch compile: Phi-3 is compatible with the torch.compile optimizing compiler, which can speed up training and inference by compiling the model code and fusing operators. The size of the gain depends on the hardware, input shapes, and model size, so it is worth measuring on your own setup.
FlashAttention CUDA kernels: As mentioned before, leveraging the optimized CUDA kernels for the self-attention and activation functions from the FlashAttention library can provide substantial improvements in compute efficiency.
Variable-length batches: Attention masks keep padding tokens from affecting real tokens, and with FlashAttention-2 the code unpads the batch so no attention compute is spent on padding positions.
Weight initialization: The _init_weights method draws linear and embedding weights from a normal distribution with mean 0 and standard deviation config.initializer_range (0.02 by default), and zeroes any biases. Keeping the initial weights small in this way helps training start stably.
These optimizations, combined with the core architecture, make it practical to train and fine-tune Phi-3 efficiently, even at large scale.
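To see why checkpointing trades compute for memory without changing the result, here is a self-contained NumPy sketch on a toy chain of tanh layers with a hand-written backward pass. It keeps only the input to every segment-th layer, recomputes each segment during the backward pass, and counts how many activation vectors are alive at the peak. The toy network and helper names are hypothetical. In Hugging Face transformers, model.gradient_checkpointing_enable() applies the same idea with one decoder layer per segment.

```python
import numpy as np

def layer_forward(w, x):
    return np.tanh(w @ x)

def layer_backward(w, y, grad_y):
    # For y = tanh(w @ x): dL/dx = w.T @ (dL/dy * (1 - y**2))
    return w.T @ (grad_y * (1.0 - y ** 2))

def grad_full(weights, x0):
    """Standard backprop: store every activation. Loss = sum of the final outputs."""
    acts = [x0]
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    grad = np.ones_like(acts[-1])
    for i in reversed(range(len(weights))):
        grad = layer_backward(weights[i], acts[i + 1], grad)
    return grad, len(acts)

def grad_checkpointed(weights, x0, segment):
    """Keep only each segment's input; recompute the segment during backward."""
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer_forward(w, x)
    grad, peak = np.ones_like(x), len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        acts = [checkpoints[start]]
        for i in range(start, end):                      # recomputation: the extra compute
            acts.append(layer_forward(weights[i], acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grad = layer_backward(weights[i], acts[i - start + 1], grad)
    return grad, peak

rng = np.random.default_rng(0)
weights = [rng.normal(scale=0.3, size=(8, 8)) for _ in range(12)]
x0 = rng.normal(size=8)
g_full, n_full = grad_full(weights, x0)
g_ckpt, peak = grad_checkpointed(weights, x0, segment=4)
print(np.allclose(g_full, g_ckpt), n_full, peak)  # True 13 7
```

The gradients match exactly, while the peak number of stored activations drops from 13 to 7 at the cost of running each segment's forward pass twice.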
Conclusion
I hope this deep dive has given you a comprehensive understanding of the key modeling choices and components that power Phi-3. From the multi-layer decoder structure to the rotary positional embeddings to the highly optimized self-attention and feedforward implementations, each piece works together to enable strong language modeling performance.
The extensive configurability and use of optimized primitives make Phi-3 well-suited for pushing the boundaries of large-scale language modeling. The Phi3Config class makes it easy to tweak hyperparameters such as the hidden size, number of layers, and number of attention heads for experimentation.
The specialized output heads also make Phi-3 simple to adapt to a variety of practical NLP tasks, even with limited fine-tuning data. Together with its ability to pick up new tasks from a few examples in the prompt (few-shot learning), this adaptability is one of the hallmarks of powerful language models.
Phi-3 represents an exciting step forward in the development of efficient and adaptable language models. I can't wait to see what new applications and research the community will build on top of it. I will try fine-tuning it on my favorite tasks.