GKE で Saxml を実行してマルチホスト TPU を使用して LLM を提供する


このチュートリアルでは、Saxml を使用して Google Kubernetes Engine(GKE)で Tensor Processing Unit(TPU)を使用して大規模言語モデル(LLM)を提供する方法について説明します。

背景

Saxml は、PaxmlJAXPyTorch の各フレームワークを提供する試験運用版のシステムです。TPU を使用すると、これらのフレームワークでデータ処理を高速化できます。GKE で TPU のデプロイのデモを行うため、このチュートリアルでは 175B の LmCloudSpmd175B32Test テストモデルを使用します。GKE は、このテストモデルをそれぞれ 4x8 トポロジの 2 つの v5e TPU ノードプールにデプロイします。

テストモデルを適切にデプロイするために、TPU トポロジはモデルのサイズに基づいて定義されています。N0 億の 16 ビットモデルには約 2 倍(2 x N)の GB 数のメモリが必要ですが、175B LmCloudSpmd175B32Test モデルには約 350 GB のメモリが必要です。TPU v5e シングルチップの容量は 16 GB です。350 GB をサポートするには、GKE に 21 個の v5e チップが必要です(350÷16= 21)。TPU 構成のマッピングに基づいて、このチュートリアルの適切な TPU 構成は次のようになります。

  • マシンタイプ: ct5lp-hightpu-4t
  • トポロジ: 4x8(32 個の TPU チップ)

GKE に TPU をデプロイする場合は、モデルの提供に適した TPU トポロジを選択することが重要です。詳細については、TPU 構成の計画をご覧ください。

目標

このチュートリアルは、データモデルを提供するために GKE オーケストレーション機能を使用する MLOps または DevOps エンジニア、プラットフォーム管理者を対象としています。

このチュートリアルでは、次の手順について説明します。

  1. GKE Standard クラスタで環境を準備します。クラスタには、4x8 トポロジの 2 つの v5e TPU ノードプールがあります。
  2. Saxml をデプロイします。Saxml には、管理者サーバー、モデルサーバーとして機能する Pod のグループ、事前に構築された HTTP サーバー、ロードバランサが必要です。
  3. Saxml を使用して LLM を提供します。

次の図は、このチュートリアルで実装するアーキテクチャを示しています。

GKE 上のマルチホスト TPU のアーキテクチャ。
図: GKE 上のマルチホスト TPU のアーキテクチャ例。

始める前に

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  • Google Cloud プロジェクトで課金が有効になっていることを確認します

  • 必要な API を有効にします。

    API を有効にする

  • Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  • Google Cloud プロジェクトで課金が有効になっていることを確認します

  • 必要な API を有効にします。

    API を有効にする

  • プロジェクトに次のロールがあることを確認します。 roles/container.admin, roles/iam.serviceAccountAdmin

    ロールを確認する

    1. Google Cloud コンソールの [IAM] ページに移動します。

      [IAM] に移動
    2. プロジェクトを選択します。
    3. [プリンシパル] 列で、自分のメールアドレスを含む行を見つけます。

      自分のメールアドレスがその列にない場合、ロールは割り当てられていません。

    4. 自分のメールアドレスを含む行の [ロール] 列で、ロールのリストに必要なロールが含まれているかどうかを確認します。

    ロールを付与する

    1. Google Cloud コンソールの [IAM] ページに移動します。

      [IAM] に移動
    2. プロジェクトを選択します。
    3. [ アクセスを許可] をクリックします。
    4. [新しいプリンシパル] フィールドに、自分のメールアドレスを入力します。
    5. [ロールを選択] リストでロールを選択します。
    6. 追加のロールを付与するには、 [別のロールを追加] をクリックして各ロールを追加します。
    7. [保存] をクリックします。

環境を準備する

  1. Google Cloud コンソールで、Cloud Shell インスタンスを起動します。
    Cloud Shell を開く

  2. デフォルトの環境変数を設定します。

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export REGION=COMPUTE_REGION
      export ZONE=COMPUTE_ZONE
      export GSBUCKET=PROJECT_ID-gke-bucket
    

    次の値を置き換えます。

