24 #include "KalmanFitter.h"
26 #include "Exception.h"
27 #include "KalmanFitterInfo.h"
28 #include "KalmanFitStatus.h"
30 #include "TrackPoint.h"
34 #include <Math/ProbFunc.h>
36 #include <TDecompChol.h>
37 #include <TMatrixDSymEigen.h>
44 double& chi2,
double& ndf,
45 int startId,
int endId,
int& nFailedHits)
52 Exception exc(
"KalmanFitter::fitTrack ==> cannot use (un)weightedClosestToReference(Wire) as multiple measurement handling.",__LINE__,__FILE__);
58 startId += tr->getNumPointsWithMeasurement();
60 endId += tr->getNumPointsWithMeasurement();
72 debugOut << tr->getNumPointsWithMeasurement() <<
" TrackPoints w/ measurement in this track." << std::endl;
74 for (
int i = startId; ; i+=direction) {
75 TrackPoint *tp = tr->getPointWithMeasurement(i);
76 assert(direction == +1 || direction == -1);
79 debugOut <<
" process TrackPoint nr. " << i <<
" (" << tp <<
")\n";
83 processTrackPoint(tp, rep, chi2, ndf, direction);
90 tr->getPoint(i)->deleteFitterInfo(rep);
96 debugOut <<
"There was an exception, try to continue with next TrackPoint " << i+direction <<
" \n";
116 Exception exc(
"KalmanFitter::processTrack: Cannot process pruned track!", __LINE__,__FILE__);
120 TrackPoint* trackPoint = tr->getPointWithMeasurement(0);
122 if (trackPoint->hasFitterInfo(rep) &&
128 debugOut <<
"take backward update of previous iteration as seed \n";
135 debugOut <<
"take smoothed state of cardinal rep fit as seed \n";
146 rep->
setTime(*currentState_, tr->getTimeSeed());
147 rep->
setPosMomCov(*currentState_, tr->getStateSeed(), tr->getCovSeed());
149 debugOut <<
"take seed state of track as seed \n";
156 double oldChi2FW(1.e6);
157 double oldChi2BW(1.e6);
158 double oldPvalFW(0.);
160 double oldPvalBW = 0.;
161 double chi2FW(0), ndfFW(0);
162 double chi2BW(0), ndfBW(0);
164 int nFailedHitsForward(0), nFailedHitsBackward(0);
168 tr->setFitStatus(status, rep);
170 unsigned int nIt = 0;
174 debugOut <<
"\033[1;21mstate pre" << std::endl;
175 currentState_->Print();
176 debugOut <<
"\033[0mfitting" << std::endl;
179 if (!fitTrack(tr, rep, chi2FW, ndfFW, 0, -1, nFailedHitsForward)) {
180 status->setIsFitted(
false);
181 status->setIsFitConvergedFully(
false);
182 status->setIsFitConvergedPartially(
false);
183 status->setNFailedPoints(nFailedHitsForward);
188 debugOut <<
"\033[1;21mstate post forward" << std::endl;
189 currentState_->Print();
196 if (!fitTrack(tr, rep, chi2BW, ndfBW, -1, 0, nFailedHitsBackward)) {
197 status->setIsFitted(
false);
198 status->setIsFitConvergedFully(
false);
199 status->setIsFitConvergedPartially(
false);
200 status->setNFailedPoints(nFailedHitsBackward);
205 debugOut <<
"\033[1;21mstate post backward" << std::endl;
206 currentState_->Print();
209 debugOut <<
"old chi2s: " << oldChi2BW <<
", " << oldChi2FW
210 <<
" new chi2s: " << chi2BW <<
", " << chi2FW
211 <<
" oldPvals " << oldPvalFW <<
", " << oldPvalBW << std::endl;
215 double PvalBW = std::max(0.,ROOT::Math::chisquared_cdf_c(chi2BW, ndfBW));
216 double PvalFW = (debugLvl_ > 0) ? std::max(0.,ROOT::Math::chisquared_cdf_c(chi2FW, ndfFW)) : 0;
222 bool converged(
false);
223 bool finished(
false);
249 if (nFailedHitsForward == 0 && nFailedHitsBackward == 0)
250 status->setIsFitConvergedFully(converged);
252 status->setIsFitConvergedFully(
false);
254 status->setIsFitConvergedPartially(converged);
256 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
272 debugOut <<
"KalmanFitter::number of max iterations reached!\n";
278 status->setIsFitted(
false);
279 status->setIsFitConvergedFully(
false);
280 status->setIsFitConvergedPartially(
false);
281 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
286 status->setIsFitted();
288 TrackPoint* tp = tr->getPointWithMeasurementAndFitterInfo(0, rep);
293 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
294 status->setCharge(charge);
295 status->setNumIterations(nIt);
296 status->setForwardChi2(chi2FW);
297 status->setBackwardChi2(chi2BW);
298 status->setForwardNdf(std::max(0., ndfFW));
299 status->setBackwardNdf(std::max(0., ndfBW));
307 Exception exc(
"KalmanFitter::processTrack: Cannot process pruned track!", __LINE__,__FILE__);
312 startId += tr->getNumPointsWithMeasurement();
314 endId += tr->getNumPointsWithMeasurement();
321 TrackPoint* trackPoint = tr->getPointWithMeasurement(startId);
325 if (direction == 1 && startId > 0)
326 prevTrackPoint = tr->getPointWithMeasurement(startId - 1);
327 else if (direction == -1 && startId < (
int)tr->getNumPointsWithMeasurement() - 1)
328 prevTrackPoint = tr->getPointWithMeasurement(startId + 1);
331 if (prevTrackPoint !=
nullptr &&
332 prevTrackPoint->hasFitterInfo(rep) &&
337 debugOut <<
"take update of previous fitter info as seed \n";
344 debugOut <<
"take smoothed state of cardinal rep fit as seed \n";
355 rep->
setTime(*currentState_, tr->getTimeSeed());
356 rep->
setPosMomCov(*currentState_, tr->getStateSeed(), tr->getCovSeed());
358 debugOut <<
"take seed of track as seed \n";
362 if (startId == 0 || startId == (
int)tr->getNumPointsWithMeasurement() - 1) {
369 debugOut <<
"\033[1;21mstate pre" << std::endl;
370 currentState_->Print();
371 debugOut <<
"\033[0mfitting" << std::endl;
376 fitTrack(tr, rep, chi2, ndf, startId, endId, nFailedHits);
382 KalmanFitter::processTrackPoint(
TrackPoint* tp,
383 const AbsTrackRep* rep,
double& chi2,
double& ndf,
int direction)
385 assert(direction == -1 || direction == +1);
387 if (!tp->hasRawMeasurements())
390 bool newFi(!tp->hasFitterInfo(rep));
401 bool oldWeightsFixed(
false);
402 std::vector<double> oldWeights;
407 oldWeights = fi->getWeights();
408 oldWeightsFixed = fi->areWeightsFixed();
411 fi->deleteForwardInfo();
412 fi->deleteBackwardInfo();
413 fi->deleteMeasurementInfo();
416 const std::vector< genfit::AbsMeasurement* >& rawMeasurements = tp->getRawMeasurements();
417 plane = rawMeasurements[0]->constructPlane(*currentState_);
420 plane = fi->getPlane();
424 debugOut <<
"extrapolated by " << extLen << std::endl;
426 fi->setPrediction(currentState_->clone(), direction);
431 const std::vector< genfit::AbsMeasurement* >& rawMeasurements = tp->getRawMeasurements();
432 for (std::vector< genfit::AbsMeasurement* >::const_iterator it = rawMeasurements.begin(); it != rawMeasurements.end(); ++it) {
433 fi->addMeasurementsOnPlane((*it)->constructMeasurementsOnPlane(*state));
435 if (oldWeights.size() == fi->getNumMeasurements()) {
436 fi->setWeights(oldWeights);
437 fi->fixWeights(oldWeightsFixed);
442 assert(fi->getPlane() == plane);
443 assert(fi->checkConsistency());
447 debugOut <<
"its plane is at R = " << plane->getO().Perp()
448 <<
" with normal pointing along (" << plane->getNormal().X() <<
", " << plane->getNormal().Y() <<
", " << plane->getNormal().Z() <<
")" << std::endl;
451 TVectorD stateVector(state->getState());
452 TMatrixDSym cov(state->getCov());
456 if (!squareRootFormalism_) {
458 const std::vector<MeasurementOnPlane *>& measurements =
getMeasurements(fi, tp, direction);
459 for (std::vector<MeasurementOnPlane *>::const_iterator it = measurements.begin(); it != measurements.end(); ++it) {
461 const double weight = mOnPlane.getWeight();
464 debugOut <<
"Weight of measurement: " << weight <<
"\n";
469 debugOut <<
"Weight of measurement is almost 0, continue ... \n";
474 const TVectorD& measurement(mOnPlane.getState());
478 1./weight * mOnPlane.getCov() :
482 debugOut <<
"State prediction: "; stateVector.Print();
483 debugOut <<
"Cov prediction: "; state->getCov().Print();
486 debugOut <<
"measurement: "; measurement.Print();
487 debugOut <<
"measurement covariance V: "; V.Print();
492 TVectorD res(measurement - H->Hv(stateVector));
494 debugOut <<
"Residual = (" << res(0);
495 if (res.GetNrows() > 1)
505 TMatrixDSym covSumInv(cov);
508 tools::invertMatrix(covSumInv);
510 TMatrixD CHt(H->MHt(cov));
511 TVectorD update(TMatrixD(CHt, TMatrixD::kMult, covSumInv) * res);
518 debugOut <<
"Update: "; update.Print();
523 stateVector += update;
524 covSumInv.Similarity(CHt);
530 debugOut <<
"updated state: "; stateVector.Print();
531 debugOut <<
"updated cov: "; cov.Print();
534 TVectorD resNew(measurement - H->Hv(stateVector));
536 debugOut <<
"Residual New = (" << resNew(0);
538 if (resNew.GetNrows() > 1)
545 TMatrixDSym HCHt(cov);
550 tools::invertMatrix(HCHt);
552 chi2inc += HCHt.Similarity(resNew);
555 ndfInc += weight * measurement.GetNrows();
558 ndfInc += measurement.GetNrows();
561 debugOut <<
"chi² increment = " << chi2inc << std::endl;
573 TDecompChol decompCov(cov);
574 decompCov.Decompose();
575 TMatrixD S(decompCov.GetU());
577 const std::vector<MeasurementOnPlane *>& measurements =
getMeasurements(fi, tp, direction);
578 for (std::vector<MeasurementOnPlane *>::const_iterator it = measurements.begin(); it != measurements.end(); ++it) {
580 const double weight = mOnPlane.getWeight();
583 debugOut <<
"Weight of measurement: " << weight <<
"\n";
588 debugOut <<
"Weight of measurement is almost 0, continue ... \n";
593 const TVectorD& measurement(mOnPlane.getState());
597 1./weight * mOnPlane.getCov() :
601 debugOut <<
"State prediction: "; stateVector.Print();
602 debugOut <<
"Cov prediction: "; state->getCov().Print();
605 debugOut <<
"measurement: "; measurement.Print();
606 debugOut <<
"measurement covariance V: "; V.Print();
611 TVectorD res(measurement - H->Hv(stateVector));
613 debugOut <<
"Residual = (" << res(0);
614 if (res.GetNrows() > 1)
620 TDecompChol decompR(V);
622 const TMatrixD&
R(decompR.GetU());
624 TVectorD update(stateVector.GetNrows());
625 tools::kalmanUpdateSqrt(S, res,
R, H,
627 stateVector += update;
634 debugOut <<
"updated state: "; stateVector.Print();
635 debugOut <<
"updated cov: "; TMatrixDSym(TMatrixDSym::kAtA, S).Print() ;
638 res -= H->Hv(update);
640 debugOut <<
"Residual New = (" << res(0);
642 if (res.GetNrows() > 1)
653 TMatrixDSym HCHt(TMatrixDSym::kAtA, H->MHt(S));
657 tools::invertMatrix(HCHt);
659 chi2inc += HCHt.Similarity(res);
662 ndfInc += weight * measurement.GetNrows();
665 ndfInc += measurement.GetNrows();
668 debugOut <<
"chi² increment = " << chi2inc << std::endl;
672 cov = TMatrixDSym(TMatrixDSym::kAtA, S);
675 currentState_->setStateCovPlane(stateVector, cov, plane);
676 currentState_->setAuxInfo(state->getAuxInfo());
683 fi->setUpdate(updatedSOP, direction);
688 void KalmanFitter::Streamer(TBuffer &R__b)
693 typedef ::genfit::KalmanFitter thisClass;
695 if (R__b.IsReading()) {
696 Version_t R__v = R__b.ReadVersion(&R__s, &R__c);
if (R__v) { }
699 baseClass0::Streamer(R__b);
702 currentState_.reset(p);
703 R__b.CheckByteCount(R__s, R__c, thisClass::IsA());
705 R__c = R__b.WriteVersion(thisClass::IsA(), kTRUE);
708 baseClass0::Streamer(R__b);
709 R__b << currentState_.get();
710 R__b.SetByteCount(R__c, kTRUE);
HMatrix for projecting from AbsTrackRep parameters to measured parameters in a DetPlane.
Abstract base class for Kalman fitter and derived fitting algorithms.
bool canIgnoreWeights() const
Returns whether the fitter can ignore the weights and handle the MeasurementOnPlanes as if they had weight...
double deltaPval_
Convergence criterion.
double blowUpFactor_
Blow up the covariance of the forward (backward) fit by this factor before seeding the backward (forw...
int maxFailedHits_
after how many failed hits (exception during construction of plane, extrapolation etc....
double relChi2Change_
Non-convergence criterion.
unsigned int minIterations_
Minimum number of iterations to attempt. Forward and backward are counted as one iteration.
double blowUpMaxVal_
Limit the cov entries to this maximum value when blowing up the cov.
unsigned int maxIterations_
Maximum number of iterations to attempt. Forward and backward are counted as one iteration.
eMultipleMeasurementHandling multipleMeasurementHandling_
How to handle if there are multiple MeasurementsOnPlane.
bool resetOffDiagonals_
Reset the off-diagonals to 0 when blowing up the cov.
const std::vector< MeasurementOnPlane * > getMeasurements(const KalmanFitterInfo *fi, const TrackPoint *tp, int direction) const
get the measurementsOnPlane taking the multipleMeasurementHandling_ into account
Abstract base class for a track representation.
virtual void setTime(StateOnPlane &state, double time) const =0
Set time at which the state was defined.
virtual unsigned int getDim() const =0
Get the dimension of the state vector used by the track representation.
virtual void setPosMomCov(MeasuredStateOnPlane &state, const TVector3 &pos, const TVector3 &mom, const TMatrixDSym &cov6x6) const =0
Set position, momentum and covariance of state.
virtual double getQop(const StateOnPlane &state) const =0
Get charge over momentum.
virtual void setQop(StateOnPlane &state, double qop) const =0
Set charge/momentum.
virtual double extrapolateToPlane(StateOnPlane &state, const genfit::SharedPlanePtr &plane, bool stopAtBoundary=false, bool calcJacobianNoise=false) const =0
Extrapolates the state to plane, and returns the extrapolation length and, via reference,...
virtual void getPosMomCov(const MeasuredStateOnPlane &state, TVector3 &pos, TVector3 &mom, TMatrixDSym &cov) const =0
Translates MeasuredStateOnPlane into 3D position, momentum and 6x6 covariance.
Exception class for error handling in GENFIT (provides storage for diagnostic information)
bool isTrackPruned() const
Has the track been pruned after the fit?
FitStatus for use with AbsKalmanFitter implementations.
#MeasuredStateOnPlane with additional info produced by a Kalman filter or DAF.
Collects information needed and produced by a AbsKalmanFitter implementations and is specific to one ...
const MeasuredStateOnPlane & getFittedState(bool biased=true) const override
Get unbiased or biased (default) smoothed state.
void processTrackPartially(Track *tr, const AbsTrackRep *rep, int startId=0, int endId=-1)
process only a part of the track.
void processTrackWithRep(Track *tr, const AbsTrackRep *rep, bool resortHits=false) override
Hit resorting currently NOT supported.
#StateOnPlane with additional covariance matrix.
Measured coordinates on a plane.
Object containing AbsMeasurement and AbsFitterInfo objects.
AbsFitterInfo * getFitterInfo(const AbsTrackRep *rep=nullptr) const
Get fitterInfo for rep. Per default, use cardinal rep.
void setFitterInfo(genfit::AbsFitterInfo *fitterInfo)
Takes Ownership.
Collection of TrackPoint objects, AbsTrackRep objects and FitStatus objects.
bool hasFitStatus(const AbsTrackRep *rep=nullptr) const
Check if track has a FitStatus for given AbsTrackRep. Per default, check for cardinal rep.
FitStatus * getFitStatus(const AbsTrackRep *rep=nullptr) const
Get FitStatus for a AbsTrackRep. Per default, return FitStatus for cardinalRep.
AbsTrackRep * getCardinalRep() const
Get cardinal track representation.
Defines for I/O streams used for error and debug printing.
std::ostream debugOut
Default stream for debug output.
std::shared_ptr< genfit::DetPlane > SharedPlanePtr
Shared Pointer to a DetPlane.
@ weightedClosestToReferenceWire
if corresponding TrackPoint has one WireMeasurement, select closest to reference, weighted with its w...
@ weightedClosestToReference
closest to reference, weighted with its weight_
@ unweightedClosestToReference
closest to reference, weighted with 1
@ unweightedClosestToReferenceWire
if corresponding TrackPoint has one WireMeasurement, select closest to reference, weighted with 1.
std::ostream errorOut
Default stream for error output.