Company

LLM開発コンペ2025参戦記: マルチノードで動かすvLLM そのハマりどころと要点

TL;DR

東京大学松尾・岩澤研究室主催のLLM開発コンペ2025における評価班の取り組みについて説明します。参加したRAMENチームはなんと予選通過を果たしました。そこで得た、LLMモデルをHPCクラスタ(マルチGPU/マルチノード)環境下でモデル評価ツールvLLMを動かすための注意点を列挙。原因・背景も記載したので、できれば機序を理解しつつFIXにご利用ください。(本記事はQiitaのNeural Group Advent Calendar 2025 - Day4 にも掲載しています)

# 一つだけ覚えてください → mmapエラーは `ulimit -v unlimited`

はじめに

 今年2025年 東京大学松尾研が主催する 松尾研LLM開発コンペ2025 に参加しました。昨年 Geniac2024の枠組みで実施された松尾研LLM開発コンペでは当時最新のモデルに劣らない性能の「 Tanuki-8×8B 」が開発・発表され話題となったことは記憶に新しいところです。
 今回参加した松尾研LLM開発コンペ2025ではチームRAMENに参加させていただき、開発モデル評価を行う評価を担当しました。そこではvLLMを運用したのですが、大いにハマり倒しました。その記録を皆さんに共有し、この記事を読む皆さんには、私のように哀しい思いをせずに、コンペを楽しんでいただければ良いかなと考えて筆を取りました。

 この記事では、vLLM利用におけるハマりどころの概要を説明します。時間があれば、詳細についても記事にする予定です。

LLM開発コンペ2025

vLLMとは?

  vLLM は推論専用のエンジンです。学習プラットフォームではありません。vLLMは様々なオープンモデル( HuggingFace にあるような重み公開モデル)に対応しており、LLMアーキテクチャ研究の現場では欠かせないツールとなっています。
 反面、様々なモデルの対応が次々と増築されてきた経緯から非常に不安定な振る舞いをすることがあります。特に今回のコンペのような最新モデルをターゲットにするLLM開発では、vLLMがそれらの最新モデルに十分に対応できていないこともしばしばあります。これが、vLLMの取り扱いを難しくする要因でもあります。

コンペのターゲット: クラスタでしか使えないモデル規模

 vLLMは、マルチGPUだけでなく、マルチノード環境に強いツールとして人気があります。マルチノードに対応する評価ツールには他にもあり、一例ではSGLangなどが挙げられます。
 この数年間でllmが要求するGPUメモリ規模は急激に巨大化し、DNNのようにゲーミングPC一台で動かす事などすっかり不可能となったことは皆さんもよくご存知でしょう。
 巨大なモデルを動かすためのハードウエアにはHPCクラスタが使われます。高性能なHPCをノードとしてネットワークで連携させ、その上でモデルを分散動作させるアプローチが取られます。実際、2025年の最新モデルQWEN3-235Bぐらいなら、 NVIDIA H100 GPU 80GB を8台積んだHPC一台でなんとかvLLMを動かせるぐらいです。トークンサイズを煽ったりしようとすると途端に厳しくなります。さらに学習ツールを使ったファインチューニングとなると、マルチノード環境が必須です。
 vLLMは様々なクラスターオーケストレーションツールと連携します。中でもOSSである” Ray ”が研究現場ではよく使われます。
 大規模モデルをvLLMを扱うためにはこの Rayを確実に設定する ことが重要になります。

NVIDIA DGX Station

共用環境での運用

 HPCクラスタは非常に高価なインフラとなるので、複数のユーザー間で計算リソース(CPUコア、メモリ、GPUなど)を効率的に使うためにSlurmなどのリソース管理ツールを使うことが一般的です。vLLMやRay自体はSlurm下でも動作しますが、特段Slurmに特化しているわけではないので、Slurm利用特有のリソース割り当て作法を取り入れる必要があります。
 もう一つポイントとなるのは、root権限もらえない環境である点です。その流れから利用できるコンテナも制約を受けます(Dockerが使えず。他のコンテナツールを使うしかない)。
 個人マシンでのroot運用に慣れた人には非常にストレスです。環境変数の設定や運用でカバーすることになるため、障害が起こりやすくなります。特にRayは、コンテナ化するかシステムサービスとしての運用が望ましいツールなのですが、後者で運用しようとするとクラスタ管理者との調整が必要なためなかなか辛いところです


