From 186b10e5b16c1285b08edb8fcec503d893f8967c Mon Sep 17 00:00:00 2001
From: Maximilian Reininghaus <maximilian.reininghaus@tu-dortmund.de>
Date: Sat, 26 Jan 2019 18:01:22 +0100
Subject: [PATCH] moved SampleTarget to NuclearComposition, removed temporary
 weight vector

---
 Environment/IMediumModel.h                    | 23 ---------
 Environment/NuclearComposition.h              | 50 +++++++++++++++++++
 Environment/testEnvironment.cc                | 32 ++++++++++++
 .../HadronicElasticModel.h                    |  4 +-
 Processes/Sibyll/Interaction.h                |  4 +-
 5 files changed, 86 insertions(+), 27 deletions(-)

diff --git a/Environment/IMediumModel.h b/Environment/IMediumModel.h
index cdfb30cd6..ddae05642 100644
--- a/Environment/IMediumModel.h
+++ b/Environment/IMediumModel.h
@@ -17,7 +17,6 @@
 #include <corsika/geometry/Point.h>
 #include <corsika/geometry/Trajectory.h>
 #include <corsika/units/PhysicalUnits.h>
-#include <random>
 
 namespace corsika::environment {
 
@@ -39,28 +38,6 @@ namespace corsika::environment {
         corsika::units::si::GrammageType) const = 0;
 
     virtual NuclearComposition const& GetNuclearComposition() const = 0;
-
-    template <class TRNG>
-    corsika::particles::Code SampleTarget(
-        std::vector<corsika::units::si::CrossSectionType> const& sigma,
-        TRNG& randomStream) const {
-      using namespace corsika::units::si;
-
-      auto const& nuclComp = GetNuclearComposition();
-      auto const& fractions = nuclComp.GetFractions();
-      assert(sigma.size() == fractions.size());
-
-      std::vector<float> weights(fractions.size());
-
-      for (size_t i = 0; i < fractions.size(); ++i) {
-        std::cout << "HomogeneousMedium: fraction: " << fractions[i] << std::endl;
-        weights[i] = fractions[i] * sigma[i].magnitude();
-      }
-
-      std::discrete_distribution channelDist(weights.cbegin(), weights.cend());
-      const int iChannel = channelDist(randomStream);
-      return nuclComp.GetComponents()[iChannel];
-    }
   };
 
 } // namespace corsika::environment
diff --git a/Environment/NuclearComposition.h b/Environment/NuclearComposition.h
index a96abd402..4f1988aa0 100644
--- a/Environment/NuclearComposition.h
+++ b/Environment/NuclearComposition.h
@@ -15,6 +15,7 @@
 #include <corsika/particles/ParticleProperties.h>
 #include <cassert>
 #include <numeric>
+#include <random>
 #include <stdexcept>
 #include <vector>
 
@@ -26,6 +27,35 @@ namespace corsika::environment {
 
     double const fAvgMassNumber;
 
+    template <class AConstIterator, class BConstIterator>
+    class WeightProviderIterator {
+      AConstIterator fAIter;
+      BConstIterator fBIter;
+
+    public:
+      using value_type = double;
+      using iterator_category = std::input_iterator_tag;
+      using pointer = double*;
+      using reference = double&;
+      using difference_type = ptrdiff_t;
+
+      WeightProviderIterator(AConstIterator a, BConstIterator b)
+          : fAIter(a)
+          , fBIter(b) {}
+
+      double operator*() const { return ((*fAIter) * (*fBIter)).magnitude(); }
+
+      WeightProviderIterator& operator++() { // prefix ++
+        ++fAIter;
+        ++fBIter;
+        return *this;
+      }
+
+      auto operator==(WeightProviderIterator other) { return fAIter == other.fAIter; }
+
+      auto operator!=(WeightProviderIterator other) { return !(*this == other); }
+    };
+
   public:
     NuclearComposition(std::vector<corsika::particles::Code> pComponents,
                        std::vector<float> pFractions)
@@ -51,6 +81,26 @@ namespace corsika::environment {
     auto const& GetFractions() const { return fNumberFractions; }
     auto const& GetComponents() const { return fComponents; }
     auto const GetAverageMassNumber() const { return fAvgMassNumber; }
+
+    template <class TRNG>
+    corsika::particles::Code SampleTarget(
+        std::vector<corsika::units::si::CrossSectionType> const& sigma,
+        TRNG& randomStream) const {
+      using namespace corsika::units::si;
+
+      assert(sigma.size() == fNumberFractions.size());
+
+      std::discrete_distribution channelDist(
+          WeightProviderIterator<decltype(fNumberFractions.begin()),
+                                 decltype(sigma.begin())>(fNumberFractions.begin(),
+                                                          sigma.begin()),
+          WeightProviderIterator<decltype(fNumberFractions.begin()),
+                                 decltype(sigma.end())>(fNumberFractions.end(),
+                                                        sigma.end()));
+
+      auto const iChannel = channelDist(randomStream);
+      return fComponents[iChannel];
+    }
   };
 
 } // namespace corsika::environment
