# Copyright (C) 2022-2026 Exaloop Inc.

from threading import Lock, ThreadLocal
from time import time as _time, sleep as _sleep
from sys import stderr as _stderr
import internal.gc as gc
import internal.static as static

# Lifecycle states stored in `Future._state`.
_FUTURE_STATE_PENDING: Literal[int] = 0
_FUTURE_STATE_FINISHED: Literal[int] = 1
_FUTURE_STATE_EXCEPTION: Literal[int] = 2
_FUTURE_STATE_CANCELLED: Literal[int] = 3


# Reinterpret `x` (of type T) as type D via an LLVM `bitcast` (no conversion).
@pure
@derives
@llvm
def _bitcast(x: T, D: type, T: type) -> D:
    %y = bitcast {=T} %x to {=D}
    ret {=D} %y


# One unit of work for the event loop: a coroutine to resume and, when that
# coroutine belongs to a Task, the raw Task object whose result/state the loop
# updates once the coroutine finishes (see EventLoop._step).
@tuple
class WorkItem:
    coro: cobj        # Coroutine handle
    task: cobj        # Raw Task pointer, or null
    result_size: int  # Size in bytes of enclosed Future/Task result


# Node of the event loop's doubly-linked ready queue.
class WorkNode:
    work: WorkItem
    prev: Optional[WorkNode]
    next: Optional[WorkNode]

    def __init__(self, work: WorkItem):
        self.work = work
        self.prev = None
        self.next = None

    def cancelled(self):
        return self.work.cancelled()


# A work item scheduled for later. `when` is an absolute time on the clock
# returned by EventLoop.time(). Cancelled timers stay in the heap and are
# skipped when they come due (see EventLoop._enqueue_timers).
class Timer:
    work: WorkItem
    when: float

    def cancelled(self):
        return self.work.cancelled()


# Event loop state. Its methods are defined in the `@extend class EventLoop`
# sections further down in this module.
class EventLoop:
    _lock: Lock                     # held around ready-queue / timer-heap access
    _work_head: Optional[WorkNode]  # ready queue: items are dequeued here
    _work_tail: Optional[WorkNode]  # ready queue: items are enqueued here
    _work_curr: Optional[WorkItem]  # item currently being resumed by _step
    _timers: Ptr[Timer]             # binary min-heap ordered by `when`
    _timers_len: int
    _timers_cap: int
    _running: bool
    _closed: bool
    _stop_flag: bool                # makes the next _step return True


# Handle returned by EventLoop.call_soon; allows cancelling the queued callback.
@tuple
class Handle:
    _node: WorkNode
    _loop: EventLoop

    def cancel(self):
        """Cancel the callback if it has not been cancelled already."""
        loop = self._loop
        with loop._lock:
            if not self.cancelled():
                loop._cancel(self._node)

    def cancelled(self):
        return self._node.cancelled()


# Handle returned by EventLoop.call_later / call_at.
@tuple
class TimerHandle:
    _timer: Timer
    _loop: EventLoop

    def cancel(self):
        """Cancel the timer if it has not been cancelled already."""
        loop = self._loop
        with loop._lock:
            if not self.cancelled():
                loop._cancel(self._timer)

    def cancelled(self):
        return self._timer.cancelled()

    def when(self):
        """Absolute time (EventLoop.time() clock) at which the timer is due."""
        return self._timer.when


# A result that becomes available later; `R` is the result type.
# `_result` must stay the FIRST field: EventLoop._step copies a Task's result
# to offset 0 of the object, and _raw_task skips over this field.
class Future:
    _result: R
    _exception: Optional[BaseException]
    _lock: Lock                     # taken when state/result/callbacks change
    _loop: EventLoop
    _cancel_msg: str
    _done_callbacks: Ptr[WorkItem]  # growable array of pending callbacks
    _done_callbacks_len: int
    _done_callbacks_cap: int
    _state: int                     # one of the _FUTURE_STATE_* constants
    R: type


# A Future driven by a coroutine scheduled on the event loop.
class Task(Static[Future[R]]):
    _name: str
    _coro: cobj                             # raw coroutine handle
    _waiting_on_task: Optional[Task[None]]  # Task being awaited, if any
    _waiting_on_lock: cobj                  # `_lock.p` of the awaited future, or null
    R: type


def _result_field_span(result_size: int) -> int:
    # Byte offset of the field that follows a leading `_result` field of
    # `result_size` bytes. The next field (`_exception`) is pointer-sized, so
    # the struct layout pads `_result` up to pointer alignment; e.g. for a
    # 1-byte `bool` result the next field is at offset 8, not 1.
    # NOTE(review): assumes 8-byte pointer alignment (64-bit target).
    return (result_size + 7) & ~7