ハマりどころ解説

この記事の読者は、vLLMを動かそうとする人がほとんどかと思いますので、動作確認に沿った順でハマりどころを説明していきます

■ HPC由来のハマりどころ

Linuxプロセスリソース上限デフォルト値がHPCタスクとマッチしてない

 Linux/UNIX では各プロセスに対して リソース上限(resource limits) があります。デフォルト値では全く問題ないのですが、 HPCクラスタがターゲットにするタスクではこの値が低すぎる ことがあります。LLMモデルの推論などは正にその典型例です。シェル組込みコマンドulimitを使い、プロセスごとのファイルディスクリプタ数やプロセス数・スタック・仮想メモリ上限を変更する必要があります。

 変更を忘れたがために発生する典型的な症状は、 mmapエラーの頻発 です。「共有ライブラリが読めなかった」「pythonのimportに失敗」「メモリマッピングに失敗」など多岐に渡ったメッセージがでますが、全ては以下に起因するmmapエラーです。

 詳細を省きますが、 これは「共有メモリのアドレス空間(Address Space)の予約」に失敗した ことを指し示しています。まさに「メモリのマッピング処理の失敗」です。注意して欲しいのは、メモリの確保に失敗しているのではないということ。HPCではCPUメモリサイズが1.7TBぐらいを搭載していることも珍しくないので、「これだけメモリあるんだからメモリ確保に失敗するはずないんだけど」と早とちりしてしまいがちです。失敗しているのは「利用予定の仮想アドレス範囲を予約すること」です。

 vLLM は起動時に (a)Pytorchコンパイルや、(b)大量のライブラリ読み込みが一気に発生します。Python上の制約からマルチプロセス動作するため、「プロセス数」の分これらの処理を倍化させることになります。 具体例を挙げると、--tensor-parallel-size 2 だと ワーカープロセスが2本立つため、インポート直後のピークメモリが倍化し、上限値に引っかかりやすくなります。これを回避するためには、シェルごとに以下のulimitコマンドで制限を変更する必要があります。

ulimit -v unlimited

■ vLLMのハマりどころ

(1)vllm serveコマンドのパラメタ設定には制約ルールあり

 vllmのパラメータには厳格なルールがあります。当たり前すぎるのか公式マニュアルぐらいしか記載がない(?)。初めてvLLMを使うケースでは間違いなくハマります

  1. TP/PP/DP設定には相互関係ルール: 総 GPU 数 = TP * PP * DP
  2. モデル特性(attention-headルール): 現在の Tensor Parallel(TP)実装は「モデルの総 attention-head 数(および KV-head 数)が tensor-parallel-size で割り切れること」を要求。割り切れないとエラーになります。
  3. MoE(Expert Parallel)モデルであれば、関連フラグ ‘--enable-expert-parallel ‘を指定しないといけない。また 総 GPU 数 = TP * PP * DP * EP(MoE次元)となる。
カテゴリ ルール・要件 詳細・注意点
1. パラメタ間ルール 総 GPU 数 = TP * PP * DP TP, PP, DPの積が、使用する物理GPUの総数と一致する必要があります。
2. モデル特性 (TP) Head数 % TPサイズ = 0 モデルの総 attention-head 数(および KV-head 数)が TP サイズで 割り切れる 必要があります。割り切れない場合はエラーになります。
3. MoE (EP) 設定 「フラグ指定: --enable-expert-parallel 」および「総 GPU 数 = TP * PP * DP * EP(MoE次元数)」 MoE (Expert Parallel) モデルを使用する場合には、このフラグの指定が必須となります。またパラメタ関係式にEPが含まれてきます

(2)【注意】vllmはエラーメッセージの出し方に偏りがあります

 vllmのエラーハンドリングではをしていません。これはPythonによく見られるEAFP型コーディングの負の部分をvLLMが代弁しているといえます。

  1. 「Easier to Ask Forgiveness than Permission(許可を得るよりも、許しを請うほうが簡単)」という考え方。事前チェックなどを行わず、エラーがでてから例外処理をすればよいというコーディングスタイル。

  2. vllmでは設定項目のチェックがなく全てEAFPベース。実際に動かさないと設定間違いが見つからない仕組み。時間が溶けまくります。

    • このためメモリ例外が起きるところまで進んでから落ちる。全てのエラーメッセージがメモリエラーに。
    • メッセージを間に受けると時間が溶けます(溶かしました)。
  3. コード全体が「動けば良い」のレベルで実装が積み上がっていくので、同じEAFP型の制御でも振る舞いに一貫性がない(例:serveコマンドの引数を一つとっても、エラー時の振舞いが全く違うという...)。

