From 4951397071518e7571ecbee01c416a19be3c1ba0 Mon Sep 17 00:00:00 2001
From: Maximilian Reininghaus <maximilian.reininghaus@kit.edu>
Date: Fri, 3 Jun 2022 17:50:36 +0200
Subject: [PATCH] wip

---
 .../modules/qgsjetII/InteractionModel.inl     | 66 +++++++++++--------
 tests/modules/testQGSJetII.cpp                | 22 ++++---
 2 files changed, 49 insertions(+), 39 deletions(-)

diff --git a/corsika/detail/modules/qgsjetII/InteractionModel.inl b/corsika/detail/modules/qgsjetII/InteractionModel.inl
index f7e0ab154..dee1f11e6 100644
--- a/corsika/detail/modules/qgsjetII/InteractionModel.inl
+++ b/corsika/detail/modules/qgsjetII/InteractionModel.inl
@@ -75,7 +75,7 @@ namespace corsika::qgsjetII {
 
     // define projectile, in lab frame
     auto const S = (projectileP4 + targetP4).getNormSqr();
-    auto const SNN = S / (AfactorProjectile * AfactorTarget);
+    auto const SNN = S / static_pow<2>(AfactorProjectile * AfactorTarget);
     auto const sqrtSNN = sqrt(SNN);
     if (!isValid(projectileId, targetId, sqrtSNN)) { return CrossSectionType::zero(); }
 
@@ -111,35 +111,44 @@ namespace corsika::qgsjetII {
         projectileId, corsika::qgsjetII::canInteract(projectileId));
 
     // define projectile, in lab frame
-    auto const SNN =
-        (projectileP4 / (is_nucleus(projectileId) ? get_nucleus_A(projectileId) : 1) +
-         targetP4 / (is_nucleus(targetId) ? get_nucleus_A(targetId) : 1))
-            .getNormSqr();
-    auto const sqrtS_NN = sqrt(SNN);
+    auto const AfactorProjectile =
+        is_nucleus(projectileId) ? get_nucleus_A(projectileId) : 1;
+    auto const AfactorTarget = is_nucleus(targetId) ? get_nucleus_A(targetId) : 1;
+
+    // define projectile, in lab frame
+    auto const S = (projectileP4 + targetP4).getNormSqr();
+    auto const SNN = S / static_pow<2>(AfactorProjectile * AfactorTarget);
+    auto const sqrtSNN = sqrt(SNN);
+
+
     if (!corsika::qgsjetII::canInteract(projectileId) ||
-        !isValid(projectileId, targetId, sqrtS_NN)) {
+        !isValid(projectileId, targetId, sqrtSNN)) {
       throw std::runtime_error("invalid target/projectile/energy combination.");
     }
-    auto const projectileMass =
-        (is_nucleus(projectileId) ? constants::nucleonMass : get_mass(projectileId));
-    auto const targetMass =
-        (projectileId == Code::Proton
-             ? get_mass(Code::Proton)
-             : constants::nucleonMass); // qgsjet target is always proton or nucleon.
-                                        // always nucleon??
+    
+    auto const projMass = get_mass(projectileId);
+    auto const targetMass = get_mass(targetId);
+
+    // lab-frame energy per projectile nucleon as required by qgsect()
+    HEPEnergyType const ElabN =
+        calculate_lab_energy(S, projMass, targetMass) / AfactorProjectile;    
+    
+    //~ auto const projectileMass =
+        //~ (is_nucleus(projectileId) ? constants::nucleonMass : get_mass(projectileId));
+    //~ auto const targetMass =
+        //~ (projectileId == Code::Proton
+             //~ ? get_mass(Code::Proton)
+             //~ : constants::nucleonMass); // qgsjet target is always proton or nucleon.
+                                        //~ // always nucleon??
 
     // lab energy/hadron
-    HEPEnergyType const Elab = calculate_lab_energy(SNN, projectileMass, targetMass);
+    //~ HEPEnergyType const Elab = calculate_lab_energy(SNN, projectileMass, targetMass);
 
-    int beamA = 0;
-    if (is_nucleus(projectileId)) { beamA = get_nucleus_A(projectileId); }
+    int const beamA = is_nucleus(projectileId) ? get_nucleus_A(projectileId) : 0;
 
-    CORSIKA_LOG_DEBUG("ebeam lab: {} GeV ", Elab / 1_GeV);
+    CORSIKA_LOG_DEBUG("ebeam lab: {} GeV per projectile nucleon", ElabN / 1_GeV);
 
-    int targetMassNumber = 1;   // proton
-    if (is_nucleus(targetId)) { // nucleus
-      targetMassNumber = get_nucleus_A(targetId);
-    }
+    int const targetMassNumber = is_nucleus(targetId) ? get_nucleus_A(targetId) : 1;   // proton
     CORSIKA_LOG_DEBUG("target: {}, qgsjetII code/A: {}", targetId, targetMassNumber);
 
     // select QGSJetII internal projectile type
@@ -160,11 +169,11 @@ namespace corsika::qgsjetII {
     }
 
     count_++;
-    int qgsjet_hadron_type_int = static_cast<QgsjetIICodeIntType>(qgsjet_hadron_type);
+    int const qgsjet_hadron_type_int = static_cast<QgsjetIICodeIntType>(qgsjet_hadron_type);
     CORSIKA_LOG_DEBUG(
         "qgsjet_hadron_type_int={} projectileMassNumber={} targetMassNumber={}",
         qgsjet_hadron_type_int, projectileMassNumber, targetMassNumber);
-    qgini_(Elab / 1_GeV, qgsjet_hadron_type_int, projectileMassNumber, targetMassNumber);
+    qgini_(ElabN / 1_GeV, qgsjet_hadron_type_int, projectileMassNumber, targetMassNumber);
     qgconf_();
 
     CoordinateSystemPtr const& rootCS = get_root_CoordinateSystem();
@@ -196,16 +205,15 @@ namespace corsika::qgsjetII {
 
     // fragments
     QGSJetIIFragmentsStack qfs;
+    std::bernoulli_distribution nucleonTypeDist;
     for (auto& fragm : qfs) {
       int const A = fragm.getFragmentSize();
       if (A == 1) { // nucleon
-        std::uniform_real_distribution<double> select;
-        Code idFragm = Code::Proton;
-        if (select(rng_) > 0.5) { idFragm = Code::Neutron; }
+        Code const idFragm = nucleonTypeDist(rng_) ? Code::Proton : Code::Neutron;
 
-        const HEPMassType nucleonMass = get_mass(idFragm);
+        HEPMassType const nucleonMass = get_mass(idFragm);
         // no pT, fragments just go forward
-        HEPEnergyType const projectileEnergyLabPerNucleon = Elab / beamA;
+        HEPEnergyType const projectileEnergyLabPerNucleon = ElabN;
         MomentumVector momentum{csPrime,
                                 {0.0_GeV, 0.0_GeV,
                                  sqrt((projectileEnergyLabPerNucleon + nucleonMass) *
diff --git a/tests/modules/testQGSJetII.cpp b/tests/modules/testQGSJetII.cpp
index 4c5d351b7..98c313e34 100644
--- a/tests/modules/testQGSJetII.cpp
+++ b/tests/modules/testQGSJetII.cpp
@@ -142,32 +142,34 @@ TEST_CASE("QgsjetIIInterface", "interaction,processes") {
   corsika::qgsjetII::InteractionModel model;
 
   SECTION("cross-sections") {
-    auto projCode = GENERATE(Code::PiPlus, Code::Proton, Code::K0Long);
+    auto projCode = GENERATE(Code::PiPlus, Code::Proton, Code::K0Long, Code::Nitrogen, Code::Helium);
     auto targetCode = GENERATE(Code::Oxygen, Code::Nitrogen);
-    auto projEnergy = GENERATE(100_GeV, 1_PeV, 1e20_eV);
+    auto projEnergy = GENERATE(1_PeV, 1e18_eV);
 
     auto momMagnitude = calculate_momentum(projEnergy, get_mass(projCode));
     MomentumVector const projMomentum{*csPtr, 0_eV, momMagnitude, 0_eV};
 
     REQUIRE(model.getCrossSection(
                 projCode, targetCode, FourMomentum{projEnergy, projMomentum},
-                FourMomentum{get_mass(Code::Oxygen), {*csPtr, 0_eV, 0_eV, 0_eV}}) /
-                1_mb >
-            0);
+                FourMomentum{get_mass(targetCode), {*csPtr, 0_eV, 0_eV, 0_eV}}) /
+                1_mb > 0);
   }
 
   SECTION("InteractionInterface") {
+    auto projCode = GENERATE(/*Code::PiPlus, Code::Proton, Code::K0Long,*/ Code::Iron/*, Code::Nitrogen, Code::Helium*/);
+    auto targetCode = GENERATE(Code::Oxygen/*, Code::Nitrogen*/);
+    auto projMomentum = GENERATE(1_PeV); //, 1e20_eV);
 
     auto [stackPtr, secViewPtr] = setup::testing::setup_stack(
-        Code::Proton, 110_GeV, (DummyEnvironment::BaseNodeType* const)nodePtr, *csPtr);
+        Code::Proton, projMomentum, (DummyEnvironment::BaseNodeType* const)nodePtr, *csPtr);
     test::StackView& view = *(secViewPtr.get());
     auto projectile = secViewPtr->getProjectile();
     auto const projectileMomentum = projectile.getMomentum();
 
-    model.doInteraction(view, Code::Proton, Code::Oxygen,
-                        {sqrt(static_pow<2>(110_GeV) + static_pow<2>(Proton::mass)),
-                         MomentumVector{cs, 110_GeV, 0_GeV, 0_GeV}},
-                        {Oxygen::mass, MomentumVector{cs, {0_eV, 0_eV, 0_eV}}});
+    model.doInteraction(view, projCode, targetCode,
+                        FourMomentum{calculate_total_energy(projMomentum, get_mass(projCode)),
+                         projectileMomentum},
+                        FourMomentum{get_mass(targetCode), MomentumVector{cs, {0_eV, 0_eV, 0_eV}}});
 
     /* **********************************
      As it turned out already two times (#291 and #307) that the detailed output of
-- 
GitLab