def _raw_task(task: Task) -> Task[None]:
    # View `task` as a `Task[None]` (whose `_result` has type None) by skipping
    # over its `_result` field, including any alignment padding after it.
    return _bitcast(task.__raw__() + _result_field_span(gc.sizeof(task.R)),
                    Task[None])


@extend
class WorkItem:
    def __new__(coro: cobj):
        """Work item that resumes the raw coroutine handle `coro` (no Task)."""
        return WorkItem(coro, cobj(), 0)

    def __new__(coro: Coroutine):
        """Work item that resumes `coro` (no Task)."""
        return WorkItem(coro.__raw__(), cobj(), 0)

    def __new__(task: Task):
        """Work item that resumes `task`'s coroutine and records the Task so the
        loop can store the coroutine's result into it when it completes."""
        data = task.__raw__()
        coro = task._coro
        # True size of the result; EventLoop._step copies exactly this many bytes.
        result_size = gc.sizeof(task.R)
        return WorkItem(coro, data, result_size)

    def __new__():
        """Empty work item (null coroutine); it reports `cancelled()`."""
        return WorkItem(cobj())

    def raw_task(self) -> Task[None]:
        # Skip over `result` field (and any alignment padding after it)
        return _bitcast(self.task + _result_field_span(self.result_size),
                        Task[None])

    def cancelled(self):
        # Cancellation is represented by a null coroutine handle.
        return not bool(self.coro)


class InvalidStateError(Exception):
    def __init__(self, message: str = ''):
        super().__init__(message)


class CancelledError(BaseException):
    def __init__(self, message: str = ''):
        super().__init__(message)


# Wrap a plain callback in a coroutine so it can be queued as a WorkItem.
async def _callback_wrapper(callback, *args):
    callback(*args)


# Wrap a Future done-callback (which receives the future) in a coroutine.
async def _future_callback_wrapper(callback, future):
    callback(future)


# Per-thread loop: `_current_loop` is the thread's default loop and
# `_running_loop` is the loop currently running on this thread, if any.
_current_loop: ThreadLocal[Optional[EventLoop]] = None
_running_loop: ThreadLocal[Optional[EventLoop]] = None


# Context manager that installs `loop` as both the current and the running
# loop for its duration and restores the previous values on exit.
@tuple
class _EnterLoop:
    loop: EventLoop
    old_current_loop: Optional[EventLoop]
    old_running_loop: Optional[EventLoop]

    def __new__(loop: EventLoop):
        return _EnterLoop(loop, _current_loop, _running_loop)

    def __enter__(self):
        global _current_loop
        global _running_loop
        _current_loop = self.loop
        _running_loop = self.loop

    def __exit__(self):
        global _current_loop
        global _running_loop
        _current_loop = self.old_current_loop
        _running_loop = self.old_running_loop