[対応策]

  • エラーメッセージをLLMに投げて分析してみるのも一案です。
    • まだ直接ズバリとは行きませんが近いヒントをくれます。

(3)モデルごとにvllm のバージョンを変える→ うっかり仮想環境破壊

 vLLMは非常に頻繁にパッチバージョンが上がります。ですが、内容的にはマイナーバージョン相当かと思う様なケースも少なくありません。このため、パッチバージョンアップなのに直前のバージョンで動いていた機能がエラーで動かなくなることも時々遭遇します。こうした災害から身を守るために、仮想環境ごとバックアップをとってください。
[注意点]

  • vLLMのバージョンアップで動かなくなる例:
    • vllmパッケージが内部に事前ビルドされたCUDAバイナリを含むため、ホスト側ドライバと整合しなくなって動かなくなるケース。

[対応策]

  • バージョンごとにドライバとの整合性を必ず確認すること。
    • 動かない場合は、ソースからビルドするのがよい。
    • 複数のPython仮想環境を用意して、運用環境を守る
      • Conda (minimamba) でOK
      • vllmのバージョンごとにconda環境を構築するとよい
        • 他のパッケージも更新され環境が破壊こともあるため
      • 可能であれば、conda-pack利用し、tarバックアップしておく

(4)実行前には必ずCleanupしましょう

 vllmは、KVキャッシュやオフロードで/tmpにスワップを残します。他にも多数のセッション・ログ・ソケット・ロックを作るなどの副作用があります。また今回の様なHPC共有環境では、複数ユーザ/マルチインスタンスとなるので/tmpでの権限競合・ロック問題が発生します。ログ・キャッシュ・ダンプの蓄積によるディスク枯渇が発生します。実際コンペでは他チームで度々HPC再起動が発生していました。

<対応策>

  • vllm終了後のClean-up手順
    • vllmプロセス終了
    • 古い一時ファイルの削除
    • [TIPS] /tmpではなく、ユーザローカルにおく様に設定する
      • /tmpを使い切ってしまい、HPC再起動が必要になるケースがある
      • 環境変数TMPDIRで設定するとよい

vLLM / 実行環境(PyTorch/NCCL/CUDA を含む)

変数 例(63 / 64) 目的・狙い どこで設定
★VLLM_HOST_IP 192.168.11.63 / 192.168.11.64 各ノードで“自分自身の”IPを明示(OOB通信用など)。 各ノード
VLLM_PORT 8000 API サーバの待受ポート(使用中なら vLLM が 8001 にフォールバック)。 フロント(63)
VLLM_WORKER_MULTIPROC_METHOD spawn vLLM ワーカー生成方式を明示(デッドロック回避)。 各ノード
VLLM_USE_V1 1(デフォ有効) vLLM v1 エンジン系フラグ(デフォで有効、明示したい時のみ)。 必要時
VLLM_USE_RAY_SPMD_WORKER 1(デフォ有効) Ray SPMD ワーカー使用。 必要時
VLLM_USE_RAY_COMPILED_DAG 1(必要時) Ray の Compiled DAG を使う最適化フラグ。 必要時
★HUGGING_FACE_HUB_TOKEN *** モデル取得が必要な場合の認証。 フロント(63)
★LD_LIBRARY_PATH CUDA/NCCL のパス CUDA/NCCL を確実に解決。 各ノード
CUDA_HOME ~/.conda/envs/llm-env-20250812/ CUDA ツールチェーンのルート。 各ノード
CUDA_VISIBLE_DEVICES 0,1,2,3,4,5,6,7 利用GPUの固定。 各ノード
★PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:256(63)/ 2048(64) 断片化抑制。大規模モデルでは 1024–2048 を推奨で統一すると安定。 各ノード
CUDA_LAUNCH_BLOCKING 0 デバッグ用でなければ 0。 各ノード
CUDA_CACHE_DISABLE 0 CUDA キャッシュの抑止を無効化(通常 0)。 各ノード
NCCL_SOCKET_IFNAME enp86s0np0 通信NICを固定(誤ったサブネット掴み防止)。 各ノード
★NCCL_IB_DISABLE 1 今回は IB を使わず Socket を使用(混在環境の誤検出回避)。 各ノード
NCCL_IB_HCA mlx5_0:1,...,mlx5_11:1 IB を使う場合のHCA指定(今回は将来オンにする時のため保持)。 各ノード
NCCL_NET_PLUGIN none NCCL Net プラグインを無効化(Socket を強制)。 各ノード
NCCL_P2P_LEVEL SYS P2P パス選択のレベル明示。 各ノード
NCCL_DEBUG INFO NCCLログを INFO で取得(初期化失敗時の原因追跡に有用)。 各ノード
NCCL_DEBUG_SUBSYS INIT,ENV,GRAPH どのサブシステムのログを出すか指定。 各ノード
NCCL_NVLS_ENABLE 0 NVLink Switch を使わない明示。 各ノード
NCCL_CUMEM_ENABLE 0 CUDA メモリアロケータの挙動調整(互換性目的)。 各ノード

