From e07751da1688e1d4a55fba356446d0ee44ca84d7 Mon Sep 17 00:00:00 2001
From: Felix Riehn <felix@matilda>
Date: Thu, 8 Apr 2021 16:49:43 +0100
Subject: [PATCH] added getter for kinetic energy

---
 corsika/detail/stack/NuclearStackExtension.inl |  7 +++++++
 corsika/stack/NuclearStackExtension.hpp        |  7 +++++++
 corsika/stack/VectorStack.hpp                  |  2 ++
 tests/stack/testNuclearStackExtension.cpp      | 15 ++++++++++++++-
 tests/stack/testVectorStack.cpp                |  6 ++++++
 5 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/corsika/detail/stack/NuclearStackExtension.inl b/corsika/detail/stack/NuclearStackExtension.inl
index 5dab28ab2..e8b72b128 100644
--- a/corsika/detail/stack/NuclearStackExtension.inl
+++ b/corsika/detail/stack/NuclearStackExtension.inl
@@ -121,6 +121,13 @@ namespace corsika::nuclear_stack {
     return super_type::getCharge();
   }
 
+  template <template <typename> class InnerParticleInterface,
+            typename StackIteratorInterface>
+  inline HEPEnergyType NuclearParticleInterface<
+      InnerParticleInterface, StackIteratorInterface>::getKineticEnergy() const {
+    return this->getEnergy() - this->getMass();
+  }
+
   template <template <typename> class InnerParticleInterface,
             typename StackIteratorInterface>
   inline int16_t NuclearParticleInterface<
diff --git a/corsika/stack/NuclearStackExtension.hpp b/corsika/stack/NuclearStackExtension.hpp
index 19e9a44c3..445350506 100644
--- a/corsika/stack/NuclearStackExtension.hpp
+++ b/corsika/stack/NuclearStackExtension.hpp
@@ -92,10 +92,17 @@ namespace corsika::nuclear_stack {
      * Overwrite normal getParticleMass function with nuclear version
      */
     HEPMassType getMass() const;
+
     /**
      * Overwrite normal getParticleCharge function with nuclear version
      */
     ElectricChargeType getCharge() const;
+
+    /**
+     * Overwrite normal getKineticEnergy function with nuclear version
+     */
+    HEPMassType getKineticEnergy() const;
+
     /**
      * Overwirte normal getChargeNumber function with nuclear version
      **/
diff --git a/corsika/stack/VectorStack.hpp b/corsika/stack/VectorStack.hpp
index e5a765781..23211e4cb 100644
--- a/corsika/stack/VectorStack.hpp
+++ b/corsika/stack/VectorStack.hpp
@@ -95,6 +95,8 @@ namespace corsika {
 
     ElectricChargeType getCharge() const { return get_charge(this->getPID()); }
 
+    HEPEnergyType getKineticEnergy() const { return this->getEnergy() - this->getMass(); }
+
     int16_t getChargeNumber() const { return get_charge_number(this->getPID()); }
     ///@}
   };
diff --git a/tests/stack/testNuclearStackExtension.cpp b/tests/stack/testNuclearStackExtension.cpp
index 1e33f6c05..1b310ba4d 100644
--- a/tests/stack/testNuclearStackExtension.cpp
+++ b/tests/stack/testNuclearStackExtension.cpp
@@ -17,6 +17,14 @@ using namespace corsika;
 #include <iostream>
 using namespace std;
 
+template <typename TParticle>
+HEPEnergyType kineticEnergy(TParticle const p) {
+  if (p.getPID() == Code::Nucleus)
+    return p.getEnergy() - get_nucleus_mass(p.getNuclearA(), p.getNuclearZ());
+  else
+    return p.getEnergy() - get_mass(p.getPID());
+}
+
 TEST_CASE("NuclearStackExtension", "[stack]") {
 
   logging::set_level(logging::level::info);
@@ -64,13 +72,18 @@ TEST_CASE("NuclearStackExtension", "[stack]") {
   }
 
   SECTION("read nucleus") {
+    auto const A = 10;
+    auto const Z = 9;
     nuclear_stack::ParticleDataStack s;
     s.addParticle(std::make_tuple(
         Code::Nucleus, 1.5_GeV, MomentumVector(dummyCS, {1_GeV, 1_GeV, 1_GeV}),
-        Point(dummyCS, {1 * meter, 1 * meter, 1 * meter}), 100_s, 10, 9));
+        Point(dummyCS, {1 * meter, 1 * meter, 1 * meter}), 100_s, A, Z));
     const auto pout = s.getNextParticle();
     CHECK(pout.getPID() == Code::Nucleus);
     CHECK(pout.getEnergy() == 1.5_GeV);
+    CHECK(pout.getMass() == get_nucleus_mass(A, Z));
+    CHECK(pout.getKineticEnergy() == kineticEnergy(pout));
+    CHECK(pout.getKineticEnergy() > 0_GeV);
     CHECK(pout.getTime() == 100_s);
     CHECK(pout.getNuclearA() == 10);
     CHECK(pout.getNuclearZ() == 9);
diff --git a/tests/stack/testVectorStack.cpp b/tests/stack/testVectorStack.cpp
index 9767b6c84..074c34dea 100644
--- a/tests/stack/testVectorStack.cpp
+++ b/tests/stack/testVectorStack.cpp
@@ -17,6 +17,11 @@
 using namespace corsika;
 using namespace std;
 
+template <typename TParticle>
+HEPEnergyType kineticEnergy(TParticle const p) {
+  return p.getEnergy() - get_mass(p.getPID());
+}
+
 TEST_CASE("VectorStack", "[stack]") {
 
   logging::set_level(logging::level::info);
@@ -37,6 +42,7 @@ TEST_CASE("VectorStack", "[stack]") {
     auto pout = s.getNextParticle();
     CHECK(pout.getPID() == Code::Electron);
     CHECK(pout.getEnergy() == 1.5_GeV);
+    CHECK(pout.getKineticEnergy() == kineticEnergy(pout));
     CHECK(pout.getTime() == 100_s);
   }
 
-- 
GitLab