diff --git a/crates/khal/src/backend/cuda.rs b/crates/khal/src/backend/cuda.rs index b734819..cf715e9 100644 --- a/crates/khal/src/backend/cuda.rs +++ b/crates/khal/src/backend/cuda.rs @@ -271,9 +271,16 @@ impl Backend for Cuda { } } - // Expect PTX text bytes. - let ptx_str = std::str::from_utf8(bytes).map_err(|_| CudaBackendError::InvalidPtx)?; - let ptx = cudarc::nvrtc::Ptx::from_src(ptx_str.to_string()); + // Accept either PTX text or a pre-linked CUBIN (detected by ELF magic). + // A cubin is required when the module references symbols the driver JIT + // cannot resolve on its own (e.g. libdevice `__nv_*` math), which a + // toolchain links into a self-contained binary ahead of time. + let ptx = if bytes.starts_with(&[0x7f, b'E', b'L', b'F']) { + cudarc::nvrtc::Ptx::from_binary(bytes.to_vec()) + } else { + let ptx_str = std::str::from_utf8(bytes).map_err(|_| CudaBackendError::InvalidPtx)?; + cudarc::nvrtc::Ptx::from_src(ptx_str.to_string()) + }; let module = self.ctx.load_module(ptx)?; // Cache the loaded module.