24 #include "KalmanFitter.h"
26 #include "Exception.h"
27 #include "KalmanFitterInfo.h"
28 #include "KalmanFitStatus.h"
30 #include "TrackPoint.h"
34 #include <Math/ProbFunc.h>
36 #include <TDecompChol.h>
37 #include <TMatrixDSymEigen.h>
44 double& chi2,
double& ndf,
45 int startId,
int endId,
int& nFailedHits)
52 Exception exc(
"KalmanFitter::fitTrack ==> cannot use (un)weightedClosestToReference(Wire) as multiple measurement handling.",__LINE__,__FILE__);
58 startId += tr->getNumPointsWithMeasurement();
60 endId += tr->getNumPointsWithMeasurement();
72 debugOut << tr->getNumPointsWithMeasurement() <<
" TrackPoints w/ measurement in this track." << std::endl;
74 for (
int i = startId; ; i+=direction) {
75 TrackPoint *tp = tr->getPointWithMeasurement(i);
76 assert(direction == +1 || direction == -1);
79 debugOut <<
" process TrackPoint nr. " << i <<
" (" << tp <<
")\n";
83 processTrackPoint(tp, rep, chi2, ndf, direction);
90 tr->getPoint(i)->deleteFitterInfo(rep);
96 debugOut <<
"There was an exception, try to continue with next TrackPoint " << i+direction <<
" \n";
116 Exception exc(
"KalmanFitter::processTrack: Cannot process pruned track!", __LINE__,__FILE__);
120 TrackPoint* trackPoint = tr->getPointWithMeasurement(0);
122 if (trackPoint->hasFitterInfo(rep) &&
128 debugOut <<
"take backward update of previous iteration as seed \n";
135 debugOut <<
"take smoothed state of cardinal rep fit as seed \n";
146 rep->
setTime(*currentState_, tr->getTimeSeed());
147 rep->
setPosMomCov(*currentState_, tr->getStateSeed(), tr->getCovSeed());
149 debugOut <<
"take seed state of track as seed \n";
156 double oldChi2FW(1.e6);
157 double oldChi2BW(1.e6);
158 double oldPvalFW(0.);
160 double oldPvalBW = 0.;
161 double chi2FW(0), ndfFW(0);
162 double chi2BW(0), ndfBW(0);
164 int nFailedHitsForward(0), nFailedHitsBackward(0);
168 tr->setFitStatus(status, rep);
170 unsigned int nIt = 0;
174 debugOut <<
"\033[1;21mstate pre" << std::endl;
175 currentState_->Print();
176 debugOut <<
"\033[0mfitting" << std::endl;
179 if (!fitTrack(tr, rep, chi2FW, ndfFW, 0, -1, nFailedHitsForward)) {
180 status->setIsFitted(
false);
181 status->setIsFitConvergedFully(
false);
182 status->setIsFitConvergedPartially(
false);
183 status->setNFailedPoints(nFailedHitsForward);
188 debugOut <<
"\033[1;21mstate post forward" << std::endl;
189 currentState_->Print();
196 if (!fitTrack(tr, rep, chi2BW, ndfBW, -1, 0, nFailedHitsBackward)) {
197 status->setIsFitted(
false);
198 status->setIsFitConvergedFully(
false);
199 status->setIsFitConvergedPartially(
false);
200 status->setNFailedPoints(nFailedHitsBackward);
205 debugOut <<
"\033[1;21mstate post backward" << std::endl;
206 currentState_->Print();
209 debugOut <<
"old chi2s: " << oldChi2BW <<
", " << oldChi2FW
210 <<
" new chi2s: " << chi2BW <<
", " << chi2FW
211 <<
" oldPvals " << oldPvalFW <<
", " << oldPvalBW << std::endl;
215 double PvalBW = std::max(0.,ROOT::Math::chisquared_cdf_c(chi2BW, ndfBW));
216 double PvalFW = (debugLvl_ > 0) ? std::max(0.,ROOT::Math::chisquared_cdf_c(chi2FW, ndfFW)) : 0;
222 bool converged(
false);
223 bool finished(
false);
249 if (nFailedHitsForward == 0 && nFailedHitsBackward == 0)
250 status->setIsFitConvergedFully(converged);
252 status->setIsFitConvergedFully(
false);
254 status->setIsFitConvergedPartially(converged);
256 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
272 debugOut <<
"KalmanFitter::number of max iterations reached!\n";
278 status->setIsFitted(
false);
279 status->setIsFitConvergedFully(
false);
280 status->setIsFitConvergedPartially(
false);
281 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
286 status->setIsFitted();
288 TrackPoint* tp = tr->getPointWithMeasurementAndFitterInfo(0, rep);
293 status->setNFailedPoints(std::max(nFailedHitsForward, nFailedHitsBackward));
294 status->setCharge(charge);
295 status->setNumIterations(nIt);
296 status->setForwardChi2(chi2FW);
297 status->setBackwardChi2(chi2BW);
298 status->setForwardNdf(std::max(0., ndfFW));
299 status->setBackwardNdf(std::max(0., ndfBW));
307 Exception exc(
"KalmanFitter::processTrack: Cannot process pruned track!", __LINE__,__FILE__);
312 startId += tr->getNumPointsWithMeasurement();
314 endId += tr->getNumPointsWithMeasurement();
321 TrackPoint* trackPoint = tr->getPointWithMeasurement(startId);
325 if (direction == 1 && startId > 0)
326 prevTrackPoint = tr->getPointWithMeasurement(startId - 1);
327 else if (direction == -1 && startId < (
int)tr->getNumPointsWithMeasurement() - 1)
328 prevTrackPoint = tr->getPointWithMeasurement(startId + 1);
331 if (prevTrackPoint !=
nullptr &&
332 prevTrackPoint->hasFitterInfo(rep) &&
337 debugOut <<
"take update of previous fitter info as seed \n";
344 debugOut <<
"take smoothed state of cardinal rep fit as seed \n";
355 rep->
setTime(*currentState_, tr->getTimeSeed());
356 rep->
setPosMomCov(*currentState_, tr->getStateSeed(), tr->getCovSeed());
358 debugOut <<
"take seed of track as seed \n";
362 if (startId == 0 || startId == (
int)tr->getNumPointsWithMeasurement() - 1) {
369 debugOut <<
"\033[1;21mstate pre" << std::endl;
370 currentState_->Print();
371 debugOut <<
"\033[0mfitting" << std::endl;
376 fitTrack(tr, rep, chi2, ndf, startId, endId, nFailedHits);
382 KalmanFitter::processTrackPoint(
TrackPoint* tp,
383 const AbsTrackRep* rep,
double& chi2,
double& ndf,
int direction)
385 assert(direction == -1 || direction == +1);
387 if (!tp->hasRawMeasurements())
390 bool newFi(!tp->hasFitterInfo(rep));
401 bool oldWeightsFixed(
false);
402 std::vector<double> oldWeights;
407 oldWeights = fi->getWeights();
408 oldWeightsFixed = fi->areWeightsFixed();
411 fi->deleteForwardInfo();
412 fi->deleteBackwardInfo();
413 fi->deleteMeasurementInfo();
416 const std::vector< genfit::AbsMeasurement* >& rawMeasurements = tp->getRawMeasurements();
417 plane = rawMeasurements[0]->constructPlane(*currentState_);
420 plane =
fi->getPlane();
424 debugOut <<
"extrapolated by " << extLen << std::endl;
426 fi->setPrediction(currentState_->clone(), direction);
431 const std::vector< genfit::AbsMeasurement* >& rawMeasurements = tp->getRawMeasurements();
432 for (std::vector< genfit::AbsMeasurement* >::const_iterator it = rawMeasurements.begin(); it != rawMeasurements.end(); ++it) {
433 fi->addMeasurementsOnPlane((*it)->constructMeasurementsOnPlane(*state));
435 if (oldWeights.size() ==
fi->getNumMeasurements()) {
436 fi->setWeights(oldWeights);
437 fi->fixWeights(oldWeightsFixed);
442 assert(
fi->getPlane() == plane);
443 assert(
fi->checkConsistency());
447 debugOut <<
"its plane is at R = " << plane->getO().Perp()
448 <<
" with normal pointing along (" << plane->getNormal().X() <<
", " << plane->getNormal().Y() <<
", " << plane->getNormal().Z() <<
")" << std::endl;
451 TVectorD stateVector(state->getState());
452 TMatrixDSym cov(state->getCov());
456 if (!squareRootFormalism_) {
458 const std::vector<MeasurementOnPlane *>& measurements =
getMeasurements(fi, tp, direction);
459 for (std::vector<MeasurementOnPlane *>::const_iterator it = measurements.begin(); it != measurements.end(); ++it) {
461 const double weight = mOnPlane.getWeight();
464 debugOut <<
"Weight of measurement: " << weight <<
"\n";
469 debugOut <<
"Weight of measurement is almost 0, continue ... \n";
474 const TVectorD& measurement(mOnPlane.getState());
478 1./weight * mOnPlane.getCov() :
482 debugOut <<
"State prediction: "; stateVector.Print();
483 debugOut <<
"Cov prediction: "; state->getCov().Print();
486 debugOut <<
"measurement: "; measurement.Print();
487 debugOut <<
"measurement covariance V: "; V.Print();
492 TVectorD res(measurement - H->Hv(stateVector));
494 debugOut <<
"Residual = (" << res(0);
495 if (res.GetNrows() > 1)
505 TMatrixDSym covSumInv(cov);
508 tools::invertMatrix(covSumInv);
510 TMatrixD CHt(H->MHt(cov));
511 TVectorD update(TMatrixD(CHt, TMatrixD::kMult, covSumInv) * res);
518 debugOut <<
"Update: "; update.Print();
523 stateVector += update;
524 covSumInv.Similarity(CHt);
530 debugOut <<
"updated state: "; stateVector.Print();
531 debugOut <<
"updated cov: "; cov.Print();
534 TVectorD resNew(measurement - H->Hv(stateVector));
536 debugOut <<
"Residual New = (" << resNew(0);
538 if (resNew.GetNrows() > 1)
545 TMatrixDSym HCHt(cov);
550 tools::invertMatrix(HCHt);
552 chi2inc += HCHt.Similarity(resNew);
555 ndfInc += weight * measurement.GetNrows();
558 ndfInc += measurement.GetNrows();
561 debugOut <<
"chi² increment = " << chi2inc << std::endl;
573 TDecompChol decompCov(cov);
574 decompCov.Decompose();
575 TMatrixD S(decompCov.GetU());
577 const std::vector<MeasurementOnPlane *>& measurements =
getMeasurements(fi, tp, direction);
578 for (std::vector<MeasurementOnPlane *>::const_iterator it = measurements.begin(); it != measurements.end(); ++it) {
580 const double weight = mOnPlane.getWeight();
583 debugOut <<
"Weight of measurement: " << weight <<
"\n";
588 debugOut <<
"Weight of measurement is almost 0, continue ... \n";
593 const TVectorD& measurement(mOnPlane.getState());
597 1./weight * mOnPlane.getCov() :
601 debugOut <<
"State prediction: "; stateVector.Print();
602 debugOut <<
"Cov prediction: "; state->getCov().Print();
605 debugOut <<
"measurement: "; measurement.Print();
606 debugOut <<
"measurement covariance V: "; V.Print();
611 TVectorD res(measurement - H->Hv(stateVector));
613 debugOut <<
"Residual = (" << res(0);
614 if (res.GetNrows() > 1)
620 TDecompChol decompR(V);
622 const TMatrixD& R(decompR.GetU());
624 TVectorD update(stateVector.GetNrows());
625 tools::kalmanUpdateSqrt(S, res, R, H,
627 stateVector += update;
634 debugOut <<
"updated state: "; stateVector.Print();
635 debugOut <<
"updated cov: "; TMatrixDSym(TMatrixDSym::kAtA, S).Print() ;
638 res -= H->Hv(update);
640 debugOut <<
"Residual New = (" << res(0);
642 if (res.GetNrows() > 1)
653 TMatrixDSym HCHt(TMatrixDSym::kAtA, H->MHt(S));
657 tools::invertMatrix(HCHt);
659 chi2inc += HCHt.Similarity(res);
662 ndfInc += weight * measurement.GetNrows();
665 ndfInc += measurement.GetNrows();
668 debugOut <<
"chi² increment = " << chi2inc << std::endl;
672 cov = TMatrixDSym(TMatrixDSym::kAtA, S);
675 currentState_->setStateCovPlane(stateVector, cov, plane);
676 currentState_->setAuxInfo(state->getAuxInfo());
683 fi->setUpdate(updatedSOP, direction);
688 void KalmanFitter::Streamer(TBuffer &R__b)
693 typedef ::genfit::KalmanFitter thisClass;
695 if (R__b.IsReading()) {
696 Version_t R__v = R__b.ReadVersion(&R__s, &R__c);
if (R__v) { }
699 baseClass0::Streamer(R__b);
702 currentState_.reset(p);
703 R__b.CheckByteCount(R__s, R__c, thisClass::IsA());
705 R__c = R__b.WriteVersion(thisClass::IsA(), kTRUE);
708 baseClass0::Streamer(R__b);
709 R__b << currentState_.get();
710 R__b.SetByteCount(R__c, kTRUE);