■ Rayのハマりどころ

一番のハマりどころになります。ただ、正しく設定すればきちんと動きます。シングルノードでvllmを稼働させる場合にはRayを使わないので問題が起きない。マルチノードとなることでハマりどころが表面化する。

Ray

変数 目的・狙い どこで設定
★RAY_ADDRESS 192.168.11.63:6379 vLLM(フロント)のプロセスを既存 Ray クラスタに接続。 vLLMを起動するノード(192.168.11.63) のシェル
(参考)RAY_DEDUP_LOGS 0(必要時) Ray ログ重複の抑止/解除(トラブル時に詳細を見たい場合だけ)。 必要時のみ

設定すべき環境変数(カテゴリ別)

(1)Ray特有のハマりどころ

Headerプロセスとworkerプロセスで引数が違う

 分散計算環境ツールである Rayは Header-workerモデルを利用しており、対象ノードでそれぞれのプロセスをあらかじめ起動しておく必要があります。当たり前ですが、Headerプロセスの起動オプション/環境変数とWorkerプロセスのオプションはそれぞれ異なります。ここまでは当たり前のことですが、設定を混在させると思わぬ副作用に悩むことになります

Rayの「『よしな』にFallback」が招く地獄

 他のツールでは HeaderへWorkderのオプションを設定するといった間違えには、起動エラーで止まってくれるのですぐ気づきますが、Rayの場合、気を聞かせて自分で設定を変更してなんとか動かそうとします。それでも失敗するのですが、「よしなに」で動くことを知らないと、おかしなエラーメッセージに遭遇して悩むことになります。

(2)マルチノード特有のハマりどころ

Slurm と Ray、vLLM の三重管理に起因

  1. GPU割当/CUDA デバイス管理について、同じパラメータを其々のプロセスへ与える必要があります。矛盾があると動きませんが、辛いのはRayが平然と起動してしまう点です。
  2. Slurmの設定が間違っている or Rayの設定が間違っている のに、vllmの失敗で初めて気づくので、間違ってvllmの原因究明をしてしまうことが多々あります

ネットワーク/ポート

ノード間バックエンド(特にNCCL)の設定周りは、トラブルがつきものです。タイミングクリティカルでありながら、メモリエラーに直結します。原因を追いかけてみたら、ここの設定だったというのがしばしばあります。

Ray のオブジェクトストア(メモリ)枯渇・オブジェクトスピル問題

こちらのオブジェクトストアでの障害も、モデルサイズが大きくなって初めて表に出てくる設定問題です。そもそもモデルサイズが小さい時にも本来は間違っている設定になっていて、「このモデルでは動いたのになんでこっちはあかんのや!!」と叫びたくなります。落ち着いてください。最初から拙いだけです。

(3)共用環境特有のハマりどころ

  1. 一時ファイル/ランタイムディレクトリの競合(/tmp や /var/tmp)
  2. 一時ファイルのクリーンアップ不足(プロセスの残存)

<対応策>

  • SlurmとRay、vLLMを一気通貫で自動設定する対話ツールとその設定を使うWrapスクリプトを書くのが吉です。→ 時間があればこのツールの紹介記事を書く予定です

■ Slurm環境(sbatch化)でのハマりどころ

