UCXを試す日記(8)
(追記:完全にグダグダなただの日記なので、タイトルに「日記」と付け足しました)
ここまで、UCXのディストリビューションに含まれている uct_hello_world.c
を自作C++ラッパーに移植するという目標で勉強をしてきたのですが、C++ラッパーの実装の方針を大幅に変更することにしました。
これまではC++での完全なラッパーを目指していたので、コンポーネントの論理的な階層とC++のクラスクラス階層を一致させ、メモリ管理は std::shared_ptr
を使うという方針でした。
しかし、
* コンポーネント間の依存性が複雑で循環参照を避けるのが面倒なこと
* 単一の shared_ptr
のインスタンスをあちこちコピーし回す必要があり、関数の関数の引数が増えたり、クラスのメンバ関数が不必要に増えたりして却って複雑になってしまうこと
などの事情が見えてきました。
なので、よりシンプルな、非常に薄いラッパーに路線変更するか、そもそもC++ラッパーを作るのをやめるという方針のどちらかに方向転換するよていです。
実際にアプリを書きながら、考えてみようと思います。
とりあえずは、ディープラーニングで定番である Allreduce
関数の実装でもしてみようかと思っています。
MPIでRank順に出力する
小ネタ。
MPIのプログラムから printf
などを使ってデバッグ出力をする場合、全プロセスから一斉に同じ出力をしたときの順序は保証されていません。
例えば、下のようなプログラムを考えます。
#include <stdio.h> #include <mpi.h> int main(int argc, char **argv) { int rank; MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &rank); printf("I'm rank %d\n", rank); MPI_Finalize(); }
# 実行例 $ mpiexec -n 10 ./a.out I'm rank 4 I'm rank 5 I'm rank 6 I'm rank 7 I'm rank 9 I'm rank 0 I'm rank 1 I'm rank 8 I'm rank 2 I'm rank 3
この順序を制御して、例えばRank順に出力したいということは頻繁にあります。MPIを使っている人なら、だいたい自前で書いてしまう処理で、以下のように簡単に書けます。
for (i = 0; i < size; i++) { if (i == rank) { printf("I'm rank %d\n", rank); } MPI_Barrier(MPI_COMM_WORLD); }
# 実行例 $ mpiexec -n 10 ./a.out I'm rank 0 I'm rank 1 I'm rank 2 I'm rank 3 I'm rank 4 I'm rank 5 I'm rank 6 I'm rank 7 I'm rank 8 I'm rank 9
次に、mpi4py
を使ったPythonスクリプトでも簡単に書きたいなーと思ったので書いてみました。 ループを書くのは面倒なので with
構文で書けないかなーと思い、やってみたら意外と簡単にできました。 with
便利。
class RankOrdered(object): def __init__(self, comm): self._comm = comm def __enter__(self): for i in range(0, comm.rank): comm.barrier() def __exit__(self, exception_type, exception_value, traceback): for i in range(comm.size - comm.rank): comm.barrier() # 利用例 from mpi4py import MPI with RankOrdered(MPI.COMM_WORLD): print("rank {}".format(comm.rank))
ChainerMNをMPIで実行中に、例外でプロセスが死んでも実行が止まらない問題
通常、MPIプログラムは、実行中のどれかのrankがエラー等で以上終了した場合(あるいはMPI_Finalize
を呼び出さずに終了した場合)は全プロセスが強制終了されることが期待されます。
が、ChainerMNを含む mpi4py
を用いたプログラムを実行している場合、Pythonの例外によってプロセスの1つが異常終了しても、その他のプロセスが終了せずにハングしてしまうという問題があります。
# test.py def func(): import mpi4py.MPI mpi_comm = mpi4py.MPI.COMM_WORLD if mpi_comm.rank == 0: raise ValueError('failure!') mpi4py.MPI.COMM_WORLD.Barrier() if __name__ == '__main__': func()
実行結果:
$ mpiexec -n 2 python test.py Traceback (most recent call last): File "main.py", line 27, in <module> func() File "main.py", line 21, in func raise ValueError('failure!') ValueError: failure! # <----- このまま固まっていて、プロセスが終了しない
この時、オンプレミスのクラスタ環境等であれば、logを監視したりなどして手動で kill
すればよいのですが、クラウド環境やジョブスケジューラ下ではなかなかそうもいきません。特にクラウド環境では、計算が進んでいないのに課金だけが進んでしまうという状態になるので非常にまずいですね。
この場合、Pythonの例外処理機構をフックして、処理されていない例外が発生した場合に MPI_Abort()
を呼び出すことによってMPIプロセスを強制終了することができます。
import sys # Global error handler def global_except_hook(exctype, value, traceback): import sys from traceback import print_exception print_exception(exctype, value, traceback) sys.stderr.flush() import mpi4py.MPI mpi4py.MPI.COMM_WORLD.Abort(1) sys.excepthook = global_except_hook def func(): import mpi4py.MPI mpi_comm = mpi4py.MPI.COMM_WORLD if mpi_comm.rank == 0: raise ValueError('failure!') mpi4py.MPI.COMM_WORLD.Barrier() if __name__ == '__main__': func()
実行結果:
$ mpiexec -n 2 python main.py Traceback (most recent call last): File "main.py", line 26, in <module> func() File "main.py", line 20, in func raise ValueError('failure!') ValueError: failure! -------------------------------------------------------------------------- MPI_ABORT was invoked on rank 0 in communicator MPI_COMM_WORLD with errorcode 1. NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes. You may or may not see output from other processes, depending on exactly when Open MPI kills them. --------------------------------------------------------------------------