在TPU上部署LLaMA 3¶
Serving LLaMA 3 on TPUs
📋 章节概览¶
所属部分:Transformer 原文标题:Serving LLaMA 3 on TPUs 原文地址:https://jax-ml.github.io/scaling-book/applied-inference 翻译时间:2026年03月30日
🎯 本章要点¶
本章将深入探讨在TPU上部署LLaMA 3的相关内容,包括:
- 核心概念:理解在TPU上部署LLaMA 3的基本原理
- 技术实现:掌握相关的技术实现方法
- 实践应用:了解在实际项目中的应用场景
- 优化策略:学习性能优化和最佳实践
翻译状态:初步翻译
技术说明: 1. 专业术语已进行基础翻译 2. 复杂概念保留英文原文 3. 公式和代码保持原样 4. 需要进一步的人工校对和完善
学习建议: - 结合原英文文档理解复杂概念 - 参考相关技术文档加深理解 - 实践教材中的代码示例
Serving LLaMA 3-70B on TPUs | How To Scale Your 模型 * How To Scale Your 模型 Toggle navigation
* Previous Part * Next Part * Sections Part 0. Introduction Part 1. Intro to Rooflines Part 2. All About TPUs Part 3. Sharded Matmuls Part 4. Transformers Part 5. 训练 Part 6. 训练 LLaMA Part 7. 推理 Part 8. Serving LLaMA Part 9. Profiling Part 10. All About JAX Part 11. Conclusions Part 12. GPUs
*
# Serving LLaMA 3-70B on TPUs Part 8 of [How To Scale Your 模型](/scaling-book) ([Part 7: 推理](../推理) | [Part 9: Profiling](../profiling))
Let's take a close look at how we'd serve LLaMA 3-70B models on TPU v5e. How expensive are different models to serve at roofline? How large are their KV caches? What batch sizes should we use? How are the parameters and activations sharded during 推理? Let's work through some back-of-the-envelope estimates for 延迟 and 吞吐量 in production.
### Contents [What's the LLaMA Serving Story?](#what-s-the-llama-serving-story) [](#)
Visualizing the 延迟 吞吐量 Tradeoff Worked Problems This section will look at what it takes to serve LLaMA-3 and how efficiently it can be done. As in the previous “applied” section, try to work out the answers on your own with a pen and paper before looking them up!
## What’s the LLaMA Serving Story? Let’s remind ourselves what LLaMA 3-70B looks like (see Section 6 for reference):
hyperparam value \(n_\text{layers}\) (L) 80 \(d_\text{模型}\) (D) 8,192 \(d_{ff}\) (F) 28,672 \(n_\text{heads}\) (N) 64 \(n_\text{kv heads}\) (K) 8 \(d_\text{qkv}\) (H) 128 \(n_\text{embeddings}\) (V) 128,256 Let’s start with a simple question: what hardware should we serve on? The answer is basically, whichever is cheapest in FLOPs / dollar.This isn't always true, sometimes more HBM or ICI 带宽 is critical rather than FLOPs, but this is a good heuristic. For this reason, we typically want to serve on TPU v5e, our current dedicated 推理 芯片 (cost comes from [Google Cloud pricing](https://cloud.google.com/tpu/pricing) as of February 2025):
TPU type bfloat16 FLOPs/s Google Cloud USD / hour FLOPs / $ H100 9.9e14 $10.8 3.3e17 v5p 4.59e14 $4.2 3.9e17 v5e 1.97e14 $1.2 5.8e17 Each TPU v5e has 16GB of HBM which will require us to shard our 模型 fairly aggressively. Let’s start by thinking about some basic quantities that might matter for us:
Question: How large are LLaMA 3-70B’s KV caches per token? You can assume we store them in int8. This determines how large our batch size can be on a given topology.
Click here once you’ve thought it through! LLaMA 3-70B has 8 KV heads, so the size per token is 2 * K * H * L = 2 * 8 * 128 * 80 = 160kB.
Note just how big this is! If we have a sequence length of 32k tokens (as is common), this uses 160e3 * 32,768 = 5.3GB / sequence. For BS=240, this is 1.3TB! Since TPU v5e only have 16GB a piece, we would need about (70e9 + 1.3e12) / 16e9 = 86 TPU v5e chips to even fit this much 内存. Also note how large this is compared to the 70GB of 模型 parameters.
Question: Let’s say we want to serve L3 70B at batch size 32 and 8192 sequence length with everything (params and KVs) in int8. How much total 内存 will this use? What’s the smallest slice we could serve this on?
Answer Since our KVs are 160e3 bytes in int8, our total KV 内存 is 160e3 * 8192 * 32 = 41.9e9 bytes. Our parameters are 70e9 bytes, since we have 1 byte per 参数. Thus, our total 内存 usage is 41.9e9 + 70e9 = 112GB.
The smallest slice we could use would have 112e9 / 16e9 = 7 TPUs, or (rounding to an even size), TPU v5e 4x2. This will be a tight fit and we might not be able to quite fit this accounting for other overhead, so we might need a 4x4 at minimum (or to drop the batch size).
Question: At this batch size and quantization on a TPU v5e 4x2, roughly what 延迟 would we expect per decode step? What 吞吐量 (tokens / sec / 芯片). What about a 4x4? Assume we perform our FLOPs in bfloat16 and everything is fully sharded.
Answer We can invoke the formula from the previous section that
[\begin{align} \tiny \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV 缓存 Size}}{\tiny \text{Total 内存 带宽}}}{\text{注意力 (always 带宽-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{参数 Count}}{\text{Total FLOPs/s}}, \frac{\text{参数 Size}}{\text{Total 内存 带宽}}\right)} \end{align}}] Here our critical batch size will be about 120 since our parameters are in int8 but our FLOPs are in bfloat16. We could also manually calculate the RHS maximum, but that’s basically a calculation we’ve already done several times. So we’re well into the 内存-bound regime for both our matmul and our FLOPs.
Strictly looking at 内存 带宽 then, our step time is basically (KV size + param size) / (8 * HBM 带宽) = 112e9 / (8 * 8.1e11) = 17ms. So theoretically our step time is about 17ms. Our 吞吐量 would be 32 / .017 = 1882 tokens / sec, or 1882 / 8 = 235 tokens / sec / 芯片.
There’s one caveat here which is to check if we might be ICI bound on our matmuls. We could dedicate 2 axes to it here, so we’re ICI bound in theory when \(Y > 2 * F / 2200 = 2 * 28672 / 2200 = 26\), so we’re golden!
If we were to run on a 4x4, we’d still be fine ICI-wise, so our 延迟 would drop to 17 / 2 = 8.5ms, but our 吞吐量 per-芯片 would remain the same.
### Thinking about 吞吐量 Let’s spend a little time thinking purely about 吞吐量. When we optimize for 吞吐量, we want to be compute bound, meaning we come close to utilizing all the TPU MXU capacity. Typically that means we want the batch size to be as large as possible, so we are doing as much work as possible.
Question: On TPU v5e, using bfloat16 weights and activations, how large do our batch sizes need to be for us to be compute-bound in our matmuls? What if we do int8 weights but perform our FLOPs in bfloat16? What about int8 weights with int8 FLOPs?
Answer As discussed in Section 7, for any bfloat16 matmul for which \(B \ll D, F\) we have
[\begin{equation} T_\text{math} > T_\text{comms} \leftrightarrow \frac{2BDF}{2DF} \geq \frac{\text{TPU bfloat16 FLOPs/s}}{\text{HBM 带宽}} = 240 \end{equation}] When our weights are in int8, we lose a factor of 2 in the denominator, so we have \(2BDF / DF = 2B > 240\), or equally \(B > 120\), half the critical batch size from before. That’s really helpful for us! When we do int8 weights and int8 FLOPs, we have to use the int8 value for TPU FLOPs/s, which goes from 1.97e14 for bfloat16 to 3.94e14, nearly double. That means we’re back where we started at about \(B > 240\).
The case of int8 weights and bfloat16 FLOPs is quite common, since quantizing parameters losslessly is often easier than doing low-precision arithmetic.
Question: What is the smallest TPU v5e topology we could serve LLaMA 3-70B on using bfloat16, int8, and int4 (both KVs and parameters) with 8k context? You can think of KV caches as negligibly small for this one.
Answer This is easy! If we’re OK with a tiny batch size then the only limit is fitting 参数 内存 in HBM, i.e. it is just ceil(num_params * sizeof(dtype) / HBM per TPU), or ceil(70e9 * sizeof(dtype) / 16e9) rounded to the nearest reasonable topology (some multiple of 2):
dtype param size KV size / token (bytes) min TPU v5es actual min slice remaining HBM for KV caches num KV caches @ 8k bf16 140GB 324kB 8.75 4x4 = 16 chips 116 43 int8 70GB 162kB 4.38 4x2 = 8 chips 58 43 int4 35GB 81kB 2.81 2x2 = 4 chips 29 43 That’s pretty cool! It tells us we could fit LLaMA 70B on a TPU v5e 2x2 if we wanted to. Except you’ll notice the number of KV caches is very small. That’s our batch size! That means we’ll be getting terrible FLOPs utilization. We’d be very happy to use a larger topology in order to push our batch size up to 240.
Question: Assume we use the largest batch size that fits on these topologies, what 延迟 could we expect for each generate step?
Answer This is also easy, since we’re picking our batch size to fill up all our HBM! This is just a question of how long it takes to load a full TPU v5e’s worth of bytes into the MXU. This is just v5e HBM / v5e HBM 内存 带宽 = 16GB / 8.2e11 = 19ms, so this is 19ms / step. Assuming our generations have a median length of 512 tokens, that is about 9s for each decode. Note that we could get marginally better 延迟 with a smaller batch size, for instance if we only looked at 模型 parameters in int4 our minimum 延迟 is about 10ms / step, since HBM is no longer full.
Takeaway: we can always lower bound decode 延迟 by asking how long it takes to load all the 模型’s parameters from HBM into the MXU. When our KV caches are small, you can think about each 层 as just loading the weights chunk-by-chunk and then discarding them. Unless we’re using large batch sizes or lots of inter-device comms, this is often a reasonable bound (within 1.5x). When our batch size is bigger, we need to 模型 the KV 缓存 loading as well, since that dominates the parameters.
Likewise, in the FLOPs-bound regime (e.g. 训练 or big-batch 推理), we can use the \(\text{Total FLOPs} / (N \cdot C) = 2 \cdot \text{param count} \cdot B / (N \cdot C)\) lower bound, which assumes no communication.
Question: For each of these, what 吞吐量 per 芯片 does this give us (in terms of queries / 芯片)? You can assume our median decode length is 512 tokens.
Answer This is an important question because it’s exactly correlated with cost / token.
With our assumption about median decode length, our 吞吐量 is just \(B / (\text{per-step 延迟} \cdot \text{median steps} \cdot N) \approx 43 / (0.019 * 512 * N)\). This gives us roughly \((4.42 / N)\) QPS, so plugging in \(N\) we get:
dtype QPS / 芯片 bfloat16 0.27 int8 0.55 int4 1.11 Note that this is rather optimistic since it totally ignores the working 内存 of the forward pass (内存 allocated to activations and 注意力). This is not ridiculous with Flash 注意力, but it is also not realistic. The real numbers are likely around 1/2 of this. For absolutely maximum 吞吐量 we would probably want to more than double the number of chips and increase the batch size significantly as well.
Question: How would our peak 吞吐量 change if we doubled our topology for each of the above examples?
Answer If we used a 4x8 slice in bfloat16, we would have 372GB remaining for KV caches, which would let us increase our batch size to 140. Then since our step time would remain the same, we would have a 吞吐量 of 14.39 / num_chips, or
dtype QPS / 芯片 bfloat16 (on 4x8) 0.44 int8 (on 4x4) 0.90 int4 (on 2x4) 1.80 A further increase would give an even bigger win! The big takeaway is that the smallest topology is not the most performant topology in all cases, if we’re limited by KV 缓存 size.
Question: Now let’s dig into the question of 分片. Let’s say we wanted to serve in bfloat16 on a TPU v5e 4x8. What 分片 would we use for our 模型 on a TPU v5e 4x8 during generation? Can we avoid being communication bound?
Answer As discussed in the previous section, we only really have one option for 分片 during generation: 模型 parallelism. How much can we do before we become communication bound? As we’ve discussed in the previous section, our models become communication bound roughly when
[Y > \frac{F \cdot M_Y}{2200}] For LLaMA 3-70B we have F = 28,672, so if we do 2 axes of 模型 分片 this gives us roughly \(Y = 28672 \cdot 2 / 2200 = 26\), so in general we could scale up to about 16 chips without being communication bound, which lets us use a 4x4 but not a 4x8. Generally, since we do not perfectly overlap 计算, even this estimate is overly optimistic.
Takeaway: we cannot actually serve on a 4x8 with pure 模型 parallelism. The best we can do here is a 4x2 or maybe a 4x4.
However, as we’ve discussed, when our batch size is small we can often do more 模型 parallelism without significantly hurting 吞吐量, since our 模型 is 内存-带宽-bound and not FLOPs bound. We said before that this value is roughly \(Y=F / (8\cdot B)\), so if we did batch size 64, we could in theory go up to Y = 28,672 / (8 * 64) = 56 way 模型 parallelism before we become ICI-bound. To sanity check this, we can look at \(T_\text{ici comms}\), \(T_\text{hbm comms}\), and \(T_\text{math}\) for a single matmul. We clearly have:
[\begin{align}T_\text{ici comms} = \frac{2BD}{W_\text{ici}} && T_\text{hbm comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{math} = \frac{2BDF}{Y \cdot C}\end{align}] For a 4x8, this would give us \(T_\text{ici comms}\) = (2 * 64 * 8192) / 9e10 = 11us, \(T_\text{hbm comms}\) = (2 * 8192 * 28,672) / (32 * 8.1e11) = 18us, and \(T_\text{math}\) = (2 * 64 * 8192 * 28,672) / (32 * 1.97e14) = 4us, so in theory we’re still HBM 带宽 bound, which is great! Note that scaling up from a 4x4 to a 4x8 probably isn’t helpful from a 吞吐量 standpoint, but it’ll reduce our 延迟!
If we look at the int8 and int4 configs, we can do those with pure 模型 parallelism. So we’ve hit a point at which quantization actually gives us a meaningful advantage beyond faster FLOPs: it lets us use a larger batch size before we become comms-bound. So the end of this story is that we can’t achieve peak 吞吐量 on a 4x8, but for the int8 and int4 configs we could do pure 模型 parallelism.
Tip: the maximum amount of useful 模型 parallelism depends on \(d_{ff}\) and the number of axes over which you’re 分片 your 模型. The maximum value usually ranges between 8 and 32 depending on the 模型 size. You can scale beyond this limit to improve 延迟 at some 吞吐量 cost.
### What about prefill? We’ve mostly ignored prefill here because it’s much simpler. Let’s put a couple of concepts together and think about the end-to-end picture.
Question: Assume we achieve a 40% FLOPs utilization during prefill. How long will a prefill of length 8192 take on 16 TPU v5e chips?
Answer At 8k tokens, we are solidly compute bound, so we just need to reason about FLOPs. We know our 模型 has 70e9 parameters so each forward pass uses 2 * 70e9 * B FLOPs. Assuming 40% MFU (FLOPs utilization), this gives us a runtime of about 2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s. Compared to the numbers we’ve been looking at before, that’s actually quite a lot!
Question: Assume we have a median prefill length of 8192 tokens and a median decode length of 4096 tokens. Say we have a generate batch size of 32. On average how many sequences finish decoding per step? On average how many tokens are evicted from our KV 缓存 each step?
Answer This is kind of straightforward. Since we have a median decode length of 4096 tokens, a sequence will finish roughly every 1 / 4096 tokens. Given a batch size of 32, this means we have 32 / 4096 sequences evicted per step. Since our KV 缓存 length is roughly 8192 + 4096, this is 32 * (8192 + 4096) / 4096 = 96 tokens evicted per step. The general formula is \(B * (P + G) / G\) where \(P\) and \(G\) are the prefill and generate lengths.
Question: Assume we do disaggregated serving with a median prefill length of 8192 and a median decode length of 512. Assume the prefill and generate latencies calculated above in bfloat16. What ratio of prefill:generate servers will you need to keep both fully saturated.
Answer This is kind of a fun question. Let \(P\) be the number of prefill servers and \(G\) be the number of generate servers. So generally speaking, this is a pipeline problem where we feed sequences in at a rate of P / prefill_latency and consume them at a rate of B * G / (generate_latency * median_decode_length). We had calculated 910ms per prefill step and 19ms per decode step at batch size 43 (let’s call that 32). Therefore we need P / 0.91 = 32 * G / (0.019 * 512) or P = 3G, i.e. we need about 3 times more prefill servers than generation servers!
## Visualizing the 延迟 吞吐量 Tradeoff Sticking with LLaMA 70B for a second, let’s actually look at the 延迟 and 吞吐量 for different batch sizes during generation. As we showed in the previous section for PaLM models, this gives us a Pareto frontier for 吞吐量/延迟. Let’s assume 16-way 张量 parallelism since that’s a reasonable bound on what we can use while staying compute-bound in the MLP blocks. We’ll use a TPU v5e 4x4 topology here. The slider controls the sequence length so you can see the effect of larger KV caches.
- See how dramatic the tradeoff is between cost and 延迟. At the cost of doubling per-token 延迟, we can achieve a roughly 100x reduction in per-token cost. Also, our 延迟 can range anywhere from 5.5ms with low batch size to 20 ms with very large batches.
- Note how at 2k context the 吞吐量 effectively plateaus at around 1 token / ms / 芯片 when it hits the BS 120 roofline (120 here because we do int8 weights but bf16 FLOPs). As the sequence length increases, however, we can no longer fit this batch size in 内存, so we never hit the point of full saturation.
- Note how much higher the 延迟 is at large batch sizes for the same 吞吐量, since KV loading becomes dominant (instead of 参数 loading).
We can understand this better by breaking down the sources of cost and 延迟 into param loading time, KV loading time, and FLOPs time. The red sector is the region in which we expect to be compute-bound in our MLP blocks.
This tells quite a story. You can see that initially, 参数 loading represents the vast majority of the 延迟, until the batch size becomes large enough that FLOPs and KV loading become more significant. Notably, at all sequence lengths greater than 2048, we spend more time on KV 缓存 loading than we do on FLOPs! So while we can improve our hardware utilization by increasing batch size, at long context lengths KV loading always dominates the total step time.
Takeaway: for LLaMA 3-70B, we are strongly KV 缓存 内存 带宽-bound (and HBM-bound) in almost all of these configurations, highlighting just how important reducing KV 缓存 size is for generation 吞吐量. Also note just how dramatic the 延迟/吞吐量 tradeoff remains here.
The code for this is quite simple. Here’s the code for computing these rooflines:
``` import numpy as np
num_chips = 16 # we fix 16 as the amount of total 模型 parallelism we do bytes_per_param = 1 # int8 means 1 byte per param param_count = 70e9 param_size = bytes_per_param * param_count sequence_length = 8192 # can vary this
hbm_bandwidth = 8.20E+11 # v5e flops = 1.97E+14 # v5e
def kv_cache_size(bs): return 2 * bs * 128 * 8 * 80
def min_topology(bytes): return 2 ** np.ceil(np.log2(bytes / 16e9))
def get_max_batch_size( num_chips: int, sequence_length: int, param_size: float, ) -> int: batch_sizes = np.arange(1, 1024, 4) kv_sizes = kv_cache_size(sequence_length * batch_sizes) required_chips = min_topology(kv_sizes + param_size) max_idx = np.where(required_chips <= num_chips)[0][-1] return max_idx
max_idx = get_max_batch_size( num_chips=num_chips, sequence_length=sequence_length, param_size=param_size, ) # get the largest batch size that can fit batch_sizes = np.arange(1, 512, 1)[:max_idx] kv_sizes = kv_cache_size(sequence_length * batch_sizes)
kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)
param_comms_time = param_size / (num_chips * hbm_bandwidth) param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])
flops_time = 2 * param_size * batch_sizes / (num_chips * flops) # roughly true in a 2ND sense
mlp_time = np.maximum(flops_time, param_comms_time) attn_time = kv_comms_time # always 带宽-bound for generate
延迟 = 1000 * (mlp_time + attn_time) 吞吐量 = batch_sizes / (延迟 * num_chips)
``` Note how we very explicitly break out 延迟 into two sources: KV loading and param loading, and how the 延迟 is either bound by FLOPs or comms, whichever is bigger.
## Worked Problems Here are a few worked problems. Some of these repeat things that are worked above, but might be pedagogically useful.
Question 1: How many FLOPs does each forward pass for LLaMA 3-405B use per-token? Assuming we’re FLOPs bound, what is a lower bound on a single forward pass on N chips on TPU v5e? What if we’re comms bound? Ignore the fact that the 模型 does not fit on a single 芯片.
Question 2: Assume we want to serve LLaMA 3-8B with BS240 using int8 weights and int8 KV caches. How many bytes are used by (a) 模型 parameters (b) KV caches and (c) peak working activations (roughly)? What’s the smallest topology we can run this on?
Question 3: How would you serve LLaMA 3-405B on TPU v5e? Assume int8 weights and bfloat16 FLOPs. Let’s say we have a firm limit of 15ms / token, what’s the highest 吞吐量 configuration we could achieve? What is the theoretical minimum step time?
### That’s all for Part 8! For Part 9, with a deep dive into XLA and TPU profiling, click here. ### Miscellaneous *Work done at Google DeepMind, now at MatX.
### Citation For attribution in academic contexts, please cite this work as:
``` Austin et al., "How to Scale Your 模型", Google DeepMind, online, 2025.
``` or as a BibTeX entry:
``` @article{scaling-book, title = {How to Scale Your 模型}, author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner}, publisher = {Google DeepMind}, howpublished = {Online}, note = {Retrieved from https://jax-ml.github.io/scaling-book/}, year = {2025} }
``` Please enable JavaScript to view the comments powered by giscus. © Copyright 2026 . Powered by Jekyll with al-folio theme. Hosted by GitHub Pages.
🔗 相关资源¶
- 官方文档:
- JAX官方文档
- XLA编译优化
-
参考论文:
- Transformer原始论文
- Attention Is All You Need
-
实践项目:
- JAX示例代码库
- Transformer实现示例
- TPU使用教程
💡 学习建议¶
理论学习¶
- 先通读全文,了解整体框架
- 重点理解核心概念和技术原理
- 结合图表和公式深入理解
实践学习¶
- 运行教材中的代码示例
- 尝试修改参数观察效果
- 应用到自己的项目中
深入学习¶
- 阅读参考文献和扩展阅读
- 参与相关技术社区讨论
- 关注最新的技术发展
本翻译由OpenClaw自动生成,正在不断完善中。 翻译问题反馈:请通过博客反馈渠道联系。