2022-03-21 15:14:36 +08:00
|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2022-06-05 10:58:58 +08:00
|
|
|
#
|
2022-03-21 15:14:36 +08:00
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
2022-06-05 10:58:58 +08:00
|
|
|
#
|
2022-03-21 15:14:36 +08:00
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
2022-06-05 10:58:58 +08:00
|
|
|
#
|
2022-03-21 15:14:36 +08:00
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import os
|
2025-09-05 10:52:13 +08:00
|
|
|
import warnings
|
2022-11-29 18:50:04 +08:00
|
|
|
from argparse import REMAINDER, ArgumentParser
|
2024-07-03 10:41:41 +08:00
|
|
|
|
|
|
|
|
from paddle.utils import strtobool
|
2022-03-21 15:14:36 +08:00
|
|
|
|
|
|
|
|
# Environment variable name -> (argparse destination, converter).
# Each destination string matches a ``dest`` registered in ``parse_args`` in
# this file.  The converter (``str``, ``int`` or ``strtobool``) is presumably
# applied by the consumer to the raw environment string before it is stored on
# the parsed namespace -- NOTE(review): confirm against the caller.
# Entry order is significant only for iteration; it is kept as originally
# written.
env_args_mapping = {
    # --- options from the "Base Parameters" argument group ---
    "POD_IP": ("host", str),
    "PADDLE_MASTER": ("master", str),
    "PADDLE_DEVICES": ("devices", str),
    "PADDLE_NNODES": ("nnodes", str),
    "PADDLE_RUN_MODE": ("run_mode", str),
    "PADDLE_LOG_LEVEL": ("log_level", str),
    "PADDLE_LOG_OVERWRITE": ("log_overwrite", strtobool),
    "PADDLE_SORT_IP": ("sort_ip", strtobool),
    "PADDLE_NPROC_PER_NODE": ("nproc_per_node", int),
    "PADDLE_JOB_ID": ("job_id", str),
    "PADDLE_RANK": ("rank", int),
    "PADDLE_LOG_DIR": ("log_dir", str),
    # --- options from the "Elastic Parameters" argument group ---
    "PADDLE_MAX_RESTART": ("max_restart", int),
    "PADDLE_ELASTIC_LEVEL": ("elastic_level", int),
    "PADDLE_ELASTIC_TIMEOUT": ("elastic_timeout", int),
    # --- options from the "Parameter-Server Parameters" argument group ---
    "PADDLE_SERVER_NUM": ("server_num", int),
    "PADDLE_TRAINER_NUM": ("trainer_num", int),
    "PADDLE_SERVERS_ENDPOINTS": ("servers", str),
    "PADDLE_TRAINERS_ENDPOINTS": ("trainers", str),
    "PADDLE_GLOO_PORT": ("gloo_port", int),
    "PADDLE_WITH_GLOO": ("with_gloo", str),
    # --- more options from the "Base Parameters" argument group ---
    "PADDLE_START_PORT": ("start_port", int),
    "PADDLE_IPS": ("ips", str),
    "PADDLE_AUTO_PARALLEL_CONFIG": ("auto_parallel_config", str),
    "PADDLE_AUTO_CLUSTER": ("auto_cluster_config", strtobool),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_envs():
    """Return a snapshot of the environment after moving proxy variables aside.

    For each of ``http_proxy`` and ``https_proxy`` that is present, the
    variable is removed from the *live* ``os.environ`` (not just from the
    returned copy) and its value is re-stored under ``<name>_original``.
    One ``UserWarning`` is emitted per variable moved.  Only these lower-case
    spellings are examined.

    Returns:
        dict: ``os.environ.copy()`` taken after the proxies were moved.
    """
    live_env = os.environ
    for name in ("http_proxy", "https_proxy"):
        # os.environ never stores None, so membership is equivalent to
        # ``get(name) is not None``.
        if name not in live_env:
            continue
        backup_name = name + "_original"
        # Move (not copy) the value: the proxy must be gone for this process,
        # while its original value stays recoverable.
        live_env[backup_name] = live_env.pop(name)
        warnings.warn(
            f"Unset '{name}' to ensure stable NCCL communication in distributed training "
            f"(backed up as '{backup_name}').",
            category=UserWarning,
        )
    return live_env.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _add_base_arguments(parser):
    """Register the "Base Parameters" argument group on *parser*."""
    base_group = parser.add_argument_group("Base Parameters")

    base_group.add_argument(
        "--master",
        type=str,
        default=None,
        help="the master/rendezvous server, ip:port",
    )
    base_group.add_argument(
        "--legacy", type=strtobool, default=False, help="use legacy launch"
    )
    base_group.add_argument(
        "--rank", type=int, default=-1, help="the node rank"
    )
    base_group.add_argument(
        "--log_level", type=str, default="INFO", help="log level. Default INFO"
    )
    base_group.add_argument(
        "--log_overwrite",
        type=strtobool,
        default=False,
        help="overwrite existing logfiles. Default False",
    )
    base_group.add_argument(
        "--sort_ip",
        type=strtobool,
        default=False,
        help="rank node by ip. Default False",
    )
    base_group.add_argument(
        "--enable_gpu_log",
        type=strtobool,
        default=True,
        help="enable capture gpu log while running. Default True",
    )
    base_group.add_argument(
        "--nnodes",
        type=str,
        default="1",
        help="the number of nodes, i.e. pod/node number",
    )
    base_group.add_argument(
        "--nproc_per_node",
        type=int,
        default=None,
        help="the number of processes in a pod",
    )
    base_group.add_argument(
        "--log_dir",
        type=str,
        default="log",
        help="the path for each process's log. Default ./log",
    )
    base_group.add_argument(
        "--run_mode",
        type=str,
        default=None,
        help="run mode of the job, collective/ps/ps-heter",
    )
    base_group.add_argument(
        "--job_id",
        type=str,
        default="default",
        help="unique id of the job. Default default",
    )
    # One destination ("devices") reachable through several option aliases.
    base_group.add_argument(
        "--devices",
        "--gpus",
        "--npus",
        "--xpus",
        type=str,
        default=None,
        help="accelerate devices. as --gpus,npus,xpus",
    )
    base_group.add_argument("--host", type=str, default=None, help="host ip")
    base_group.add_argument(
        "--ips",
        type=str,
        default=None,
        help="nodes ips, e.g. 10.10.1.1,10.10.1.2",
    )
    base_group.add_argument(
        "--start_port", type=int, default=6070, help="fix port start with"
    )
    base_group.add_argument(
        "--auto_parallel_config",
        type=str,
        default=None,
        help="auto parallel config file absolute path, the file should be json format",
    )
    base_group.add_argument(
        "--auto_cluster_config",
        type=strtobool,
        default=0,
        help="auto parallel auto cluster config switch",
    )
    base_group.add_argument(
        "training_script",
        type=str,
        # The trailing space matters: these literals are concatenated.
        help="the full path of py script, "
        "followed by arguments for the "
        "training script",
    )
    base_group.add_argument(
        "--auto_tuner_json",
        type=str,
        default=None,
        help="auto tuner json file path",
    )
    # Everything following the training script is collected here verbatim.
    base_group.add_argument('training_script_args', nargs=REMAINDER)


def _add_ps_arguments(parser):
    """Register the "Parameter-Server Parameters" argument group on *parser*."""
    ps_group = parser.add_argument_group("Parameter-Server Parameters")

    ps_group.add_argument(
        "--servers", type=str, default='', help="servers endpoints full list"
    )
    ps_group.add_argument(
        "--trainers", type=str, default='', help="trainers endpoints full list"
    )
    ps_group.add_argument(
        "--trainer_num", type=int, default=None, help="number of trainers"
    )
    ps_group.add_argument(
        "--server_num", type=int, default=None, help="number of servers"
    )
    ps_group.add_argument(
        "--gloo_port", type=int, default=6767, help="gloo http port"
    )
    ps_group.add_argument(
        "--with_gloo", type=str, default="1", help="use gloo or not"
    )


def _add_elastic_arguments(parser):
    """Register the "Elastic Parameters" argument group on *parser*."""
    elastic_group = parser.add_argument_group("Elastic Parameters")

    elastic_group.add_argument(
        "--max_restart",
        type=int,
        default=3,
        help="the times can restart. Default 3",
    )
    elastic_group.add_argument(
        "--elastic_level",
        type=int,
        default=-1,
        help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart",
    )
    elastic_group.add_argument(
        "--elastic_timeout",
        type=int,
        default=30,
        help="seconds to wait before elastic job begin to train",
    )


def _rank_from_env():
    """Return the rank from ``PADDLE_TRAINER_ID``, or -1 when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-integer value.
    """
    raw = os.getenv('PADDLE_TRAINER_ID', '').strip()
    if not raw:
        # An empty value is treated like an absent one, instead of crashing
        # in int('').
        return -1
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"PADDLE_TRAINER_ID must be an integer, got {raw!r}"
        ) from err


def parse_args():
    """Build the launcher's argument parser and parse ``sys.argv``.

    Options are registered in three groups: base, parameter-server and
    elastic. Parsing uses ``parse_known_args``, so unrecognised options are
    returned rather than rejected. The arguments following
    ``training_script`` are collected into ``training_script_args``
    (``nargs=REMAINDER``).

    If ``PADDLE_TRAINER_ID`` holds a non-negative integer, it overrides
    ``rank``. This includes a ``--rank`` passed explicitly on the command line.

    Returns:
        tuple: ``(namespace, unknown_args)`` exactly as returned by
        ``ArgumentParser.parse_known_args``.

    Raises:
        ValueError: if ``PADDLE_TRAINER_ID`` is set to a non-integer value.
        SystemExit: raised by argparse for ``-h`` or invalid/missing arguments.
    """
    parser = ArgumentParser()

    # Group order here determines the section order in ``--help``.
    _add_base_arguments(parser)
    _add_ps_arguments(parser)
    _add_elastic_arguments(parser)

    args = parser.parse_known_args()

    env_rank = _rank_from_env()
    if env_rank >= 0:
        # ``rank`` always exists on the namespace: --rank is registered
        # unconditionally by _add_base_arguments.
        args[0].rank = env_rank

    return args
|