diff --git a/backends/xnnpack/runtime/XNNWeightsCache.cpp b/backends/xnnpack/runtime/XNNWeightsCache.cpp index 005169249bc..09cfe3641b5 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.cpp +++ b/backends/xnnpack/runtime/XNNWeightsCache.cpp @@ -348,6 +348,19 @@ size_t XNNWeightsCache::look_up( if (packed_weight_entry == context->name_to_packed_data_metadata_.end()) { return SIZE_MAX; } + // XNNPACK upgrade detection: a ukernel whose implementation changed + // produces a different seed. Reject the cached entry so look_up_or_insert + // falls through to re-pack with the current ukernel. + if (packed_weight_entry->second.seed != cache_key->seed) { + ET_LOG( + Info, + "look_up: seed mismatch for '%s' (cached=0x%08x, current=0x%08x); " + "treating as miss for re-pack", + weight_bias_name.c_str(), + packed_weight_entry->second.seed, + cache_key->seed); + return SIZE_MAX; + } packed_weight_entry->second.in_current_runtime = true; return packed_weight_entry->second.offset; } @@ -474,6 +487,7 @@ size_t XNNWeightsCache::look_up_or_insert( packed_data_metadata.ref_count = 0; // ref_count is only incremented after finalizing for runtime packed_data_metadata.in_current_runtime = true; + packed_data_metadata.seed = cache_key->seed; context->name_to_packed_data_metadata_[weight_bias_name] = packed_data_metadata; } else { @@ -524,7 +538,7 @@ Error XNNWeightsCache::save_packed_index() { std::vector buf; uint32_t entry_count = 0; - // Index entry: [name_len:u32][name][file_offset:u64][data_size:u64] + // Index entry: [name_len:u32][name][file_offset:u64][data_size:u64][seed:u32] for (const auto& [name, meta] : name_to_packed_data_metadata_) { void* ptr = packed_data_ptrs_[meta.offset]; auto it = ptr_to_file_offset_.find(ptr); @@ -536,6 +550,7 @@ Error XNNWeightsCache::save_packed_index() { buf.insert(buf.end(), name.begin(), name.end()); append_le(buf, static_cast(it->second)); append_le(buf, static_cast(meta.data_size)); + append_le(buf, meta.seed); } // Footer: [index_start:u64][entry_count:u32][magic:u32][version:u32] @@ -635,7 +650,8 @@ bool XNNWeightsCache::load_packed_cache() { for (uint32_t i = 0; i < entry_count && cursor + 4 <= end; ++i) { uint32_t name_len = read_le(cursor); cursor += 4; - if (cursor + name_len + 16 > end) { + // [file_offset:u64][data_size:u64][seed:u32] = 20 bytes + if (cursor + name_len + 20 > end) { // Truncated entry header: trailer doesn't match the entry_count we // read from the footer, so the cache is corrupt. Apply the same // full rollback as the invalid-bounds branch below — otherwise the @@ -660,6 +676,8 @@ bool XNNWeightsCache::load_packed_cache() { cursor += 8; uint64_t data_size = read_le(cursor); cursor += 8; + uint32_t seed = read_le(cursor); + cursor += 4; // Bounds check: the entry's bytes must lie entirely inside the // packed-data region. @@ -692,6 +710,7 @@ bool XNNWeightsCache::load_packed_cache() { meta.ref_count = 0; meta.in_current_runtime = false; meta.from_load = true; + meta.seed = seed; name_to_packed_data_metadata_[name] = meta; } diff --git a/backends/xnnpack/runtime/XNNWeightsCache.h b/backends/xnnpack/runtime/XNNWeightsCache.h index 4bfa916d289..851b452441f 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.h +++ b/backends/xnnpack/runtime/XNNWeightsCache.h @@ -44,6 +44,13 @@ struct PackedDataMeta { // cache_loaded_ is auto-invalidated so the next init re-enters // load_packed_cache and reuses the saved file instead of re-packing. bool from_load{false}; + // Per-ukernel seed from xnn_weights_cache_look_up_key.seed. XNNPACK + // guarantees this is consistent across runs of the same ukernel; when + // XNNPACK upgrades and a ukernel implementation changes, the seed + // changes. look_up rejects entries whose stored seed doesn't match + // the caller's seed so that stale cache entries don't deliver wrongly + // packed weights to a newer ukernel. + uint32_t seed{0}; }; class XNNWeightsCache { @@ -151,7 +158,11 @@ class XNNWeightsCache { private: static constexpr uint32_t kCacheMagic = 0x58505743; // "XPWC" - static constexpr uint32_t kCacheVersion = 1; + // Bump when the on-disk layout (footer or per-entry record) changes. + // v2: per-entry seed added — old v1 files don't carry seeds and would + // load with seed=0, mismatching every fresh look_up with a non-zero + // seed, causing a stampede of re-packs. Reject v1 outright. + static constexpr uint32_t kCacheVersion = 2; bool load_packed_cache(); void reset_for_fresh_write(); void release_entry(void* packed_data_ptr); diff --git a/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp b/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp index 59cfbcbdb5d..4c65e97f3a0 100644 --- a/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp +++ b/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp @@ -704,6 +704,236 @@ TEST_F(XNNWeightsCacheTest, MultiplePTEsInSameInstance_NoFileGrowth) { ::unlink(cache_path.c_str()); } +namespace { + +// Little-endian decode helpers matching XNNWeightsCache's on-disk format. +uint32_t read_le_u32(const uint8_t* p) { + uint32_t v = 0; + for (int i = 0; i < 4; ++i) { + v |= static_cast(p[i]) << (8 * i); + } + return v; +} +uint64_t read_le_u64(const uint8_t* p) { + uint64_t v = 0; + for (int i = 0; i < 8; ++i) { + v |= static_cast(p[i]) << (8 * i); + } + return v; +} +void write_le_u32(std::ostream& f, uint32_t v) { + for (int i = 0; i < 4; ++i) { + char b = static_cast((v >> (8 * i)) & 0xff); + f.write(&b, 1); + } +} +void write_le_u64(std::ostream& f, uint64_t v) { + for (int i = 0; i < 8; ++i) { + char b = static_cast((v >> (8 * i)) & 0xff); + f.write(&b, 1); + } +} + +} // namespace + +// A cache file written by older code (kCacheVersion=1) carries no per-entry +// seed field. Loading such a file with the current schema would yield +// entries with seed=0 and mismatch every fresh look_up. The version bump +// must reject it outright so the next init re-packs from scratch. +TEST_F(XNNWeightsCacheTest, LoadPackedCache_RejectsV1Format) { + std::string cache_path = std::string("/tmp/xnn_weights_cache_v1_") + + std::to_string(::getpid()) + ".packed_cache"; + ::unlink(cache_path.c_str()); + + // v1 layout: 64 bytes of dummy data, then 20-byte footer with version=1. + { + std::ofstream f(cache_path, std::ios::binary); + std::vector data(64, 0); + f.write(data.data(), data.size()); + write_le_u64(f, 64); // index_start + write_le_u32(f, 0); // entry_count + write_le_u32(f, 0x58505743); // kCacheMagic "XPWC" + write_le_u32(f, 1); // OLD kCacheVersion = 1 + } + + XNNWeightsCache cache; + cache.set_packed_cache_path(cache_path); + Error err = + cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + ASSERT_EQ(err, Error::Ok); + // Version mismatch → load_packed_cache returned false → no entries. + EXPECT_EQ(cache.get_packed_data_names().size(), 0u); + + ::unlink(cache_path.c_str()); +} + +// Verify save_packed_index writes the schema version 2 footer and embeds a +// 4-byte seed field in each entry record. Guards against future refactors +// silently dropping the seed write. +TEST_F(XNNWeightsCacheTest, SavePackedIndex_EntryFormatIncludesSeed) { + std::string cache_path = std::string("/tmp/xnn_weights_cache_format_") + + std::to_string(::getpid()) + ".packed_cache"; + ::unlink(cache_path.c_str()); + + std::vector batches{1, 2, 3}; + size_t input_channels = 3; + size_t output_channels = 4; + size_t num_batches = 1 * 2 * 3; + size_t padding = 32; + std::vector input(num_batches * input_channels + padding, 1.0f); + std::vector output(num_batches * output_channels, 0.0f); + + { + XNNWeightsCache cache; + cache.set_packed_cache_path(cache_path); + cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + BuildAndRunGraphWithWeightsCache( + cache, + batches, + input_channels, + output_channels, + input.data(), + output.data()); + ASSERT_EQ(cache.save_packed_index(), Error::Ok); + } + + // Parse footer at file_size - 20. + std::ifstream f(cache_path, std::ios::binary); + ASSERT_TRUE(f.is_open()); + f.seekg(0, std::ios::end); + size_t file_size = f.tellg(); + ASSERT_GE(file_size, 24u); + + uint8_t footer[20]; + f.seekg(file_size - 20); + f.read(reinterpret_cast(footer), 20); + uint32_t magic = read_le_u32(footer + 12); + uint32_t version = read_le_u32(footer + 16); + EXPECT_EQ(magic, 0x58505743u); + EXPECT_EQ(version, 2u); + + // Walk first entry: + // [name_len:u32][name][file_offset:u64][data_size:u64][seed:u32] + uint64_t index_start = read_le_u64(footer); + uint32_t entry_count = read_le_u32(footer + 8); + ASSERT_GT(entry_count, 0u); + + f.seekg(index_start); + uint8_t name_len_buf[4]; + f.read(reinterpret_cast(name_len_buf), 4); + uint32_t name_len = read_le_u32(name_len_buf); + + // The seed field sits at index_start + 4 + name_len + 8 + 8. + f.seekg(index_start + 4 + name_len + 8 + 8); + uint8_t seed_buf[4]; + f.read(reinterpret_cast(seed_buf), 4); + // XNNPACK ukernel seeds are non-zero in practice. The signal here is + // simply that 4 well-formed bytes follow the size field — confirming + // the new entry layout was written, not the legacy 16-byte tail. + uint32_t stored_seed = read_le_u32(seed_buf); + EXPECT_NE(stored_seed, 0u); + + ::unlink(cache_path.c_str()); +} + +// After loading a cache file whose entry seed has been tampered with +// (simulating an XNNPACK upgrade where the same ukernel now emits a +// different seed), the next inference must produce correct output. Either +// look_up's seed check or look_up_or_insert's memcmp fallback drives the +// re-pack; this test exercises the end-to-end safety net. +TEST_F( + XNNWeightsCacheTest, + LoadPackedCache_CorruptedSeed_ProducesCorrectOutput) { + std::string cache_path = std::string("/tmp/xnn_weights_cache_badseed_") + + std::to_string(::getpid()) + ".packed_cache"; + ::unlink(cache_path.c_str()); + + std::vector batches{1, 2, 3}; + size_t input_channels = 3; + size_t output_channels = 4; + size_t num_batches = 1 * 2 * 3; + size_t padding = 32; + std::vector input(num_batches * input_channels + padding, 1.0f); + + // Baseline: fresh pack, heap-only, no cache file. + std::vector baseline(num_batches * output_channels, 0.0f); + { + XNNWeightsCache cache; + cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + BuildAndRunGraphWithWeightsCache( + cache, + batches, + input_channels, + output_channels, + input.data(), + baseline.data()); + } + + // Write a valid cache file. + { + XNNWeightsCache cache; + cache.set_packed_cache_path(cache_path); + cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + std::vector out(num_batches * output_channels, 0.0f); + BuildAndRunGraphWithWeightsCache( + cache, + batches, + input_channels, + output_channels, + input.data(), + out.data()); + ASSERT_EQ(cache.save_packed_index(), Error::Ok); + } + + // Corrupt the seed field of the first entry to a value no real ukernel + // would emit (0xDEADBEEF). + { + std::fstream f(cache_path, std::ios::binary | std::ios::in | std::ios::out); + ASSERT_TRUE(f.is_open()); + f.seekg(0, std::ios::end); + size_t file_size = f.tellg(); + ASSERT_GE(file_size, 24u); + + uint8_t footer_buf[20]; + f.seekg(file_size - 20); + f.read(reinterpret_cast(footer_buf), 20); + uint64_t index_start = read_le_u64(footer_buf); + uint32_t entry_count = read_le_u32(footer_buf + 8); + ASSERT_GT(entry_count, 0u); + + f.seekg(index_start); + uint8_t name_len_buf[4]; + f.read(reinterpret_cast(name_len_buf), 4); + uint32_t name_len = read_le_u32(name_len_buf); + + size_t seed_offset = index_start + 4 + name_len + 8 + 8; + f.seekp(seed_offset); + uint32_t corrupted = 0xDEADBEEFu; + f.write(reinterpret_cast(&corrupted), 4); + f.close(); + } + + // Reload and run. Output must still match baseline. + std::vector after_corruption(num_batches * output_channels, 0.0f); + { + XNNWeightsCache cache; + cache.set_packed_cache_path(cache_path); + cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + ASSERT_GT(cache.get_packed_data_names().size(), 0u); + BuildAndRunGraphWithWeightsCache( + cache, + batches, + input_channels, + output_channels, + input.data(), + after_corruption.data()); + } + + EXPECT_EQ(after_corruption, baseline); + + ::unlink(cache_path.c_str()); +} + // save_packed_index must be a true no-op when no new reserve_space happened // since the last save — same content but writing would still bump mtime, // making the cache file look modified on every model load.