# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import ipaddress import json import random import sys import threading import time from paddle.distributed.launch.utils.kv_client import KVClient from paddle.distributed.launch.utils.kv_server import KVServer ETCD_PROTOCOL = 'etcd://' def _cmp_by_ip(x): x = json.loads(x) ip_x = x.get('candidate', '127.0.0.1:8080').split(':')[0] return int(ipaddress.IPv4Address(ip_x)) class Master: ''' Master is a distributed store design to exchange info among nodes ''' MAIN = "main" STANDBY = "standby" PARTICIPANT = "participant" def __init__(self, ctx): self.ctx = ctx self.server = None self.initialized = False self.endpoint = None def stop(self): raise NotImplementedError def set_status(self, status): pass def get_status(self): return None def restart_peer(self): pass def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): raise NotImplementedError @classmethod def factory(cls, ctx): if ctx.args.master and ctx.args.master.startswith(ETCD_PROTOCOL): return ETCDMaster(ctx) else: return HTTPMaster(ctx) class HTTPMaster(Master): def lazy_init(self): if self.initialized: return self.role = Master.PARTICIPANT if self.ctx.args.master: self.endpoint = self.ctx.args.master ip, port = self.endpoint.split(':') if ip in ['127.0.0.1', self.ctx.node.ip]: time.sleep(2 * random.random()) while not self.ctx.node.is_server_ready(ip, int(port)): try: self.server = KVServer(int(port)) self.role = Master.MAIN break except Exception as e: self.ctx.logger.warning(f"start master failed {e}") time.sleep(0.1) continue else: port = self.ctx.node.get_free_port() self.endpoint = f"{self.ctx.node.ip}:{port}" self.server = KVServer(port) self.role = Master.MAIN print("Copy the following command to other nodes to run.") cmd = [ sys.executable.split('/')[-1], "-m", "paddle.distributed.launch", ] cmd.extend(["--master", self.endpoint]) cmd.extend(sys.argv[1:]) print("-" * 80) print(" ".join(cmd)) print("-" * 80) if int(self.ctx.args.rank) >= 0: self.ctx.logger.warning( "--rank set in the command may not compatible in auto mode" ) if '127.0.0.1' in self.endpoint: self.endpoint = self.endpoint.replace('127.0.0.1', self.ctx.node.ip) self.client = KVClient(self.endpoint) self.initialized = True self._start_server() def _start_server(self): if self.server and not self.server.started: self.server.start() self.ctx.logger.debug(f"KV server start at {self.endpoint}") def _stop_server(self): if self.server and not self.server.stopped: self.server.stop() self.ctx.logger.debug("KV server stopped") def stop(self): self._stop_server() def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): if size < 2: return [value], 0 self.ctx.logger.info("Waiting peer start...") self.lazy_init() while not self.ctx.status.is_done(): if self.client.wait_server_ready(timeout=5): break else: self.ctx.logger.warning("master not ready") time.sleep(0.1) # 'aaaaaa' make sure main pod (master server) as rank 0 ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key k = f"{prefix}/{ky}/{rank}" while not self.ctx.status.is_done(): if not self.client.put(k, value): self.ctx.logger.warning("put value failed") time.sleep(0.1) continue rjson = self.client.get_prefix(prefix) self.ctx.logger.debug(f"sync peers {rjson}") if rjson and len(rjson) == size: if self.ctx.args.sort_ip: ret = sorted(rjson.values(), key=_cmp_by_ip) idx = ret.index(value) return ret, idx elif rank < 0: keys = list(rjson.keys()) keys.sort() ret = [rjson[k] for k in keys] idx = ret.index(value) return ret, idx else: ret = [None] * size for k, v in rjson.items(): ret[int(k.split('/')[-1])] = v return ret, rank else: time.sleep(0.5) return [], 0 class ETCDMaster(Master): def __init__(self, ctx): super().__init__(ctx) if self.ctx.args.master: # etcd://localhost:2379 self.endpoint = self.ctx.args.master.removeprefix("etcd://") import etcd3 from ..utils.etcd_client import ETCDClient host, port = self.endpoint.split(':') if ctx.is_auto_tuner_mode(): self.client = ETCDClient(host=host, port=port) else: self.client = etcd3.client(host=host, port=port) def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): ''' sync_peers gather all value for key under scope prefix result always be sorted either by rank or alphabet of pod.name ''' if size < 2: return [value], 0 self.ctx.logger.info("Waiting peer start...") path = f"{prefix}/{key}/{rank}" self.client.delete_prefix(prefix) self.ctx.logger.debug(f"sync path {path} value {value}") while not self.ctx.status.is_done(): self.client.put(path, value.encode('latin-1')) result = list(self.client.get_prefix(prefix)) result = copy.deepcopy(result) self.ctx.logger.debug(f"sync peers {result}") if len(result) == size: if self.ctx.args.sort_ip: values = [i[0].decode() for i in result] ret = sorted(values, key=_cmp_by_ip) idx = ret.index(value) return ret, idx elif rank < 0: keys = [i[1].key.decode() for i in result] sorted_keys = [i[1].key.decode() for i in result] sorted_keys.sort() values = [i[0].decode() for i in result] ret = [values[keys.index(k)] for k in sorted_keys] idx = ret.index(value) return ret, idx else: ret = [None] * size for v, k in result: ii = int(k.key.decode().split('/')[-1]) if ii < 0: self.ctx.logger.error(f"rank {ii} error in sync") ret[ii] = v.decode() return ret, rank else: time.sleep(0.5) def register_heartbeat(self, job_id, pod_id, ttl=10): if hasattr(self, 'heartbeat_prefix'): self.ctx.logger.warning("Heartbeat already done") return self.job_prefix = f'/paddle/{job_id}' self.heartbeat_prefix = f'{self.job_prefix}/heartbeat' self.client.delete_prefix(self.job_prefix) lease = self.client.lease(ttl) # self.client.delete_prefix(self.job_prefix) beat_path = f"{self.heartbeat_prefix}/{pod_id}" self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) def _beat_watch(event): self.ctx.status.restart() beat_watch = self.client.add_watch_prefix_callback( self.heartbeat_prefix, _beat_watch ) def _heartbeat(): while not self.ctx.status.is_done(): try: lease.refresh() if pod_id not in self.fetch_peer_alive(): self.client.put( beat_path, pod_id.encode('latin-1'), lease=lease ) self.ctx.logger.debug("Heartbeat register again") except Exception as e: self.ctx.logger.error(f"Heartbeat error {e}") time.sleep(ttl / 2) self.ctx.logger.debug("Heartbeat done") self.client.cancel_watch(beat_watch) self.beat_thread = threading.Thread( name='heartbeat', target=_heartbeat, daemon=True ) self.beat_thread.start() def fetch_peer_alive(self): peer_alive = [ i[0].decode() for i in self.client.get_prefix(self.heartbeat_prefix) ] self.ctx.logger.debug(f"peer alive {peer_alive}") return peer_alive def wait_peer_ready(self, replicas_min, replicas_max, timeout): timeout = timeout if timeout > 1 else 3 end = time.time() + timeout np_pre = len(self.fetch_peer_alive()) while not self.ctx.status.is_done() and time.time() < end: np = len(self.fetch_peer_alive()) if np == replicas_max: # maximum replicas reached, return immediately return (True, replicas_max) elif np != np_pre: # replicas are changing, reset timeout end = time.time() + timeout np_pre = np time.sleep(0.2) else: time.sleep(0.5) np = len(self.fetch_peer_alive()) if np >= replicas_min and np <= replicas_max: return (True, np) else: return (False, np) def restart_peer(self): self.client.delete_prefix(self.heartbeat_prefix) def set_status(self, status): assert self.client.put( self.job_prefix, status.encode('latin-1'), lease=self.client.lease(600), ), f"set status failed {status}" def get_status(self): value = self.client.get(self.job_prefix)[0] return value.decode() if value is not None else '' def stop(self): if hasattr(self, 'beat_thread'): self.ctx.status.done() # daemon thread # self.beat_thread.join()