GKE Standard クラスタを作成する

Cloud Shell で以下の操作を行います。

  1. GKE 用 Workload Identity 連携を使用する Standard クラスタを作成します。

    gcloud container clusters create saxml \
        --zone=${ZONE} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    VERSION は、GKE のバージョン番号に置き換えます。GKE は、バージョン 1.27.2-gke.2100 以降で TPU v5e をサポートしています。詳細については、GKE での TPU の可用性をご覧ください。

    クラスタの作成には数分かかることもあります。

  2. tpu1 という名前で 1 つ目のノードプールを作成します。

    gcloud container node-pools create tpu1 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    
  3. tpu2 という名前で 2 つ目のノードプールを作成します。

    gcloud container node-pools create tpu2 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    

次のリソースを作成しました。

  • 4 つの CPU ノードを持つ Standard クラスタ。
  • 4x8 トポロジを持つ 2 つの v5e TPU ノードプール。各ノードプールは、それぞれ 4 つのチップを持つ 8 つの TPU ノードを表します。

175B モデルは、少なくとも 4x8 トポロジ スライス(32 個の v5e TPU チップ)を持つマルチホスト v5e TPU スライスで提供する必要があります。

Cloud Storage バケットを作成する

Saxml 管理者サーバーの構成を保存する Cloud Storage バケットを作成します。実行中の管理者サーバーは、その状態と公開モデルの詳細を定期的に保存します。

Cloud Shell で次のコマンドを実行します。

gcloud storage buckets create gs://${GSBUCKET}

GKE 用 Workload Identity 連携を使用してワークロード アクセスを構成する

アプリケーションに Kubernetes ServiceAccount を割り当て、IAM サービス アカウントとして機能するようにその Kubernetes ServiceAccount を構成します。

  1. クラスタと通信を行うように kubectl を構成します。

    gcloud container clusters get-credentials saxml --zone=${ZONE}
    
  2. アプリケーションで使用する Kubernetes ServiceAccount を作成します。

    kubectl create serviceaccount sax-sa --namespace default
    
  3. アプリケーションの IAM サービス アカウントを作成します。

    gcloud iam service-accounts create sax-iam-sa
    
  4. IAM サービス アカウントの IAM ポリシー バインディングを追加して、Cloud Storage に対する読み取りと書き込みを行います。

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. 2 つのサービス アカウントの間に IAM ポリシー バインディングを追加して、Kubernetes ServiceAccount が IAM サービス アカウントの権限を借用できるようにします。このバインドで、Kubernetes ServiceAccount が IAM サービス アカウントとして機能するようになるため、Kubernetes ServiceAccount が Cloud Storage に対して読み書きを行うことができます。

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. Kubernetes サービス アカウントに IAM サービス アカウントのメールアドレスでアノテーションを付けます。これにより、サンプルアプリが Google Cloud サービスへのアクセスに使用するサービス アカウントを認識できます。そのため、アプリが標準の Google API クライアント ライブラリを使用して Google Cloud サービスにアクセスする場合は、その IAM サービス アカウントを使用します。

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    

Saxml をデプロイする

このセクションでは、Saxml 管理者サーバーと Saxml モデルサーバーをデプロイします。

Saxml 管理者サーバーをデプロイする

  1. 次の sax-admin-server.yaml マニフェストを作成します。

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

  2. 次のようにマニフェストを適用します。

    kubectl apply -f sax-admin-server.yaml
    
  3. 管理者サーバーの Pod が稼働していることを確認します。

    kubectl get deployment
    

    出力は次のようになります。

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
    

Saxml モデルサーバーをデプロイする