これまでの集大成的なハマり方になる

  • Slurmでのリソース予約が、下流となるRay→ vLLMへと影響する
    • ベストプラクティスは 設定スクリプト生成と各コマンドをWrapスクリプト経由起動へ
  • sbatchへのバッチタスクへの移行では、python仮想環境管理(=conda管理)が試される形になる

ベストプラクティス

  • フェーズ1)srunで対話シェル起動して開発
  • フェーズ2)sallocで対話的にリソース制御
  • フィーズ3)sbatch化して動作確認

Slurm

変数 目的・狙い どこで設定
(手動設定なし) このワークフローでは Slurm はノード割り当てとログインシェルの提供のみ。SLURM_* は自動付与を利用。

ハマりの振り返り

言うは簡単「マルチノードでvllmを使って巨大モデルを評価」

辛い点①:動作確認に時間がかかる

実際は中々大変でした。一番辛かったのは動作確認に時間がかかること。終盤のSlurm+Ray+vLLMのコンボで環境構築するには、1セッション3〜8分ぐらいがかかるため、めっちゃ時間が溶けていきました。Slurmではsbatchを投入してからリソースが割り当てられて稼働状態になるまで1分前後。その後のConda環境〜Ray起動まで1分。ここからvLLMが起動しモデルを読み込むのだが、vLLM起動が1分、モデル読み込みでTorchコンパイルが走ってエンドポイント起動まで3分〜5分かかる。ログも膨大でエラーログを探すのが辛い状況でした。

辛い点②: 動いてしまえば何ともない

終わってみてのこの結果。結構心を削りました。ベストな設定がわかってしまえばほぼほぼNoトラブルで運用できる様になります。正しい設定ができれば学びを生かす場所iはないわけで。LLM研究に携わるための経験値としては有意義ですが、コンペのような短期間決戦では、こうしたノウハウ構築にコンペが律速されるのは勿体無いと思いました。2026年コンペに参加される方があれば、この記事を読んでもらって、コンペをより快適に進めてもらえれば幸いです。

良かった点①: エラーログ解析へのLLM活用

これは自分でもかなりうまくいった試みです。2025年らしいアプローチだったと思います。エラーメッセージを拾っての検索でしたが、調査と原因絞り込みの時間を大幅に短縮できました。mmapエラーでは明後日に行きかけましたが、LLM提示の障害切り分けスクリプトを試していくことで、時間はかかりましたが真実にたどり着くことができました。LLM時代の幕開けらしい経験でした。


まとめ

 自分が主としてコンペへ参加できた期間は予選期間でした。運用したvLLM環境がうまく運用でき予選締切直前には様々なチューニングができるところまで漕ぎ着けられました。参加させていただいたチームRAMENは予選を突破し、その後の本戦では優勝🏆を勝ち取ってくれました。長い様で短い期間でしたが、優勝チームに参画できた形となり良い記念になりました。
 来年のコンペに参加される方は是非このページを参考にしていただき、今年よりも楽しくコンペに調整していただけると幸いです。by RickeyIron

YouTube video player

参考文献

  1. vLLM — GitHub(公式リポジトリ)
  2. vLLM — Releases(リリースノート)
  3. Ray — GitHub公式
  4. Ray - 公式サイト
  5. Slurm — sbatch(公式ドキュメント)
  6. Hugging Face Hub
  7. Qwen3(公式ブログ / リリース情報)
  8. NVIDIA H100 製品ページ(データセンター向け H100)
  9. NVIDIA NCCL — 環境変数ドキュメント(公式) NCCL_IB_DISABLE / NCCL_SOCKET_IFNAME 等の設定と意味の公式説明。ネットワーク周りのトラブルシュートに必須。
  10. Linux の仮想メモリ(オーバーコミット等)の解説 — Unix StackExchange(概説)
    Linux の overcommit(仮想メモリの振る舞い)に関する解説。アドレス空間予約の振る舞い理解に有用。 
  11. mmap エラー/ulimit に関する Q&A(実例) — Stack Overflow
    mmap が失敗する事例と ulimit/仮想メモリ周りのトラブルシューティング。 
  12. 実例:mmap: cannot allocate memory に関する GitHub issue(事例集)
    実運用で出た mmap 例外のログ例や対処議論の実例(Prometheus issue)。
  13. Apptainer(旧 Singularity)公式ドキュメント 非root環境でのコンテナ実行方法。本文の Singularity 回避案に対応。 

