wdk-mutex: An idiomatic mutex for Rust Windows Kernel Drivers
An open-source idiomatic Windows Driver mutex for Rust
Intro to wdk-mutex
In this post we will look at using wdk-mutex, a library I created so that you can work with mutexes in idiomatic Rust inside your Windows kernel driver. We will also look at how to spawn system threads in your driver, and how you can use a wdk-mutex across thread boundaries.
If you like this, give the repo a star on GitHub, and contributions are welcome! The project can also be found on crates.io, and the documentation at GitHub Pages.
wdk-mutex supports both the KMUTEX and FAST_MUTEX mutex types. Just like std::sync::Mutex, you are able to use it like so:
{
    let mtx = KMutex::new(0u32).unwrap();
    let lock = mtx.lock().unwrap();

    // If T implements Display, you do not need to dereference the lock to print.
    println!("The value is: {}", lock);
} // The mutex is unlocked here, as the guard is managed via RAII
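For comparison, here is the same shape in user mode with std::sync::Mutex. This is only a side-by-side reference for readers who know the standard library; it does not use wdk-mutex and is not meant to run in a driver. The helper name `read_value` is made up for the example.

```rust
use std::sync::Mutex;

/// Locks the mutex, formats the protected value, and releases the lock
/// when the guard goes out of scope at the end of the function.
fn read_value(mtx: &Mutex<u32>) -> String {
    let lock = mtx.lock().unwrap();
    // MutexGuard implements Display when T does, so no dereference is needed.
    format!("The value is: {}", lock)
}

fn main() {
    let mtx = Mutex::new(0u32);
    println!("{}", read_value(&mtx));

    // The guard created inside read_value has been dropped,
    // so locking again here does not deadlock.
    *mtx.lock().unwrap() += 1;
    println!("{}", read_value(&mtx));
}
```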
Using wdk-mutex Across Thread Boundaries in Windows Drivers
A somewhat more complex example involves crossing thread boundaries. There are two good ways to share an inner <T> wrapped in a wdk-mutex:
- Utilise the Grt module of wdk-mutex, which stands for Global Reference Tracker (this is the easiest way to do it).
- Utilise the Device Extension so that Windows will clean up the memory when the driver unloads, so long as you set it up correctly. I'd still recommend keeping an atomic global pointer to the Device Extension field of your DriverObject so you can access it globally.
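The atomic global pointer idea from the second option can be sketched in user-mode terms as follows. Everything here is an assumption made for illustration: `DeviceExtension`, `DEVICE_EXT`, `publish_extension`, and `with_extension` are hypothetical names, `Box` stands in for the memory Windows owns in a real driver, and std::sync::Mutex stands in for a wdk-mutex type.

```rust
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Mutex;

/// Hypothetical device extension layout holding a mutex-protected counter.
struct DeviceExtension {
    counter: Mutex<u32>,
}

/// Global, atomically updated pointer to the extension (null until published).
static DEVICE_EXT: AtomicPtr<DeviceExtension> = AtomicPtr::new(null_mut());

/// Publishes (or, with a null pointer, unpublishes) the extension pointer.
fn publish_extension(ext: *mut DeviceExtension) {
    DEVICE_EXT.store(ext, Ordering::Release);
}

/// Runs `f` against the extension if one has been published.
fn with_extension<R>(f: impl FnOnce(&DeviceExtension) -> R) -> Option<R> {
    let ptr = DEVICE_EXT.load(Ordering::Acquire);
    if ptr.is_null() {
        return None;
    }
    // SAFETY: the sketch assumes the pointer stays valid until unpublished.
    Some(f(unsafe { &*ptr }))
}

fn main() {
    // Box stands in for the allocation Windows would own in a driver.
    let ext = Box::into_raw(Box::new(DeviceExtension {
        counter: Mutex::new(0),
    }));
    publish_extension(ext);

    let value = with_extension(|e| {
        let mut lock = e.counter.lock().unwrap();
        *lock += 1;
        *lock
    });
    println!("counter = {:?}", value);

    // Unpublish first, then free the stand-in allocation.
    publish_extension(null_mut());
    drop(unsafe { Box::from_raw(ext) });
}
```

The null check is what lets code that runs before setup or after teardown fail gracefully instead of dereferencing a dangling pointer.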
Here is an example of accessing the mutex across thread boundaries using the Grt module:
// Initialise the mutex
#[export_name = "DriverEntry"]
pub unsafe extern "system" fn driver_entry(
    driver: &mut DRIVER_OBJECT,
    registry_path: PCUNICODE_STRING,
) -> NTSTATUS {
    if let Err(e) = Grt::init() {
        println!("Error creating Grt!: {:?}", e);
        return STATUS_UNSUCCESSFUL;
    }

    // ...
    my_function();

    STATUS_SUCCESS
}
// Register a new Mutex in the `Grt` of value 0u32:
pub fn my_function() {
    Grt::register_kmutex("my_test_mutex", 0u32);

    // spawn a thread to do some work
    let mut thread_handle: HANDLE = null_mut();
    let _res_ = unsafe {
        PsCreateSystemThread(
            &mut thread_handle,
            0,
            null_mut::<OBJECT_ATTRIBUTES>(),
            null_mut(),
            null_mut::<CLIENT_ID>(),
            Some(my_thread_fn_pointer),
            null_mut(),
        )
    };
}
unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
    let my_mutex = Grt::get_kmutex::<u32>("my_test_mutex");
    if let Err(e) = my_mutex {
        println!("Error in thread: {:?}", e);
        return;
    }

    let mut lock = my_mutex.unwrap().lock().unwrap();
    *lock += 1;
}
// Destroy the Grt to prevent a memory leak when the driver unloads
extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
    unsafe { Grt::destroy() };
}
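To make the idea of a label-keyed global registry more concrete, here is a minimal user-mode sketch. It is not how Grt is implemented: `REGISTRY`, `register`, and `get` are hypothetical names, and std::sync::Mutex and Arc stand in for the kernel types. It only illustrates the "register by label, look up by label and type" shape used above.

```rust
use std::any::Any;
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical global registry: label -> type-erased Arc<Mutex<T>>.
static REGISTRY: Mutex<BTreeMap<&'static str, Arc<dyn Any + Send + Sync>>> =
    Mutex::new(BTreeMap::new());

/// Registers `data` under `label`; fails if the label is already taken.
fn register<T: Send + 'static>(label: &'static str, data: T) -> Result<(), &'static str> {
    let mut map = REGISTRY.lock().unwrap();
    if map.contains_key(label) {
        return Err("label already registered");
    }
    map.insert(label, Arc::new(Mutex::new(data)));
    Ok(())
}

/// Looks up `label` and downcasts the entry back to the concrete `Mutex<T>`.
fn get<T: Send + 'static>(label: &'static str) -> Result<Arc<Mutex<T>>, &'static str> {
    let map = REGISTRY.lock().unwrap();
    let entry = map.get(label).ok_or("label not found")?.clone();
    entry.downcast::<Mutex<T>>().map_err(|_| "type mismatch")
}

fn main() {
    register("my_test_mutex", 0u32).unwrap();

    // Each worker finds the mutex by label rather than being handed a reference.
    let workers: Vec<_> = (0..3)
        .map(|_| {
            thread::spawn(|| {
                let mtx = get::<u32>("my_test_mutex").unwrap();
                *mtx.lock().unwrap() += 1;
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }

    let total = *get::<u32>("my_test_mutex").unwrap().lock().unwrap();
    println!("total = {}", total);
}
```

Asking for the type at the call site, as in `get::<u32>`, is what lets one registry hold mutexes over different types while still rejecting a mismatched lookup.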
Initializing KMUTEX with wdk-mutex in Rust
Using a KMUTEX is handled by the functions described in the sections below.
As I was developing this, I had some difficulty with the different memory pools, IRQLs, and the lifetimes of certain objects, so I elected to split the KMutex type in wdk-mutex into an outer value, which the end user interacts with, that points to its inner value, which is the memory for a KMUTEX and the data the mutex protects.
To avoid the memory issues, creating a new KMutex allocates memory in the non-paged pool to store the KMutexInner, which looks like the below:
struct KMutexInner<T> {
    mutex: KMUTEX,
    data: T,
}
After this memory is allocated in the non-paged pool, we can write our KMUTEX to it, initialised via its Default trait implementation, along with the user's data <T>.
At this point, all we have to return to the user is a pointer to this allocated object, wrapped in the KMutex type.
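The outer/inner split can be sketched in user mode like this. The names `SketchMutex` and `SketchInner` are hypothetical, `Box` stands in for a non-paged pool allocation, and an `AtomicBool` stands in for the KMUTEX. Freeing the allocation in `Drop` is a choice made for the sketch, not a statement about what the crate does.

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Inner allocation: lock state and protected data side by side.
/// `AtomicBool` is only a stand-in for the kernel's KMUTEX.
struct SketchInner<T> {
    locked: AtomicBool,
    data: T,
}

/// Outer handle given to the caller: a pointer to the inner allocation.
struct SketchMutex<T> {
    inner: *mut SketchInner<T>,
}

impl<T> SketchMutex<T> {
    /// Allocates the inner value on the heap (standing in for the non-paged
    /// pool), writes the lock state and data into it, and wraps the pointer.
    fn new(data: T) -> Self {
        let inner = Box::new(SketchInner {
            locked: AtomicBool::new(false),
            data,
        });
        Self { inner: Box::into_raw(inner) }
    }

    /// Reports whether the stand-in lock is currently held.
    fn is_locked(&self) -> bool {
        // SAFETY: `inner` is valid from `new` until `drop`.
        let inner = unsafe { &*self.inner };
        inner.locked.load(Ordering::Acquire)
    }

    /// Exclusive access to the data; `&mut self` rules out other users.
    fn get_mut(&mut self) -> &mut T {
        // SAFETY: `inner` is valid and the exclusive borrow prevents aliasing.
        unsafe { &mut (*self.inner).data }
    }
}

impl<T> Drop for SketchMutex<T> {
    fn drop(&mut self) {
        // SAFETY: `inner` came from Box::into_raw and is freed exactly once.
        drop(unsafe { Box::from_raw(self.inner) });
    }
}

fn main() {
    let mut mtx = SketchMutex::new(0u32);
    *mtx.get_mut() += 1;
    let value = *mtx.get_mut();
    println!("value = {}, locked = {}", value, mtx.is_locked());
}
```

Because the outer value is only a pointer, it can be moved around freely while the lock object and its data stay at a fixed address.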
Acquiring a lock
When you acquire a lock, the KMutex performs a check of the IRQL, ensuring that you are at a level permitted by the Windows docs, and if so, calls the KeWaitForSingleObject function.
Once the kernel has granted you access to the KMUTEX, we return a RAII-style mutex guard, in keeping with how the std mutex works.
Drop
When the mutex guard falls out of scope, the guard itself is dropped, thanks to the Drop trait. At that point we want to unlock the mutex so other threads can lock it. This is implemented simply with KeReleaseMutex like so:
impl<T> Drop for KMutexGuard<'_, T> {
    fn drop(&mut self) {
        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) };
    }
}
The MSDN documentation specifies some IRQL rules for using KeReleaseMutex, so there is a chance this will panic if the IRQL is too high. To get around this, I provided a function drop_safe which will attempt to release the mutex, returning an error if the caller's IRQL is too high:
impl<T> KMutexGuard<'_, T> {
    /// Safely drop the KMutexGuard, an alternative to RAII.
    ///
    /// This function checks the IRQL before attempting to drop the guard.
    ///
    /// # Errors
    ///
    /// If the IRQL > DISPATCH_LEVEL, no unlock will occur and a DriverMutexError will be returned to the
    /// caller.
    ///
    /// # IRQL
    ///
    /// This function is safe to call at any IRQL, but it will not release the mutex if IRQL > DISPATCH_LEVEL
    pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {
        let irql = unsafe { KeGetCurrentIrql() };
        if irql > DISPATCH_LEVEL as u8 {
            if cfg!(feature = "debug") {
                println!("[wdk-mutex] [-] Unable to safely drop the KMUTEX. Calling IRQL is too high: {}", irql);
            }
            return Err(DriverMutexError::IrqlTooHigh);
        }

        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) };

        Ok(())
    }
}
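The general pattern of offering a checked release next to the RAII release can be sketched in user mode as below. All names are hypothetical (`SpinMutex`, `SpinGuard`, `release_checked`, `MAX_RELEASE_LEVEL`), a spin lock stands in for the KMUTEX, and the level is passed in as a plain number rather than read from the system. Unlike the signature shown above, this sketch consumes the guard, so a successful checked release cannot be followed by a second release from `Drop`. That is a design choice of the sketch and makes no claim about the crate.

```rust
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};

/// Hypothetical ceiling, standing in for DISPATCH_LEVEL.
const MAX_RELEASE_LEVEL: u8 = 2;

#[derive(Debug, PartialEq)]
struct LevelTooHigh;

/// Minimal spin lock standing in for a KMUTEX plus its protected data.
struct SpinMutex<T> {
    locked: AtomicBool,
    data: UnsafeCell<T>,
}

// SAFETY: access to `data` is serialised by `locked`.
unsafe impl<T: Send> Sync for SpinMutex<T> {}

struct SpinGuard<'a, T> {
    mutex: &'a SpinMutex<T>,
}

impl<T> SpinMutex<T> {
    fn new(data: T) -> Self {
        Self { locked: AtomicBool::new(false), data: UnsafeCell::new(data) }
    }

    /// Spins until the lock is acquired, then hands back a guard.
    fn lock(&self) -> SpinGuard<'_, T> {
        while self
            .locked
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            std::hint::spin_loop();
        }
        SpinGuard { mutex: self }
    }

    fn is_locked(&self) -> bool {
        self.locked.load(Ordering::Acquire)
    }
}

impl<T> Deref for SpinGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: holding the guard means holding the lock.
        unsafe { &*self.mutex.data.get() }
    }
}

impl<T> DerefMut for SpinGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: holding the guard means holding the lock.
        unsafe { &mut *self.mutex.data.get() }
    }
}

impl<T> Drop for SpinGuard<'_, T> {
    /// RAII path: unlock unconditionally when the guard leaves scope.
    fn drop(&mut self) {
        self.mutex.locked.store(false, Ordering::Release);
    }
}

impl<'a, T> SpinGuard<'a, T> {
    /// Checked path: unlock only if `level` is low enough. On failure the
    /// guard is handed back, so the caller still holds the lock.
    fn release_checked(self, level: u8) -> Result<(), (Self, LevelTooHigh)> {
        if level > MAX_RELEASE_LEVEL {
            return Err((self, LevelTooHigh));
        }
        drop(self); // runs Drop exactly once, which performs the unlock
        Ok(())
    }
}

fn main() {
    let mtx = SpinMutex::new(0u32);

    let mut guard = mtx.lock();
    *guard += 1;
    assert!(guard.release_checked(0).is_ok());

    // Refused release: the guard comes back and the lock stays held.
    if let Err((guard, err)) = mtx.lock().release_checked(MAX_RELEASE_LEVEL + 1) {
        println!("refused: {:?}, still locked: {}", err, mtx.is_locked());
        drop(guard);
    }
    let value = *mtx.lock();
    println!("value = {}, locked = {}", value, mtx.is_locked());
}
```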
Joining system threads
Finally, I wanted to touch on how to join system threads. Much like in user mode, if we spawn our threads and then carry on with other work on the main thread, the two run independently and neither waits on the other. This is fine for threads designed to handle communication or UI, where you don't want them blocking the main thread, but sometimes you want threads to run to completion before the main thread continues. This is what thread joining is.
To keep this simple, there is a pattern we need to employ in the kernel to wait on threads:
- Spawn the threads, obtaining a thread handle for each.
- Convert the thread handle to a PETHREAD.
- Wait on the PETHREAD (aka our thread).
- Close the thread handle.
- Decrement the reference count obtained for the PETHREAD.
The TL;DR, just show me the code:
//
// spawn 3 threads
//
// Handles of the threads that were created successfully
let mut th: Vec<HANDLE> = Vec::new();

for _ in 0..3 {
    let mut thread_handle: HANDLE = null_mut();
    let res = unsafe {
        PsCreateSystemThread(
            &mut thread_handle,
            0,
            null_mut::<OBJECT_ATTRIBUTES>(),
            null_mut(),
            null_mut::<CLIENT_ID>(),
            Some(KMutexTest::callback_test_multithread_mutex_global_static),
            null_mut(),
        )
    };
    if res == STATUS_SUCCESS {
        th.push(thread_handle);
    }
}

//
// Join the thread handles
//
for thread_handle in th {
    if !thread_handle.is_null() && unsafe { KeGetCurrentIrql() } <= APC_LEVEL as u8 {
        // Convert the handle into a referenced thread object
        let mut thread_obj: PVOID = null_mut();
        let ref_status = unsafe {
            ObReferenceObjectByHandle(
                thread_handle,
                THREAD_ALL_ACCESS,
                null_mut(),
                KernelMode as i8,
                &mut thread_obj,
                null_mut(),
            )
        };

        // The handle is no longer needed; the object reference keeps the thread object alive
        unsafe { let _ = ZwClose(thread_handle); };

        if ref_status == STATUS_SUCCESS {
            // Block until the thread has terminated
            unsafe {
                let _ = KeWaitForSingleObject(
                    thread_obj,
                    Executive,
                    KernelMode as i8,
                    FALSE as u8,
                    null_mut(),
                );
            }

            // Release the reference taken by ObReferenceObjectByHandle
            unsafe { ObfDereferenceObject(thread_obj) };
        }
    }
}

// now the main thread is resumed
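For readers more familiar with user mode, the same spawn-then-join shape looks like this with std::thread. It is only an analogy: `JoinHandle::join` plays the role of waiting on the thread object, and the handle and reference-count bookkeeping from the kernel version has no direct counterpart here. The function name `spawn_and_join` is made up for the example.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Spawns `n` workers that each increment a shared counter, joins them all,
/// and returns the final count.
fn spawn_and_join(n: u32) -> u32 {
    let counter = Arc::new(Mutex::new(0u32));

    // Spawn the threads, keeping every join handle.
    let handles: Vec<_> = (0..n)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                *counter.lock().unwrap() += 1;
            })
        })
        .collect();

    // Join: block until each worker has finished.
    for handle in handles {
        handle.join().expect("worker thread panicked");
    }

    // Every worker is done, so the count is final.
    let total = *counter.lock().unwrap();
    total
}

fn main() {
    println!("total after joining: {}", spawn_and_join(3));
}
```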