REST-for-Physics  v2.3
Rare Event Searches ToolKit for Physics
TRestDataSetOdds.cxx
1/*************************************************************************
2 * This file is part of the REST software framework. *
3 * *
4 * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
5 * For more information see https://gifna.unizar.es/trex *
6 * *
7 * REST is free software: you can redistribute it and/or modify *
8 * it under the terms of the GNU General Public License as published by *
9 * the Free Software Foundation, either version 3 of the License, or *
10 * (at your option) any later version. *
11 * *
12 * REST is distributed in the hope that it will be useful, *
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of *
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
15 * GNU General Public License for more details. *
16 * *
17 * You should have a copy of the GNU General Public License along with *
18 * REST in $REST_PATH/LICENSE. *
19 * If not, see https://www.gnu.org/licenses/. *
20 * For the list of contributors see $REST_PATH/CREDITS. *
21 *************************************************************************/
22
102
103#include "TRestDataSetOdds.h"
104
105#include "TRestDataSet.h"
106
107ClassImp(TRestDataSetOdds);
108
113
///////////////////////////////////////////////
/// \brief Constructor that configures the object from an RML configuration file.
///
/// \param configFilename Path to the RML file; forwarded to the TRestMetadata base constructor.
/// \param name Presumably the name of the metadata section to load from the RML file.
///
/// NOTE(review): original listing lines 130 and 133 are not visible in this view. `name` is not
/// used in the visible lines, so it is presumably consumed there (LoadConfigFromFile and
/// GetVerboseLevel are among the referenced symbols) -- confirm against the full source.
///
128TRestDataSetOdds::TRestDataSetOdds(const char* configFilename, std::string name)
129 : TRestMetadata(configFilename) {
// Sets the RML section name for this class (see Initialize()).
131 Initialize();
132
134}
135
140
145void TRestDataSetOdds::Initialize() { SetSectionName(this->ClassName()); }
146
// NOTE(review): the function header (original listing line 151) and listing line 153 are not
// visible in this view. From the referenced symbols this is presumably the body of
// TRestDataSetOdds::InitFromConfigFile(), with line 153 presumably calling the base
// TRestMetadata::InitFromConfigFile() -- confirm against the full source.
//
// Reads every <observable> element of the RML section. Each one must provide "name", "range"
// and "nBins"; a missing field aborts the program with exit(1). It then reads the optional
// "outputFileName" parameter and instantiates the child TRestCut used to select PDF events.
152 Initialize();
154
155 TiXmlElement* obsDefinition = GetElement("observable");
156 while (obsDefinition != nullptr) {
// An empty string or "Not defined" is treated as a missing field.
157 std::string obsName = GetFieldValue("name", obsDefinition);
158 if (obsName.empty() || obsName == "Not defined") {
159 RESTError << "< observable variable key does not contain a name!" << RESTendl;
160 exit(1);
161 } else {
162 fObsName.push_back(obsName);
163 }
164
// The range string is parsed into a TVector2; ComputeLogOdds uses X() and Y() as the
// lower and upper histogram limits.
165 std::string range = GetFieldValue("range", obsDefinition);
166 if (range.empty() || range == "Not defined") {
167 RESTError << "< observable key does not contain a range value!" << RESTendl;
168 exit(1);
169 } else {
170 TVector2 roi = StringTo2DVector(range);
171 fObsRange.push_back(roi);
172 }
173
// NOTE(review): the parsed bin count is not checked to be positive -- confirm whether
// StringToInteger/Histo1D reject invalid values downstream.
174 std::string nBins = GetFieldValue("nBins", obsDefinition);
175 if (nBins.empty() || nBins == "Not defined") {
176 RESTError << "< observable key does not contain a nBins value!" << RESTendl;
177 exit(1);
178 } else {
179 fObsNbins.push_back(StringToInteger(nBins));
180 }
181
182 obsDefinition = GetNextElement(obsDefinition);
183 }
184
// At least one observable is required.
185 if (fObsName.empty() || fObsRange.empty()) {
186 RESTError << "No observables provided, exiting..." << RESTendl;
187 exit(1);
188 }
189
// An output file name that was already set takes precedence over the RML parameter.
190 if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", "");
191
// NOTE(review): if the section has no TRestCut child, fCut presumably ends up nullptr;
// it is later passed to TRestDataSet::MakeCut -- confirm that a null cut is handled there.
192 fCut = (TRestCut*)InstantiateChildMetadata("TRestCut");
193}
194
// NOTE(review): the opening of this definition (original listing lines ~195-205, including the
// function header) is not visible in this view. It is presumably the body of
// TRestDataSetOdds::ComputeLogOdds() -- confirm against the full source.
//
// Overview:
// - Imports the dataset named by fDataSetName.
// - Obtains one normalised PDF histogram per observable. If fOddsFile is empty, the PDFs are
//   built from the dataset after fCut is applied; otherwise they are read from fOddsFile.
// - Defines an "odds_<obs>" column equal to ln((1-p)/p), where p is the PDF bin content at the
//   event's value.
// - Defines "odds_total" as the sum of those columns.
// - Exports the result when fOutputFileName is set.
206
207 TRestDataSet dataSet;
208 dataSet.Import(fDataSetName);
209
210 if (fOddsFile.empty()) {
// PDFs are built only from the events that pass fCut.
211 auto DF = dataSet.MakeCut(fCut);
212 RESTInfo << "Generating PDFs for dataset: " << fDataSetName << RESTendl;
213 for (size_t i = 0; i < fObsName.size(); i++) {
214 const std::string obsName = fObsName[i];
215 const TVector2 range = fObsRange[i];
216 const std::string histName = "h" + obsName;
217 const int nBins = fObsNbins[i];
218 RESTDebug << "\tGenerating PDF for " << obsName << " with range: (" << range.X() << ", "
219 << range.Y() << ") and nBins: " << nBins << RESTendl;
220 auto histo =
221 DF.Histo1D({histName.c_str(), histName.c_str(), nBins, range.X(), range.Y()}, obsName);
// NOTE(review): per the ROOT docs, RInterface::Histo1D fills a TH1D. This static_cast to
// TH1F* is therefore a cast between sibling classes. It only works because the calls made
// through it are TH1 virtuals. A proper fix needs fHistos' element type (declared in the
// header) to change. DrawClone() also draws the clone on the current pad, creating a
// default canvas if none exists; Clone() would detach the histogram without drawing.
222 TH1F* h = static_cast<TH1F*>(histo->DrawClone());
223 RESTDebug << "\tNormalizing by integral = " << h->Integral() << RESTendl;
// NOTE(review): if no event falls inside the range, Integral() is 0 and this divides by zero.
224 h->Scale(1. / h->Integral());
225 fHistos[obsName] = h;
226 }
227 } else {
// Precomputed PDFs are read from the odds file as histograms named "h<observable>".
228 TFile* f = TFile::Open(fOddsFile.c_str());
229 if (f == nullptr) {
230 RESTError << "Cannot open calibration odds file " << fOddsFile << RESTendl;
231 exit(1);
232 }
233 RESTInfo << "Opening " << fOddsFile << " as oddsFile." << RESTendl;
// NOTE(review): `f` is never closed or deleted. The histograms read from it are presumably
// owned by the file, so this may be intentional; confirm. The Get() result is not
// null-checked, and a missing "h<obs>" histogram would be dereferenced later in GetLogOdds.
234 for (size_t i = 0; i < fObsName.size(); i++) {
235 const std::string obsName = fObsName[i];
236 const std::string histName = "h" + obsName;
237 TH1F* h = (TH1F*)f->Get(histName.c_str());
238 fHistos[obsName] = h;
239 }
240 }
241
// Log odds are computed on the full dataset (no cut), not only the events used for the PDFs.
242 auto df = dataSet.GetDataFrame();
243 std::string totName = "";
244 RESTDebug << "Computing log odds from " << fDataSetName << RESTendl;
245 for (const auto& [obsName, histo] : fHistos) {
246 const std::string oddsName = "odds_" + obsName;
// C++17 lambdas cannot capture structured bindings directly, hence the init-capture.
// It binds to the TH1F* stored in fHistos, and std::map element references stay valid,
// so the lambda may safely run later when the (lazy) data frame is evaluated.
// A bin content of 0 is mapped to the sentinel value 1000.
// NOTE(review): a bin content of exactly 1 gives log(0) = -inf. Values outside the range
// land in the under/overflow bins. Those bins are scaled too, but with ROOT's default
// TH1::Integral() they were not part of the normalisation, so their content is not
// guaranteed to lie in [0, 1]. Confirm the intended behaviour.
247 auto GetLogOdds = [&histo = histo](double val) {
248 double odds = histo->GetBinContent(histo->GetXaxis()->FindBin(val));
249 if (odds == 0) return 1000.;
250 return log(1. - odds) - log(odds);
251 };
252
// GetLogOdds takes a double, so columns of any other type are redefined as double.
253 if (df.GetColumnType(obsName) != "Double_t") {
254 RESTWarning << "Column " << obsName << " is not of type 'double'. It will be converted."
255 << RESTendl;
256 df = df.Redefine(obsName, "static_cast<double>(" + obsName + ")");
257 }
258 df = df.Define(oddsName, GetLogOdds, {obsName});
// NOTE(review): `h` is never used. Booking this histogram adds work to the event loop for
// nothing; it can likely be removed.
259 auto h = df.Histo1D(oddsName);
260
// Build the expression "odds_a+odds_b+..." for the total log odds.
261 if (!totName.empty()) totName += "+";
262 totName += oddsName;
263 }
264
265 RESTDebug << "Computing total log odds" << RESTendl;
266 RESTDebug << "\tTotal log odds = " << totName << RESTendl;
// NOTE(review): if fHistos is empty, totName is "" and this Define would receive an empty
// expression.
267 df = df.Define("odds_total", totName);
268
269 dataSet.SetDataFrame(df);
270
271 if (!fOutputFileName.empty()) {
// NOTE(review): listing line 272 is missing from this view. It opens the scope closed at
// listing line 280, presumably a check on TRestTools::GetFileNameExtension(fOutputFileName),
// which is among the referenced symbols -- confirm.
273 RESTDebug << "Exporting dataset to " << fOutputFileName << RESTendl;
274 dataSet.Export(fOutputFileName);
// Reopen the exported file to append this metadata object and the PDF histograms.
// NOTE(review): the TFile::Open result is not null-checked, and `f` is closed but never
// deleted.
275 TFile* f = TFile::Open(fOutputFileName.c_str(), "UPDATE");
276 this->Write();
277 RESTDebug << "Writing histograms to " << fOutputFileName << RESTendl;
278 for (const auto& [obsName, histo] : fHistos) histo->Write();
279 f->Close();
280 }
281 }
282}
283
284std::vector<std::tuple<std::string, TVector2, int>> TRestDataSetOdds::GetOddsObservables() {
285 std::vector<std::tuple<std::string, TVector2, int>> obs;
286 for (size_t i = 0; i < fObsName.size(); i++) {
287 if (i >= fObsName.size() || i >= fObsRange.size() || i >= fObsNbins.size()) {
288 RESTError << "Sizes for observables names, ranges and bins do not match!" << RESTendl;
289 break;
290 }
291 obs.push_back(std::make_tuple(fObsName[i], fObsRange[i], fObsNbins[i]));
292 }
293 return obs;
294}
295
296void TRestDataSetOdds::AddOddsObservable(const std::string& name, const TVector2& range, int nbins) {
297 fObsName.push_back(name);
298 fObsRange.push_back(range);
299 fObsNbins.push_back(nbins);
300}
301
302void TRestDataSetOdds::SetOddsObservables(const std::vector<std::tuple<std::string, TVector2, int>>& obs) {
303 fObsName.clear();
304 fObsRange.clear();
305 fObsNbins.clear();
306 for (const auto& [name, range, nbins] : obs) AddOddsObservable(name, range, nbins);
307}
308
// NOTE(review): the function header (original listing lines ~309-313) is not visible in this
// view. This is presumably the body of TRestDataSetOdds::PrintMetadata(); any call to the base
// TRestMetadata::PrintMetadata() would also be in the missing lines -- confirm.
//
// Prints the odds file (only when set), the dataset file and one line per observable with its
// name, range and number of bins.
314
315 // if (fCut) fCut->PrintMetadata();
316 if (!fOddsFile.empty()) RESTMetadata << " Odds file: " << fOddsFile << RESTendl;
317 RESTMetadata << " DataSet file: " << fDataSetName << RESTendl;
318
319 RESTMetadata << " Observables to compute: " << RESTendl;
// NOTE(review): fObsRange[i] and fObsNbins[i] are indexed without checking that the parallel
// vectors have the same size as fObsName.
320 for (size_t i = 0; i < fObsName.size(); i++) {
321 RESTMetadata << fObsName[i] << "; Range: (" << fObsRange[i].X() << ", " << fObsRange[i].Y()
322 << "); nBins: " << fObsNbins[i] << RESTendl;
323 }
324 RESTMetadata << "----" << RESTendl;
325}
A class to help with cut definitions. To be used with TRestAnalysisTree.
Definition: TRestCut.h:31
This class is meant to compute the log odds for different datasets.
std::string fOddsFile
Name of the odds file to be used to get the PDF.
std::vector< std::string > fObsName
Vector containing the different observable names.
TRestDataSetOdds()
Default constructor.
void PrintMetadata() override
Prints on screen the information about the metadata members of TRestDataSetOdds.
std::map< std::string, TH1F * > fHistos
Map containing the PDF of the different observables.
std::vector< TVector2 > fObsRange
Vector containing the different observable ranges.
void Initialize() override
Function to initialize input/output event members and define the section name.
~TRestDataSetOdds()
Default destructor.
TRestCut * fCut
Cuts over the dataset for PDF selection.
void ComputeLogOdds()
This function computes the log odds for a given dataSet. If no calibration odds file is provided it c...
std::string fOutputFileName
Name of the output file.
std::string fDataSetName
Name of the dataSet inside the config file.
std::vector< int > fObsNbins
Vector containing number of bins for the different observables.
void InitFromConfigFile() override
Function to initialize some variables from the config file.
It allows to group a number of runs that satisfy given metadata conditions.
Definition: TRestDataSet.h:34
void Import(const std::string &fileName)
This function imports metadata from a root file it import metadata info from the previous dataSet whi...
ROOT::RDF::RNode GetDataFrame() const
Gives access to the RDataFrame.
Definition: TRestDataSet.h:129
ROOT::RDF::RNode MakeCut(const TRestCut *cut)
This function applies a TRestCut to the dataframe and returns a dataframe with the applied cuts....
void Export(const std::string &filename, std::vector< std::string > excludeColumns={})
It will generate an output file with the dataset compilation. Only the selected branches and the file...
A base class for any REST metadata class.
Definition: TRestMetadata.h:74
virtual void PrintMetadata()
Implement it in the derived metadata class to print out specific metadata information.
endl_t RESTendl
Termination flag object for TRestStringOutput.
TiXmlElement * GetElement(std::string eleDeclare, TiXmlElement *e=nullptr)
Get an xml element from a given parent element, according to its declaration.
Int_t LoadConfigFromFile(const std::string &configFilename, const std::string &sectionName="")
Give the file name, find out the corresponding section. Then call the main starter.
TRestMetadata * InstantiateChildMetadata(int index, std::string pattern="")
This method will retrieve a new TRestMetadata instance of a child element of the present TRestMetadat...
virtual void InitFromConfigFile()
To make settings from rml file. This method must be implemented in the derived class.
TRestStringOutput::REST_Verbose_Level GetVerboseLevel()
Returns the verbose level as a REST_Verbose_Level enumerator value.
std::string GetFieldValue(std::string parName, TiXmlElement *e)
Returns the field value of an xml element which has the specified name.
void SetSectionName(std::string sName)
Sets the section name and clears the section content.
std::string fConfigFileName
Full name of the rml file.
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Overrides the Write() method, taking fStore into account.
TiXmlElement * GetNextElement(TiXmlElement *e)
Gets the next sibling XML element of this element with the same eleDeclare.
std::string GetParameter(std::string parName, TiXmlElement *e, TString defaultValue=PARAMETER_NOT_FOUND_STR)
Returns the value for the parameter named parName in the given section.
@ REST_Info
+show most of the information for each step
static std::string GetFileNameExtension(const std::string &fullname)
Gets the file extension as the substring found after the last ".".
Definition: TRestTools.cxx:823
Int_t StringToInteger(std::string in)
Gets an integer from a string.
TVector2 StringTo2DVector(std::string in)
Gets a 2D-vector from a string.