Kani and CUDA

5 minute read


CUDA kernels in pure Rust?!

I've been implementing an experimental feature in the cudarc crate. The goal is to write CUDA kernels in pure Rust to improve portability and safety. It's far from finished.

Currently cudarc produces PTX, CUDA's equivalent of assembly, by JIT-compiling CUDA C source code with NVRTC. The resulting PTX is then loaded at runtime:

let ptx = cudarc::nvrtc::compile_ptx("
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, const size_t numel) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < numel) {
        out[i] = sin(inp[i]);
    }
}")?;

// and dynamically load it into the device
dev.load_ptx(ptx, "my_module", &["sin_kernel"])?;
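The `if (i < numel)` guard exists because a kernel is launched over a whole grid of blocks, and the number of blocks is usually rounded up so that every element gets a thread. The sketch below emulates that launch on the CPU to show the indexing. `blocks_for` and `emulate_sin_kernel` are hypothetical helpers for illustration, not cudarc APIs.

```rust
/// Number of blocks needed so that `blocks * block_dim >= numel`
/// (ceiling division), a common way to size a 1D launch grid.
pub fn blocks_for(numel: usize, block_dim: usize) -> usize {
    (numel + block_dim - 1) / block_dim
}

/// Host-side emulation of sin_kernel: every (block, thread) pair computes
/// i = blockIdx.x * blockDim.x + threadIdx.x and writes only if i < numel.
pub fn emulate_sin_kernel(out: &mut [f32], inp: &[f32], block_dim: usize) {
    let numel = inp.len();
    let grid_dim = blocks_for(numel, block_dim);
    for block_idx in 0..grid_dim {
        for thread_idx in 0..block_dim {
            let i = block_idx * block_dim + thread_idx;
            // The last block may have threads past the end of the data.
            if i < numel {
                out[i] = inp[i].sin();
            }
        }
    }
}
```

With 1000 elements and 256 threads per block, four blocks (1024 threads) are launched, and the guard keeps the last 24 threads from writing out of bounds.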

For kernels written in Rust, I instead JIT-compile a whole kernel crate to PTX.

// use compile_crate_to_ptx to build kernels in pure Rust
// relies on the unstable abi_ptx feature
let kernel_path: PathBuf = "examples/rust-kernel/src/lib.rs".into();
let kernels: Vec<Ptx> = PtxCrate::compile_crate_to_ptx(&kernel_path).unwrap();
// first() returns Option<&Ptx>, so `kernel` is a reference
let kernel: &Ptx = kernels.first().unwrap();
dev.load_ptx(kernel.clone(), "rust_kernel", &["square_kernel"])?;

// we can also manage and clean up the built PTX files with a PtxCrate
let mut rust_ptx: PtxCrate = kernel_path.try_into().unwrap();
rust_ptx.build_ptx().unwrap();
let _kernel: &Ptx = rust_ptx.peek_kernels().unwrap().first().unwrap();
println!("cleaned successfully? {:?}", rust_ptx.clean());

The development environment is rough. Notice that the kernel squares its input instead of taking the sine? I couldn't actually find sin! I'm still learning how to get the linter on board inside the kernel crate.
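sin is most likely missing because f32::sin is provided by std, not core, so a no_std crate can't call it. One workaround is a device function built only from core arithmetic. The sketch below is an assumption, not part of cudarc: `sin_approx` is a hypothetical name, it uses a degree-9 Taylor polynomial, and it is only valid for inputs in [-PI, PI] (range reduction for larger inputs is omitted). The libm crate's `sinf` is another option often used in no_std code.

```rust
use core::f32::consts::{FRAC_PI_2, PI};

/// Hypothetical device function: sin(x) for x in [-PI, PI], using only
/// core arithmetic (no std float methods), so it could live in a no_std crate.
pub fn sin_approx(x: f32) -> f32 {
    // Fold [-PI, PI] onto [-PI/2, PI/2] using sin(x) = sin(PI - x)
    // and sin(x) = sin(-PI - x).
    let r = if x > FRAC_PI_2 {
        PI - x
    } else if x < -FRAC_PI_2 {
        -PI - x
    } else {
        x
    };
    // Taylor series r - r^3/3! + r^5/5! - r^7/7! + r^9/9!, in Horner form on r^2.
    let r2 = r * r;
    r * (1.0
        + r2 * (-1.0 / 6.0
            + r2 * (1.0 / 120.0 + r2 * (-1.0 / 5040.0 + r2 * (1.0 / 362880.0)))))
}
```

On [-PI/2, PI/2] the truncation error of this polynomial is below (PI/2)^11 / 11!, roughly 4e-6, which is why the input is folded into that interval first.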

Kernel crate source code

// lib.rs
#![feature(abi_ptx)]        // extern "ptx-kernel" ABI for emitting PTX (unstable)
#![feature(stdsimd)]        // gates the core::arch::nvptx intrinsics (unstable)
#![no_std]                  // no standard library on the CUDA device

mod device;

use core::arch::nvptx::*;   // access to thread id, etc

#[panic_handler]
fn my_panic(_: &core::panic::PanicInfo) -> ! {
    loop {}
}

#[no_mangle]
pub unsafe extern "ptx-kernel" fn square_kernel(input: *const f32, output: *mut f32, size: i32) {
    /* https://doc.rust-lang.org/stable/core/arch/nvptx/index.html */
    let thread_id: i32 = _thread_idx_x();
    let block_id: i32 = _block_idx_x();
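
    let block_dim: i32 = _block_dim_x();
    let grid_dim: i32 = _grid_dim_x();

    // total threads in the grid; unused by this kernel (a grid-stride loop would need it)
    let _n_threads = (block_dim * grid_dim) as u64;

    // global index of this thread across all blocks
    let thread_index = thread_id + block_id * block_dim;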

    if thread_index < size {
        let value = device::square(*input.offset(thread_index as isize));
        *output.offset(thread_index as isize) = value;
    }
}

A no_std crate has to define its own #[panic_handler], the function that runs when a panic occurs. Here it tells the CUDA device what to do when a panic can't be avoided: loop forever. Branches and panics drag on performance, so avoid them as much as possible.
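One common source of hidden panics is slice indexing: `output[i]` inserts a bounds check whose failure path ends in the panic handler. A minimal sketch of the usual refactor, using only core, replaces indexing with `get`/`get_mut` so that an out-of-range index becomes an ordinary branch. The function names are hypothetical, and whether the bounds check really disappears from the PTX depends on the optimizer, so compare the emitted PTX to be sure.

```rust
/// Panicking version: each `[i]` carries a bounds check that can call
/// the panic handler when `i` is out of range.
pub fn square_at_indexing(input: &[f32], output: &mut [f32], i: usize) {
    output[i] = input[i] * input[i];
}

/// Panic-free version: `get`/`get_mut` return `None` for an out-of-range
/// index, so the only branch is the explicit `match` below.
pub fn square_at(input: &[f32], output: &mut [f32], i: usize) -> bool {
    match (input.get(i), output.get_mut(i)) {
        (Some(x), Some(y)) => {
            *y = x * x;
            true
        }
        _ => false,
    }
}
```

Inside a kernel, slices like these could be built from the raw pointers with `core::slice::from_raw_parts`, which is itself unsafe and trusts that `size` is correct.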

// device.rs
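// no #![no_std] here: it is a crate-level attribute, and lib.rs already sets it
// a no_std-compatible device function (uses only core)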
pub fn square(num: f32) -> f32 {
    // your device function implementation
    num * num
}

We compile the crate as a C dynamic library (crate-type cdylib) for the nvptx64-nvidia-cuda target and emit PTX.

# Cargo.toml
[package]
name = "rust-kernel"
version = "0.1.0"
edition = "2021"

[lib]
name = "kernel"
# build = "cargo +nightly rustc --lib --target nvptx64-nvidia-cuda --release -- --emit asm"
crate-type = ["cdylib"]
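The commented build line can be reproduced from Rust with std::process::Command. This is a minimal sketch, assuming a JIT crate compiler does something along these lines; `ptx_build_command` is a hypothetical helper, not the actual PtxCrate code. It only constructs the command: calling `.status()` on it would run the nightly toolchain, which needs the nvptx64-nvidia-cuda target installed.

```rust
use std::path::Path;
use std::process::Command;

/// Hypothetical helper: builds (but does not run) the cargo invocation from
/// the commented `build` line in Cargo.toml, for the crate at `crate_dir`.
pub fn ptx_build_command(crate_dir: &Path) -> Command {
    let mut cmd = Command::new("cargo");
    cmd.current_dir(crate_dir)
        .args(["+nightly", "rustc", "--lib", "--target", "nvptx64-nvidia-cuda", "--release"])
        // everything after `--` is passed to rustc itself
        .args(["--", "--emit", "asm"]);
    cmd
}
```

Building the command separately from running it makes the arguments easy to inspect or extend, which fits the "specify more command arguments" item in the next steps.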

Emitted PTX

The PTX itself is pretty clean.

//
// Generated by LLVM NVPTX Back-End
//

.version 6.0
.target sm_30
.address_size 64

	// .globl	square_kernel

.visible .entry square_kernel(
	.param .u64 square_kernel_param_0,
	.param .u64 square_kernel_param_1,
	.param .u32 square_kernel_param_2
)
{
	.reg .pred 	%p<2>;
	.reg .b32 	%r<6>;
	.reg .f32 	%f<3>;
	.reg .b64 	%rd<8>;

	ld.param.u32 	%r1, [square_kernel_param_2];
	mov.u32 	%r2, %tid.x;
	mov.u32 	%r3, %ctaid.x;
	mov.u32 	%r4, %ntid.x;
	mad.lo.s32 	%r5, %r3, %r4, %r2;
	setp.lt.s32 	%p1, %r5, %r1;
	@%p1 bra 	$L__BB0_2;
	bra.uni 	$L__BB0_1;
$L__BB0_2:
	ld.param.u64 	%rd3, [square_kernel_param_0];
	ld.param.u64 	%rd4, [square_kernel_param_1];
	cvta.to.global.u64 	%rd5, %rd4;
	cvta.to.global.u64 	%rd6, %rd3;
	mul.wide.s32 	%rd7, %r5, 4;
	add.s64 	%rd1, %rd5, %rd7;
	add.s64 	%rd2, %rd6, %rd7;
	ld.global.f32 	%f1, [%rd2];
	mul.rn.f32 	%f2, %f1, %f1;
	st.global.f32 	[%rd1], %f2;
$L__BB0_1:
	ret;

}
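The PTX maps closely onto the Rust source. `%tid.x`, `%ctaid.x` and `%ntid.x` are threadIdx.x, blockIdx.x and blockDim.x; `mad.lo.s32` computes `block_id * block_dim + thread_id`, keeping only the low 32 bits; `setp.lt.s32` and `@%p1 bra` are the `thread_index < size` guard; and `mul.wide.s32 ..., 4` sign-extends the index to 64 bits while scaling it by the 4-byte size of an f32. grid_dim and n_threads never appear because the optimizer removed the unused computation, and there is no panic branch because release builds disable overflow checks, so the i32 arithmetic simply wraps. A host-side sketch of that arithmetic (the function names are hypothetical):

```rust
/// Mirrors `mad.lo.s32 %r5, %r3, %r4, %r2`: the low 32 bits of
/// ctaid.x * ntid.x + tid.x, i.e. wrapping i32 arithmetic.
pub fn ptx_index(ctaid_x: i32, ntid_x: i32, tid_x: i32) -> i32 {
    ctaid_x.wrapping_mul(ntid_x).wrapping_add(tid_x)
}

/// Mirrors `mul.wide.s32 %rd7, %r5, 4`: sign-extend the 32-bit index to
/// 64 bits and scale by size_of::<f32>() to get a byte offset.
pub fn ptx_byte_offset(index: i32) -> i64 {
    (index as i64) * core::mem::size_of::<f32>() as i64
}
```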

On the other hand, abi_ptx is still a very experimental Rust feature. Working on it some day would be an exciting challenge.

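As of this writing, the PTX ISA is at version 8.1, while our output declares version 6.0. For simple kernels, perhaps this doesn't matter. But what about sm_30? NVIDIA's CUDA documentation explains the architecture names: sm_30 is Kepler (compute capability 3.0), an architecture that CUDA 11 and later no longer support. I don't think it's a good sign to see Kepler here!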

Kani

What if my device function compiles to PTX with lots of panic paths? Ideally I'd refactor it so that the Rust compiler emits fewer of them.

I'm experimenting with Kani, a bit-precise model checker for Rust.

// https://github.com/model-checking/kani/blob/2df67e380a42a78748b8e84dcc699b0378b287c7/README.md?plain=1#L30
use my_crate::{function_under_test, meets_specification, precondition};

#[kani::proof]
fn check_my_property() {
   // Create a nondeterministic input
   let input = kani::any();

   // Constrain it according to the function's precondition
   kani::assume(precondition(input));

   // Call the function under verification
   let output = function_under_test(input);

   // Check that it meets the specification
   assert!(meets_specification(input, output));
}

Here we assert! that every input to function_under_test that satisfies the precondition produces an output that meets_specification. Kani checks this by converting the problem into a boolean satisfiability (SAT) formula (SAT solvers are fun!) and handing it to a solver. Input and output are encoded by assigning each bit a boolean variable. The formula asserts that there exists an input, meeting the precondition, whose output does not meet the specification, so the property holds exactly when the formula is unsatisfiable. Many thousands of auxiliary variables are introduced to compute output from input through function_under_test.
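Applied to square_kernel, a harness could check the index guard. This is a sketch assuming Kani's `kani::any`, `kani::assume` and `#[kani::proof]` as used in the README example above; `wrapping_index` and `checked_index` are hypothetical helpers. The harness is compiled only under `cargo kani`, which sets `cfg(kani)`. The property should fail: with block_id = 1 << 21 and block_dim = 1024, the i32 product wraps to a negative index that still passes `thread_index < size`, exactly the kind of counterexample a model checker is meant to surface.

```rust
/// The kernel's index arithmetic in a release build, where overflow
/// checks are off and i32 arithmetic wraps.
pub fn wrapping_index(block_id: i32, block_dim: i32, thread_id: i32) -> i32 {
    block_id.wrapping_mul(block_dim).wrapping_add(thread_id)
}

/// Hypothetical panic-free alternative: overflow or an out-of-range
/// result becomes `None` instead of a bogus offset.
pub fn checked_index(block_id: i32, block_dim: i32, thread_id: i32, size: i32) -> Option<usize> {
    let i = block_id.checked_mul(block_dim)?.checked_add(thread_id)?;
    if i >= 0 && i < size { Some(i as usize) } else { None }
}

// Only compiled by `cargo kani`; ordinary builds strip it via cfg.
#[cfg(kani)]
#[kani::proof]
fn guard_keeps_wrapping_index_in_bounds() {
    let block_id: i32 = kani::any();
    let block_dim: i32 = kani::any();
    let thread_id: i32 = kani::any();
    let size: i32 = kani::any();
    // Precondition: hardware-like ranges (blockDim.x is at most 1024).
    kani::assume(block_id >= 0 && block_dim > 0 && block_dim <= 1024);
    kani::assume(thread_id >= 0 && thread_id < block_dim && size > 0);
    let i = wrapping_index(block_id, block_dim, thread_id);
    // Specification: if the kernel's guard passes, the offset is non-negative.
    if i < size {
        assert!(i >= 0);
    }
}
```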

When given a boolean formula in conjunctive normal form (CNF), a SAT solver produces either:

  • a witness: a truth assignment that satisfies the formula (equivalently, a model)
    • here, a counterexample to the proof: the underlying bits of an input that fails the assertion
    • (!) save this to write a failing unit test; or
  • a certificate: a step-by-step proof, derived from the search over boolean assignments, that the formula is unsatisfiable
    • this can be saved for reuse

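A deliberately naive sketch of the two outcomes: exhaustive search over all 2^n assignments either finds a witness or shows that none exists. This is not how Kani or its CBMC backend work; real SAT solvers prune the search (for example with conflict-driven clause learning), and this sketch confirms unsatisfiability without emitting a certificate. Clauses use DIMACS-style literals (k for variable k, -k for its negation), an assumption of the sketch.

```rust
/// Checks a witness: every clause must contain at least one true literal.
/// Literal k (1-based, non-zero) reads assignment[k - 1]; -k negates it.
pub fn satisfies(cnf: &[Vec<i32>], assignment: &[bool]) -> bool {
    cnf.iter().all(|clause| {
        clause.iter().any(|&lit| {
            let value = assignment[(lit.unsigned_abs() - 1) as usize];
            if lit > 0 { value } else { !value }
        })
    })
}

/// Tries all 2^num_vars assignments in order. Returns the first witness,
/// or None if the formula is unsatisfiable.
pub fn brute_force_sat(cnf: &[Vec<i32>], num_vars: usize) -> Option<Vec<bool>> {
    assert!(num_vars < 64, "exhaustive search only handles tiny formulas");
    (0u64..(1u64 << num_vars))
        // bit k of `bits` is the value of variable k + 1
        .map(|bits| (0..num_vars).map(|k| (bits >> k) & 1 == 1).collect::<Vec<bool>>())
        .find(|assignment| satisfies(cnf, assignment))
}
```

For example, (x1 ∨ x2) ∧ ¬x1 has the witness x1 = false, x2 = true, while x1 ∧ ¬x1 has none. In the Kani setting, a witness like the first one is the counterexample to turn into a failing unit test.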
Will this meet my needs? It certainly has me interested.

Next steps

  • finish the PR
  • investigate panic-free no_std refactoring with kani
  • investigate abi_ptx development
  • apply to other crates
  • potential features:
    • build scripts to build PTX before runtime
    • specify more command arguments
    • improve compatibility
    • compile standalone kernel code (uses rustc, not cargo)
    • macros to build kernels from device functions