Model Optimization with PyTorch

Greetings, I am Keith, a Machine Learning Engineer at Money Forward. Today, we will take a quick look at the techniques and tools used to optimize the Segment Anything implementation, as presented in the influential PyTorch blog post “Accelerating Generative AI”. The images in this post are from the original PyTorch blog post.


  • Model Optimization with PyTorch
    • Outline
    • Background
    • Baseline
    • Progressive Optimization
      • Bfloat16 Half Precision
      • GPU Syncs Removal
      • Torch.compile
      • Scaled Dot Product Attention(SDPA)
      • Triton
      • NestedTensor and batching predict_torch
      • Int8: quantization and approximating matmul
      • Sparse: Semi-structured (2:4) sparsity
    • Conclusion
    • Source


Background

In recent years, the field of Generative AI and large models has seen remarkable advancements. As model architectures grow larger and more complex, the number of trainable parameters is increasing drastically. Optimization has become an important step to speed up computation and reduce cost.

In the original blog, the PyTorch team rewrote Meta’s Segment Anything Model (“SAM”) implementation using only native PyTorch optimizations. According to the original post, the result runs 8 times faster than the original implementation with no loss of accuracy. The team applied a series of optimization techniques progressively and observed significant improvements in throughput and memory usage.


Baseline

Before applying any optimization, the team set up a baseline model as the reference point: Meta Research’s original SAM implementation, using the float32 dtype and a batch size of 1. All subsequent optimization steps were compared against this baseline.

The baseline model and each optimization step were profiled using Perfetto. Here is what a trace looks like, with the computation threads and the breakdown of operations.

 Image from the PyTorch blog. Profiled using Perfetto.
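
To produce a similar trace for your own model, one option is the built-in torch.profiler, whose Chrome-trace JSON export can be opened in the Perfetto UI. Below is a minimal CPU-only sketch; the small stand-in model, input shape and file name are assumptions for illustration, not the team’s setup. On a GPU you would add ProfilerActivity.CUDA to the activities list.

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Hypothetical stand-in for the baseline model: float32 weights, batch size 1
model = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 256),
).eval()
x = torch.randn(1, 256)

# Add ProfilerActivity.CUDA to this list when profiling on a GPU
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad(), record_function("baseline_forward"):
        model(x)

# Chrome-trace JSON file; open it in the Perfetto UI (https://ui.perfetto.dev)
trace_path = os.path.join(tempfile.gettempdir(), "baseline_trace.json")
prof.export_chrome_trace(trace_path)

# Text summary of the most expensive operators
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

The record_function label makes the forward pass easy to find on the timeline, which helps when comparing traces across optimization steps.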

Here is another figure from the original blog. A figure of this kind was used at each optimization step to compare model performance. vit_b and vit_h refer to the smallest and largest vision transformer backbones, respectively.

 Image from the PyTorch blog. Memory and throughput of the baseline SAM model.

Progressive Optimization

In this section, I will summarize the optimization techniques applied in the example. You can check the Source section or the original blog for detailed explanations and implementations.

Before looking at each optimization step, let’s jump to the comparison of all the progressive optimizations. The image below is a bar chart showing all optimizations applied progressively (from left to right). The left-most bar represents the baseline model.

 Image from the PyTorch blog. Memory and throughput progressive improvement.

Bfloat16 Half Precision

The progressive optimization steps started with simple but effective methods. The first one applied was bfloat16 half precision, a common 16-bit floating-point type that both speeds up computation and reduces memory consumption. The lower memory footprint made a larger batch size possible. For more about precision types, refer to posts from PyTorch, or here if you are using PyTorch Lightning.

Be aware that it is critical to validate end-to-end model accuracy for this optimization technique, because it reduces numerical precision, which can lead to an accuracy drop.
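
As a minimal sketch of both points, the snippet below casts a small stand-in model to bfloat16, confirms that parameter memory halves, and compares the outputs with the float32 reference. The model and the printed error are illustrative assumptions; for SAM, the check that matters is the end-to-end accuracy metric on real validation data.

```python
import copy

import torch

torch.manual_seed(0)

# Hypothetical stand-in model; SAM's image encoder is far larger
model_fp32 = torch.nn.Sequential(
    torch.nn.Linear(128, 512),
    torch.nn.GELU(),
    torch.nn.Linear(512, 128),
).eval()
model_bf16 = copy.deepcopy(model_fp32).to(torch.bfloat16)

x = torch.randn(8, 128)
with torch.no_grad():
    ref = model_fp32(x)
    out = model_bf16(x.to(torch.bfloat16)).float()

# bfloat16 stores each parameter in 2 bytes instead of 4
bytes_fp32 = sum(p.numel() * p.element_size() for p in model_fp32.parameters())
bytes_bf16 = sum(p.numel() * p.element_size() for p in model_bf16.parameters())
print(f"parameter memory: {bytes_fp32} -> {bytes_bf16} bytes")

# Always measure how far the lower-precision outputs drift from float32
max_abs_err = (out - ref).abs().max().item()
print(f"max abs error vs float32: {max_abs_err:.4f}")
```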

GPU Syncs Removal

By analyzing the profiling results, the PyTorch team found a major bottleneck caused by GPU syncs: points where the CPU has to wait for the GPU to finish before it can queue more work. The team located the pieces of code responsible and rewrote them to avoid these syncs. This optimization led to a significant boost in throughput.
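
A common source of such syncs is an operation whose output shape depends on tensor values, such as boolean-mask indexing: the CPU must learn how many elements were selected before it can continue. The sketch below is a generic example of this pattern, not the specific code changed in SAM. It contrasts boolean indexing with torch.where, which keeps a static shape. On CUDA, torch.cuda.set_sync_debug_mode("warn") can help locate such lines.

```python
import torch


def positive_sum_with_sync(x: torch.Tensor) -> torch.Tensor:
    # x[mask] has a data-dependent shape, so on CUDA the number of selected
    # elements must be copied back to the CPU: an implicit GPU sync
    return x[x > 0].sum()


def positive_sum_without_sync(x: torch.Tensor) -> torch.Tensor:
    # torch.where keeps the input shape, so no value has to reach the CPU
    return torch.where(x > 0, x, torch.zeros_like(x)).sum()


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
print(torch.allclose(positive_sum_with_sync(x), positive_sum_without_sync(x)))
```

Both functions return the same value; only the second one lets the CPU keep launching kernels without waiting.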


Torch.compile

The team compiled the model with torch.compile, the compiler introduced in PyTorch 2.0 to speed up PyTorch code. You might refer to the official documentation for detailed usage. It fuses operations together to reduce the number of kernel launches, which speeds up the whole computation.

Scaled Dot Product Attention(SDPA)

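torch.nn.functional.scaled_dot_product_attention is a PyTorch operation designed specifically to speed up the attention computation in transformer models. Depending on the hardware and inputs, it dispatches to fused kernels such as FlashAttention or memory-efficient attention.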

Because SDPA reduces memory consumption, it becomes viable to greatly increase the batch size. In the example, the batch size was increased from 8 to 32, 4 times larger, while memory consumption grew only by a factor of 2 to 2.5.
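
The sketch below checks that torch.nn.functional.scaled_dot_product_attention computes the same result as the textbook softmax(QKᵀ / √d)·V formulation; the tensor shapes are illustrative assumptions. The naive version materializes the full sequence-by-sequence score matrix for every head, while SDPA can route the same math through fused kernels.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 196, 64  # illustrative shapes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)


def naive_attention(q, k, v):
    # Materializes the full (seq_len x seq_len) score matrix per head
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return scores.softmax(dim=-1) @ v


# Fused attention; PyTorch picks an available backend for the device
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(fused, naive_attention(q, k, v), atol=1e-4))
```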


Triton

SDPA improved memory consumption in general. However, the profiling results still showed memory consumption spikes within the image encoder, caused by the large size of the attention variable.

The team found it difficult to improve the SDPA operation directly, since it is written in CUDA. So the team turned to Triton and implemented Flash Attention there. (See the Source section for details on Flash Attention.) After this optimization, memory consumption dropped by a factor of 2 to 3.
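
The team’s kernel is written in Triton and runs on the GPU. As a CPU-friendly illustration of the idea behind Flash Attention, the pure-PyTorch sketch below processes keys and values block by block and keeps a running maximum and running softmax denominator (an "online softmax"), so it never holds the full attention matrix at once. This is a simplified sketch for understanding, not the team’s kernel; the block size is an arbitrary assumption, and real kernels also tile over queries and keep blocks in fast on-chip memory.

```python
import math

import torch
import torch.nn.functional as F


def blockwise_attention(q, k, v, block_size=32):
    """Attention over key/value blocks with an online softmax.

    q: (..., Lq, d), k: (..., Lk, d), v: (..., Lk, dv)
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    row_max = q.new_full(q.shape[:-1] + (1,), float("-inf"))  # running max per query
    row_sum = q.new_zeros(q.shape[:-1] + (1,))                # running denominator
    acc = q.new_zeros(q.shape[:-1] + (v.shape[-1],))          # running unnormalized output

    for start in range(0, k.shape[-2], block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale  # (..., Lq, block)

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        # Rescale earlier partial results to the new maximum
        correction = torch.exp(row_max - new_max)
        probs = torch.exp(scores - new_max)
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        acc = acc * correction + probs @ v_blk
        row_max = new_max

    return acc / row_sum


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 100, 64) for _ in range(3))
out = blockwise_attention(q, k, v, block_size=32)  # 100 keys: 4 blocks, last one partial
print(torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-4))
```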

NestedTensor and batching predict_torch

At this point, the image encoder was largely optimized, so the team shifted their focus to the other components: the prompt encoder and the mask decoder. They found that some tensors in these components have different sizes (shapes), which makes them awkward to batch as a regular dense tensor. The team addressed this with PyTorch’s NestedTensor feature, which allowed them to batch predict_torch. In the figure above, we can observe that throughput improves, especially for vit_b.
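
As a minimal sketch of the data-structure idea (the prompt counts and embedding size are hypothetical), torch.nested.nested_tensor can hold per-image tensors of different lengths in one object without padding them up front. The set of operations supported on nested tensors is still limited and depends on the PyTorch version.

```python
import torch

# Hypothetical example: two images with 3 and 5 prompts,
# each prompt embedded into 256 dimensions
prompts_img0 = torch.randn(3, 256)
prompts_img1 = torch.randn(5, 256)

# One nested tensor holds both, even though their first dims differ
nt = torch.nested.nested_tensor([prompts_img0, prompts_img1])
print(nt.is_nested)                    # True
print([t.shape for t in nt.unbind()])  # [torch.Size([3, 256]), torch.Size([5, 256])]

# When a dense tensor is required, pad only at that point
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)                    # torch.Size([2, 5, 256])
```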

Int8: quantization and approximating matmul

The first optimization step used lower precision. The team pushed this further by testing a precision below bfloat16: int8 quantization. This increased the throughput of the largest model, vit_h, although throughput dropped slightly for the smallest model, vit_b. As a tradeoff, according to the original post, accuracy also dropped a bit because of the lower precision.
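
The sketch below illustrates the general idea of dynamic int8 quantization for a linear layer in plain PyTorch: weights get one scale per output channel, activations get one scale per row computed at run time, the matmul runs on integers, and the result is rescaled to floating point. It is an illustration of the technique under simple assumptions (symmetric quantization, no zero points), not the team’s implementation; the integer matmul is only emulated with int32 CPU tensors, so it shows the numerics rather than the speedup.

```python
import torch


def quantize_per_row(t: torch.Tensor):
    """Symmetric int8 quantization with one scale per row."""
    # Map the largest |value| in each row to 127 (clamp avoids division by zero)
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale


def int8_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Approximates x @ weight.T, with weight laid out like nn.Linear (out, in)."""
    x_q, x_scale = quantize_per_row(x)       # dynamic: per call, per token
    w_q, w_scale = quantize_per_row(weight)  # per output channel (could be cached)
    # Accumulate in int32, as int8 matmul kernels typically do (emulated here)
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).T
    # Undo both scales to return to floating point
    return acc.to(torch.float32) * x_scale * w_scale.T


torch.manual_seed(0)
x = torch.randn(8, 64)
weight = torch.randn(16, 64)
ref = x @ weight.T
approx = int8_linear(x, weight)
rel_err = ((approx - ref).norm() / ref.norm()).item()
print(f"relative error: {rel_err:.4f}")
```

The small but nonzero relative error is the accuracy tradeoff mentioned above, in miniature.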

Sparse: Semi-structured (2:4) sparsity

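After all the optimizations above were applied, matrix multiplication became the bottleneck again. So the team used a classic way to approximate matrix multiplication: sparsification. It was achieved by pruning small weights in the weight tensors so that, in the 2:4 semi-structured pattern, at most two of every four contiguous values are nonzero, a pattern that NVIDIA Ampere and newer GPUs can accelerate. According to the original blog, this reduced model size without a significant loss of accuracy.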
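
The 2:4 pattern itself is easy to express. The sketch below builds a magnitude-based 2:4 mask in plain PyTorch, keeping the two largest-magnitude weights in every group of four along the input dimension. It covers only the pruning step under simple assumptions (one-shot magnitude pruning, no fine-tuning). For the actual speedup, PyTorch 2.1 and later provide torch.sparse.to_sparse_semi_structured, which converts a pruned weight into a format for the sparse tensor cores on supported NVIDIA GPUs.

```python
import torch


def prune_2_4(weight: torch.Tensor):
    """Zero the two smallest-magnitude values in every group of 4 along dim -1."""
    out_features, in_features = weight.shape
    assert in_features % 4 == 0, "2:4 sparsity needs the input dim divisible by 4"
    groups = weight.reshape(out_features, in_features // 4, 4)
    # Positions of the 2 largest |w| inside each group of 4
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, keep, torch.ones_like(keep, dtype=torch.bool))
    pruned = groups * mask
    return pruned.reshape(out_features, in_features), mask.reshape(out_features, in_features)


torch.manual_seed(0)
weight = torch.randn(8, 16)
pruned, mask = prune_2_4(weight)
print(mask.reshape(-1, 4).sum(dim=-1))               # every group keeps exactly 2 weights
print(f"density: {mask.float().mean().item():.2f}")  # density: 0.50
```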


Conclusion

The PyTorch team successfully rewrote Segment Anything in pure PyTorch, producing what the blog describes as the fastest implementation, with little to no loss of accuracy.

Through their great blog post, we learned about the Segment Anything model, a method for profiling model performance, a series of optimization techniques and, most importantly, how to optimize progressively, one step at a time.

In addition, it is a great example of why optimization matters, especially for large and complex model architectures. The benefits are not limited to faster computation and lower cost: more efficient memory usage also enables larger batch sizes and therefore higher throughput.

In this post, we listed the methods the PyTorch team used. I highly recommend reading the details in the original blog and checking out the sequel blog post as well.