マルチホスト TPU スライスで実行されるワークロードでは、同じ TPU スライス内のピアを検出するために、各 Pod に安定したネットワーク識別子が必要です。これらの識別子を定義するには、IndexedJobStatefulSet ヘッドレス Service または JobSet を使用します。これにより、JobSet に属するすべての Job に対してヘッドレス Service が自動的に作成されます。次のセクションでは、JobSet を使用してモデルサーバー Pod の複数のグループを管理する方法について説明します。

  1. v0.2.3 以降の JobSet をインストールします。

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION は、JobSet のバージョンに置き換えます。例: v0.2.3

  2. JobSet コントローラが jobset-system Namespace で実行されていることを確認します。

    kubectl get pod -n jobset-system
    

    出力は次のようになります。

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. 2 つの TPU ノードプールに 2 つのモデルサーバーをデプロイします。次の sax-model-server-set マニフェストを保存します。

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

    このマニフェストの内容:

    • replicas: 2 は、Job のレプリカの数です。各ジョブはモデルサーバーを表します。したがって、8 つの Pod のグループになります。
    • parallelism: 8completions: 8 は、各ノードプール内のノード数と等しくなります。
    • Pod が失敗した場合に Job を失敗としてマークするには、backoffLimit: 0 を 0 にする必要があります。
    • ports.containerPort: 8471 は、TPU VM 通信用のデフォルト ポートです。
    • GKE はマルチスライス トレーニングを実行していないため、name: MEGASCALE_NUM_SLICES は環境変数の設定を解除します。
  4. 次のようにマニフェストを適用します。

    kubectl apply -f sax-model-server-set.yaml
    
  5. Saxml 管理サーバーと Model Server Pod のステータスを確認します。

    kubectl get pods
    

    出力は次のようになります。

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

この例では、16 個のモデルサーバー コンテナがあります。sax-model-server-set-sax-model-server-0-0-nj4smsax-model-server-set-sax-model-server-1-0-8mlf5 は、各グループの 2 つのプライマリ モデルサーバーです。

Saxml クラスタには、それぞれ 4x8 トポロジを持つ 2 つの v5e TPU ノードプールにデプロイされた 2 つのモデルサーバーがあります。

Saxml HTTP Server とロードバランサをデプロイする

  1. 次のビルド済みイメージの HTTP サーバー イメージを使用します。次の sax-http.yaml マニフェストを保存します。

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

  2. sax-http.yaml マニフェストを適用します。

    kubectl apply -f sax-http.yaml
    
  3. HTTP サーバー コンテナの作成が完了するまで待ちます。

    kubectl get pods
    

    出力は次のようになります。

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  4. Service に外部 IP アドレスが割り当てられるまで待ちます。

    kubectl get svc
    

    出力は次のようになります。

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
    

Saxml を使用する

v5e TPU マルチホスト スライスの Saxml でモデルを読み込んでデプロイし、提供します。

モデルを読み込む

  1. Saxml のロードバランサの IP アドレスを取得します。

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. 2 つの v5e TPU ノードプールに LmCloudSpmd175B テストモデルを読み込みます。

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    テストモデルにはファインチューニングされたチェックポイントがなく、重みはランダムに生成されます。モデルの読み込みには最大 10 分かかります。

    出力は次のようになります。

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
  3. モデルの準備状況を確認します。

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    出力は次のようになります。

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    モデルが完全に読み込まれました。

  4. モデルに関する情報を取得します。

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    出力は次のようになります。

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
    

モデルを提供する

プロンプト リクエストを処理します。

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

出力には、モデルのレスポンスの例が表示されます。テストモデルにはランダムな重みがあるため、このレスポンスは意味をなさない可能性があります。

モデルの公開を停止する

次のコマンドを実行して、モデルを非公開にします。

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

出力は次のようになります。

{
  "model": "/sax/test/spmd"
}

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

デプロイされたリソースを削除する

  1. このチュートリアル用に作成したクラスタを削除します。

    gcloud container clusters delete saxml --zone ${ZONE}
    
  2. サービス アカウントを削除します。

    gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    
  3. Cloud Storage バケットを削除します。

    gcloud storage rm -r gs://${GSBUCKET}
    

次のステップ