diff --git a/src/main.rs b/src/main.rs index 1bbd8136..9674abf5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ extern crate alloc; use alloc::{boxed::Box, rc::Rc, vec, vec::Vec}; -use blog_os::multitasking::{thread::Thread, with_scheduler}; +use blog_os::multitasking::{self, thread::Thread, with_scheduler}; use blog_os::{print, println}; use bootloader::{entry_point, BootInfo}; use core::panic::PanicInfo; @@ -55,33 +55,37 @@ fn kernel_main(boot_info: &'static BootInfo) -> ! { #[cfg(test)] test_main(); + let idle_thread = Thread::create(idle_thread, 2, &mut mapper, &mut frame_allocator).unwrap(); + with_scheduler(|s| s.set_idle_thread(idle_thread)); + for _ in 0..10 { let thread = Thread::create(thread_entry, 2, &mut mapper, &mut frame_allocator).unwrap(); with_scheduler(|s| s.add_new_thread(thread)); } - let thread = Thread::create_from_closure( - || loop { - print!("{}", with_scheduler(|s| s.current_thread_id()).as_u64()); - x86_64::instructions::hlt(); - }, - 2, - &mut mapper, - &mut frame_allocator, - ) - .unwrap(); + let thread = + Thread::create_from_closure(|| thread_entry(), 2, &mut mapper, &mut frame_allocator) + .unwrap(); with_scheduler(|s| s.add_new_thread(thread)); println!("It did not crash!"); thread_entry(); } -fn thread_entry() -> ! { +fn idle_thread() -> ! { loop { - print!("{}", with_scheduler(|s| s.current_thread_id()).as_u64()); x86_64::instructions::hlt(); } } +fn thread_entry() -> ! { + let thread_id = with_scheduler(|s| s.current_thread_id()).as_u64(); + for _ in 0..=thread_id { + print!("{}", thread_id); + x86_64::instructions::hlt(); + } + multitasking::exit_thread(); +} + /// This function is called on panic. #[cfg(not(test))] #[panic_handler] diff --git a/src/multitasking/context_switch.rs b/src/multitasking/context_switch.rs index 2e6ef279..a158f357 100644 --- a/src/multitasking/context_switch.rs +++ b/src/multitasking/context_switch.rs @@ -1,4 +1,4 @@ -use super::with_scheduler; +use super::{with_scheduler, SwitchReason}; use crate::multitasking::thread::ThreadId; use alloc::boxed::Box; use core::mem; @@ -41,11 +41,15 @@ impl Stack { } } -pub unsafe fn context_switch_to(new_stack_pointer: VirtAddr, prev_thread_id: ThreadId) { +pub unsafe fn context_switch_to( + new_stack_pointer: VirtAddr, + prev_thread_id: ThreadId, + switch_reason: SwitchReason, +) { asm!( "call asm_context_switch" : - : "{rdi}"(new_stack_pointer), "{rsi}"(prev_thread_id) + : "{rdi}"(new_stack_pointer), "{rsi}"(prev_thread_id), "{rdx}"(switch_reason as u64) : "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rpb", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "rflags", "memory" : "intel", "volatile" @@ -72,8 +76,12 @@ global_asm!( ); #[no_mangle] -pub extern "C" fn add_paused_thread(paused_stack_pointer: VirtAddr, paused_thread_id: ThreadId) { - with_scheduler(|s| s.add_paused_thread(paused_stack_pointer, paused_thread_id)); +pub extern "C" fn add_paused_thread( + paused_stack_pointer: VirtAddr, + paused_thread_id: ThreadId, + switch_reason: SwitchReason, +) { + with_scheduler(|s| s.add_paused_thread(paused_stack_pointer, paused_thread_id, switch_reason)); } #[naked] diff --git a/src/multitasking/mod.rs b/src/multitasking/mod.rs index 2cd08b10..082f23ef 100644 --- a/src/multitasking/mod.rs +++ b/src/multitasking/mod.rs @@ -6,12 +6,42 @@ pub mod thread; static SCHEDULER: spin::Mutex> = spin::Mutex::new(None); +#[repr(u64)] +pub enum SwitchReason { + Paused, + Blocked, + Exit, +} + pub fn invoke_scheduler() { let next = SCHEDULER .try_lock() .and_then(|mut scheduler| scheduler.as_mut().and_then(|s| s.schedule())); if let Some((next_stack_pointer, prev_thread_id)) = next { - unsafe { context_switch::context_switch_to(next_stack_pointer, prev_thread_id) }; + unsafe { + context_switch::context_switch_to( + next_stack_pointer, + prev_thread_id, + SwitchReason::Paused, + ) + }; + } +} + +pub fn exit_thread() -> ! { + let next = with_scheduler(|s| s.schedule()); + match next { + Some((next_stack_pointer, prev_thread_id)) => { + unsafe { + context_switch::context_switch_to( + next_stack_pointer, + prev_thread_id, + SwitchReason::Exit, + ) + } + unreachable!("finished thread continued") + } + None => panic!("can't exit last thread"), } } diff --git a/src/multitasking/scheduler.rs b/src/multitasking/scheduler.rs index b52bc156..e09a709f 100644 --- a/src/multitasking/scheduler.rs +++ b/src/multitasking/scheduler.rs @@ -1,12 +1,16 @@ +use super::SwitchReason; use crate::multitasking::thread::{Thread, ThreadId}; -use alloc::collections::{BTreeMap, VecDeque}; +use alloc::collections::{BTreeMap, BTreeSet, VecDeque}; use core::mem; use x86_64::VirtAddr; pub struct Scheduler { threads: BTreeMap, + idle_thread_id: Option, current_thread_id: ThreadId, paused_threads: VecDeque, + blocked_threads: BTreeSet, + wakeups: BTreeSet, } impl Scheduler { @@ -21,6 +25,9 @@ impl Scheduler { threads, current_thread_id: root_id, paused_threads: VecDeque::new(), + blocked_threads: BTreeSet::new(), + wakeups: BTreeSet::new(), + idle_thread_id: None, } } @@ -29,7 +36,11 @@ impl Scheduler { } pub fn schedule(&mut self) -> Option<(VirtAddr, ThreadId)> { - if let Some(next_id) = self.next_thread() { + let mut next_thread_id = self.next_thread(); + if next_thread_id.is_none() && Some(self.current_thread_id) != self.idle_thread_id { + next_thread_id = self.idle_thread_id + } + if let Some(next_id) = next_thread_id { let next_thread = self .threads .get_mut(&next_id) @@ -49,6 +60,7 @@ impl Scheduler { &mut self, paused_stack_pointer: VirtAddr, paused_thread_id: ThreadId, + switch_reason: SwitchReason, ) { let paused_thread = self .threads @@ -58,7 +70,23 @@ impl Scheduler { .stack_pointer() .replace(paused_stack_pointer) .expect_none("running thread should have stack pointer set to None"); - self.paused_threads.push_back(paused_thread_id); + if Some(paused_thread_id) == self.idle_thread_id { + return; // do nothing + } + match switch_reason { + SwitchReason::Paused => self.paused_threads.push_back(paused_thread_id), + SwitchReason::Blocked => { + self.blocked_threads.insert(paused_thread_id); + self.check_for_wakeup(paused_thread_id); + } + SwitchReason::Exit => { + let thread = self + .threads + .remove(&paused_thread_id) + .expect("thread not found"); + // TODO: free stack memory again + } + } } pub fn add_new_thread(&mut self, thread: Thread) { @@ -69,7 +97,24 @@ impl Scheduler { self.paused_threads.push_back(thread_id); } + pub fn set_idle_thread(&mut self, thread: Thread) { + let thread_id = thread.id(); + self.threads + .insert(thread_id, thread) + .expect_none("thread already exists"); + self.idle_thread_id + .replace(thread_id) + .expect_none("idle thread should be set only once"); + } + pub fn current_thread_id(&self) -> ThreadId { self.current_thread_id } + + fn check_for_wakeup(&mut self, thread_id: ThreadId) { + if self.wakeups.remove(&thread_id) { + assert!(self.blocked_threads.remove(&thread_id)); + self.paused_threads.push_back(thread_id); + } + } }