2015-11-19 17:52:41 -05:00
#!/usr/bin/env python
2017-08-08 16:36:23 -07:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2015-11-19 17:52:41 -05:00
"""
Launch a distributed job
"""
import argparse
import os , sys
import signal
import logging
2016-03-23 23:17:04 -04:00
# Absolute directory containing this script.
curr_path = os.path.dirname(os.path.abspath(__file__))
# Make the dmlc_tracker package (under 3rdparty/dmlc-core/tracker, relative
# to this script's parent directory) importable.
sys.path.append(os.path.join(curr_path, "../3rdparty/dmlc-core/tracker"))
2016-03-23 23:17:04 -04:00
def dmlc_opts(opts):
    """Convert this launcher's parsed options into dmlc_tracker options.

    Args:
        opts: namespace built by main(). The fields read are num_workers,
            num_servers, launcher, hostfile, sync_dst_dir, the env_server /
            env_worker / env lists of strings, and the command list.

    Returns:
        The result of ``dmlc_tracker.opts.get_opts`` for the rebuilt command
        line (main() reads its ``cluster`` and ``host_file`` attributes).

    Raises:
        ImportError: if ``dmlc_tracker`` cannot be imported; a hint about
            initialising the git submodules is printed before re-raising.
    """
    # Import under its own name so the module never shadows the ``opts``
    # parameter; importing first also fails fast when the package is missing.
    try:
        from dmlc_tracker import opts as tracker_opts
    except ImportError:
        print("Can't load dmlc_tracker package. Perhaps you need to run")
        print("git submodule update --init --recursive")
        raise

    argv = ['--num-workers', str(opts.num_workers),
            '--num-servers', str(opts.num_servers),
            '--cluster', opts.launcher,
            # NOTE(review): hostfile / sync_dst_dir may be None and are
            # forwarded unchanged; main() treats both None and the string
            # 'None' as "no hostfile". Confirm how dmlc_tracker's parser
            # handles a None entry before changing this.
            '--host-file', opts.hostfile,
            '--sync-dst-dir', opts.sync_dst_dir]

    # Each repeatable env option becomes one "--flag value" pair per entry,
    # e.g. env_server=['A:1'] -> ['--env-server', 'A:1'].
    for key in ('env_server', 'env_worker', 'env'):
        flag = '--' + key.replace('_', '-')
        for value in getattr(opts, key):
            argv.extend((flag, value))

    # The user's program command goes last.
    argv += opts.command
    return tracker_opts.get_opts(argv)
2015-11-19 17:52:41 -05:00
# Launchers that read the machine list from -H/--hostfile; every other
# launcher must be run without one.
_HOSTFILE_LAUNCHERS = ('ssh', 'mpi')


def _build_parser():
    """Return the argparse parser describing this launcher's command line."""
    parser = argparse.ArgumentParser(description='Launch a distributed job')
    parser.add_argument('-n', '--num-workers', required=True, type=int,
                        help='number of worker nodes to be launched')
    parser.add_argument('-s', '--num-servers', type=int,
                        help='number of server nodes to be launched, '
                             'in default it is equal to NUM_WORKERS')
    parser.add_argument('-H', '--hostfile', type=str,
                        help='the hostfile of slave machines which will run '
                             'the job. Required for ssh and mpi launcher')
    parser.add_argument('--sync-dst-dir', type=str,
                        help="if specified, it will sync the current "
                             "directory into slave machines's SYNC_DST_DIR if ssh "
                             "launcher is used")
    parser.add_argument('--launcher', type=str, default='ssh',
                        choices=['local', 'ssh', 'mpi', 'sge', 'yarn'],
                        help='the launcher to use')
    parser.add_argument('--env-server', action='append', default=[],
                        help='Given a pair of environment_variable:value, sets this value of '
                             'environment variable for the server processes. This overrides values of '
                             'those environment variable on the machine where this script is run from. '
                             'Example OMP_NUM_THREADS:3')
    parser.add_argument('--env-worker', action='append', default=[],
                        help='Given a pair of environment_variable:value, sets this value of '
                             'environment variable for the worker processes. This overrides values of '
                             'those environment variable on the machine where this script is run from. '
                             'Example OMP_NUM_THREADS:3')
    parser.add_argument('--env', action='append', default=[],
                        help='given a environment variable, passes their '
                             'values from current system to all workers and servers. '
                             'Not necessary when launcher is local as in that case '
                             'all environment variables which are set are copied.')
    parser.add_argument('command', nargs='+',
                        help='command for launching the program')
    return parser


def _check_hostfile(launcher, hostfile):
    """Raise RuntimeError unless a hostfile is given exactly when *launcher* needs one.

    The string 'None' counts as "no hostfile", the same as a missing value.
    """
    has_hostfile = hostfile is not None and hostfile != 'None'
    if launcher in _HOSTFILE_LAUNCHERS and not has_hostfile:
        raise RuntimeError('The %s launcher requires a hostfile; '
                           'pass one with -H/--hostfile' % launcher)
    if launcher not in _HOSTFILE_LAUNCHERS and has_hostfile:
        raise RuntimeError('The %s launcher does not use a hostfile; '
                           'drop -H/--hostfile' % launcher)


def _submit(args):
    """Hand dmlc-converted *args* to the dmlc_tracker backend named by args.cluster."""
    cluster = args.cluster
    # Backends are imported lazily so only the selected one has to be importable.
    if cluster == 'ssh':
        from dmlc_tracker import ssh as backend
    elif cluster == 'mpi':
        from dmlc_tracker import mpi as backend
    elif cluster == 'local':
        from dmlc_tracker import local as backend
    elif cluster == 'sge':
        from dmlc_tracker import sge as backend
    elif cluster == 'yarn':
        from dmlc_tracker import yarn as backend
    else:
        raise RuntimeError('Unknown submission cluster type %s' % cluster)
    backend.submit(args)


def main():
    """Parse the command line, then submit the job through dmlc_tracker.

    Raises:
        RuntimeError: if --hostfile is missing for the ssh/mpi launchers,
            is given for any other launcher, or the converted options name
            a cluster type this script cannot dispatch.
    """
    args, unknown = _build_parser().parse_known_args()
    # Options this parser does not recognise belong to the user's program,
    # so they are appended to the command rather than rejected.
    args.command += unknown
    if args.num_servers is None:
        args.num_servers = args.num_workers
    # Validate on the raw options so a bad launcher/hostfile combination
    # fails fast with a clear message, before dmlc_tracker is even imported.
    _check_hostfile(args.launcher, args.hostfile)
    _submit(dmlc_opts(args))
2015-11-19 17:52:41 -05:00
def signal_handler(signal, frame):
    """Log that the launcher is stopping and terminate with exit status 0.

    The (signal, frame) pair follows the ``signal.signal`` handler calling
    convention; neither argument is inspected.
    """
    logging.info('Stop launcher')
    raise SystemExit(0)
if __name__ == '__main__':
    # Timestamped, INFO-level log records for the launcher's own messages.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    # Route Ctrl-C (SIGINT) to signal_handler before starting the job.
    signal.signal(signal.SIGINT, signal_handler)
    main()