@extend
class EventLoop:
    def __init__(self):
        TIMERS_CAP_INIT: Literal[int] = 8  # must be power of 2
        self._lock = Lock()
        self._work_head = None
        self._work_tail = None
        self._work_curr = None
        self._timers = Ptr[Timer](TIMERS_CAP_INIT)
        self._timers_len = 0
        self._timers_cap = TIMERS_CAP_INIT
        self._running = False
        self._closed = False
        self._stop_flag = False

    def _ensure_open(self):
        if self._closed:
            raise RuntimeError("Event loop is closed")

    def _ensure_not_running(self):
        if self._running:
            raise RuntimeError("This event loop is already running")

    def close(self):
        """Close the loop: drop all queued work and timers and mark it closed.

        Calling it again on a closed loop does nothing. The stop flag is set so
        that a loop that is still stepping exits on its next step.
        """
        with self._lock:
            if not self._closed:
                self._work_head = None
                self._work_tail = None
                self._work_curr = None
                self._timers = Ptr[Timer]()
                self._timers_len = 0
                self._timers_cap = 0
                self._running = False
                self._closed = True
                self._stop_flag = True

    def time(self):
        """Current time in seconds (wall clock from `time.time`)."""
        return _time()

    def stop(self):
        """Request that the running loop exit on its next step."""
        with self._lock:
            self._stop_flag = True

    def is_running(self):
        return self._running

    def is_closed(self):
        return self._closed

    def _call_soon(self, work: WorkItem, return_handle: Literal[bool] = False):
        # Append `work` to the ready queue; optionally return a Handle to it.
        # Raises RuntimeError if the loop is closed.
        with self._lock:
            self._ensure_open()
            node = self._work_enqueue(work)
            if return_handle:
                return Handle(node, self)

    def _call_soon_no_duplicate(self, work: WorkItem):
        # Enqueue `work` unless an item with the same coroutine is already
        # queued. Returns True if it was enqueued. O(queue length).
        with self._lock:
            self._ensure_open()
            node = self._work_head
            while node is not None:
                if node.work.coro == work.coro:
                    return False
                node = node.next
            self._work_enqueue(work)
            return True

    def call_soon(self, callback, *args):
        """Schedule `callback(*args)` on the ready queue; return a Handle."""
        item = WorkItem(_callback_wrapper(callback, *args))
        return self._call_soon(item, return_handle=True)

    def call_soon_threadsafe(self, callback, *args):
        """Thread-safe variant of call_soon; return a Handle."""
        # `call_soon` already takes `self._lock` (inside `_call_soon`). `Lock` is
        # not reentrant, so also holding it here would make the same thread
        # block on its own lock (deadlock).
        return self.call_soon(callback, *args)

    def _call_later(self, work: WorkItem, delay: float,
                    return_handle: Literal[bool] = False):
        # Push a timer that becomes due `delay` seconds from now (clock: time()).
        timer = Timer(work, self.time() + delay)
        with self._lock:
            self._ensure_open()
            self._timers_push(timer)
            if return_handle:
                return TimerHandle(timer, self)

    def call_later(self, delay: float, callback, *args):
        """Schedule `callback(*args)` after `delay` seconds; return a TimerHandle."""
        item = WorkItem(_callback_wrapper(callback, *args))
        return self._call_later(item, delay, return_handle=True)

    def call_at(self, delay: float, callback, *args):
        """Schedule `callback(*args)` at an absolute time; return a TimerHandle.

        As in asyncio's `loop.call_at(when, ...)`, the first argument is an
        absolute timestamp on the same clock as `self.time()`. (The parameter
        keeps its historical name `delay` so keyword callers still work.)
        """
        return self.call_later(delay - self.time(), callback, *args)

    def _handle_exception(self, exc: BaseException):
        # Report an exception that escaped non-Task work as "<type>: <exc>" on
        # stderr. Word 1 of the exception object is read as its TypeInfo.
        rtti = _bitcast(Ptr[cobj](exc.__raw__())[1], TypeInfo)
        _stderr.write(f'{rtti.nice_name}: {exc}\n')

    def _enqueue_timers(self, now: float):
        # Move every timer due at or before `now` onto the ready queue, skipping
        # cancelled ones. Caller must hold `self._lock`.
        while self._timers_len > 0 and self._timers[0].when <= now:
            timer = self._timers_pop()
            if not timer.cancelled():
                self._work_enqueue(timer.work)

    def _step(self):
        """Run one loop iteration; return True if the loop should stop.

        Moves due timers onto the ready queue, then resumes at most one ready
        work item. If nothing is ready, sleeps for 10 ms, or less if a timer
        is due sooner.
        """
        work: Optional[WorkItem] = None
        now = self.time()
        stop = False

        with self._lock:
            self._enqueue_timers(now)
            stop = self._stop_flag
            # Only dequeue work that will actually run on this step. When a
            # stop was requested the queue is left intact, instead of dropping
            # an already-dequeued item.
            if not stop and not self._work_empty():
                work = self._work_dequeue()
                self._work_curr = work

        if stop:
            return True

        if work is not None:
            g = Generator[None](work.coro)
            exception: Optional[BaseException] = None

            try:
                g.__resume__()
            except CancelledError:
                # CancelledError is not recorded as this item's exception.
                pass
            except SystemExit:
                self._work_curr = None
                raise
            except AssertionError:
                self._work_curr = None
                raise
            except BaseException as e:
                exception = e

            if work.task:
                # Task work: record the outcome on the Task, then schedule its
                # done-callbacks after releasing the Task's lock.
                task = work.raw_task()
                callbacks = Ptr[WorkItem]()
                num_callbacks = 0
                with task._lock:
                    if not task.done():
                        if exception is not None:
                            callbacks, num_callbacks = \
                                task._finish_with_exception(exception)
                        elif g.__done__():
                            # Copy the coroutine's return value into the Task's
                            # leading `_result` field (offset 0).
                            str.memcpy(work.task, g.__promise__().as_byte(),
                                       work.result_size)
                            callbacks, num_callbacks = task._finish()
                task._schedule_callbacks(callbacks, num_callbacks)
            elif exception is not None:
                # Non-Task work has nowhere to store the exception; report it.
                self._handle_exception(exception)

            self._work_curr = None
        else:
            sleep_time = 0.01  # 10ms default
            with self._lock:
                if self._timers_len > 0:
                    dt = self._timers[0].when - now
                    if dt > 0 and dt < sleep_time:
                        sleep_time = dt
            if sleep_time > 0:
                _sleep(sleep_time)

        return False

    def run_forever(self):
        """Run the loop until stop() (or close()) is called.

        Raises RuntimeError if the loop is closed or already running.
        """
        self._ensure_open()
        self._ensure_not_running()
        self._running = True
        self._stop_flag = False
        try:
            with _EnterLoop(self):
                while True:
                    stop = self._step()
                    if stop:
                        break
        finally:
            # Reset even if an exception escapes _step, so the loop can run again.
            self._running = False

    def run_until_complete(self, future):
        """Run the loop until `future` is done (or a stop is requested), then
        return `future.result()`.

        Raises RuntimeError if the loop is closed or already running; if the
        loop stops before `future` is done, `future.result()` raises.
        """
        self._ensure_open()
        self._ensure_not_running()
        self._running = True
        self._stop_flag = False
        try:
            with _EnterLoop(self):
                while not future.done():
                    stop = self._step()
                    if stop:
                        break
        finally:
            # Reset even if an exception escapes _step, so the loop can run again.
            self._running = False
        return future.result()

    def _work_empty(self):
        return self._work_tail is None

    def _work_enqueue(self, work: WorkItem):
        # Append `work` at the tail of the ready queue; return its node.
        # Caller must hold `self._lock`.
        node = WorkNode(work)
        tail = self._work_tail
        if tail is None:
            self._work_head = node
            self._work_tail = node
        else:
            node.prev = tail
            tail.next = node
            self._work_tail = node
        return node

    def _work_dequeue(self):
        # caller must ensure non-empty
        # Pop the head node, fully unlinking it, and return its work item.
        head = self._work_head
        self._work_head = head.next
        if head.next is None:
            self._work_tail = None
        else:
            head.next.prev = None
            head.next = None
        return head.work

    def _cancel(self, node: WorkNode):
        # Cancel a ready-queue node and unlink it if it is still queued.
        # Caller holds `self._lock` (see Handle.cancel).
        node.work = WorkItem()  # null coroutine => node.cancelled() is True
        if self._work_head is None or self._work_tail is None:
            return
        head: WorkNode = self._work_head
        tail: WorkNode = self._work_tail
        if node is head:
            self._work_head = head.next
            if head.next is None:
                self._work_tail = None
            else:
                head.next.prev = None
                head.next = None
        elif node is tail:
            self._work_tail = tail.prev
            if tail.prev is None:
                self._work_head = None
            else:
                tail.prev.next = None
                tail.prev = None
        else:
            prev = node.prev
            nxt = node.next
            if prev is None or nxt is None:
                # Not linked into the queue (already dequeued, e.g. its callback
                # has run): there is nothing to unlink.
                return
            prev.next = nxt
            nxt.prev = prev
            node.prev = None
            node.next = None

    def _cancel(self, timer: Timer):
        # Lazy deletion: the timer stays in the heap and is skipped when due.
        timer.work = WorkItem()

    def _timers_reserve(self, new_cap: int):
        # Grow the timer heap's storage to `new_cap` entries (never shrinks).
        old_cap = self._timers_cap
        if new_cap <= old_cap:
            return
        sz = gc.sizeof(Timer)
        self._timers = Ptr[Timer](gc.realloc(
            self._timers.as_byte(), new_cap * sz, old_cap * sz))
        self._timers_cap = new_cap

    def _timers_swap(self, i: int, j: int):
        timers = self._timers
        tmp = timers[i]
        timers[i] = timers[j]
        timers[j] = tmp

    def _timers_push(self, t: Timer):
        # Insert `t` into the min-heap keyed on `when`, doubling capacity if full.
        if self._timers_len == self._timers_cap:
            self._timers_reserve(self._timers_cap * 2)
        i = self._timers_len
        self._timers_len += 1
        timers = self._timers
        timers[i] = t

        # Sift up
        while i > 0:
            parent = (i - 1) >> 1
            if timers[parent].when <= timers[i].when:
                break
            self._timers_swap(parent, i)
            i = parent

    def _timers_pop(self):
        # caller must ensure non-empty
        # Remove and return the timer with the smallest `when`.
        timers = self._timers
        out = timers[0]
        self._timers_len -= 1
        timers_len = self._timers_len

        if timers_len > 0:
            timers[0] = timers[timers_len]
            # Sift down
            i = 0
            while True:
                left = 2*i + 1
                right = 2*i + 2
                smallest = i
                if (left < timers_len and
                        timers[left].when < timers[smallest].when):
                    smallest = left
                if (right < timers_len and
                        timers[right].when < timers[smallest].when):
                    smallest = right
                if smallest == i:
                    break
                self._timers_swap(i, smallest)
                i = smallest

        return out

    def create_future(self, T: type = NoneType):
        """Create a pending Future[T] bound to this loop."""
        return Future[T](loop=self)
