diff --git a/devtools/etdump/data_sinks/buffer_data_sink.cpp b/devtools/etdump/data_sinks/buffer_data_sink.cpp index 5678aefb181..513a93aeeca 100644 --- a/devtools/etdump/data_sinks/buffer_data_sink.cpp +++ b/devtools/etdump/data_sinks/buffer_data_sink.cpp @@ -33,7 +33,12 @@ BufferDataSink::create(void* ptr, size_t size, size_t alignment) noexcept { } Result BufferDataSink::write(const void* ptr, size_t length) { - if (length == 0) { + bool inPlaceTensor = false; + + if (length != 0 && ptr == nullptr) { + inPlaceTensor = true; + } else if (length == 0 || ptr == nullptr) { + ET_LOG(Info, "Invalid data to write to buffer"); return offset_; } @@ -50,7 +55,13 @@ Result BufferDataSink::write(const void* ptr, size_t length) { // Zero out the padding between data blobs memset(last_data_end, 0, cur_data_begin - last_data_end); - memcpy(cur_data_begin, ptr, length); + + if (inPlaceTensor) { + memset(cur_data_begin, 0, length); + } else { + memcpy(cur_data_begin, ptr, length); + } + offset_ = (size_t)(cur_data_end - debug_buffer_.data()); return (size_t)(cur_data_begin - debug_buffer_.data()); diff --git a/devtools/etdump/data_sinks/file_data_sink.cpp b/devtools/etdump/data_sinks/file_data_sink.cpp index e9f9f44a899..55f661d4699 100644 --- a/devtools/etdump/data_sinks/file_data_sink.cpp +++ b/devtools/etdump/data_sinks/file_data_sink.cpp @@ -44,20 +44,34 @@ Result FileDataSink::write(const void* ptr, size_t size) { return Error::AccessFailed; } - size_t offset = total_written_bytes_; + bool inPlaceTensor = false; - if (size == 0) { - // No data to write, return current offset - return offset; + if (size != 0 && ptr == nullptr) { + inPlaceTensor = true; + } else if (size == 0 || ptr == nullptr) { + ET_LOG(Info, "Invalid data to write to file"); + return total_written_bytes_; } - size_t written = fwrite(ptr, 1, size, file_); - if (written != size) { - ET_LOG(Error, "Write failed: wrote %zu bytes of %zu", written, size); - return Error::Internal; + size_t offset = total_written_bytes_; + + if (inPlaceTensor) { + std::vector zeros(size, 0); + size_t written = fwrite(zeros.data(), 1, size, file_); + if (written != size) { + ET_LOG(Error, "Write failed: wrote %zu bytes of %zu", written, size); + return Error::Internal; + } + total_written_bytes_ += written; + } else { + size_t written = fwrite(ptr, 1, size, file_); + if (written != size) { + ET_LOG(Error, "Write failed: wrote %zu bytes of %zu", written, size); + return Error::Internal; + } + total_written_bytes_ += written; } - total_written_bytes_ += written; return offset; } diff --git a/devtools/etdump/data_sinks/tests/buffer_data_sink_test.cpp b/devtools/etdump/data_sinks/tests/buffer_data_sink_test.cpp index c4178c29a4b..014c5f877b1 100644 --- a/devtools/etdump/data_sinks/tests/buffer_data_sink_test.cpp +++ b/devtools/etdump/data_sinks/tests/buffer_data_sink_test.cpp @@ -148,3 +148,57 @@ TEST_F(BufferDataSinkTest, illegalAlignment) { ASSERT_EQ(buffer_data_sink_ret.error(), Error::InvalidArgument); } } + +TEST_F(BufferDataSinkTest, WriteInPlaceTensorZeroFill) { + // Test writing with nullptr ptr and non-zero length (in-place tensor case) + // This should zero-fill the buffer + size_t length = 16; + + Result ret = buffer_data_sink_->write(nullptr, length); + ASSERT_EQ(ret.error(), Error::Ok); + + size_t offset = ret.get(); + EXPECT_NE(offset, static_cast(-1)); + + // Verify the data in the buffer is zero-filled + const uint8_t* buffer_data = buffer_.data() + offset; + for (size_t i = 0; i < length; ++i) { + EXPECT_EQ(buffer_data[i], 0); + } +} + +TEST_F(BufferDataSinkTest, WriteZeroLengthReturnsCurrentOffset) { + // Test writing with zero length returns current offset + Result ret = buffer_data_sink_->write(nullptr, 0); + ASSERT_EQ(ret.error(), Error::Ok); + EXPECT_EQ(ret.get(), 0); + + // Write some data first + TensorFactory tf; + Tensor tensor = tf.make({1, 4}, {1.0, 2.0, 3.0, 4.0}); + Result write_ret = + buffer_data_sink_->write(tensor.const_data_ptr(), tensor.nbytes()); + ASSERT_EQ(write_ret.error(), Error::Ok); + + // Zero length write should return current offset + size_t current_used = buffer_data_sink_->get_used_bytes(); + Result ret2 = buffer_data_sink_->write(nullptr, 0); + ASSERT_EQ(ret2.error(), Error::Ok); + EXPECT_EQ(ret2.get(), current_used); +} + +TEST_F(BufferDataSinkTest, WriteNullptrWithZeroLengthReturnsCurrentOffset) { + // Write some data first to advance the offset + TensorFactory tf; + Tensor tensor = tf.make({1, 4}, {1.0, 2.0, 3.0, 4.0}); + Result write_ret = + buffer_data_sink_->write(tensor.const_data_ptr(), tensor.nbytes()); + ASSERT_EQ(write_ret.error(), Error::Ok); + + size_t current_used = buffer_data_sink_->get_used_bytes(); + + // Writing nullptr with zero length should return current offset + Result ret = buffer_data_sink_->write(nullptr, 0); + ASSERT_EQ(ret.error(), Error::Ok); + EXPECT_EQ(ret.get(), current_used); +} diff --git a/devtools/etdump/data_sinks/tests/file_data_sink_test.cpp b/devtools/etdump/data_sinks/tests/file_data_sink_test.cpp index 33122d320aa..fb9e49ae128 100644 --- a/devtools/etdump/data_sinks/tests/file_data_sink_test.cpp +++ b/devtools/etdump/data_sinks/tests/file_data_sink_test.cpp @@ -137,3 +137,88 @@ TEST_F(FileDataSinkTest, WriteMultipleDataAndCheckOffsets) { EXPECT_EQ( std::memcmp(file_content.data() + offset3.get(), data3, data3_size), 0); } + +TEST_F(FileDataSinkTest, WriteInPlaceTensorZeroFill) { + // Test writing with nullptr ptr and non-zero length (in-place tensor case) + // This should zero-fill the file + size_t length = 16; + + // Create a FileDataSink instance + Result result = FileDataSink::create(file_path_.c_str()); + ASSERT_TRUE(result.ok()); + FileDataSink* data_sink = &result.get(); + + // Write nullptr with non-zero length + Result write_result = data_sink->write(nullptr, length); + ASSERT_TRUE(write_result.ok()); + EXPECT_EQ(write_result.get(), 0); + + size_t used_bytes = data_sink->get_used_bytes(); + EXPECT_EQ(used_bytes, length); + + data_sink->close(); + + // Verify the file contents are zero-filled + std::ifstream file(file_path_, std::ios::binary); + file.seekg(0, std::ios::end); + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + EXPECT_EQ(file_size, length); + + // Read the file content and verify it's all zeros + std::vector file_content(file_size); + file.read(reinterpret_cast(file_content.data()), file_size); + file.close(); + + for (size_t i = 0; i < length; ++i) { + EXPECT_EQ(file_content[i], 0); + } +} + +TEST_F(FileDataSinkTest, WriteZeroLengthReturnsCurrentOffset) { + // Create a FileDataSink instance + Result result = FileDataSink::create(file_path_.c_str()); + ASSERT_TRUE(result.ok()); + FileDataSink* data_sink = &result.get(); + + // Zero length write should return current offset (0) + Result ret = data_sink->write(nullptr, 0); + ASSERT_TRUE(ret.ok()); + EXPECT_EQ(ret.get(), 0); + + // Write some data first + const char* data = "Hello"; + size_t data_size = strlen(data); + Result write_ret = data_sink->write(data, data_size); + ASSERT_TRUE(write_ret.ok()); + + // Zero length write should return current offset + size_t current_used = data_sink->get_used_bytes(); + Result ret2 = data_sink->write(nullptr, 0); + ASSERT_TRUE(ret2.ok()); + EXPECT_EQ(ret2.get(), current_used); + + data_sink->close(); +} + +TEST_F(FileDataSinkTest, WriteNullptrWithZeroLengthReturnsCurrentOffset) { + // Create a FileDataSink instance + Result result = FileDataSink::create(file_path_.c_str()); + ASSERT_TRUE(result.ok()); + FileDataSink* data_sink = &result.get(); + + // Write some data first to advance the offset + const char* data = "Test data"; + size_t data_size = strlen(data); + Result write_ret = data_sink->write(data, data_size); + ASSERT_TRUE(write_ret.ok()); + + size_t current_used = data_sink->get_used_bytes(); + + // Writing nullptr with zero length should return current offset + Result ret = data_sink->write(nullptr, 0); + ASSERT_TRUE(ret.ok()); + EXPECT_EQ(ret.get(), current_used); + + data_sink->close(); +}