LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
LArSupportVectorMachine.cc
Go to the documentation of this file.
1 
9 #include "Helpers/XmlHelper.h"
10 
12 
13 using namespace pandora;
14 
15 namespace lar_content
16 {
17 
18 SupportVectorMachine::SupportVectorMachine() :
19  m_isInitialized(false),
20  m_enableProbability(false),
21  m_probAParameter(0.),
22  m_probBParameter(0.),
23  m_standardizeFeatures(true),
24  m_nFeatures(0),
25  m_bias(0.),
26  m_scaleFactor(1.),
27  m_kernelType(QUADRATIC),
28  m_kernelFunction(QuadraticKernel),
30 {
31 }
32 
33 //------------------------------------------------------------------------------------------------------------------------------------------
34 
35 StatusCode SupportVectorMachine::Initialize(const std::string &parameterLocation, const std::string &svmName)
36 {
37  if (m_isInitialized)
38  {
39  std::cout << "SupportVectorMachine: svm was already initialized" << std::endl;
40  return STATUS_CODE_ALREADY_INITIALIZED;
41  }
42 
43  this->ReadXmlFile(parameterLocation, svmName);
44 
45  // Check the sizes of sigma and scale factor if they are to be used as divisors
47  {
48  for (const FeatureInfo &featureInfo : m_featureInfoList)
49  {
50  if (featureInfo.m_sigmaValue < std::numeric_limits<double>::epsilon())
51  {
52  std::cout << "SupportVectorMachine: could not standardize parameters because sigma value was too small" << std::endl;
53  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
54  }
55  }
56  }
57 
58  // Check the number of features is consistent.
60 
61  for (const SupportVectorInfo &svInfo : m_svInfoList)
62  {
63  if (svInfo.m_supportVector.size() != m_nFeatures)
64  {
65  std::cout << "SupportVectorMachine: the number of features in the xml file was inconsistent" << std::endl;
66  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
67  }
68  }
69 
70  // There's the possibility of a user-defined kernel that doesn't use this as a divisor but let's be safe
71  if (m_scaleFactor < std::numeric_limits<double>::epsilon())
72  {
73  std::cout << "SupportVectorMachine: could not evaluate kernel because scale factor was too small" << std::endl;
74  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
75  }
76 
77  m_isInitialized = true;
78  return STATUS_CODE_SUCCESS;
79 }
80 
81 //------------------------------------------------------------------------------------------------------------------------------------------
82 
83 void SupportVectorMachine::ReadXmlFile(const std::string &svmFileName, const std::string &svmName)
84 {
85  TiXmlDocument xmlDocument(svmFileName);
86 
87  if (!xmlDocument.LoadFile())
88  {
89  std::cout << "SupportVectorMachine::Initialize - Invalid xml file." << std::endl;
90  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
91  }
92 
93  const TiXmlHandle xmlDocumentHandle(&xmlDocument);
94  TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
95 
96  // Try to find the svm container with the required name
97  while (pContainerXmlNode)
98  {
99  if (pContainerXmlNode->ValueStr() != "SupportVectorMachine")
100  throw StatusCodeException(STATUS_CODE_FAILURE);
101 
102  const TiXmlHandle currentHandle(pContainerXmlNode);
103 
104  std::string currentName;
105  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
106 
107  if (currentName.empty() || (currentName.size() > 1000))
108  {
109  std::cout << "SupportVectorMachine::Initialize - Implausible svm name extracted from xml." << std::endl;
110  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
111  }
112 
113  if (currentName == svmName)
114  break;
115 
116  pContainerXmlNode = pContainerXmlNode->NextSibling();
117  }
118 
119  if (!pContainerXmlNode)
120  {
121  std::cout << "SupportVectorMachine: Could not find an svm by the name " << svmName << std::endl;
122  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
123  }
124 
125  // Read the components of this svm container
126  TiXmlHandle localHandle(pContainerXmlNode);
127  TiXmlElement *pCurrentXmlElement = localHandle.FirstChild().Element();
128 
129  while (pCurrentXmlElement)
130  {
131  if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
132  {
133  std::cout << "SupportVectorMachine: Unknown component in xml file" << std::endl;
134  throw StatusCodeException(STATUS_CODE_FAILURE);
135  }
136 
137  pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
138  }
139 }
140 
141 //------------------------------------------------------------------------------------------------------------------------------------------
142 
143 StatusCode SupportVectorMachine::ReadComponent(TiXmlElement *pCurrentXmlElement)
144 {
145  const std::string componentName(pCurrentXmlElement->ValueStr());
146  const TiXmlHandle currentHandle(pCurrentXmlElement);
147 
148  if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
149  return STATUS_CODE_SUCCESS;
150 
151  if (std::string("Machine") == componentName)
152  return this->ReadMachine(currentHandle);
153 
154  if (std::string("Features") == componentName)
155  return this->ReadFeatures(currentHandle);
156 
157  if (std::string("SupportVector") == componentName)
158  return this->ReadSupportVector(currentHandle);
159 
160  return STATUS_CODE_INVALID_PARAMETER;
161 }
162 
163 //------------------------------------------------------------------------------------------------------------------------------------------
164 
165 StatusCode SupportVectorMachine::ReadMachine(const TiXmlHandle &currentHandle)
166 {
167  int kernelType(0);
168  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "KernelType", kernelType));
169 
170  double bias(0.);
171  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "Bias", bias));
172 
173  double scaleFactor(0.);
174  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ScaleFactor", scaleFactor));
175 
176  bool standardize(true);
177  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "Standardize", standardize));
178 
179  bool enableProbability(false);
180  PANDORA_RETURN_RESULT_IF_AND_IF(
181  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "EnableProbability", enableProbability));
182 
183  double probAParameter(0.);
184  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ProbAParameter", probAParameter));
185 
186  double probBParameter(0.);
187  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "ProbBParameter", probBParameter));
188 
189  m_kernelType = static_cast<KernelType>(kernelType);
190  m_bias = bias;
191  m_scaleFactor = scaleFactor;
192  m_enableProbability = enableProbability;
193  m_probAParameter = probAParameter;
194  m_probBParameter = probBParameter;
195 
196  if (kernelType != USER_DEFINED) // if user-defined, leave it so it alone can be set before/after initialization
198 
199  return STATUS_CODE_SUCCESS;
200 }
201 
202 //------------------------------------------------------------------------------------------------------------------------------------------
203 
204 StatusCode SupportVectorMachine::ReadFeatures(const TiXmlHandle &currentHandle)
205 {
206  std::vector<double> muValues;
207  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "MuValues", muValues));
208 
209  std::vector<double> sigmaValues;
210  PANDORA_RETURN_RESULT_IF_AND_IF(
211  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "SigmaValues", sigmaValues));
212 
213  if (muValues.size() != sigmaValues.size())
214  {
215  std::cout << "SupportVectorMachine: could not add feature info because the size of mu (" << muValues.size()
216  << ") did not match "
217  "the size of sigma ("
218  << sigmaValues.size() << ")" << std::endl;
219  return STATUS_CODE_INVALID_PARAMETER;
220  }
221 
222  m_featureInfoList.reserve(muValues.size());
223 
224  for (std::size_t i = 0; i < muValues.size(); ++i)
225  m_featureInfoList.emplace_back(muValues.at(i), sigmaValues.at(i));
226 
227  return STATUS_CODE_SUCCESS;
228 }
229 
230 //------------------------------------------------------------------------------------------------------------------------------------------
231 
232 StatusCode SupportVectorMachine::ReadSupportVector(const TiXmlHandle &currentHandle)
233 {
234  double yAlpha(0.0);
235  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(currentHandle, "AlphaY", yAlpha));
236 
237  std::vector<double> values;
238  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(currentHandle, "Values", values));
239 
240  LArMvaHelper::MvaFeatureVector valuesFeatureVector;
241  for (const double &value : values)
242  valuesFeatureVector.emplace_back(value);
243 
244  m_svInfoList.emplace_back(yAlpha, valuesFeatureVector);
245  return STATUS_CODE_SUCCESS;
246 }
247 
248 //------------------------------------------------------------------------------------------------------------------------------------------
249 
251 {
252  if (!m_isInitialized)
253  {
254  std::cout << "SupportVectorMachine: could not perform classification because the svm was uninitialized" << std::endl;
255  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
256  }
257 
258  if (m_svInfoList.empty())
259  {
260  std::cout << "SupportVectorMachine: could not perform classification because the initialized svm had no support vectors in the model"
261  << std::endl;
262  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
263  }
264 
265  LArMvaHelper::MvaFeatureVector standardizedFeatures;
266  standardizedFeatures.reserve(m_nFeatures);
267 
269  {
270  for (std::size_t i = 0; i < m_nFeatures; ++i)
271  standardizedFeatures.push_back(m_featureInfoList.at(i).StandardizeParameter(features.at(i).Get()));
272  }
273 
274  double classScore(0.);
275  for (const SupportVectorInfo &supportVectorInfo : m_svInfoList)
276  {
277  classScore += supportVectorInfo.m_yAlpha *
278  m_kernelFunction(supportVectorInfo.m_supportVector, (m_standardizeFeatures ? standardizedFeatures : features), m_scaleFactor);
279  }
280 
281  return classScore + m_bias;
282 }
283 
284 } // namespace lar_content
static double CubicKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
An inhomogeneous cubic kernel.
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read the component at the current xml element.
pandora::StatusCode ReadSupportVector(const pandora::TiXmlHandle &currentHandle)
Read the support vector component at the current xml handle.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:75
bool m_enableProbability
Whether to enable probability calculations.
double m_scaleFactor
The kernel scale factor.
FeatureInfoVector m_featureInfoList
The list of FeatureInfo objects.
SVInfoList m_svInfoList
The list of SupportVectorInfo objects.
static double QuadraticKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
An inhomogeneous quadratic kernel.
static double LinearKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
A linear kernel.
pandora::StatusCode ReadMachine(const pandora::TiXmlHandle &currentHandle)
Read the machine component at the current xml handle.
bool m_standardizeFeatures
Whether to standardize the features.
Header file for the lar support vector machine class.
decltype(auto) values(Coll &&coll)
Range-for loop helper iterating across the values of the specified collection.
void ReadXmlFile(const std::string &svmFileName, const std::string &svmName)
Read the svm parameters from an xml file.
KernelMap m_kernelMap
Map from the kernel types to the kernel functions.
KernelFunction m_kernelFunction
The kernel function.
KernelType m_kernelType
The kernel type.
double value
Definition: spectrum.C:18
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &svmName)
Initialize the svm using a serialized model.
double m_probAParameter
The first-order score coefficient for mapping to a probability using the logistic function.
unsigned int m_nFeatures
The number of features.
static double GaussianRbfKernel(const LArMvaHelper::MvaFeatureVector &supportVector, const LArMvaHelper::MvaFeatureVector &features, const double scaleFactor=1.)
A gaussian RBF kernel.
double CalculateClassificationScoreImpl(const LArMvaHelper::MvaFeatureVector &features) const
Implementation method for calculating the classification score using the trained model.
bool m_isInitialized
Whether this svm has been initialized.
double m_probBParameter
The score offset parameter for mapping to a probability using the logistic function.
pandora::StatusCode ReadFeatures(const pandora::TiXmlHandle &currentHandle)
Read the feature component at the current xml handle.