Skip to content

Commit

Permalink
[RF] Implement channel masking for simultaneous likelihoods
Browse files Browse the repository at this point in the history
Upstreaming a feature from CMS combine.

Draft for now.
  • Loading branch information
guitargeek committed Nov 20, 2024
1 parent 3b38fed commit d5e95ac
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 13 deletions.
2 changes: 2 additions & 0 deletions roofit/codegen/inc/RooFit/CodegenImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace Detail {
class RooFixedProdPdf;
class RooNLLVarNew;
class RooNormalizedPdf;
class RooSimNLL;
} // namespace Detail

namespace Experimental {
Expand All @@ -76,6 +77,7 @@ class CodegenContext;
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooSimNLL &arg, CodegenContext &ctx);
void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx);
void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx);
void codegenImpl(RooAbsArg &arg, CodegenContext &ctx);
Expand Down
32 changes: 23 additions & 9 deletions roofit/codegen/src/CodegenImpl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@ void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
}
}

void codegenImpl(RooFit::Detail::RooSimNLL &arg, CodegenContext &ctx)
{
if (arg.terms().empty()) {
ctx.addResult(&arg, "0.0");
}

std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
ctx.addResult(&arg, resName);
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");

std::stringstream ss;

std::size_t i = 0;

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac13 ARM64 LLVM_ENABLE_ASSERTIONS=On, builtin_zlib=ON

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac14 X64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac15 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac-beta ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

variable 'i' set but not used [-Wunused-but-set-variable]
for (auto *component : static_range_cast<RooAbsReal *>(arg.terms())) {

// TODO: support channel masking here
ss << resName << " += " << ctx.buildFunction(*component, ctx.outputSizes()) << "(params, obs, xlArr);\n";
++i;
}
ctx.addToGlobalScope(ss.str());
}

