通过 Saxml 在 GKE 上使用多主机 TPU 提供 LLM


本教程介绍如何使用 Google Kubernetes Engine (GKE) 上的 Saxml 和张量处理单元 (TPU) 来提供大型语言模型 (LLM)。

背景

Saxml 是一个实验性系统,可提供 PaxmlJAXPyTorch 框架。借助这些框架,您可以使用 TPU 加快数据处理速度。为了演示 GKE 中 TPU 的部署,本教程提供 175B 的 LmCloudSpmd175B32Test 测试模型。GKE 分别在具有 4x8 拓扑的两个 v5e TPU 节点池上部署此测试模型。

为了正确部署测试模型,TPU 拓扑是根据模型大小定义的。由于 N0 亿 16 位模型大约需要 2 倍 (2xN) GB 的内存,因此 175B 的 LmCloudSpmd175B32Test 模型需要大约 350 GB 的内存。TPU v5e 单芯片有 16 GB。为了支持 350 GB,GKE 需要 21 个 v5e 芯片 (350/16= 21)。根据 TPU 配置的映射,本教程的正确 TPU 配置是:

  • 机器类型:ct5lp-hightpu-4t
  • 拓扑:4x8(32 个 TPU 芯片)

在 GKE 中部署 TPU 时,选择正确的 TPU 拓扑来处理模型非常重要。如需了解详情,请参阅规划 TPU 配置

目标

本教程适用于希望使用 GKE 编排功能来提供数据模型的 MLOps 或 DevOps 工程师或平台管理员。

本教程介绍以下步骤:

  1. 使用 GKE Standard 集群准备环境。该集群有两个具有 4x8 拓扑的 v5e TPU 节点池。
  2. 部署 Saxml。Saxml 需要一个管理员服务器、一组用作模型服务器的 Pod、一个预建的 HTTP 服务器和一个负载均衡器。
  3. 使用 Saxml 提供 LLM。

下图展示了以下教程实现的架构:

GKE 上的多主机 TPU 的架构。
:GKE 上的多主机 TPU 的示例架构。

准备工作

  • 登录您的 Google Cloud 账号。如果您是 Google Cloud 新手,请创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
  • 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  • 确保您的 Google Cloud 项目已启用结算功能

  • 启用所需的 API。

    启用 API

  • 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  • 确保您的 Google Cloud 项目已启用结算功能

  • 启用所需的 API。

    启用 API

  • 确保您拥有项目的以下一个或多个角色: roles/container.admin, roles/iam.serviceAccountAdmin

    检查角色

    1. 在 Google Cloud 控制台中,前往 IAM 页面。

      转到 IAM
    2. 选择项目。
    3. 主账号列中,找到您的电子邮件地址所在的行。

      如果您的电子邮件地址不在此列,则表示您没有任何角色。

    4. 在您的电子邮件地址所在的行对应的角色列中,检查角色列表是否包含所需的角色。

    授予角色

    1. 在 Google Cloud 控制台中,前往 IAM 页面。

      转到 IAM
    2. 选择项目。
    3. 点击 授予访问权限
    4. 新的主账号字段中,输入您的电子邮件地址。
    5. 选择角色列表中,选择一个角色。
    6. 如需授予其他角色,请点击 添加其他角色,然后添加其他各个角色。
    7. 点击 Save(保存)。

准备环境

  1. 在 Google Cloud 控制台中,启动 Cloud Shell 实例:
    打开 Cloud Shell

  2. 设置默认环境变量:

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export REGION=COMPUTE_REGION
      export ZONE=COMPUTE_ZONE
      export GSBUCKET=PROJECT_ID-gke-bucket
    

    替换以下值:

创建 GKE Standard 集群

使用 Cloud Shell 执行以下操作:

  1. 创建使用适用于 GKE 的工作负载身份联合的 Standard 集群:

    gcloud container clusters create saxml \
        --zone=${ZONE} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    VERSION 替换为 GKE 版本号。GKE 在 1.27.2-gke.2100 及更高版本中支持 TPU v5e。如需了解详情,请参阅 GKE 中的 TPU 可用性

    集群创建可能需要几分钟时间。

  2. 创建名为 tpu1 的第一个节点池:

    gcloud container node-pools create tpu1 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    
  3. 创建名为 tpu2 的第二个节点池:

    gcloud container node-pools create tpu2 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    

您已创建以下资源:

  • 具有四个 CPU 节点的标准集群。
  • 两个具有 4x8 拓扑的 v5e TPU 节点池。每个节点池表示 8 个 TPU 节点,每个节点有 4 个芯片。

175B 模型必须在具有 4x8 拓扑切片(32 个 v5e TPU 芯片)的多主机 v5e TPU 切片上提供。

创建 Cloud Storage 存储桶

创建 Cloud Storage 存储桶以存储 Saxml 管理员服务器配置。正在运行的管理员服务器会定期保存其状态以及已发布模型的详细信息。

在 Cloud Shell 中,运行以下命令:

gcloud storage buckets create gs://${GSBUCKET}

使用适用于 GKE 的工作负载身份联合配置工作负载访问权限

