rusticl/program: use blob.h to parse binaries

It checks for alignment and overruns, and is a lot safer than whatever was
done before here.

Cc: mesa-stable
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29946>
This commit is contained in:
Karol Herbst
2024-07-02 23:33:32 +02:00
committed by Marge Bot
parent 81bb379c94
commit eda15ddafa
+31 -22
View File
@@ -19,7 +19,6 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::ffi::CString;
use std::mem::size_of;
use std::ptr;
use std::ptr::addr_of;
use std::slice;
use std::sync::Arc;
@@ -380,35 +379,43 @@ impl Program {
}
fn spirv_from_bin_for_dev(bin: &[u8]) -> CLResult<(SPIRVBin, cl_program_binary_type)> {
let mut ptr = bin.as_ptr();
if bin.is_empty() {
return Err(CL_INVALID_VALUE);
}
unsafe {
// 1. version
let version = ptr.cast::<u32>().read();
ptr = ptr.add(size_of::<u32>());
let mut blob = blob_reader::default();
blob_reader_init(&mut blob, bin.as_ptr().cast(), bin.len());
// 1. version
let version = blob_read_uint32(&mut blob);
match version {
1 => {
// 2. size of the spirv
let spirv_size = ptr.cast::<u32>().read();
ptr = ptr.add(size_of::<u32>());
let spirv_size = blob_read_uint32(&mut blob) as usize;
// 3. binary_type
let bin_type = ptr.cast::<cl_program_binary_type>().read();
ptr = ptr.add(size_of::<cl_program_binary_type>());
let bin_type = blob_read_uint32(&mut blob);
debug_assert!(
// `blob_read_*` doesn't advance the pointer on failure to read
blob.current.offset_from(blob.data) == BIN_HEADER_SIZE_V1 as isize
|| blob.overrun,
);
// 4. the spirv
assert!(bin.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr);
let spirv_data = blob_read_bytes(&mut blob, spirv_size);
if bin.len() != BIN_HEADER_SIZE_V1 + spirv_size as usize {
// check that all the reads are valid before accessing the data, which might
// be uninitialized otherwise.
if blob.overrun {
return Err(CL_INVALID_BINARY);
}
let spirv =
spirv::SPIRVBin::from_bin(slice::from_raw_parts(ptr, spirv_size as usize));
let spirv = spirv::SPIRVBin::from_bin(slice::from_raw_parts(
spirv_data.cast(),
spirv_size,
));
Ok((spirv, bin_type))
}
@@ -541,7 +548,6 @@ impl Program {
let lock = self.build_info();
for (i, d) in self.devs.iter().enumerate() {
let mut ptr = ptrs[i];
let info = lock.dev_build(d);
// no spirv means nothing to write
@@ -551,21 +557,24 @@ impl Program {
let spirv = spirv.to_bin();
unsafe {
let mut blob = blob::default();
// sadly we have to trust the buffer to be correctly sized...
blob_init_fixed(&mut blob, ptrs[i].cast(), usize::MAX);
// 1. binary format version
ptr.cast::<u32>().write(1);
ptr = ptr.add(size_of::<u32>());
blob_write_uint32(&mut blob, 1);
// 2. size of the spirv
ptr.cast::<u32>().write(spirv.len() as u32);
ptr = ptr.add(size_of::<u32>());
blob_write_uint32(&mut blob, spirv.len() as u32);
// 3. binary_type
ptr.cast::<cl_program_binary_type>().write(info.bin_type);
ptr = ptr.add(size_of::<cl_program_binary_type>());
blob_write_uint32(&mut blob, info.bin_type);
debug_assert!(blob.size == BIN_HEADER_SIZE);
// 4. the spirv
assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr);
ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len());
blob_write_bytes(&mut blob, spirv.as_ptr().cast(), spirv.len());
blob_finish(&mut blob);
}
}