rusticl/program: use blob.h to parse binaries
It checks for alignment and overruns, and is a lot safer than whatever was done before here. Cc: mesa-stable Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29946>
This commit is contained in:
@@ -19,7 +19,6 @@ use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::CString;
|
||||
use std::mem::size_of;
|
||||
use std::ptr;
|
||||
use std::ptr::addr_of;
|
||||
use std::slice;
|
||||
use std::sync::Arc;
|
||||
@@ -380,35 +379,43 @@ impl Program {
|
||||
}
|
||||
|
||||
fn spirv_from_bin_for_dev(bin: &[u8]) -> CLResult<(SPIRVBin, cl_program_binary_type)> {
|
||||
let mut ptr = bin.as_ptr();
|
||||
if bin.is_empty() {
|
||||
return Err(CL_INVALID_VALUE);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// 1. version
|
||||
let version = ptr.cast::<u32>().read();
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
let mut blob = blob_reader::default();
|
||||
blob_reader_init(&mut blob, bin.as_ptr().cast(), bin.len());
|
||||
|
||||
// 1. version
|
||||
let version = blob_read_uint32(&mut blob);
|
||||
match version {
|
||||
1 => {
|
||||
// 2. size of the spirv
|
||||
let spirv_size = ptr.cast::<u32>().read();
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
let spirv_size = blob_read_uint32(&mut blob) as usize;
|
||||
|
||||
// 3. binary_type
|
||||
let bin_type = ptr.cast::<cl_program_binary_type>().read();
|
||||
ptr = ptr.add(size_of::<cl_program_binary_type>());
|
||||
let bin_type = blob_read_uint32(&mut blob);
|
||||
|
||||
debug_assert!(
|
||||
// `blob_read_*` doesn't advance the pointer on failure to read
|
||||
blob.current.offset_from(blob.data) == BIN_HEADER_SIZE_V1 as isize
|
||||
|| blob.overrun,
|
||||
);
|
||||
|
||||
// 4. the spirv
|
||||
assert!(bin.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr);
|
||||
let spirv_data = blob_read_bytes(&mut blob, spirv_size);
|
||||
|
||||
if bin.len() != BIN_HEADER_SIZE_V1 + spirv_size as usize {
|
||||
// check that all the reads are valid before accessing the data, which might
|
||||
// be uninitialized otherwise.
|
||||
if blob.overrun {
|
||||
return Err(CL_INVALID_BINARY);
|
||||
}
|
||||
|
||||
let spirv =
|
||||
spirv::SPIRVBin::from_bin(slice::from_raw_parts(ptr, spirv_size as usize));
|
||||
let spirv = spirv::SPIRVBin::from_bin(slice::from_raw_parts(
|
||||
spirv_data.cast(),
|
||||
spirv_size,
|
||||
));
|
||||
|
||||
Ok((spirv, bin_type))
|
||||
}
|
||||
@@ -541,7 +548,6 @@ impl Program {
|
||||
|
||||
let lock = self.build_info();
|
||||
for (i, d) in self.devs.iter().enumerate() {
|
||||
let mut ptr = ptrs[i];
|
||||
let info = lock.dev_build(d);
|
||||
|
||||
// no spirv means nothing to write
|
||||
@@ -551,21 +557,24 @@ impl Program {
|
||||
let spirv = spirv.to_bin();
|
||||
|
||||
unsafe {
|
||||
let mut blob = blob::default();
|
||||
|
||||
// sadly we have to trust the buffer to be correctly sized...
|
||||
blob_init_fixed(&mut blob, ptrs[i].cast(), usize::MAX);
|
||||
|
||||
// 1. binary format version
|
||||
ptr.cast::<u32>().write(1);
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
blob_write_uint32(&mut blob, 1);
|
||||
|
||||
// 2. size of the spirv
|
||||
ptr.cast::<u32>().write(spirv.len() as u32);
|
||||
ptr = ptr.add(size_of::<u32>());
|
||||
blob_write_uint32(&mut blob, spirv.len() as u32);
|
||||
|
||||
// 3. binary_type
|
||||
ptr.cast::<cl_program_binary_type>().write(info.bin_type);
|
||||
ptr = ptr.add(size_of::<cl_program_binary_type>());
|
||||
blob_write_uint32(&mut blob, info.bin_type);
|
||||
debug_assert!(blob.size == BIN_HEADER_SIZE);
|
||||
|
||||
// 4. the spirv
|
||||
assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr);
|
||||
ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len());
|
||||
blob_write_bytes(&mut blob, spirv.as_ptr().cast(), spirv.len());
|
||||
blob_finish(&mut blob);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user