wdk-mutex: An idiomatic mutex for Rust Windows Kernel Drivers

An open-source idiomatic Windows Driver mutex for Rust


Intro to wdk-mutex

In this post we will look at wdk-mutex, a library I have created that lets you use idiomatic Rust mutexes in your Windows kernel driver. We will also look at how to spawn system threads in your driver, and how to share a wdk-mutex across thread boundaries.

If you like this, give the repo a star on GitHub, and contributions are welcome! The project can also be found on crates.io, and the documentation at GitHub Pages.

Currently, the mutex type available is KMutex, which wraps the kernel's KMUTEX.

Just like a std::sync::Mutex, you are able to use it like so:

{
    let mtx = KMutex::new(0u32).unwrap();
    let lock = mtx.lock().unwrap();

    // If T implements Display, you do not need to dereference the lock to print.
    println!("The value is: {}", lock);
} // Mutex will become unlocked as it is managed via RAII 

Using wdk-mutex Across Thread Boundaries in Windows Drivers

A somewhat more complex example involves crossing thread boundaries. For now (until an update hopefully coming soon), there are two ways I can identify to use the mutex and its inner <T>:

  1. Use a static AtomicPtr<KMutex<T>> - if you take this option, you MUST ensure your driver exit routine cleans up the memory you allocate which is pointed to by the AtomicPtr.
  2. Use the Device Extension so that Windows cleans up the memory when the driver unloads, so long as you set it up correctly. I'd still recommend a global atomic pointer to the Device Extension field of your DriverObject so you can access it globally.
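Option 1's lifecycle can be illustrated in userland Rust. This is a sketch only: it swaps the crate's KMutex for std::sync::Mutex so it can run outside the kernel, and the function names (init, increment, cleanup) are hypothetical stand-ins for the driver entry, worker, and exit routines.

```rust
use std::ptr::null_mut;
use std::sync::Mutex;
use std::sync::atomic::{AtomicPtr, Ordering};

// Global pointer to a heap-allocated mutex; null until initialised.
static MTX_PTR: AtomicPtr<Mutex<u32>> = AtomicPtr::new(null_mut());

// Stand-in for DriverEntry: leak a Box so the mutex outlives this scope.
fn init() {
    MTX_PTR.store(Box::into_raw(Box::new(Mutex::new(0u32))), Ordering::SeqCst);
}

// Stand-in for a worker thread body: load the pointer, lock, mutate.
fn increment() {
    let p = MTX_PTR.load(Ordering::SeqCst);
    if !p.is_null() {
        let mtx = unsafe { &*p };
        *mtx.lock().unwrap() += 1;
    }
}

// Stand-in for the driver exit routine: reclaim the allocation exactly once.
fn cleanup() {
    let p = MTX_PTR.swap(null_mut(), Ordering::SeqCst);
    if !p.is_null() {
        drop(unsafe { Box::from_raw(p) });
    }
}
```

The swap in cleanup is what makes the free happen exactly once, even if the exit routine were reachable from more than one path.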

Here is an example of accessing the mutex between thread boundaries:

pub static HEAP_MTX_PTR: AtomicPtr<KMutex<u32>> = AtomicPtr::new(null_mut());

pub fn test_multithread_mutex() {

    //
    // Prepare global static for access in multiple threads.
    //

    let heap_mtx = Box::new(KMutex::new(0u32).unwrap());
    let heap_mtx_ptr = Box::into_raw(heap_mtx);
    HEAP_MTX_PTR.store(heap_mtx_ptr, Ordering::SeqCst);

    //
    // spawn x threads to test
    //
    for _ in 0..3 {
        let mut thread_handle: HANDLE = null_mut();

        let status = unsafe {
            PsCreateSystemThread(
                &mut thread_handle, 
                0, 
                null_mut::<OBJECT_ATTRIBUTES>(), 
                null_mut(),
                null_mut::<CLIENT_ID>(), 
                Some(callback_fn), 
                null_mut(),
            )
        };

        println!("[i] Thread status: {status}");
    }
}

unsafe extern "C" fn callback_fn(_: *mut c_void) {
    for _ in 0..1500 {
        let p = HEAP_MTX_PTR.load(Ordering::SeqCst);
        if !p.is_null() {
            let p = unsafe { &*p };
            let mut lock = p.lock().unwrap();
            // println!("Got the lock before change! {}", *lock);
            *lock += 1;
            println!("After the change: {}", *lock);
        }
    }

    // Proof that the threads act concurrently: if these lines only printed after all
    // iterations of the for loop, that would indicate the threads were not running concurrently.
    println!("THREAD FINISHED.");
}

Initializing KMUTEX with wdk-mutex in Rust

Using a KMUTEX with wdk-mutex works as follows.

As I was developing this, I had some difficulty with different page pools, IRQLs, and the lifetimes of certain objects; so I elected to split the KMutex type in wdk-mutex into an outer value, which the end user interacts with, pointing to an inner value: the memory for a KMUTEX and the data the mutex protects.

To avoid the memory issues, creating a new KMutex will allocate memory in the non-paged pool to store the data for the KMutexInner which looks like the below:

struct KMutexInner<T> {
    mutex: KMUTEX,
    data: T,
}

After this memory is allocated in the non-paged pool, we can write our KMUTEX to it with a Default trait initialisation, along with the user's data <T>.

At this point, all we have to return to the user is a pointer to this allocated object, wrapped in the KMutex type.
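The outer/inner split can be sketched in userland Rust. This is an analogy only: std's Mutex and Box stand in for the KMUTEX and the non-paged pool allocation, and the names Inner and Handle are hypothetical, not the crate's actual types.

```rust
use std::sync::Mutex; // userland stand-in for the kernel KMUTEX

// Inner allocation: the lock and the protected data live together on the heap.
struct Inner<T> {
    mutex: Mutex<()>,
    data: T,
}

// Outer handle the user holds; just a pointer to the heap allocation.
struct Handle<T> {
    inner: *mut Inner<T>,
}

impl<T> Handle<T> {
    fn new(data: T) -> Self {
        // In the driver this would be a non-paged pool allocation;
        // Box is the userland analogue.
        let inner = Box::into_raw(Box::new(Inner {
            mutex: Mutex::new(()),
            data,
        }));
        Handle { inner }
    }
}

impl<T> Drop for Handle<T> {
    fn drop(&mut self) {
        // Free the heap allocation when the outer handle is dropped.
        drop(unsafe { Box::from_raw(self.inner) });
    }
}
```

Keeping the lock and the data in one allocation means a single pointer is enough to reach both, which is what lets the outer type stay small and cheap to pass around.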

Acquiring a lock

When you acquire a lock, the KMutex performs a check of the IRQL, ensuring that you are at a sufficient level as specified in the Windows docs, and, if so, calls the KeWaitForSingleObject function.

Once the kernel has granted you access to the KMUTEX, we return a RAII-style mutex guard, in keeping with how std::sync::Mutex works.
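The essence of such a guard can be sketched in plain Rust: Deref and DerefMut give transparent access to the protected data, and Drop releases the lock. This is a minimal illustration, not the crate's actual KMutexGuard; the locked flag is a hypothetical stand-in for the kernel mutex state.

```rust
use std::ops::{Deref, DerefMut};

// A minimal RAII guard over borrowed data; dropping it "releases" the lock.
struct Guard<'a, T> {
    data: &'a mut T,
    locked: &'a mut bool,
}

impl<T> Deref for Guard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        &*self.data
    }
}

impl<T> DerefMut for Guard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.data
    }
}

impl<T> Drop for Guard<'_, T> {
    fn drop(&mut self) {
        // Releasing the lock when the guard falls out of scope is
        // what makes the API RAII-style; the real guard calls
        // KeReleaseMutex here instead.
        *self.locked = false;
    }
}
```

Because the guard implements Deref, expressions like *lock += 1 in the earlier examples work without touching the raw pointer underneath.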

Drop

When the mutex guard falls out of scope, the guard is dropped via its Drop implementation, at which point we want to unlock the mutex so other threads can acquire it. This is implemented simply with KeReleaseMutex like so:

impl<T> Drop for KMutexGuard<'_, T> {
    fn drop(&mut self) {
        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) }; 
    }
}

The MSDN documentation specifies some IRQL rules for using KeReleaseMutex, so there is a chance this will cause a bug check if the IRQL is too high. To get around this, I provide a function, drop_safe, which attempts to release the mutex, returning an error if the caller's IRQL is too high:

impl<T> KMutexGuard<'_, T> {
    /// Safely drop the KMutexGuard, an alternative to RAII.
    /// 
    /// This function checks the IRQL before attempting to drop the guard. 
    /// 
    /// # Errors
    /// 
    /// If the IRQL > DISPATCH_LEVEL, no unlock will occur and a DriverMutexError will be returned to the 
    /// caller.
    /// 
    /// # IRQL
    /// 
    /// This function is safe to call at any IRQL, but it will not release the mutex if IRQL > DISPATCH_LEVEL
    pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {

        let irql = unsafe { KeGetCurrentIrql() };
        if irql > DISPATCH_LEVEL as u8 {
            if cfg!(feature = "debug") {
                println!("[wdk-mutex] [-] Unable to safely drop the KMUTEX. Calling IRQL is too high: {}", irql);
            }
            return Err(DriverMutexError::IrqlTooHigh);
        }

        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) }; 

        Ok(())
    }
}

Coming soon

I have a few features I would like to implement, such as a FAST_MUTEX, as outlined in my previous blog post. Aside from that, I have an idea for making the API more ergonomic for developers, reducing the overhead of Box'ing and having statics all over the place.

I’ll be testing these on an experimental branch as time goes on, and I have made a new repository designated for ‘unit testing’ of the wdk-mutex crate, you can find it here: wdk_mutex_tests.

If you have any ideas, suggestions etc, please engage with me on the repo in issues or discussions! If you liked this post or project, please show the repo some love and give it a star on GitHub!

Joining system threads

Finally, I wanted to touch on how to join system threads. Much like in usermode programming, if we spawn threads and then go off and do other things on the main thread, the two never interfere. This is fine for thread activities such as handling communication or UI, where you don't want to block the main thread; but sometimes you want threads to execute and then wait for their completion before continuing on the main thread. This is what thread joining is.

To keep this simple, there is a pattern we need to employ in the kernel to wait on threads.

  1. Spawn the threads, obtaining a thread handle.
  2. Convert the thread handle to a PETHREAD.
  3. Wait on the PETHREAD (aka our thread).
  4. Close the thread handle.
  5. Decrement the reference count obtained for the PETHREAD.

The TL;DR, just show me the code:

// Collect the handles of the spawned threads so we can join them later.
let mut th: Vec<HANDLE> = Vec::new();

//
// spawn 3 threads
//
for _ in 0..3 {
    let mut thread_handle: HANDLE = null_mut();

    let res = unsafe {
        PsCreateSystemThread(
            &mut thread_handle, 
            0, 
            null_mut::<OBJECT_ATTRIBUTES>(), 
            null_mut(),
            null_mut::<CLIENT_ID>(), 
            Some(KMutexTest::callback_test_multithread_mutex_global_static), 
            null_mut(),
        )
    };

    if res == STATUS_SUCCESS {
        th.push(thread_handle);
    }
}


//
// Join the thread handles
//

for thread_handle in th {
    if !thread_handle.is_null() && unsafe{KeGetCurrentIrql()} <= APC_LEVEL as u8 {
        let mut thread_obj: PVOID = null_mut();
        let ref_status = unsafe {
            ObReferenceObjectByHandle(
                thread_handle,
                THREAD_ALL_ACCESS,
                null_mut(),
                KernelMode as i8,
                &mut thread_obj,
                null_mut(),
            )
        };
        unsafe { let _ = ZwClose(thread_handle); };

        if ref_status == STATUS_SUCCESS {
            unsafe {
                let _ = KeWaitForSingleObject(
                    thread_obj,
                    Executive,
                    KernelMode as i8,
                    FALSE as u8,
                    null_mut(),
                );
                ObfDereferenceObject(thread_obj);
            }
        }
    }
}

// now the main thread is resumed