diff --git a/kernel/src/asynk/mod.rs b/kernel/src/asynk/mod.rs index 6a5ca4cd..97b15323 100644 --- a/kernel/src/asynk/mod.rs +++ b/kernel/src/asynk/mod.rs @@ -32,7 +32,7 @@ use core::{ mem::MaybeUninit, pin::Pin, sync::atomic::{AtomicUsize, Ordering}, - task::{Context, Poll, Waker}, + task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, }; impl_simple_intrusive_adapter!(TaskletNode, Tasklet, node); @@ -43,6 +43,7 @@ pub struct Tasklet { lock: ISpinLock, future: Pin>>, blocked: Option, + state: AtomicUsize, } impl Tasklet { @@ -52,6 +53,7 @@ impl Tasklet { future, lock: ISpinLock::new(), blocked: None, + state: AtomicUsize::new(TASKLET_IDLE), } } @@ -61,6 +63,12 @@ impl Tasklet { } type AsyncWorkQueue = ArcBufferingQueue; +const TASKLET_IDLE: usize = 0; +const TASKLET_QUEUED: usize = 1; +const TASKLET_POLLING: usize = 2; +const TASKLET_WOKEN: usize = 3; +const TASKLET_COMPLETED: usize = 4; + static mut POLLER_STORAGE: SystemThreadStorage = SystemThreadStorage::new(ThreadKind::AsyncPoller); static mut POLLER: MaybeUninit = MaybeUninit::zeroed(); static POLLER_WAKER: AtomicUsize = AtomicUsize::new(0); @@ -89,12 +97,11 @@ fn create_tasklet(future: impl Future + 'static) -> Arc { pub fn block_on(future: impl Future + Send + 'static) { let t = scheduler::current_thread(); - let mut task = create_tasklet(future); - enqueue_active_tasklet(task.clone()); + let task = create_tasklet(future); let mut w = task.lock(); w.blocked = Some(t.clone()); t.disable_preempt(); - wake_poller(); + wake_tasklet(task.clone()); scheduler::suspend_me_for::(Tick::MAX, Some(w)); t.enable_preempt(); } @@ -106,19 +113,23 @@ fn wake_poller() { pub fn spawn(future: impl Future + Send + 'static) -> Arc { let task = create_tasklet(future); - enqueue_active_tasklet(task.clone()); - wake_poller(); + wake_tasklet(task.clone()); task } pub fn enqueue_active_tasklet(t: Arc) { + wake_tasklet(t); +} + +fn enqueue_queued_tasklet(t: Arc) { #[cfg(debugging_scheduler)] crate::trace!( "[TH:0x{:x}] is enqueuing tasklet", scheduler::current_thread_id() ); let mut q = ASYNC_WORK_QUEUE.get_active_queue(); - q.push_back(t); + let ok = q.push_back(t); + debug_assert!(ok); #[cfg(debugging_scheduler)] crate::trace!( "[TH:0x{:x}] has enqueued tasklet", @@ -126,28 +137,125 @@ pub fn enqueue_active_tasklet(t: Arc) { ); } +fn wake_tasklet(task: Arc) { + loop { + match task.state.load(Ordering::Acquire) { + TASKLET_IDLE => { + if task + .state + .compare_exchange( + TASKLET_IDLE, + TASKLET_QUEUED, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + enqueue_queued_tasklet(task); + wake_poller(); + return; + } + } + TASKLET_POLLING => { + if task + .state + .compare_exchange( + TASKLET_POLLING, + TASKLET_WOKEN, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + return; + } + } + TASKLET_QUEUED | TASKLET_WOKEN | TASKLET_COMPLETED => return, + _ => unreachable!(), + } + } +} + +fn tasklet_waker_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + |data| { + let task = unsafe { Arc::from_raw(data as *const Tasklet) }; + let cloned = task.clone(); + core::mem::forget(task); + RawWaker::new(Arc::into_raw(cloned) as *const (), tasklet_waker_vtable()) + }, + |data| { + let task = unsafe { Arc::from_raw(data as *const Tasklet) }; + wake_tasklet(task); + }, + |data| { + let task = unsafe { Arc::from_raw(data as *const Tasklet) }; + let cloned = task.clone(); + core::mem::forget(task); + wake_tasklet(cloned); + }, + |data| { + let _ = unsafe { Arc::from_raw(data as *const Tasklet) }; + }, + ) +} + +fn tasklet_waker(task: Arc) -> Waker { + let raw_waker = RawWaker::new(Arc::into_raw(task) as *const (), tasklet_waker_vtable()); + unsafe { Waker::from_raw(raw_waker) } +} + +fn finish_pending_poll(task: Arc) { + match task.state.compare_exchange( + TASKLET_POLLING, + TASKLET_IDLE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => {} + Err(TASKLET_WOKEN) => { + task.state.store(TASKLET_QUEUED, Ordering::Release); + enqueue_queued_tasklet(task); + wake_poller(); + } + Err(state) => unreachable!("unexpected tasklet state after poll: {}", state), + } +} + fn poll_inner() { - let mut ctx = Context::from_waker(Waker::noop()); - let mut w = ASYNC_WORK_QUEUE.advance_active_queue(); - for task in w.iter() { + loop { + let Some(task) = ({ + let mut w = ASYNC_WORK_QUEUE.advance_active_queue(); + w.pop_front() + }) else { + break; + }; + + let old = task.state.compare_exchange( + TASKLET_QUEUED, + TASKLET_POLLING, + Ordering::AcqRel, + Ordering::Acquire, + ); + debug_assert_eq!(old, Ok(TASKLET_QUEUED)); + + let waker = tasklet_waker(task.clone()); + let mut ctx = Context::from_waker(&waker); let mut l = task.lock(); - if let Poll::Ready(()) = l.future.as_mut().poll(&mut ctx) { - if let Some(t) = l.blocked.take() { - let ok = scheduler::queue_ready_thread(thread::SUSPENDED, t); - debug_assert_eq!(ok, Ok(())); + match l.future.as_mut().poll(&mut ctx) { + Poll::Ready(()) => { + task.state.store(TASKLET_COMPLETED, Ordering::Release); + let blocked = l.blocked.take(); + drop(l); + if let Some(t) = blocked { + let ok = scheduler::queue_ready_thread(thread::SUSPENDED, t); + debug_assert_eq!(ok, Ok(())); + } + } + Poll::Pending => { + drop(l); + finish_pending_poll(task); } - - // In SMP case, task might be dropped immediately after being detached, - // so we need to drop the lock before detaching. - drop(l); - // If we detach the task what ever it's ready or - // pending, it would be edge-level triggered. Now - // we're using level-trigger mode conservatively. - AsyncWorkQueue::WorkList::detach(&mut unsafe { Arc::clone_from(task) }); - } else { - // FIXME: This is not an efficient impl right now. We - // might need a waker for each future, so that the poller - // doesn't need to poll all futures when woken up. } } } @@ -159,3 +267,26 @@ extern "C" fn poll() { atomic_wait::atomic_wait(&POLLER_WAKER, n, Tick::MAX); } } + +pub fn yield_now() -> impl Future { + YieldNow { polled: false } +} + +struct YieldNow { + polled: bool, +} + +impl Future for YieldNow { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.polled { + Poll::Ready(()) + } else { + let waker = cx.waker().clone(); + waker.wake_by_ref(); + self.get_mut().polled = true; + Poll::Pending + } + } +} diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 0b229037..8e984e92 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -732,6 +732,15 @@ mod tests { } } + async fn yield_now() { + asynk::yield_now().await; + } + + #[test] + fn test_yield_now() { + asynk::block_on(yield_now()); + } + #[cfg(target_abi = "eabihf")] #[test] fn test_basic_float_add_sub() {