diff --git a/src/lib.rs b/src/lib.rs index c9809d0d..2ce93a66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ #![feature(const_fn)] #![feature(alloc_layout_extra)] #![feature(const_in_array_repeat_expressions)] +#![feature(wake_trait)] #![test_runner(crate::test_runner)] #![reexport_test_harness_main = "test_main"] diff --git a/src/main.rs b/src/main.rs index b1384dab..1ce2165f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ extern crate alloc; use blog_os::println; -use blog_os::task::{keyboard, simple_executor::SimpleExecutor, Task}; +use blog_os::task::{executor::Executor, keyboard, Task}; use bootloader::{entry_point, BootInfo}; use core::panic::PanicInfo; @@ -27,16 +27,13 @@ fn kernel_main(boot_info: &'static BootInfo) -> ! { allocator::init_heap(&mut mapper, &mut frame_allocator).expect("heap initialization failed"); - let mut executor = SimpleExecutor::new(); - executor.spawn(Task::new(example_task())); - executor.spawn(Task::new(keyboard::print_keypresses())); - executor.run(); - #[cfg(test)] test_main(); - println!("It did not crash!"); - blog_os::hlt_loop(); + let mut executor = Executor::new(); + executor.spawn(Task::new(example_task())); + executor.spawn(Task::new(keyboard::print_keypresses())); + executor.run(); } /// This function is called on panic. diff --git a/src/task/executor.rs b/src/task/executor.rs new file mode 100644 index 00000000..c7211d79 --- /dev/null +++ b/src/task/executor.rs @@ -0,0 +1,95 @@ +use super::{Task, TaskId}; +use alloc::{ + collections::{BTreeMap, VecDeque}, + sync::Arc, + task::Wake, +}; +use core::task::{Context, Poll, Waker}; +use crossbeam_queue::ArrayQueue; + +pub struct Executor { + task_queue: VecDeque, + waiting_tasks: BTreeMap, + wake_queue: Arc>, + waker_cache: BTreeMap, +} + +impl Executor { + pub fn new() -> Self { + Executor { + task_queue: VecDeque::new(), + waiting_tasks: BTreeMap::new(), + wake_queue: Arc::new(ArrayQueue::new(100)), + waker_cache: BTreeMap::new(), + } + } + + pub fn spawn(&mut self, task: Task) { + self.task_queue.push_back(task) + } + + pub fn run(&mut self) -> ! { + loop { + self.wake_tasks(); + self.run_ready_tasks(); + } + } + + fn run_ready_tasks(&mut self) { + while let Some(mut task) = self.task_queue.pop_front() { + let task_id = task.id(); + if !self.waker_cache.contains_key(&task_id) { + self.waker_cache.insert(task_id, self.create_waker(task_id)); + } + let waker = self.waker_cache.get(&task_id).expect("should exist"); + let mut context = Context::from_waker(waker); + match task.poll(&mut context) { + Poll::Ready(()) => { + // task done -> remove cached waker + self.waker_cache.remove(&task_id); + } + Poll::Pending => { + if self.waiting_tasks.insert(task_id, task).is_some() { + panic!("task with same ID already in waiting_tasks"); + } + } + } + } + } + + fn wake_tasks(&mut self) { + while let Ok(task_id) = self.wake_queue.pop() { + if let Some(task) = self.waiting_tasks.remove(&task_id) { + self.task_queue.push_back(task); + } + } + } + + fn create_waker(&self, task_id: TaskId) -> Waker { + Waker::from(Arc::new(TaskWaker { + task_id, + wake_queue: self.wake_queue.clone(), + })) + } +} + +struct TaskWaker { + task_id: TaskId, + wake_queue: Arc>, +} + +impl TaskWaker { + fn wake_task(&self) { + self.wake_queue.push(self.task_id).expect("wake_queue full"); + } +} + +impl Wake for TaskWaker { + fn wake(self: Arc) { + self.wake_task(); + } + + fn wake_by_ref(self: &Arc) { + self.wake_task(); + } +} diff --git a/src/task/mod.rs b/src/task/mod.rs index 7624cca7..a69650a3 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -5,6 +5,7 @@ use core::{ task::{Context, Poll}, }; +pub mod executor; pub mod keyboard; pub mod simple_executor; @@ -22,4 +23,14 @@ impl Task { fn poll(&mut self, context: &mut Context) -> Poll<()> { self.future.as_mut().poll(context) } + + fn id(&self) -> TaskId { + use core::ops::Deref; + + let addr = Pin::deref(&self.future) as *const _ as *const () as usize; + TaskId(addr) + } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +struct TaskId(usize);