From eda15ddafa0049fd94525d19092bef306da44e11 Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Tue, 2 Jul 2024 23:33:32 +0200 Subject: [PATCH] rusticl/program: use blob.h to parse binaries It checks for alignment and overruns, and is a lot safer than whatever was done before here. Cc: mesa-stable Part-of: --- src/gallium/frontends/rusticl/core/program.rs | 53 +++++++++++-------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 67f8dbea3c4..c0015cce505 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -19,7 +19,6 @@ use std::collections::HashMap; use std::collections::HashSet; use std::ffi::CString; use std::mem::size_of; -use std::ptr; use std::ptr::addr_of; use std::slice; use std::sync::Arc; @@ -380,35 +379,43 @@ impl Program { } fn spirv_from_bin_for_dev(bin: &[u8]) -> CLResult<(SPIRVBin, cl_program_binary_type)> { - let mut ptr = bin.as_ptr(); if bin.is_empty() { return Err(CL_INVALID_VALUE); } unsafe { - // 1. version - let version = ptr.cast::().read(); - ptr = ptr.add(size_of::()); + let mut blob = blob_reader::default(); + blob_reader_init(&mut blob, bin.as_ptr().cast(), bin.len()); + // 1. version + let version = blob_read_uint32(&mut blob); match version { 1 => { // 2. size of the spirv - let spirv_size = ptr.cast::().read(); - ptr = ptr.add(size_of::()); + let spirv_size = blob_read_uint32(&mut blob) as usize; // 3. binary_type - let bin_type = ptr.cast::().read(); - ptr = ptr.add(size_of::()); + let bin_type = blob_read_uint32(&mut blob); + + debug_assert!( + // `blob_read_*` doesn't advance the pointer on failure to read + blob.current.offset_from(blob.data) == BIN_HEADER_SIZE_V1 as isize + || blob.overrun, + ); // 4. the spirv - assert!(bin.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr); + let spirv_data = blob_read_bytes(&mut blob, spirv_size); - if bin.len() != BIN_HEADER_SIZE_V1 + spirv_size as usize { + // check that all the reads are valid before accessing the data, which might + // be uninitialized otherwise. + if blob.overrun { return Err(CL_INVALID_BINARY); } - let spirv = - spirv::SPIRVBin::from_bin(slice::from_raw_parts(ptr, spirv_size as usize)); + let spirv = spirv::SPIRVBin::from_bin(slice::from_raw_parts( + spirv_data.cast(), + spirv_size, + )); Ok((spirv, bin_type)) } @@ -541,7 +548,6 @@ impl Program { let lock = self.build_info(); for (i, d) in self.devs.iter().enumerate() { - let mut ptr = ptrs[i]; let info = lock.dev_build(d); // no spirv means nothing to write @@ -551,21 +557,24 @@ impl Program { let spirv = spirv.to_bin(); unsafe { + let mut blob = blob::default(); + + // sadly we have to trust the buffer to be correctly sized... + blob_init_fixed(&mut blob, ptrs[i].cast(), usize::MAX); + // 1. binary format version - ptr.cast::().write(1); - ptr = ptr.add(size_of::()); + blob_write_uint32(&mut blob, 1); // 2. size of the spirv - ptr.cast::().write(spirv.len() as u32); - ptr = ptr.add(size_of::()); + blob_write_uint32(&mut blob, spirv.len() as u32); // 3. binary_type - ptr.cast::().write(info.bin_type); - ptr = ptr.add(size_of::()); + blob_write_uint32(&mut blob, info.bin_type); + debug_assert!(blob.size == BIN_HEADER_SIZE); // 4. the spirv - assert!(ptrs[i].add(BIN_HEADER_SIZE) == ptr); - ptr::copy_nonoverlapping(spirv.as_ptr(), ptr, spirv.len()); + blob_write_bytes(&mut blob, spirv.as_ptr().cast(), spirv.len()); + blob_finish(&mut blob); } }