Appedix - Next Action

本記事に挙げたハマりどころは 実はSingularity (Apptainer)を使うことでかなり回避できます。一つのSlurmスクリプトに閉じますし、最後にこれを使えばよかったかなと。

  • 一貫したパラメタ管理
    • GPU系:TP/PP/DP/EP
    • Network: ノードIP/ポート
  • 一時ファイルの管理
    • マルチユーザ環境でも衝突リスク皆無
  • 高いシステム透過性
    • ファイルアクセス
    • ホストネットワークベース
    • GPUアクセス
#!/bin/bash
#SBATCH --job-name=vllm-ray-cluster
#SBATCH --nodes=2                    # 合計ノード数 (Head + Worker)
#SBATCH --gpus-per-node=4            # ノードあたりのGPU数
#SBATCH --cpus-per-task=40           # ノードあたりのCPU数
#SBATCH --mem=640G                   # ノードあたりのメモリ
#SBATCH --time=01:00:00
#SBATCH --output=logs/vllm_%j.log    # ログ出力先

# === 設定 ===
# 作成したSIFファイルのパス
IMAGE_PATH="./vllm-openai.sif"

# 使用するモデル (HuggingFaceのパス または ローカルパス)
MODEL_NAME="facebook/opt-125m"
# ※ 本番では "meta-llama/Llama-2-7b-chat-hf" などに変更
# ※ HuggingFaceへのログインが必要なモデルの場合、HF_TOKEN環境変数が必須

# Tensor Parallelismのサイズ (全GPU数に合わせるのが一般的)
# ここでは 2ノード x 1GPU = 2
TP_SIZE=2

# Rayの設定
REDIS_PORT=6379
DASHBOARD_PORT=8265

# ==========================================
# 1. ネットワーク設定 (Headノードの特定)
# ==========================================
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)

head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Head Node: $head_node ($head_node_ip)"

# Rayの一時ディレクトリ用 (競合回避)
export RAY_TMPDIR="/tmp/ray_$SLURM_JOB_ID"
mkdir -p $RAY_TMPDIR

# ==========================================
# 2. Ray Headノードの起動
# ==========================================
echo "Starting Ray Head on $head_node..."
srun --nodes=1 --ntasks=1 -w "$head_node" \
    singularity exec --nv --bind /tmp:/tmp \
    $IMAGE_PATH \
    ray start --head \
    --node-ip-address="$head_node_ip" \
    --port=$REDIS_PORT \
    --dashboard-host=0.0.0.0 \
    --dashboard-port=$DASHBOARD_PORT \
    --block &

# Headが立ち上がるのを少し待つ
sleep 15

# ==========================================
# 3. Ray Workerノードの起動 (Head以外)
# ==========================================
worker_num=$((SLURM_JOB_NUM_NODES - 1))

if [ $worker_num -gt 0 ]; then
    echo "Starting $worker_num Ray Workers..."
    srun --nodes=$worker_num --ntasks=$worker_num --exclude="$head_node" \
        singularity exec --nv --bind /tmp:/tmp \
        $IMAGE_PATH \
        ray start --address="$head_node_ip:$REDIS_PORT" \
        --block &

    sleep 15
fi

# ==========================================
# 4. vLLM (OpenAI API Server) の起動
# ==========================================
# vLLMは環境変数 RAY_ADDRESS を見て既存クラスタに接続します
export RAY_ADDRESS="$head_node_ip:$REDIS_PORT"

echo "Starting vLLM Server..."
echo "Model: $MODEL_NAME"
echo "Tensor Parallelism: $TP_SIZE"

# Headノード上で vLLM のPythonモジュールを実行
# --host 0.0.0.0 で外部からのアクセスを許可
# --bind でモデルのキャッシュディレクトリなどをマウントするのを推奨
# 例: --bind $HOME/.cache/huggingface:/root/.cache/huggingface

singularity exec --nv --bind /tmp:/tmp \
    --env HF_TOKEN=$HF_TOKEN \
    $IMAGE_PATH \
    python3 -m vllm.entrypoints.openai.api_server \
    --model "$MODEL_NAME" \
    --tensor-parallel-size $TP_SIZE \
    --host 0.0.0.0 \
    --port 8000

# vLLMサーバーが終了したらジョブも終わる