# More EventLoop methods.
@extend
class EventLoop:
    def create_task(self, coro: Coroutine, name: Optional[str] = None):
        """Wrap `coro` in a Task bound to this loop, queue it, and return it.

        Raises RuntimeError if the loop is closed.
        """
        self._ensure_open()
        task = Task(coro, loop=self, name=name)
        work = WorkItem(task)
        self._call_soon(work)
        return task


@extend
class Future:
    def __init__(self, loop: Optional[EventLoop] = None):
        """Create a pending future bound to `loop`; if omitted, to the running
        loop, or else to get_event_loop()."""
        self._exception = None
        self._lock = Lock()
        if loop is None:
            if _running_loop is not None:
                self._loop = _running_loop
            else:
                self._loop = get_event_loop()
        else:
            self._loop = loop
        self._cancel_msg = ''
        self._done_callbacks = Ptr[WorkItem]()
        self._done_callbacks_len = 0
        self._done_callbacks_cap = 0
        self._state = _FUTURE_STATE_PENDING

    def _result_size(self):
        return gc.sizeof(R)

    def _reset_callbacks(self):
        # Detach the callback array. Whoever read it beforehand now owns it
        # (_schedule_callbacks frees it).
        self._done_callbacks = Ptr[WorkItem]()
        self._done_callbacks_len = 0
        self._done_callbacks_cap = 0

    def add_done_callback(self, callback):
        """Arrange for `callback(self)` to run on the loop once this is done
        (scheduled immediately if already done)."""
        self._add_done_callback(
            WorkItem(_future_callback_wrapper(callback, self)))

    def _add_done_callback(self, work: WorkItem, add_if_done: Literal[bool] = True):
        # Store `work` to be scheduled when this future completes; return True.
        # If already done, return False and (when `add_if_done`) schedule `work`
        # right away instead.
        with self._lock:
            if not self.done():
                n = self._done_callbacks_len
                m = self._done_callbacks_cap
                if m == 0:
                    self._done_callbacks = Ptr[WorkItem](1)
                    self._done_callbacks_cap = 1
                elif n >= m:
                    # Grow the callback array geometrically.
                    new_m = m * 2
                    sz = gc.sizeof(WorkItem)
                    self._done_callbacks = Ptr[WorkItem](
                        gc.realloc(
                            self._done_callbacks.as_byte(), new_m * sz, m * sz))
                    self._done_callbacks_cap = new_m
                self._done_callbacks[n] = work
                self._done_callbacks_len += 1
                return True
        # Already done: schedule after releasing this future's lock.
        if add_if_done:
            self._loop._call_soon(work)
        return False

    def _schedule_callbacks(self, callbacks: Ptr[WorkItem], num_callbacks: int):
        # Queue each detached callback on the loop, then free the array. A Task
        # callback is only requeued if that Task is still waiting on this
        # future (its `_waiting_on_lock` is this future's lock); otherwise it
        # was already woken elsewhere, e.g. by Task.cancel.
        for i in range(num_callbacks):
            item = callbacks[i]
            if item.task:
                task = item.raw_task()
                waiting_on_lock = task._waiting_on_lock
                if waiting_on_lock != self._lock.p:
                    continue
                task._waiting_on_task = None
                task._waiting_on_lock = cobj()
            self._loop._call_soon(item)
        if callbacks:
            gc.free(callbacks.as_byte())

    def result(self):
        """Return the result. Raises CancelledError if cancelled, the stored
        exception if one was set, or InvalidStateError if still pending."""
        with self._lock:
            state = self._state
            if state == _FUTURE_STATE_CANCELLED:
                raise CancelledError(self._cancel_msg)
            elif state == _FUTURE_STATE_EXCEPTION:
                raise self._exception.__val__()
            elif state == _FUTURE_STATE_PENDING:
                raise InvalidStateError("Result is not set.")
            else:
                return self._result

    def _finish(self):
        # Mark finished (result already stored) and detach the callbacks, which
        # are returned for the caller to schedule. Caller holds `self._lock`.
        if self.done():
            raise InvalidStateError("Invalid state")
        self._state = _FUTURE_STATE_FINISHED
        callbacks = self._done_callbacks
        num_callbacks = self._done_callbacks_len
        self._reset_callbacks()
        return callbacks, num_callbacks

    def _finish_with_exception(self, exception: BaseException):
        # Like _finish, but records `exception`. Caller holds `self._lock`.
        if self.done():
            raise InvalidStateError("Invalid state")
        self._state = _FUTURE_STATE_EXCEPTION
        self._exception = exception
        callbacks = self._done_callbacks
        num_callbacks = self._done_callbacks_len
        self._reset_callbacks()
        return callbacks, num_callbacks

    def set_result(self, result: R):
        """Store `result` and schedule callbacks; InvalidStateError if done."""
        callbacks = Ptr[WorkItem]()
        num_callbacks = 0
        with self._lock:
            if self.done():
                raise InvalidStateError("Invalid state")
            self._result = result
            self._state = _FUTURE_STATE_FINISHED
            callbacks = self._done_callbacks
            num_callbacks = self._done_callbacks_len
            self._reset_callbacks()
        self._schedule_callbacks(callbacks, num_callbacks)

    def _set_result_if_not_done(self, result: R):
        # Like set_result, but returns False instead of raising when done.
        callbacks = Ptr[WorkItem]()
        num_callbacks = 0
        with self._lock:
            if self.done():
                return False
            self._result = result
            self._state = _FUTURE_STATE_FINISHED
            callbacks = self._done_callbacks
            num_callbacks = self._done_callbacks_len
            self._reset_callbacks()
        self._schedule_callbacks(callbacks, num_callbacks)
        return True

    def set_exception(self, exception: BaseException):
        """Store `exception` and schedule callbacks; InvalidStateError if done."""
        callbacks = Ptr[WorkItem]()
        num_callbacks = 0
        with self._lock:
            if self.done():
                raise InvalidStateError("Invalid state")
            self._exception = exception
            self._state = _FUTURE_STATE_EXCEPTION
            callbacks = self._done_callbacks
            num_callbacks = self._done_callbacks_len
            self._reset_callbacks()
        self._schedule_callbacks(callbacks, num_callbacks)

    def _set_exception_if_not_done(self, exception: BaseException):
        # Like set_exception, but returns False instead of raising when done.
        callbacks = Ptr[WorkItem]()
        num_callbacks = 0
        with self._lock:
            if self.done():
                return False
            self._exception = exception
            self._state = _FUTURE_STATE_EXCEPTION
            callbacks = self._done_callbacks
            num_callbacks = self._done_callbacks_len
            self._reset_callbacks()
        self._schedule_callbacks(callbacks, num_callbacks)
        return True

    def cancelled(self):
        return self._state == _FUTURE_STATE_CANCELLED

    def done(self):
        return self._state != _FUTURE_STATE_PENDING

    def cancel(self, msg: Optional[str] = None):
        """Cancel if pending (optionally recording `msg`) and schedule the
        callbacks. Returns False if already done."""
        callbacks = Ptr[WorkItem]()
        num_callbacks = 0
        with self._lock:
            if self.done():
                return False
            self._state = _FUTURE_STATE_CANCELLED
            if msg is not None:
                self._cancel_msg = msg
            callbacks = self._done_callbacks
            num_callbacks = self._done_callbacks_len
            self._reset_callbacks()
        self._schedule_callbacks(callbacks, num_callbacks)
        return True

    def get_loop(self):
        return self._loop

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or None if finished normally. Raises
        CancelledError if cancelled or InvalidStateError if still pending."""
        with self._lock:
            state = self._state
            if state == _FUTURE_STATE_CANCELLED:
                raise CancelledError(self._cancel_msg)
            elif state == _FUTURE_STATE_EXCEPTION:
                return self._exception.__val__()
            elif state == _FUTURE_STATE_PENDING:
                raise InvalidStateError("Exception is not set.")
            else:
                return None

    async def __await__(self):
        return await self


# Counter for default task names ("Task-1", "Task-2", ...).
_default_task_name_counter = 1


def _default_task_name():
    global _default_task_name_counter
    n = _default_task_name_counter
    _default_task_name_counter += 1
    return f'Task-{n}'


@extend
class Task:
    def __init__(self, coro: Coroutine[R], loop: Optional[EventLoop] = None,
                 name: Optional[str] = None):
        """Create a pending Task for `coro` (loop default as in Future.__init__).
        This does not queue the Task; EventLoop.create_task does that."""
        super().__init__(loop)
        if name is None:
            self._name = _default_task_name()
        else:
            self._name = name
        self._coro = coro.__raw__()
        self._waiting_on_task = None
        self._waiting_on_lock = cobj()

    def get_name(self):
        return self._name

    def set_name(self, value: str):
        self._name = value

    def get_coro(self, T: type = NoneType) -> Coroutine[T]:
        return Coroutine[T](self._coro)

    # NOTE(review): the explicit forwards below presumably exist because of how
    # `Static[...]` inheritance resolves methods; confirm before removing.
    def add_done_callback(self, callback):
        super().add_done_callback(callback)

    def _add_done_callback(self, work: WorkItem, add_if_done: Literal[bool] = True):
        return super()._add_done_callback(work, add_if_done)

    def _schedule_callbacks(self, callbacks: Ptr[WorkItem], num_callbacks: int):
        super()._schedule_callbacks(callbacks, num_callbacks)

    def get_loop(self):
        return super().get_loop()

    def done(self):
        return super().done()

    def cancelled(self):
        return super().cancelled()

    def result(self):
        return super().result()

    def _finish(self):
        return super()._finish()

    def _finish_with_exception(self, exception: BaseException):
        return super()._finish_with_exception(exception)

    def cancel(self, msg: Optional[str] = None):
        """Cancel the Task and any Task it is awaiting, then requeue it so its
        coroutine can observe the cancellation. False if already done."""
        if not super().cancel(msg):
            return False
        with self._lock:
            waiting_on_task = self._waiting_on_task
            if waiting_on_task is not None:
                waiting_on_task.cancel(msg)
                self._waiting_on_task = None
            # Clear so the awaited future's completion will not requeue us again.
            self._waiting_on_lock = cobj()
            self._loop._call_soon_no_duplicate(WorkItem(self))
        return True

    async def __await__(self):
        return await self


# NOTE(review): the helpers below (_wait_on .. _cancel_checkpoint) have no
# callers in this file; presumably they are invoked by compiler-generated
# `await` code — confirm.

def _wait_on(future):
    # Register the running work item as a done-callback of `future`. Returns
    # True if registered, False if `future` is already done or no work item is
    # running. For Task work, records what the Task is waiting on.
    loop = get_running_loop()
    if loop is not future.get_loop():
        raise RuntimeError("running loop is not the same as task loop")
    work_curr = loop._work_curr
    if work_curr is None:
        return False
    added = future._add_done_callback(work_curr, add_if_done=False)
    if added and work_curr.task:
        task = work_curr.raw_task()
        task._waiting_on_lock = future._lock.p
        if isinstance(future, Task):
            task._waiting_on_task = _raw_task(future)
    return added


def _requeue():
    # Put the running work item back on the ready queue.
    loop = get_running_loop()
    loop._call_soon(loop._work_curr)


def _promise(coro):
    g = Generator[coro.T](coro.__raw__())
    return g.__promise__()[0]


def _done(coro):
    g = Generator[coro.T](coro.__raw__())
    return g.__done__()


def _resume(coro):
    g = Generator[coro.T](coro.__raw__())
    g.__resume__()


def _curr_task():
    return get_running_loop()._work_curr.raw_task()


def _is_waiting(task):
    return bool(task._waiting_on_lock)


def _cancel_checkpoint():
    # Raise CancelledError in the running Task if it has been cancelled.
    loop = get_running_loop()
    work_curr = loop._work_curr
    if work_curr is not None and work_curr.task:
        task = work_curr.raw_task()
        if task.cancelled():
            raise CancelledError(task._cancel_msg)


def new_event_loop():
    return EventLoop()


def set_event_loop(loop: EventLoop):
    global _current_loop
    _current_loop = loop


def get_event_loop():
    """Return this thread's current loop, creating and installing one if unset."""
    global _current_loop
    if _current_loop is not None:
        return _current_loop
    loop = new_event_loop()
    _current_loop = loop
    return loop


