Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/SnapshotHelpers.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ struct RBranchData {
TBranch *fOutputBranch = nullptr;
void *fBranchAddressForCArrays = nullptr; // Used to detect if branch addresses need to be updated

int fVariationIndex = -1; // For branches that are only valid if a specific filter passed
// A negative index indicates no variations, 0 is for nominal, >0 marks columns that are only valid if a specific
// filter passed
int fVariationIndex = -1;
std::variant<FundamentalType, EmptyDynamicType> fTypeData = FundamentalType{0};
bool fIsCArray = false;
bool fIsDefine = false;
Expand Down
24 changes: 16 additions & 8 deletions tree/dataframe/src/RDFSnapshotHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,10 @@ void ROOT::Internal::RDF::SnapshotHelperWithVariations::RegisterVariedColumn(uns
std::string const &variationName)
{
if (columnIndex == originalColumnIndex) {
fBranchData[columnIndex].fVariationIndex = variationIndex; // The base column has variations
// This is a nominal column, but it participates in variations.
// It always needs to be written, but we still need to create a mask bit to mark when nominal is invalid.
assert(variationIndex == 0);
fBranchData[columnIndex].fVariationIndex = 0;
fOutputHandle->RegisterBranch(fBranchData[columnIndex].fOutputBranchName, variationIndex);
} else if (columnIndex >= fBranchData.size()) {
// First task, need to create branches
Expand Down Expand Up @@ -1245,15 +1248,20 @@ void ROOT::Internal::RDF::SnapshotHelperWithVariations::Exec(unsigned int /*slot
for (std::size_t i = 0; i < values.size(); i++) {
const auto variationIndex = fBranchData[i].fVariationIndex;
if (variationIndex < 0) {
// Branch without variations
// Branch without variations, it always needs to be written
SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]);
} else if (filterPassed[variationIndex]) {
// Branch with variations
const bool fundamentalType = fBranchData[i].WriteValueIfFundamental(values[i]);
if (!fundamentalType) {
SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]);
} else {
// Nominal will always be written, systematics only if needed
if (variationIndex == 0 || filterPassed[variationIndex]) {
const bool fundamentalType = fBranchData[i].WriteValueIfFundamental(values[i]);
if (!fundamentalType) {
SetBranchesHelper(fInputTree, *fOutputHandle->fTree, fBranchData, i, fOptions.fBasketSize, values[i]);
}
}

if (filterPassed[variationIndex]) {
fOutputHandle->SetMaskBit(variationIndex);
}
fOutputHandle->SetMaskBit(variationIndex);
}
}

Expand Down
186 changes: 135 additions & 51 deletions tree/dataframe/test/dataframe_snapshotWithVariations.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
ASSERT_GT(tree.GetEntry(i), 0);

EXPECT_EQ(x, -1 * y);
if (!activeCuts(x, y)) {
if (!activeCuts(x, y) && !sysName.empty()) {
// Branches with systematics should be zeroed when cuts don't pass
EXPECT_EQ(x, X_t{});
EXPECT_EQ(y, Y_t{});
}
Expand Down Expand Up @@ -98,7 +99,7 @@
for (const auto branchName : {"x", "y", "x__xVar_0", "x__xVar_1", "y__xVar_0", "y__xVar_0"})
EXPECT_NE(tree->FindBranch(branchName), nullptr) << branchName;

checkOutput<float, double>(*tree, std::vector<std::string>{"__xVar_0", "__xVar_1"}, cuts);
checkOutput<float, double>(*tree, std::vector<std::string>{"", "__xVar_0", "__xVar_1"}, cuts);

if (HasFailure()) {
tree->Print();
Expand Down Expand Up @@ -264,50 +265,59 @@
if (HasFailure())
break;
}
}

// Test that the Masked column reader works
{
SCOPED_TRACE("Usage of bitmask in RDF");
auto rdf = ROOT::RDataFrame(treename, filename);

auto filterAvailable = rdf.FilterAvailable("x");
auto meanX = filterAvailable.Mean<int>("x");
auto meanY = filterAvailable.Mean<int>("y");
auto count = filterAvailable.Count();

EXPECT_EQ(count.GetValue(), 3ull); // 0, 6, 12
EXPECT_EQ(meanX.GetValue(), 6.);
EXPECT_EQ(meanY.GetValue(), -6.);

// Test reading invalid columns
auto mean = rdf.Mean<int>("x");
EXPECT_THROW(mean.GetValue(), std::out_of_range);

for (unsigned int systematicIndex : {0, 1, 100}) {
const std::string systematic = "__xVar_" + std::to_string(systematicIndex);
auto filterAv = rdf.FilterAvailable("x" + systematic);
auto meanX_sys = filterAv.Mean("x" + systematic);
auto meanY_sys = filterAv.Mean("y" + systematic);
auto count_sys = filterAv.Count();

std::vector<int> expect(N);
std::iota(expect.begin(), expect.end(), systematicIndex);

const auto nVal = std::count_if(expect.begin(), expect.end(), [](int v) { return v % 6 == 0; });
// gcc8.5 on alma8 doesn't support transform_reduce, nor reduce
// const int sum = std::transform_reduce(expect.begin(), expect.end(), 0, std::plus<>(),
// [](int v) { return v % 6 == 0 ? v : 0; });
std::transform(expect.begin(), expect.end(), expect.begin(), [](int v) { return v % 6 == 0 ? v : 0; });
const int sum = std::accumulate(expect.begin(), expect.end(), 0);

ASSERT_EQ(count_sys.GetValue(), nVal) << "systematic: " << systematic;
EXPECT_EQ(meanX_sys.GetValue(), sum / nVal) << "systematic: " << systematic;
EXPECT_EQ(meanY_sys.GetValue(), -1. * sum / nVal) << "systematic: " << systematic;
}
// Test that the Masked column reader works
{
SCOPED_TRACE("Usage of bitmask in RDF");
auto rdf = ROOT::RDataFrame(treename, filename);

auto filterAvailable = rdf.FilterAvailable("x");
auto meanX = filterAvailable.Mean<int>("x");
auto meanY = filterAvailable.Mean<int>("y");
auto count = filterAvailable.Count();

EXPECT_EQ(count.GetValue(), 3ull); // 0, 6, 12
EXPECT_EQ(meanX.GetValue(), 6.);
EXPECT_EQ(meanY.GetValue(), -6.);

// Test reading invalid columns
auto mean = rdf.Mean<int>("x");
EXPECT_THROW(mean.GetValue(), std::out_of_range);

for (unsigned int systematicIndex : {0, 1, 100}) {
const std::string systematic = "__xVar_" + std::to_string(systematicIndex);
auto filterAv = rdf.FilterAvailable("x" + systematic);
auto meanX_sys = filterAv.Mean("x" + systematic);
auto meanY_sys = filterAv.Mean("y" + systematic);
auto count_sys = filterAv.Count();

std::vector<int> expect(N);
std::iota(expect.begin(), expect.end(), systematicIndex);

const auto nVal = std::count_if(expect.begin(), expect.end(), [](int v) { return v % 6 == 0; });
// gcc8.5 on alma8 doesn't support transform_reduce, nor reduce
// const int sum = std::transform_reduce(expect.begin(), expect.end(), 0, std::plus<>(),
// [](int v) { return v % 6 == 0 ? v : 0; });
std::transform(expect.begin(), expect.end(), expect.begin(), [](int v) { return v % 6 == 0 ? v : 0; });
const int sum = std::accumulate(expect.begin(), expect.end(), 0);

ASSERT_EQ(count_sys.GetValue(), nVal) << "systematic: " << systematic;
EXPECT_EQ(meanX_sys.GetValue(), sum / nVal) << "systematic: " << systematic;
EXPECT_EQ(meanY_sys.GetValue(), -1. * sum / nVal) << "systematic: " << systematic;
}
}

if (HasFailure()) {
tree->Scan("entry:x:y:x__xVar_0:y__xVar_0:x__xVar_1:y__xVar_1:x__xVar_2:y__xVar_2");
if (HasFailure()) {
TFile file(filename, "READ");
std::unique_ptr<TTree> tree{file.Get<TTree>(treename.data())};
ASSERT_NE(tree, nullptr);
tree->Scan("entry:R_rdf_mask_testTree_0:x:y:x__xVar_0:y__xVar_0:x__xVar_1:y__xVar_1:x__xVar_2:y__xVar_2");

auto map = file.Get<std::unordered_map<std::string, std::pair<std::string, unsigned int>>>(
("R_rdf_column_to_bitmask_mapping_" + treename).c_str());
for (auto const &[name, mapping] : *map) {
std::cout << std::setw(20) << name << " --> " << mapping.first << " " << mapping.second << "\n";

Check failure on line 320 in tree/dataframe/test/dataframe_snapshotWithVariations.cxx

View workflow job for this annotation

GitHub Actions / alma8

‘setw’ is not a member of ‘std’
}
}
}
Expand Down Expand Up @@ -422,15 +432,16 @@
: ((systematicName.find("xVariation_0") != std::string::npos) ? entry + 1 : entry * 3);
const bool passCuts = (originalX % 2 == 0) || originalX == 5;

if (passCuts)
if (passCuts || systematicName.empty()) {
EXPECT_EQ(x, originalX) << "sys:'" << systematicName << "' originalX: " << originalX << " event: " << event;
else
ASSERT_EQ(y->size(), 4) << "sys:'" << systematicName << "' entry: " << entry << " originalX: " << originalX
<< " event: " << event;
for (unsigned int i = 0; i < y->size(); ++i) {
EXPECT_EQ((*y)[i], x + i);
}
} else {
EXPECT_EQ(x, 0) << "sys:'" << systematicName << "' originalX: " << originalX << " event: " << event;

ASSERT_EQ(y->size(), passCuts ? 4 : 0)
<< "sys:'" << systematicName << "' entry: " << entry << " originalX: " << originalX << " event: " << event;
for (unsigned int i = 0; i < y->size(); ++i) {
EXPECT_EQ((*y)[i], x + i);
ASSERT_EQ(y->size(), 0);
}
}
tree->ResetBranchAddresses();
Expand Down Expand Up @@ -555,7 +566,8 @@
TFile file(fileName);
auto tree = file.Get<TTree>("Events");
ASSERT_NE(tree, nullptr);
tree->Scan();
if (verbose)
tree->Scan();

double Muon_pt, Muon_pt_up, Muon_pt_down;
double Muon_2pt, Muon_2pt_up, Muon_2pt_down;
Expand All @@ -577,3 +589,75 @@
EXPECT_EQ(2. * Muon_2pt, Muon_2pt_up);
}
}

TEST(RDFVarySnapshot, TwoVaryExpressions)
{
constexpr auto filename = "VarySnapshot_TwoVaryExpressions.root";
RemoveFileRAII(filename);
constexpr unsigned int N = 10;
ROOT::RDF::RSnapshotOptions options;
options.fOverwriteIfExists = true;
options.fIncludeVariations = true;

auto cuts = [](float x, float y) { return (x < 20 || x > 30) && (y < 600 || y > 700); };

auto snap = ROOT::RDataFrame(N)
.Define("x", [](ULong64_t e) -> float { return 10.f * e; }, {"rdfentry_"})
.Define("y", [](ULong64_t e) -> float { return 100.f * e; }, {"rdfentry_"})
.Vary(
"x", [](float x) { return ROOT::RVecF{x - 0.5f, x + 0.5f}; }, {"x"}, {"down", "up"}, "xVar")
.Vary(
"y", [](float y) { return ROOT::RVecF{y - 0.5f, y + 0.5f}; }, {"y"}, {"down", "up"}, "yVar")
.Filter(cuts, {"x", "y"})
.Snapshot("t", filename, {"x", "y"}, options);

{
std::unique_ptr<TFile> file{TFile::Open(filename)};
auto tree = file->Get<TTree>("t");

EXPECT_EQ(tree->GetEntries(), 10);
EXPECT_EQ(tree->GetNbranches(), 7); // 6 branches for x/y with variations, bitmask
for (const auto branchName : {"x", "y", "x__xVar_down", "x__xVar_up", "y__yVar_down", "y__yVar_up"})
EXPECT_NE(tree->FindBranch(branchName), nullptr) << branchName;

for (std::string combos : {"x:y", "x__xVar_down:y", "x__xVar_up:y", "x:y__yVar_down", "x:y__yVar_up"}) {
const auto xName = combos.substr(0, combos.find(':'));
const auto yName = combos.substr(combos.find(':') + 1);

float x;
float y;
ASSERT_EQ(TTree::kMatch, tree->SetBranchAddress(xName.c_str(), &x)) << xName;
ASSERT_EQ(TTree::kMatch, tree->SetBranchAddress(yName.c_str(), &y)) << yName;

for (unsigned int i = 0; i < tree->GetEntries(); ++i) {
ASSERT_GT(tree->GetEntry(i), 0);
const float expectedX = i * 10.f - (xName.find("xVar_down") != std::string::npos) * 0.5f +
(xName.find("xVar_up") != std::string::npos) * 0.5f;
const float expectedY = i * 100.f - (yName.find("yVar_down") != std::string::npos) * 0.5f +
(yName.find("yVar_up") != std::string::npos) * 0.5f;

if (cuts(expectedX, expectedY)) {
EXPECT_EQ(x, expectedX) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "="
<< expectedY << ")";
EXPECT_EQ(y, expectedY) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "="
<< expectedY << ")";
EXPECT_TRUE(cuts(x, y)) << "entry:" << i << " (" << xName << "=" << expectedX << ", " << yName << "="
<< expectedY << ")";
}
}

tree->ResetBranchAddresses();
if (HasFailure())
break;
}

if (verbose || HasFailure()) {
auto map = file->Get<std::unordered_map<std::string, std::pair<std::string, unsigned int>>>(
"R_rdf_column_to_bitmask_mapping_t");
for (auto const &[name, mapping] : *map) {
std::cout << std::setw(20) << name << " --> " << mapping.first << " " << mapping.second << "\n";

Check failure on line 658 in tree/dataframe/test/dataframe_snapshotWithVariations.cxx

View workflow job for this annotation

GitHub Actions / alma8

‘setw’ is not a member of ‘std’
}
printTree(*tree);
}
}
}
Loading