void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx)
{
std::string const &idx = arg.dataHist().calculateTreeIndexForCodeSquash(&arg, ctx, arg.dataVars(), true);
Expand Down Expand Up @@ -251,15 +273,7 @@ void codegenImpl(RooAddition &arg, CodegenContext &ctx)

std::size_t i = 0;
for (auto *component : static_range_cast<RooAbsReal *>(arg.list())) {

if (!dynamic_cast<RooFit::Detail::RooNLLVarNew *>(component) || arg.list().size() == 1) {
result += ctx.getResult(*component);
++i;
if (i < arg.list().size())
result += '+';
continue;
}
result += ctx.buildFunction(*component, ctx.outputSizes()) + "(params, obs, xlArr)";
result += ctx.getResult(*component);
++i;
if (i < arg.list().size())
result += '+';
Expand Down
1 change: 1 addition & 0 deletions roofit/roofitcore/inc/LinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,5 +337,6 @@
#pragma link C++ class RooBinWidthFunction+;
#pragma link C++ class RooFit::Detail::RooNLLVarNew+;
#pragma link C++ class RooFit::Detail::RooNormalizedPdf+ ;
#pragma link C++ class RooFit::Detail::RooSimNLL+;

#endif
27 changes: 27 additions & 0 deletions roofit/roofitcore/inc/RooFit/Detail/RooNLLVarNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
#include <RooAbsPdf.h>
#include <RooAbsReal.h>
#include <RooGlobalFunc.h>
#include <RooListProxy.h>
#include <RooTemplateProxy.h>

#include <Math/Util.h>

class RooAbsCategory;

namespace RooFit {
namespace Detail {

Expand Down Expand Up @@ -87,6 +90,30 @@ class RooNLLVarNew : public RooAbsReal {
ClassDefOverride(RooFit::Detail::RooNLLVarNew, 0);
};

class RooSimNLL : public RooAbsReal {
public:
RooSimNLL(const char *name, const char *title, const RooArgSet &terms, RooAbsCategoryLValue const &indexCat,
bool channelMasking);

RooSimNLL(const RooSimNLL &other, const char *name = nullptr);
TObject *clone(const char *newname) const override { return new RooSimNLL(*this, newname); }

double defaultErrorLevel() const override;

const RooArgSet &terms() const { return _set; }
const RooArgSet &masks() const { return _mask; }

void doEval(RooFit::EvalContext &) const override;

protected:
double evaluate() const override;

RooSetProxy _set; ///< set of terms to be summed
RooSetProxy _mask;

ClassDefOverride(RooFit::Detail::RooSimNLL, 0) // Sum of RooNLLVarNew instances
};

} // namespace Detail
} // namespace RooFit

Expand Down
3 changes: 2 additions & 1 deletion roofit/roofitcore/src/FitHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#endif

using RooFit::Detail::RooNLLVarNew;
using RooFit::Detail::RooSimNLL;

namespace {

Expand Down Expand Up @@ -357,7 +358,7 @@ std::unique_ptr<RooAbsArg> createSimultaneousNLL(RooSimultaneous const &simPdf,
}

// Time to sum the NLLs
auto nll = std::make_unique<RooAddition>("mynll", "mynll", nllTerms);
auto nll = std::make_unique<RooSimNLL>("mynll", "mynll", nllTerms, simCat, true);
nll->addOwnedComponents(std::move(nllTerms));
return nll;
}
Expand Down
59 changes: 56 additions & 3 deletions roofit/roofitcore/src/RooNLLVarNew.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ computation times.

#include "RooFit/Detail/RooNLLVarNew.h"

#include <RooHistPdf.h>
#include <RooAbsCategoryLValue.h>
#include <RooBatchCompute.h>
#include <RooConstVar.h>
#include <RooDataHist.h>
#include <RooFit/Detail/MathFuncs.h>
#include <RooHistPdf.h>
#include <RooNaNPacker.h>
#include <RooConstVar.h>
#include <RooRealVar.h>
#include <RooSetProxy.h>
#include <RooFit/Detail/MathFuncs.h>

#include "RooFitImplHelpers.h"

Expand All @@ -47,6 +48,7 @@ computation times.
#include <vector>

ClassImp(RooFit::Detail::RooNLLVarNew);
ClassImp(RooFit::Detail::RooSimNLL);

namespace RooFit {
namespace Detail {
Expand Down Expand Up @@ -336,6 +338,57 @@ void RooNLLVarNew::finalizeResult(RooFit::EvalContext &ctx, ROOT::Math::KahanSum
ctx.setOutputWithOffset(this, result, _offset);
}

RooSimNLL::RooSimNLL(const char *name, const char *title, const RooArgSet &terms, RooAbsCategoryLValue const &indexCat,
bool channelMasking)
: RooAbsReal(name, title), _set("!set", "set of components", this), _mask("!mask", "set of masks", this)
{
_set.addTyped<RooAbsReal>(terms);

if (channelMasking) {
for (auto const &catState : indexCat) {
std::string const &catName = catState.first;
std::string maskName = "mask_" + catName;
_mask.addOwned(std::make_unique<RooRealVar>(maskName.c_str(), maskName.c_str(), 0.0));
}
}
}

RooSimNLL::RooSimNLL(const RooSimNLL &other, const char *name)
: RooAbsReal(other, name), _set("!set", this, other._set), _mask("!mask", this, other._set)
{
}

double RooSimNLL::evaluate() const
{
double sum(0);
const RooArgSet *nset = _set.nset();

std::size_t i = 0;
for (auto *comp : static_range_cast<RooAbsReal *>(_set)) {
if (_mask.empty() || static_cast<RooAbsReal const *>(_mask[i])->getVal() == 0.0) {
sum += comp->getVal(nset);
}
++i;
}
return sum;
}

void RooSimNLL::doEval(RooFit::EvalContext &ctx) const
{
double result = 0.;
for (std::size_t i = 0; i < _set.size(); ++i) {
if (_mask.empty() || ctx.at(_mask[i])[0] == 0.0) {
result += ctx.at(_set[i])[0];
}
}
ctx.output()[0] = result;
}

double RooSimNLL::defaultErrorLevel() const
{
return static_cast<RooNLLVarNew *>(_set[0])->defaultErrorLevel();
}

} // namespace Detail
} // namespace RooFit

Expand Down

0 comments on commit d5e95ac

Please sign in to comment.