def get_running_loop() -> EventLoop:
    """Return the loop running on this thread; RuntimeError if there is none."""
    if _running_loop is None:
        raise RuntimeError("no running event loop")
    return _running_loop


def create_task(coro, name: Optional[str] = None):
    return get_running_loop().create_task(coro, name=name)


def isfuture(obj) -> Literal[bool]:
    return isinstance(obj, Future)


def iscoroutine(obj) -> Literal[bool]:
    return isinstance(obj, Coroutine)


def ensure_future(obj, loop: Optional[EventLoop] = None):
    """Return `obj` if it is a future; wrap a coroutine in a Task on `loop`
    (default: get_event_loop()); any other type is a compile-time error."""
    if isfuture(obj):
        return obj
    if iscoroutine(obj):
        return (loop if loop is not None else get_event_loop()).create_task(obj)
    compile_error("An asyncio.Future, a coroutine or an awaitable is required")


def current_task() -> Optional[Task[None]]:
    """Return the running Task (as Task[None]) or None if the running work item
    is not a Task. Raises RuntimeError if no loop is running."""
    loop = get_running_loop()
    work_curr = loop._work_curr
    if work_curr is not None and work_curr.task:
        return work_curr.raw_task()
    else:
        return None


def run(coro, debug=None, loop_factory=None):
    """Run `coro` to completion on `loop_factory()` (default: get_event_loop())
    and return its result. `debug` is accepted but unused."""
    if loop_factory is not None:
        loop = loop_factory()
    else:
        loop = get_event_loop()
    task = loop.create_task(coro)
    return loop.run_until_complete(task)


