Skip to content

Commit 801f5a3

Browse files
committed
Check size before memcpy
1 parent 86b4bea commit 801f5a3

6 files changed

Lines changed: 218 additions & 42 deletions

File tree

extension/runner_util/inputs.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,47 @@ Result<BufferCleanup> prepare_input_tensors(
8686
Debug, "Verifying and setting input for non-tensor input %zu", i);
8787

8888
if (tag.get() == Tag::Int) {
89-
int64_t int_input;
90-
std::memcpy(&int_input, buffer, buffer_size);
91-
err = method.set_input(runtime::EValue(int_input), i);
89+
if (buffer_size != sizeof(int64_t)) {
90+
ET_LOG(
91+
Error,
92+
"Int input at index %zu has size %zu, expected sizeof(int64_t) %zu",
93+
i,
94+
buffer_size,
95+
sizeof(int64_t));
96+
err = Error::InvalidArgument;
97+
} else {
98+
int64_t int_input;
99+
std::memcpy(&int_input, buffer, buffer_size);
100+
err = method.set_input(runtime::EValue(int_input), i);
101+
}
92102
} else if (tag.get() == Tag::Double) {
93-
double double_input;
94-
std::memcpy(&double_input, buffer, buffer_size);
95-
err = method.set_input(runtime::EValue(double_input), i);
103+
if (buffer_size != sizeof(double)) {
104+
ET_LOG(
105+
Error,
106+
"Double input at index %zu has size %zu, expected sizeof(double) %zu",
107+
i,
108+
buffer_size,
109+
sizeof(double));
110+
err = Error::InvalidArgument;
111+
} else {
112+
double double_input;
113+
std::memcpy(&double_input, buffer, buffer_size);
114+
err = method.set_input(runtime::EValue(double_input), i);
115+
}
96116
} else if (tag.get() == Tag::Bool) {
97-
bool bool_input;
98-
std::memcpy(&bool_input, buffer, buffer_size);
99-
err = method.set_input(runtime::EValue(bool_input), i);
117+
if (buffer_size != sizeof(bool)) {
118+
ET_LOG(
119+
Error,
120+
"Bool input at index %zu has size %zu, expected sizeof(bool) %zu",
121+
i,
122+
buffer_size,
123+
sizeof(bool));
124+
err = Error::InvalidArgument;
125+
} else {
126+
bool bool_input;
127+
std::memcpy(&bool_input, buffer, buffer_size);
128+
err = method.set_input(runtime::EValue(bool_input), i);
129+
}
100130
} else {
101131
ET_LOG(
102132
Error,

extension/runner_util/test/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
2020
add_custom_command(
2121
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
2222
COMMAND ${PYTHON_EXECUTABLE} -m test.models.export_program --modules
23-
"ModuleAdd" --outdir "${CMAKE_CURRENT_BINARY_DIR}"
23+
"ModuleAdd,ModuleIntBool" --outdir "${CMAKE_CURRENT_BINARY_DIR}"
2424
WORKING_DIRECTORY ${EXECUTORCH_ROOT}
2525
)
2626

2727
add_custom_target(
2828
executorch_runner_util_test_resources
29-
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
29+
DEPENDS
30+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
31+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleIntBool.pte"
3032
)
3133

32-
set(test_env "ET_MODULE_ADD_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte")
34+
set(test_env "ET_MODULE_ADD_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
35+
"ET_MODULE_INTBOOL_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleIntBool.pte")
3336

3437
set(_test_srcs inputs_test.cpp)
3538

extension/runner_util/test/inputs_test.cpp

Lines changed: 159 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,52 +40,77 @@ class InputsTest : public ::testing::Test {
4040
void SetUp() override {
4141
torch::executor::runtime_init();
4242

43-
// Create a loader for the serialized ModuleAdd program.
44-
const char* path = std::getenv("ET_MODULE_ADD_PATH");
45-
Result<FileDataLoader> loader = FileDataLoader::from(path);
46-
ASSERT_EQ(loader.error(), Error::Ok);
47-
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
48-
49-
// Use it to load the program.
50-
Result<Program> program = Program::load(
51-
loader_.get(), Program::Verification::InternalConsistency);
52-
ASSERT_EQ(program.error(), Error::Ok);
53-
program_ = std::make_unique<Program>(std::move(program.get()));
54-
55-
mmm_ = std::make_unique<ManagedMemoryManager>(
43+
// Load ModuleAdd
44+
const char* add_path = std::getenv("ET_MODULE_ADD_PATH");
45+
Result<FileDataLoader> add_loader = FileDataLoader::from(add_path);
46+
ASSERT_EQ(add_loader.error(), Error::Ok);
47+
add_loader_ = std::make_unique<FileDataLoader>(std::move(add_loader.get()));
48+
49+
Result<Program> add_program = Program::load(
50+
add_loader_.get(), Program::Verification::InternalConsistency);
51+
ASSERT_EQ(add_program.error(), Error::Ok);
52+
add_program_ = std::make_unique<Program>(std::move(add_program.get()));
53+
54+
add_mmm_ = std::make_unique<ManagedMemoryManager>(
55+
/*planned_memory_bytes=*/32 * 1024U,
56+
/*method_allocator_bytes=*/32 * 1024U);
57+
58+
Result<Method> add_method =
59+
add_program_->load_method("forward", &add_mmm_->get());
60+
ASSERT_EQ(add_method.error(), Error::Ok);
61+
add_method_ = std::make_unique<Method>(std::move(add_method.get()));
62+
63+
// Load ModuleIntBool
64+
const char* intbool_path = std::getenv("ET_MODULE_INTBOOL_PATH");
65+
Result<FileDataLoader> intbool_loader = FileDataLoader::from(intbool_path);
66+
ASSERT_EQ(intbool_loader.error(), Error::Ok);
67+
intbool_loader_ =
68+
std::make_unique<FileDataLoader>(std::move(intbool_loader.get()));
69+
70+
Result<Program> intbool_program = Program::load(
71+
intbool_loader_.get(), Program::Verification::InternalConsistency);
72+
ASSERT_EQ(intbool_program.error(), Error::Ok);
73+
intbool_program_ =
74+
std::make_unique<Program>(std::move(intbool_program.get()));
75+
76+
intbool_mmm_ = std::make_unique<ManagedMemoryManager>(
5677
/*planned_memory_bytes=*/32 * 1024U,
5778
/*method_allocator_bytes=*/32 * 1024U);
5879

59-
// Load the forward method.
60-
Result<Method> method = program_->load_method("forward", &mmm_->get());
61-
ASSERT_EQ(method.error(), Error::Ok);
62-
method_ = std::make_unique<Method>(std::move(method.get()));
80+
Result<Method> intbool_method =
81+
intbool_program_->load_method("forward", &intbool_mmm_->get());
82+
ASSERT_EQ(intbool_method.error(), Error::Ok);
83+
intbool_method_ = std::make_unique<Method>(std::move(intbool_method.get()));
6384
}
6485

6586
private:
66-
// Must outlive method_, but tests shouldn't need to touch them.
67-
std::unique_ptr<FileDataLoader> loader_;
68-
std::unique_ptr<ManagedMemoryManager> mmm_;
69-
std::unique_ptr<Program> program_;
87+
std::unique_ptr<FileDataLoader> add_loader_;
88+
std::unique_ptr<Program> add_program_;
89+
std::unique_ptr<ManagedMemoryManager> add_mmm_;
90+
91+
std::unique_ptr<FileDataLoader> intbool_loader_;
92+
std::unique_ptr<Program> intbool_program_;
93+
std::unique_ptr<ManagedMemoryManager> intbool_mmm_;
7094

7195
protected:
72-
std::unique_ptr<Method> method_;
96+
std::unique_ptr<Method> add_method_;
97+
std::unique_ptr<Method> intbool_method_;
7398
};
7499

75100
TEST_F(InputsTest, Smoke) {
76-
Result<BufferCleanup> input_buffers = prepare_input_tensors(*method_);
101+
Result<BufferCleanup> input_buffers = prepare_input_tensors(*add_method_);
77102
ASSERT_EQ(input_buffers.error(), Error::Ok);
78-
auto input_err = method_->set_input(executorch::runtime::EValue(1.0), 2);
103+
auto input_err = add_method_->set_input(executorch::runtime::EValue(1.0), 2);
79104
ASSERT_EQ(input_err, Error::Ok);
80105

81106
// We can't look at the input tensors, but we can check that the outputs make
82107
// sense after executing the method.
83-
Error status = method_->execute();
108+
Error status = add_method_->execute();
84109
ASSERT_EQ(status, Error::Ok);
85110

86111
// Get the single output, which should be a floating-point Tensor.
87-
ASSERT_EQ(method_->outputs_size(), 1);
88-
const EValue& output_value = method_->get_output(0);
112+
ASSERT_EQ(add_method_->outputs_size(), 1);
113+
const EValue& output_value = add_method_->get_output(0);
89114
ASSERT_EQ(output_value.tag, Tag::Tensor);
90115
Tensor output = output_value.toTensor();
91116
ASSERT_EQ(output.scalar_type(), ScalarType::Float);
@@ -107,14 +132,14 @@ TEST_F(InputsTest, ExceedingInputCountLimitFails) {
107132
// The smoke test above demonstrated that we can prepare inputs with the
108133
// default limits. It should fail if we lower the max below the number of
109134
// actual inputs.
110-
MethodMeta method_meta = method_->method_meta();
135+
MethodMeta method_meta = add_method_->method_meta();
111136
size_t num_inputs = method_meta.num_inputs();
112137
ASSERT_GE(num_inputs, 1);
113138
executorch::extension::PrepareInputTensorsOptions options;
114139
options.max_inputs = num_inputs - 1;
115140

116141
Result<BufferCleanup> input_buffers =
117-
prepare_input_tensors(*method_, options);
142+
prepare_input_tensors(*add_method_, options);
118143
ASSERT_NE(input_buffers.error(), Error::Ok);
119144
}
120145

@@ -128,7 +153,7 @@ TEST_F(InputsTest, ExceedingInputAllocationLimitFails) {
128153
options.max_total_allocation_size = 1;
129154

130155
Result<BufferCleanup> input_buffers =
131-
prepare_input_tensors(*method_, options);
156+
prepare_input_tensors(*add_method_, options);
132157
ASSERT_NE(input_buffers.error(), Error::Ok);
133158
}
134159

@@ -186,3 +211,107 @@ TEST(BufferCleanupTest, Smoke) {
186211
// complaint.
187212
bc2.reset();
188213
}
214+
215+
TEST_F(InputsTest, DoubleInputWrongSizeFails) {
216+
MethodMeta method_meta = add_method_->method_meta();
217+
218+
// ModuleAdd has 3 inputs: tensor, tensor, double (alpha)
219+
ASSERT_EQ(method_meta.num_inputs(), 3);
220+
221+
// Verify input 2 is a Double
222+
auto tag = method_meta.input_tag(2);
223+
ASSERT_TRUE(tag.ok());
224+
ASSERT_EQ(tag.get(), Tag::Double);
225+
226+
// Create input_buffers with wrong size for the Double input
227+
std::vector<std::pair<char*, size_t>> input_buffers;
228+
229+
// Allocate correct buffers for tensors (inputs 0 and 1)
230+
auto tensor0_meta = method_meta.input_tensor_meta(0);
231+
auto tensor1_meta = method_meta.input_tensor_meta(1);
232+
ASSERT_TRUE(tensor0_meta.ok());
233+
ASSERT_TRUE(tensor1_meta.ok());
234+
235+
std::vector<char> buf0(tensor0_meta->nbytes());
236+
std::vector<char> buf1(tensor1_meta->nbytes());
237+
238+
// ModuleAdd expects alpha=1.0. Need to set this correctly, otherwise
239+
// set_input fails validation before the buffer overflow happens.
240+
double alpha = 1.0;
241+
// Double is size 8; use a larger buffer to invoke overflow.
242+
char large_buffer[16];
243+
memcpy(large_buffer, &alpha, sizeof(double));
244+
245+
input_buffers.push_back({buf0.data(), buf0.size()});
246+
input_buffers.push_back({buf1.data(), buf1.size()});
247+
input_buffers.push_back({large_buffer, sizeof(large_buffer)});
248+
249+
Result<BufferCleanup> result =
250+
prepare_input_tensors(*add_method_, {}, input_buffers);
251+
EXPECT_EQ(result.error(), Error::InvalidArgument);
252+
}
253+
254+
TEST_F(InputsTest, IntBoolInputWrongSizeFails) {
255+
MethodMeta method_meta = intbool_method_->method_meta();
256+
257+
// ModuleIntBool has 3 inputs: tensor, int, bool
258+
ASSERT_EQ(method_meta.num_inputs(), 3);
259+
260+
// Verify input types
261+
auto int_tag = method_meta.input_tag(1);
262+
ASSERT_TRUE(int_tag.ok());
263+
ASSERT_EQ(int_tag.get(), Tag::Int);
264+
265+
auto bool_tag = method_meta.input_tag(2);
266+
ASSERT_TRUE(bool_tag.ok());
267+
ASSERT_EQ(bool_tag.get(), Tag::Bool);
268+
269+
// Allocate correct buffer for tensor (input 0)
270+
auto tensor0_meta = method_meta.input_tensor_meta(0);
271+
ASSERT_TRUE(tensor0_meta.ok());
272+
std::vector<char> buf0(tensor0_meta->nbytes());
273+
274+
// Prepare scalar values
275+
int64_t y = 1;
276+
bool z = true;
277+
278+
// Test 1: Int input with wrong size
279+
{
280+
std::vector<std::pair<char*, size_t>> input_buffers;
281+
282+
// Int is size 8; use a larger buffer to invoke overflow.
283+
char large_int_buffer[16];
284+
memcpy(large_int_buffer, &y, sizeof(int64_t));
285+
286+
char bool_buffer[1];
287+
memcpy(bool_buffer, &z, sizeof(bool));
288+
289+
input_buffers.push_back({buf0.data(), buf0.size()});
290+
input_buffers.push_back({large_int_buffer, sizeof(large_int_buffer)});
291+
input_buffers.push_back({bool_buffer, sizeof(bool_buffer)});
292+
293+
Result<BufferCleanup> result =
294+
prepare_input_tensors(*intbool_method_, {}, input_buffers);
295+
EXPECT_EQ(result.error(), Error::InvalidArgument);
296+
}
297+
298+
// Test 2: Bool input with wrong size
299+
{
300+
std::vector<std::pair<char*, size_t>> input_buffers;
301+
302+
char int_buffer[8];
303+
memcpy(int_buffer, &y, sizeof(int64_t));
304+
305+
// Bool is size 1; use a larger buffer to invoke overflow.
306+
char large_bool_buffer[8];
307+
memcpy(large_bool_buffer, &z, sizeof(bool));
308+
309+
input_buffers.push_back({buf0.data(), buf0.size()});
310+
input_buffers.push_back({int_buffer, sizeof(int_buffer)});
311+
input_buffers.push_back({large_bool_buffer, sizeof(large_bool_buffer)});
312+
313+
Result<BufferCleanup> result =
314+
prepare_input_tensors(*intbool_method_, {}, input_buffers);
315+
EXPECT_EQ(result.error(), Error::InvalidArgument);
316+
}
317+
}

extension/runner_util/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ def define_common_targets(is_fbcode = False):
2828
],
2929
env = {
3030
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
31+
"ET_MODULE_INTBOOL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleIntBool.pte])",
3132
},
3233
)

test/models/export_program.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ def get_random_inputs(self):
6767
return (torch.randn(10, 10, 10),)
6868

6969

70+
# Used for testing int and bool inputs.
71+
class ModuleIntBool(torch.nn.Module):
72+
def __init__(self):
73+
super().__init__()
74+
75+
def forward(self, x: torch.Tensor, y: int, z: bool):
76+
return x + y + int(z)
77+
78+
def get_random_inputs(self):
79+
return (torch.ones(1), 1, True)
80+
81+
7082
class ModuleNoOp(nn.Module):
7183
def __init__(self):
7284
super(ModuleNoOp, self).__init__()

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def define_common_targets():
6868
"ModuleMultipleEntry",
6969
"ModuleNoKVCache",
7070
"ModuleIndex",
71+
"ModuleIntBool",
7172
"ModuleDynamicCatUnallocatedIO",
7273
"ModuleSimpleTrain",
7374
"ModuleStateful",

0 commit comments

Comments
 (0)