为应用分配 Kubernetes ServiceAccount,并将该 Kubernetes ServiceAccount 配置为充当 IAM 服务账号。

  1. 配置 kubectl 以与您的集群通信:

    gcloud container clusters get-credentials saxml --zone=${ZONE}
    
  2. 为您的应用创建 Kubernetes 服务账号:

    kubectl create serviceaccount sax-sa --namespace default
    
  3. 为您的应用创建 IAM 服务账号:

    gcloud iam service-accounts create sax-iam-sa
    
  4. 为您的 IAM 服务账号添加 IAM 政策绑定,以便对 Cloud Storage 执行读写操作:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. 通过在两个服务账号之间添加 IAM 政策绑定,允许 Kubernetes 服务账号模拟 IAM 服务账号。此绑定允许 Kubernetes 服务账号充当 IAM 服务账号,以便 Kubernetes 服务账号可以对 Cloud Storage 执行读写操作。

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. 使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解。这样,您的示例应用便知道要用于访问 Google Cloud 服务的服务账号。因此,在应用要使用任何标准 Google API 客户端库访问 Google Cloud 服务时,便会使用该 IAM 服务账号。

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    

部署 Saxml

在本部分中,您将部署 Saxml 管理员服务器和 Saxml 模型服务器。

部署 Saxml 管理员服务器

  1. 创建以下 sax-admin-server.yaml 清单:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

  2. 应用清单:

    kubectl apply -f sax-admin-server.yaml
    
  3. 验证管理员服务器 Pod 是否已启动且正在运行:

    kubectl get deployment
    

    输出类似于以下内容:

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
    

部署 Saxml 模型服务器

在多主机 TPU 切片中运行的工作负载需要每个 Pod 使用稳定的网络标识符,以发现同一 TPU 切片中的对等体。要定义这些标识符,请使用IndexedJobStatefulSet使用无头 Service 或作业集它会自动为属于 JobSet 的所有 Job 创建无头 Service。以下部分介绍了如何使用 JobSet 管理多组模型服务器 Pod。

  1. 安装 JobSet v0.2.3 或更高版本。

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION 替换为 JobSet 版本。例如 v0.2.3

  2. 验证 JobSet 控制器是否正在 jobset-system 命名空间中运行:

    kubectl get pod -n jobset-system
    

    输出类似于以下内容:

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. 在两个 TPU 节点池中部署两个模型服务器。保存以下 sax-model-server-set 清单:

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

    在此清单中:

    • replicas: 2 是作业副本的数量。每个作业代表一个模型服务器。因此,一组 8 个 Pod。
    • parallelism: 8completions: 8 等于每个节点池中的节点数。
    • 如果任何 Pod 失败,backoffLimit: 0 必须为零才能将 Job 标记为失败。
    • ports.containerPort: 8471 是 TPU 虚拟机通信的默认端口
    • name: MEGASCALE_NUM_SLICES 会取消设置环境变量,因为 GKE 未运行多切片训练。
  4. 应用清单:

    kubectl apply -f sax-model-server-set.yaml
    
  5. 验证 Saxml 管理服务器和模型服务器 Pod 的状态:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

在此示例中,有 16 个模型服务器容器:sax-model-server-set-sax-model-server-0-0-nj4smsax-model-server-set-sax-model-server-1-0-8mlf5 是每个组中的两个主要模型服务器。

您的 Saxml 集群有两个部署在 v5e TPU 节点池上的模型服务器,这些节点池分别采用 4x8 拓扑。

部署 Saxml HTTP 服务器和负载均衡器

  1. 使用以下预构建的映像 HTTP 服务器映像。保存以下 sax-http.yaml 清单:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

  2. 应用 sax-http.yaml 清单:

    kubectl apply -f sax-http.yaml
    
  3. 等待 HTTP 服务器容器完成创建:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  4. 等待系统为 Service 分配外部 IP 地址:

    kubectl get svc
    

    输出类似于以下内容:

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
    

使用 Saxml

在 v5e TPU 多主机切片中的 Saxml 中加载、部署和投放模型:

加载模型

  1. 检索 Saxml 的负载均衡器 IP 地址。

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. 在两个 v5e TPU 节点池中加载 LmCloudSpmd175B 测试模型:

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    测试模型没有微调检查点,权重是随机生成的。模型加载最多可能需要 10 分钟。

    输出类似于以下内容:

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
  3. 检查模型就绪情况:

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    输出类似于以下内容:

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    模型已完全加载。

  4. 获取模型相关信息:

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    输出类似于以下内容:

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
    

应用模型

处理提示请求:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

输出展示了模型响应的示例。此响应可能没有意义,因为测试模型具有随机权重。

取消发布模型

运行以下命令来取消发布模型:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

输出类似于以下内容:

{
  "model": "/sax/test/spmd"
}

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除已部署的资源

  1. 删除您为本教程创建的集群:

    gcloud container clusters delete saxml --zone ${ZONE}
    
  2. 删除服务账号:

    gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    
  3. 删除 Cloud Storage 存储桶:

    gcloud storage rm -r gs://${GSBUCKET}
    

后续步骤