async def sleep(delay: float, result=None):
    """Suspend for `delay` seconds, then return `result`. A non-positive delay
    schedules the wake-up with call_soon instead of a timer."""
    loop = get_running_loop()
    future = loop.create_future(type(result))
    if delay <= 0.0:
        loop.call_soon(future.set_result, result)
    else:
        loop.call_later(delay, future.set_result, result)
    return await future


def gather(*aws):
    """Return a Future whose result is the tuple of results of `aws`, in order.

    Each awaitable is passed through ensure_future on the running loop. The
    first child to fail -- by raising, or by being cancelled (reported as
    CancelledError) -- fails the outer future. Cancelling the outer future
    cancels every child.
    """
    # Zero-initialized value of type T; used only to derive the result type.
    @pure
    @llvm
    def zero(T: type) -> T:
        ret {=T} zeroinitializer

    # Pointer to element `idx` of the tuple stored at `p`.
    @pure
    @derives
    @llvm
    def gep(p: Ptr[T], idx: Literal[int], R: type, T: type) -> Ptr[R]:
        %q = getelementptr {=T}, ptr %p, i32 0, i32 {=idx}
        ret ptr %q

    # Atomically decrement `*i`; returns the value before the decrement.
    @nocapture
    @llvm
    def atomic_decrement(i: Ptr[int]) -> int:
        %j = atomicrmw sub ptr %i, i64 1 seq_cst
        ret i64 %j

    # Done-callback for one child: stores its result in its tuple slot, or
    # fails the outer future; the last successful child completes the outer.
    @tuple
    class GatherCallback[T]:
        payload: T  # (outer, results, my_result, remaining)

        def __call__(self, future):
            outer, results, my_result, remaining = self.payload
            with future._lock:
                state = future._state
                if state == _FUTURE_STATE_FINISHED:
                    my_result[0] = future._result
                elif state == _FUTURE_STATE_EXCEPTION:
                    outer._set_exception_if_not_done(future._exception)
                    return
                elif state == _FUTURE_STATE_CANCELLED:
                    # As in asyncio, a cancelled child counts as raising
                    # CancelledError. It must not contribute a (zero) result.
                    outer._set_exception_if_not_done(
                        CancelledError(future._cancel_msg))
                    return
            if atomic_decrement(remaining) == 1:
                outer._set_result_if_not_done(results[0])

    # Done-callback for the outer future: propagate cancellation to children.
    @tuple
    class CancelCallback[T]:
        futures: T

        def __call__(self, future):
            if future.cancelled():
                for f in self.futures:
                    f.cancel()

    loop = get_running_loop()
    futures = tuple(ensure_future(a, loop=loop) for a in aws)
    ret_type = type(tuple(zero(f.R) for f in futures))
    outer = loop.create_future(ret_type)
    n: Literal[int] = static.len(futures)

    if n == 0:
        outer.set_result(())
        return outer

    # Shared state: the result tuple being filled in and a count of children
    # that have not yet delivered a result.
    results = Ptr[ret_type](1)
    remaining = Ptr[int](1)
    remaining[0] = n

    for i in static.range(n):
        my_future = futures[i]
        my_result = gep(results, i, my_future.R)
        payload = (outer, results, my_result, remaining)
        my_future.add_done_callback(GatherCallback(payload))

    outer.add_done_callback(CancelCallback(futures))
    return outer