Skip to content

Commit

Permalink
Updated kf functions to work inside the manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Valerio Pia committed Oct 15, 2024
1 parent f007aac commit 7f8b25e
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 688 deletions.
34 changes: 24 additions & 10 deletions include/STTKFKalmanFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@
#include "TGeoManager.h"
#include "TMatrixD.h"

#include "SANDTrackerCluster.h"
#include "SANDTrackerClusterCollection.h"
#include "STTKFTrack.h"
#include "SANDTrackerUtils.h"
#include "STTKFGeoManager.h"

struct SParticleInfo {
int charge;
double mass;
int pdg_code;
int id;

TVector3 pos;
TVector3 mom;
};

class STTKFChecks;

using STTKFStateCovarianceMatrix = TMatrixD;
using STTKFMeasurement = TMatrixD;
using TrackletMap = std::map<double, std::vector<TVectorD>>;

class STTKFKalmanFilterManager {

Expand All @@ -22,7 +32,9 @@ class STTKFKalmanFilterManager {

STTKFTrackStep::STTKFTrackStateStage fCurrentStage; // forward or backward
int fCurrentStep; // index of the STTKFTrackStep in STTKFTrack

double fCurrentZ;
TrackletMap* z_to_tracklets_;
SParticleInfo particleInfo_;


public:
Expand Down Expand Up @@ -53,6 +65,7 @@ class STTKFKalmanFilterManager {
public:
Orientation fCurrentOrientation = Orientation::kVertical;

STTKFMeasurement GetMeasurementFromTracklet(const TVectorD& tracklet);
double DeltaRadius(const STTKFStateVector& stateVector, double nextPhi, double dZ, double dE, double particle_mass) const;
inline double DEDTanl(const STTKFStateVector& stateVector, double nextPhi, double dZ, double dE, double particle_mass) const { auto tan = stateVector.TanLambda(); return dE * tan / (1 + tan*tan); };
inline double DEDPhi(const STTKFStateVector& stateVector, double nextPhi, double dZ, double dE, double particle_mass) const {
Expand Down Expand Up @@ -191,14 +204,15 @@ class STTKFKalmanFilterManager {

STTKFMeasurement GetPrediction(Orientation orientation, const STTKFStateVector& stateVector);

// void Propagate(double dE, const STTPlaneID& nextPlaneID);
double EvalChi2(const STTKFMeasurement& observation, const STTKFMeasurement& prediction, const TMatrixD& measurementNoiseMatrix);
// int FindBestMatch(const std::vector<int>& clusterIDs, const STTKFMeasurement& prediction, const TMatrixD& measurementNoiseMatrix);
// void Filter(const STTKFMeasurement& observation, const STTKFMeasurement& prediction, Orientation orientation);
// void Smooth();
// void Init(const STTPlaneID& plane, int clusterID);
// void Run();
// const STTKFTrack& GetTrack() {return fThisTrack; };
void Propagate(double& dE, double& dZ, double& beta);
double EvalChi2(const STTKFMeasurement& measurement, const STTKFMeasurement& prediction, const TMatrixD& measurementNoiseMatrix);
int FindBestMatch(double& nextZ, const STTKFMeasurement& prediction, const TMatrixD& measurementNoiseMatrix);
void SetNextOrientation();
void Filter(const STTKFMeasurement& measurement, const STTKFMeasurement& prediction);
void Smooth();
void InitFromMC(TrackletMap* z_to_tracklets, const SParticleInfo& particloInfo);
void Run();
const STTKFTrack& GetTrack() {return fThisTrack; };

friend class STTKFChecks;
};
Expand Down
6 changes: 4 additions & 2 deletions include/STTKFTrack.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class STTKFTrackStep {
STTKFState fSmoothed;

// the propagation that bring the vector in this state
// TMatrixD fPropagatorMatrix;
TMatrixD fPropagatorMatrix;
// TMatrixD fProjectionMatrix;
// TMatrixD fProcessNoiseMatrix;
// TMatrixD fMeasurementNoiseMatrix;
Expand All @@ -125,7 +125,7 @@ class STTKFTrackStep {
int fClusterID;

public:
// STTKFTrackStep(): fPropagatorMatrix(5,5),
STTKFTrackStep(): fPropagatorMatrix(5,5) {};
// fProjectionMatrix(2,5),
// fProcessNoiseMatrix(5,5),
// fMeasurementNoiseMatrix(2,2),
Expand All @@ -137,6 +137,8 @@ class STTKFTrackStep {
int GetClusterIDForThisState() const { return fClusterID; }
void SetStage(STTKFTrackStateStage stage, STTKFState state);
const STTKFState& GetStage(STTKFTrackStateStage stage) const;
void SetPropagatorMatrix(TMatrixD pMatrix) { fPropagatorMatrix = pMatrix; };
const TMatrixD GetPropagatorMatrix() { return fPropagatorMatrix; };
};

class STTKFTrack {
Expand Down
235 changes: 21 additions & 214 deletions src/SANDMeasurementsBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,23 @@

#include "EDEPTree.h"

struct SParticleInfo {
int charge;
double mass;
int pdg_code;
int id;
void TryCompleteManager(TrackletMap z_to_tracklets, SParticleInfo particleInfo) {
STTKFKalmanFilterManager manager;
manager.InitFromMC(&z_to_tracklets, particleInfo);
manager.Run();

TVector3 pos;
TVector3 mom;
};
auto track = manager.GetTrack();

auto step = track.GetSteps().back();
auto reco_state =
step.GetStage(STTKFTrackStep::STTKFTrackStateStage::kSmoothing).GetStateVector();
auto reco_mom = SANDTrackerUtils::GetMomentumInMeVFromRadiusInMM(
reco_state.Radius(), reco_state.TanLambda());

std::cout << "Initial Smoothed Reco Momentum " << reco_mom << std::endl;

return;
}

void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vector<dg_wire>* digits)
{
Expand All @@ -39,9 +47,10 @@ void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vecto

SANDTrackerDigitCollection::FillMap(digits);
auto digit_map = SANDTrackerDigitCollection::GetDigits();
if (SANDTrackerDigitCollection::GetDigits().empty()) {
return;
}
std::string tracker_name = SANDTrackerDigitCollection::GetDigits().begin()->det;
// std::cout << "TRACKER: " << tracker_name << std::endl;

SANDTrackerClusterCollection clusters(sand_geo, SANDTrackerDigitCollection::GetDigits(), SANDTrackerClusterCollection::ClusteringMethod::kCellAdjacency);

TrackletFinder traklet_finder;
Expand Down Expand Up @@ -76,10 +85,8 @@ void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vecto

int sum = 0;
for (auto el:z_to_tracklets) {
// std::cout << "At z = " << el.first << " there are " << el.second.size() << " tracklets" << std::endl;
sum += el.second.size();
}
std::cout << "Total tracklets: " << sum << std::endl;
if (sum == 0) {
return;
}
Expand All @@ -91,9 +98,6 @@ void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vecto
tree.Filter(std::back_insert_iterator<std::vector<EDEPTrajectory>>(primaryTrj),
[](const EDEPTrajectory& trj) { return trj.GetParentId() == -1;} );

// std::string print;
// for (auto trj:primaryTrj) trj.Print(print);

TDatabasePDG pdg_db;
std::vector<SParticleInfo> particleInfos;
for (auto trj:primaryTrj) {
Expand All @@ -109,10 +113,7 @@ void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vecto
pi.mom = trj.GetTrajectoryPoints().at(string_to_component[tracker_name]).back().GetMomentum();
particleInfos.push_back(pi);

// std::cout << pi.pdg_code << " " << pi.id << " " << pi.mass << " " << pi.charge << std::endl;
std::cout << "Initial Momentum " << trj.GetInitialMomentum().Vect().Mag() << std::endl;
// pi.pos.Print();
// pi.mom.Print();
}

int nParticles = particleInfos.size();
Expand All @@ -124,202 +125,8 @@ void ProcessEventWithKF(SANDGeoManager* sand_geo, TG4Event* mc_event, std::vecto
}

for (int ip = 0; ip < nParticles; ip++) {
int charge = particleInfos[ip].charge;
double particle_mass = particleInfos[ip].mass;
int pdg = particleInfos[ip].pdg_code;

STTKFTrack this_track;
STTKFStateVector current_state;
STTKFKalmanFilterManager manager;
STTKFKalmanFilterManager::Orientation current_orientation = manager.GetOrientation();

// To Do: find a smarter algorithm to compute this
TMatrixD initial_cov_matrix(5, 5);
initial_cov_matrix[0][0] = pow(200E-6, 2);
initial_cov_matrix[1][1] = pow(200E-6, 2);
initial_cov_matrix[2][2] = pow(0.1, 2);
initial_cov_matrix[3][3] = pow(0.01, 2);
initial_cov_matrix[4][4] = pow(0.01, 2);


// To Do: implment a seeding algorithm
STTKFStateVector initial_state_vector = STTKFCheck::get_state_vector(particleInfos[ip].mom * 1E-3, // GeV
particleInfos[ip].pos * 1E-3, // m
particleInfos[ip].charge);


// initial_cov_matrix.Print();


std::vector<STTKFStateCovarianceMatrix> propagator_matrices;

STTKFTrackStep trackStep;
trackStep.SetStage(STTKFTrackStep::STTKFTrackStateStage::kPrediction,
STTKFState(initial_state_vector, initial_cov_matrix));
trackStep.SetStage(STTKFTrackStep::STTKFTrackStateStage::kFiltering,
STTKFState(initial_state_vector, initial_cov_matrix));

propagator_matrices.push_back(initial_cov_matrix);

this_track.AddStep(trackStep);


double previous_z = particleInfos[ip].pos.Z();
auto current_z_it = std::prev(z_to_tracklets.lower_bound(particleInfos[ip].pos.Z()));
for ( ; current_z_it != z_to_tracklets.begin(); current_z_it--) {

// Get previous kf step
auto previous_step =
this_track.GetStep(this_track.GetSteps().size() - 1);
auto previousStateVector =
previous_step
.GetStage(STTKFTrackStep::STTKFTrackStateStage::kFiltering)
.GetStateVector();
auto previousCovMatrix =
previous_step
.GetStage(STTKFTrackStep::STTKFTrackStateStage::kFiltering)
.GetStateCovMatrix();


// Compute energy loss for the new step
auto dir = -1. * manager.GetDirectiveCosinesFromStateVector(
previousStateVector);

// To Do: check if this is still valid and add a real fix if needed
if (dir.Z() > 0) {
dir *= -1;
}

auto current_mom = SANDTrackerUtils::GetMomentumInMeVFromRadiusInMM(
previousStateVector.Radius(),
previousStateVector.TanLambda()) / 1000;
double gamma = sqrt(current_mom * current_mom + particle_mass * particle_mass) /
particle_mass;
double beta_from_gamma = sqrt(1 - pow(1 / gamma, 2));

// std::cout << previous_z << " "
// << previousStateVector.X() << " "
// << previousStateVector.Y() << " "
// << current_z_it->first << " " << std::endl;
// dir.Print();

// To Do: check all units
auto de_step = STTKFGeoManager::GetDE(
current_z_it->first,
1000 * previousStateVector.X(), 1000 * previousStateVector.Y(), previous_z,
dir.X(), dir.Y(), dir.Z(),
beta_from_gamma, particle_mass, charge) / 1000;
// std::cout << "de_step " << de_step << std::endl;

double dz = (current_z_it->first - previous_z) / 1000;
// std::cout << dz << std::endl;

// Propagation of state and cov matrix
auto predictedStateVector = manager.PropagateState(
previousStateVector, dz, de_step, particle_mass);
auto nextPhi = predictedStateVector.Phi();
TMatrixD covariance_noise(5, 5);
covariance_noise = manager.GetProcessNoiseMatrix(
previousStateVector, nextPhi, dz, de_step, previous_z,
particle_mass);

auto propagatorMatrix = manager.GetPropagatorMatrix(
previousStateVector, nextPhi, dz, de_step, particle_mass);

auto predictedCovMatrix = manager.PropagateCovMatrix(
previousCovMatrix, propagatorMatrix, covariance_noise);


// Filling prediction step of track
STTKFTrackStep currentTrackState;
currentTrackState.SetStage(
STTKFTrackStep::STTKFTrackStateStage::kPrediction,
STTKFState(predictedStateVector, predictedCovMatrix));

// previousStateVector().Print();
// predictedStateVector().Print();

// Prediction and filtering on the next plane
STTKFMeasurement prediction =
manager.GetPrediction(current_orientation, predictedStateVector);
auto projectionMatrix =
manager.GetProjectionMatrix(current_orientation, predictedStateVector);
auto measurementNoiseMatrix = manager.GetMeasurementNoiseMatrix();

auto kalmanGainMatrix = manager.GetKalmanGainMatrix(
predictedCovMatrix, projectionMatrix, measurementNoiseMatrix);

TMatrixD projectionMatrixTransposed(TMatrixD::kTransposed,
projectionMatrix);
auto Sk = measurementNoiseMatrix + projectionMatrix *
predictedCovMatrix *
projectionMatrixTransposed;


// prediction.Print();
double best_chi = 1E9;
STTKFMeasurement best_measurement(2, 1);
for (const auto& tracklet:current_z_it->second) {
STTKFMeasurement measurement(2, 1);
// To Do: vertical and horizontal are outdated and confusing. Replace with something more meaningful.
// Notice: vertical planes means horizontal measurements and the opposite
if (current_orientation == STTKFKalmanFilterManager::Orientation::kVertical) {
measurement[0][0] = tracklet[0] / 1000.;
measurement[1][0] = M_PI_2 - tracklet[2];
} else {
measurement[0][0] = tracklet[1] / 1000.;
measurement[1][0] = tracklet[3];
}
// measurement[0][0] += rand->Gaus(0, sigmaPos);
// measurement[1][0] += rand->Gaus(0, sigmaAng);
// measurement.Print();
// std::cout << tracklet[0] << " " << tracklet[1] << " " << tracklet[2] << " " << tracklet[3] << std::endl;
auto chi2 = manager.EvalChi2(measurement, prediction, Sk);
if (chi2 < best_chi) {
best_chi = chi2;
best_measurement = measurement;
}
}
if (best_chi > 1.5) {
continue;
}
propagator_matrices.push_back(propagatorMatrix);

// std::cout << "chi2 " << best_chi << std::endl;
// best_measurement.Print();

auto filteredStateVector =
manager.FilterState(predictedStateVector, kalmanGainMatrix,
best_measurement, prediction);
// filteredStateVector().Print();

auto filteredCovMatrix = manager.FilterCovMatrix(
predictedCovMatrix, projectionMatrix, measurementNoiseMatrix);

currentTrackState.SetStage(
STTKFTrackStep::STTKFTrackStateStage::kFiltering,
STTKFState(filteredStateVector, filteredCovMatrix));
this_track.AddStep(currentTrackState);


if (current_orientation == STTKFKalmanFilterManager::Orientation::kVertical) {
current_orientation = STTKFKalmanFilterManager::Orientation::kHorizontal;
} else {
current_orientation = STTKFKalmanFilterManager::Orientation::kVertical;
}

previous_z = current_z_it->first;
}

auto reco_state =
this_track.GetSteps()[this_track.GetSteps().size() - 1]
.GetStage(STTKFTrackStep::STTKFTrackStateStage::kFiltering)
.GetStateVector();
auto reco_mom = SANDTrackerUtils::GetMomentumInMeVFromRadiusInMM(
reco_state.Radius(), reco_state.TanLambda());
std::cout << "Initial Reco Momentum " << reco_mom << std::endl;
TryCompleteManager(z_to_tracklets, particleInfos[ip]);
}

}

int main(int argc, char* argv[])
Expand Down Expand Up @@ -353,7 +160,7 @@ int main(int argc, char* argv[])
SANDGeoManager sand_geo;
sand_geo.init(geo);

for (int i = 2; i < 3; i++) {
for (int i = 0; i < 30; i++) {
t_h->GetEntry(i);
t->GetEntry(i);

Expand Down
Loading

0 comments on commit 7f8b25e

Please sign in to comment.