diff --git a/Environment/testEnvironment.cc b/Environment/testEnvironment.cc
index 8a8e9c13b..0695d02e1 100644
--- a/Environment/testEnvironment.cc
+++ b/Environment/testEnvironment.cc
@@ -17,6 +17,8 @@
 #include <corsika/environment/VolumeTreeNode.h>
 #include <corsika/particles/ParticleProperties.h>
 #include <catch2/catch.hpp>
+#include <random>
+#include <vector>
 
 using namespace corsika::geometry;
 using namespace corsika::environment;
@@ -28,3 +30,33 @@ TEST_CASE("HomogeneousMedium") {
       std::vector<float>{1.f});
   HomogeneousMedium<IMediumModel> const medium(19.2_g / cube(1_cm), protonComposition);
 }
+
+TEST_CASE("NuclearComposition") {
+  NuclearComposition const composition(
+      std::vector<corsika::particles::Code>{corsika::particles::Code::Proton,
+                                            corsika::particles::Code::Neutron},
+      std::vector<float>{2.f / 3.f, 1.f / 3.f});
+  SECTION("SampleTarget") {
+    std::vector<CrossSectionType> crossSections{50_mbarn, 100_mbarn};
+
+    std::mt19937 rng;
+
+    int proton{0}, neutron{0};
+
+    for (int i = 0; i < 1'000'000; ++i) {
+      corsika::particles::Code p = composition.SampleTarget(crossSections, rng);
+      switch (p) {
+        case corsika::particles::Code::Proton:
+          proton++;
+          break;
+        case corsika::particles::Code::Neutron:
+          neutron++;
+          break;
+        default:
+          throw std::runtime_error("");
+      }
+    }
+
+    REQUIRE(static_cast<double>(proton) / neutron == Approx(1).epsilon(1e-2));
+  }
+}
diff --git a/Processes/HadronicElasticModel/HadronicElasticModel.h b/Processes/HadronicElasticModel/HadronicElasticModel.h
index 9899e4e4c..31b6fd056 100644
--- a/Processes/HadronicElasticModel/HadronicElasticModel.h
+++ b/Processes/HadronicElasticModel/HadronicElasticModel.h
@@ -156,8 +156,8 @@ namespace corsika::process::HadronicElasticModel {
         cross_section_of_components[i] = CrossSection(s);
       }
 
-      const auto targetCode = currentNode->GetModelProperties().SampleTarget(
-          cross_section_of_components, fRNG);
+      const auto targetCode =
+          mediumComposition.SampleTarget(cross_section_of_components, fRNG);
 
       auto const targetMass = corsika::particles::GetMass(targetCode);
 
diff --git a/Processes/Sibyll/Interaction.h b/Processes/Sibyll/Interaction.h
index 288e724fe..91806cf4d 100644
--- a/Processes/Sibyll/Interaction.h
+++ b/Processes/Sibyll/Interaction.h
@@ -264,8 +264,8 @@ namespace corsika::process::sibyll {
           cross_section_of_components[i] = sigProd;
         }
 
-        const auto targetCode = currentNode->GetModelProperties().SampleTarget(
-            cross_section_of_components, fRNG);
+        const auto targetCode =
+            mediumComposition.SampleTarget(cross_section_of_components, fRNG);
         cout << "Interaction: target selected: " << targetCode << endl;
         /*
           FOR NOW: allow nuclei with A<18 or protons only.
-- 
GitLab