通常、MPIプログラムは、実行中のどれかのrankがエラー等で以上終了した場合(あるいはMPI_Finalizeを呼び出さずに終了した場合)は全プロセスが強制終了されることが期待されます。
が、ChainerMNを含む mpi4py を用いたプログラムを実行している場合、Pythonの例外によってプロセスの1つが異常終了しても、その他のプロセスが終了せずにハングしてしまうという問題があります。
# test.py def func(): import mpi4py.MPI mpi_comm = mpi4py.MPI.COMM_WORLD if mpi_comm.rank == 0: raise ValueError('failure!') mpi4py.MPI.COMM_WORLD.Barrier() if __name__ == '__main__': func()
実行結果:
$ mpiexec -n 2 python test.py
Traceback (most recent call last):
File "main.py", line 27, in <module>
func()
File "main.py", line 21, in func
raise ValueError('failure!')
ValueError: failure! # <----- このまま固まっていて、プロセスが終了しない
この時、オンプレミスのクラスタ環境等であれば、logを監視したりなどして手動で kill すればよいのですが、クラウド環境やジョブスケジューラ下ではなかなかそうもいきません。特にクラウド環境では、計算が進んでいないのに課金だけが進んでしまうという状態になるので非常にまずいですね。
この場合、Pythonの例外処理機構をフックして、処理されていない例外が発生した場合に MPI_Abort() を呼び出すことによってMPIプロセスを強制終了することができます。
import sys # Global error handler def global_except_hook(exctype, value, traceback): import sys from traceback import print_exception print_exception(exctype, value, traceback) sys.stderr.flush() import mpi4py.MPI mpi4py.MPI.COMM_WORLD.Abort(1) sys.excepthook = global_except_hook def func(): import mpi4py.MPI mpi_comm = mpi4py.MPI.COMM_WORLD if mpi_comm.rank == 0: raise ValueError('failure!') mpi4py.MPI.COMM_WORLD.Barrier() if __name__ == '__main__': func()
実行結果:
$ mpiexec -n 2 python main.py
Traceback (most recent call last):
File "main.py", line 26, in <module>
func()
File "main.py", line 20, in func
raise ValueError('failure!')
ValueError: failure!
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 0 in communicator MPI_COMM_WORLD
with errorcode 1.
NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes.
You may or may not see output from other processes, depending on
exactly when Open MPI kills them.
--------------------------------------------------------------------------