From d2e8018b995db0d8b4058a0209630ab954380c43 Mon Sep 17 00:00:00 2001
From: Remy Prechelt <prechelt@hawaii.edu>
Date: Tue, 30 Mar 2021 22:58:05 -1000
Subject: [PATCH] Port TrackWriter to the Parquet data format.

---
 corsika/detail/modules/TrackWriter.inl        | 57 ++++++------
 .../modules/writers/TrackWriterParquet.inl    | 68 +++++++++++++++
 corsika/modules/TrackWriter.hpp               | 18 ++--
 .../modules/writers/TrackWriterParquet.hpp    | 60 +++++++++++++
 python/corsika/io/outputs/__init__.py         |  3 +-
 python/corsika/io/outputs/track_writer.py     | 87 +++++++++++++++++++
 6 files changed, 249 insertions(+), 44 deletions(-)
 create mode 100644 corsika/detail/modules/writers/TrackWriterParquet.inl
 create mode 100644 corsika/modules/writers/TrackWriterParquet.hpp
 create mode 100644 python/corsika/io/outputs/track_writer.py

diff --git a/corsika/detail/modules/TrackWriter.inl b/corsika/detail/modules/TrackWriter.inl
index 27b47cf65..5a82ee75c 100644
--- a/corsika/detail/modules/TrackWriter.inl
+++ b/corsika/detail/modules/TrackWriter.inl
@@ -8,53 +8,48 @@
 
 #pragma once
 
-#include <corsika/modules/TrackWriter.hpp>
-
 #include <corsika/framework/core/ParticleProperties.hpp>
+#include <corsika/framework/core/PhysicalUnits.hpp>
 
-#include <corsika/setup/SetupStack.hpp>
-#include <corsika/setup/SetupTrajectory.hpp>
+// #include <corsika/setup/SetupStack.hpp>
+// #include <corsika/setup/SetupTrajectory.hpp>
 
-#include <iomanip>
 #include <limits>
 
 namespace corsika {
 
-  TrackWriter::TrackWriter(std::string const& filename)
-      : filename_(filename) {
-    using namespace std::string_literals;
-
-    file_.open(filename_);
-    file_
-        << "# PID, E / eV, start coordinates / m, displacement vector to end / m, steplength / m "s
-        << '\n';
-  }
+  template <typename TOutput>
+  TrackWriter<TOutput>::TrackWriter() {}
 
+  template <typename TOutput>
   template <typename TParticle, typename TTrack>
-  ProcessReturn TrackWriter::doContinuous(const TParticle& vP, const TTrack& vT) {
+  ProcessReturn TrackWriter<TOutput>::doContinuous(const TParticle& vP,
+                                                   const TTrack& vT) {
     auto const start = vT.getPosition(0).getCoordinates();
-    auto const delta = vT.getPosition(1).getCoordinates() - start;
-    auto const pdg = static_cast<int>(get_PDG(vP.getPID()));
-
-    // clang-format off
-    file_ << std::setw(7) << pdg
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << vP.getEnergy() / 1_eV
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[0] / 1_m 
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[1] / 1_m
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[2] / 1_m
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[0] / 1_m
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[1] / 1_m
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[2] / 1_m 
-          << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta.getNorm() / 1_m
-          << '\n';
-    // clang-format on
+    auto const end = vT.getPosition(1).getCoordinates();
+
+    // write the track to the file
+    this->write(vP.getPID(), vP.getEnergy(), start, end);
 
     return ProcessReturn::Ok;
   }
 
+  template <typename TOutput>
   template <typename TParticle, typename TTrack>
-  LengthType TrackWriter::getMaxStepLength(const TParticle&, const TTrack&) {
+  LengthType TrackWriter<TOutput>::getMaxStepLength(const TParticle&, const TTrack&) {
     return meter * std::numeric_limits<double>::infinity();
   }
 
+  template <typename TOutput>
+  YAML::Node TrackWriter<TOutput>::getConfig() const {
+    // Returns a YAML node holding this output's type tag and value units.
+
+    YAML::Node node;
+
+    // "GeV | m" matches what the default TrackWriterParquet writes
+    node["type"] = "TrackWriter";
+    node["units"] = "GeV | m";
+
+    return node;
+  }
 } // namespace corsika
diff --git a/corsika/detail/modules/writers/TrackWriterParquet.inl b/corsika/detail/modules/writers/TrackWriterParquet.inl
new file mode 100644
index 000000000..27c7b3c78
--- /dev/null
+++ b/corsika/detail/modules/writers/TrackWriterParquet.inl
@@ -0,0 +1,68 @@
+/*
+ * (c) Copyright 2021 CORSIKA Project, corsika-project@lists.kit.edu
+ *
+ * This software is distributed under the terms of the GNU General Public
+ * Licence version 3 (GPL Version 3). See file LICENSE for a full version of
+ * the license.
+ */
+
+#pragma once
+
+namespace corsika {
+  // Definitions are `inline`: this .inl is textually included from a header.
+  inline TrackWriterParquet::TrackWriterParquet()
+      : output_() {}
+
+  inline void TrackWriterParquet::startOfLibrary(
+      std::filesystem::path const& directory) {
+    // setup the streamer
+    output_.initStreamer((directory / "tracks.parquet").string());
+
+    // build the schema
+    output_.addField("pdg", parquet::Repetition::REQUIRED, parquet::Type::INT32,
+                     parquet::ConvertedType::INT_32);
+    output_.addField("energy", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("start_x", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("start_y", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("start_z", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("end_x", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("end_y", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+    output_.addField("end_z", parquet::Repetition::REQUIRED, parquet::Type::FLOAT,
+                     parquet::ConvertedType::NONE);
+
+    // and build the streamer
+    output_.buildStreamer();
+  }
+
+  inline void TrackWriterParquet::endOfShower() { ++shower_; }
+
+  inline void TrackWriterParquet::endOfLibrary() { output_.closeStreamer(); }
+
+  inline void TrackWriterParquet::write(
+      Code const& pid, units::si::HEPEnergyType const& energy,
+      QuantityVector<length_d> const& start, QuantityVector<length_d> const& end) {
+    using namespace units::si;
+
+    // write the next row - we must write `shower_` first.
+    // clang-format off
+    *(output_.getWriter())
+        << shower_
+        << static_cast<int>(get_PDG(pid))
+        << static_cast<float>(energy / 1_GeV)
+        << static_cast<float>(start[0] / 1_m)
+        << static_cast<float>(start[1] / 1_m)
+        << static_cast<float>(start[2] / 1_m)
+        << static_cast<float>(end[0] / 1_m)
+        << static_cast<float>(end[1] / 1_m)
+        << static_cast<float>(end[2] / 1_m)
+        << parquet::EndRow;
+    // clang-format on
+  }
+
+} // namespace corsika
diff --git a/corsika/modules/TrackWriter.hpp b/corsika/modules/TrackWriter.hpp
index c437bc7d1..9a9addb96 100644
--- a/corsika/modules/TrackWriter.hpp
+++ b/corsika/modules/TrackWriter.hpp
@@ -8,18 +8,17 @@
 
 #pragma once
 
-#include <corsika/framework/core/PhysicalUnits.hpp>
 #include <corsika/framework/process/ContinuousProcess.hpp>
-
-#include <fstream>
-#include <string>
+#include <corsika/modules/writers/TrackWriterParquet.hpp>
 
 namespace corsika {
 
-  class TrackWriter : public ContinuousProcess<TrackWriter> {
+  template <typename TOutputWriter = TrackWriterParquet>
+  class TrackWriter : public ContinuousProcess<TrackWriter<TOutputWriter>>,
+                      public TOutputWriter {
 
   public:
-    TrackWriter(std::string const& filename);
+    TrackWriter();
 
     template <typename TParticle, typename TTrack>
     ProcessReturn doContinuous(TParticle const&, TTrack const&);
@@ -27,12 +26,7 @@ namespace corsika {
     template <typename TParticle, typename TTrack>
     LengthType getMaxStepLength(TParticle const&, TTrack const&);
 
-  private:
-    std::string const filename_;
-    std::ofstream file_;
-
-    int width_ = 14;
-    int precision_ = 6;
+    YAML::Node getConfig() const;
   };
 
 } // namespace corsika
diff --git a/corsika/modules/writers/TrackWriterParquet.hpp b/corsika/modules/writers/TrackWriterParquet.hpp
new file mode 100644
index 000000000..5298222f6
--- /dev/null
+++ b/corsika/modules/writers/TrackWriterParquet.hpp
@@ -0,0 +1,60 @@
+/*
+ * (c) Copyright 2021 CORSIKA Project, corsika-project@lists.kit.edu
+ *
+ * This software is distributed under the terms of the GNU General Public
+ * Licence version 3 (GPL Version 3). See file LICENSE for a full version of
+ * the license.
+ */
+
+#pragma once
+
+#include <corsika/output/BaseOutput.hpp>
+#include <corsika/output/ParquetStreamer.hpp>
+#include <corsika/framework/core/ParticleProperties.hpp>
+#include <corsika/framework/core/PhysicalUnits.hpp>
+
+namespace corsika {
+
+  class TrackWriterParquet : public BaseOutput {
+
+    ParquetStreamer output_; ///< The primary output file.
+
+  public:
+    /**
+     * Construct a new writer.
+     *
+     * Takes no arguments; the streamer is initialised in startOfLibrary().
+     */
+    TrackWriterParquet();
+
+    /**
+     * Called at the start of each library: sets up `directory`/tracks.parquet.
+     */
+    void startOfLibrary(std::filesystem::path const& directory) final override;
+
+    /**
+     * Called at the end of each shower; increments the shower counter.
+     */
+    void endOfShower() final override;
+
+    /**
+     * Called at the end of each library.
+     *
+     * Closes the streamer set up in startOfLibrary(). The shower counter
+     * is advanced in endOfShower(), not here.
+     */
+    void endOfLibrary() final override;
+
+  protected:
+    /**
+     * Write one track as a row: PDG code, energy / GeV, start and end point / m.
+     */
+    void write(Code const& pid, units::si::HEPEnergyType const& energy,
+               QuantityVector<length_d> const& start,
+               QuantityVector<length_d> const& end);
+
+  }; // class TrackWriterParquet
+
+} // namespace corsika
+
+#include <corsika/detail/modules/writers/TrackWriterParquet.inl>
diff --git a/python/corsika/io/outputs/__init__.py b/python/corsika/io/outputs/__init__.py
index 67fb35a7b..fa38acdde 100644
--- a/python/corsika/io/outputs/__init__.py
+++ b/python/corsika/io/outputs/__init__.py
@@ -8,6 +8,7 @@
 """
 
 from .observation_plane import ObservationPlane
+from .track_writer import TrackWriter
 from .output import Output
 
-__all__ = ["Output", "ObservationPlane"]
+__all__ = ["Output", "ObservationPlane", "TrackWriter"]
diff --git a/python/corsika/io/outputs/track_writer.py b/python/corsika/io/outputs/track_writer.py
new file mode 100644
index 000000000..0859660d5
--- /dev/null
+++ b/python/corsika/io/outputs/track_writer.py
@@ -0,0 +1,87 @@
+"""
+ Read data written by TrackWriter
+
+ (c) Copyright 2020 CORSIKA Project, corsika-project@lists.kit.edu
+
+ This software is distributed under the terms of the GNU General Public
+ Licence version 3 (GPL Version 3). See file LICENSE for a full version of
+ the license.
+"""
+import logging
+import os.path as op
+from typing import Any
+
+import pyarrow.parquet as pq
+
+from .output import Output
+
+
+class TrackWriter(Output):
+    """
+    Read particle data from a TrackWriter
+    """
+
+    def __init__(self, path: str):
+        """
+        Load the particle data into a parquet table.
+
+        Parameters
+        ----------
+        path: str
+            The path to the directory containing this output.
+        """
+        super().__init__(path)
+
+        # default to None so is_good() is well-defined if loading fails
+        self.__data = None
+        try:
+            self.__data = pq.read_table(op.join(path, "tracks.parquet"))
+        except Exception as e:
+            message = f"An error occured loading a TrackWriter: {e}"
+            logging.getLogger("corsika").warning(message)
+
+    def is_good(self) -> bool:
+        """
+        Returns true if this output has been read successfully
+        and has the correct files/state/etc.
+
+        Returns
+        -------
+        bool:
+            True if this is a good output.
+        """
+        return self.__data is not None
+
+    def astype(self, dtype: str = "pandas", **kwargs: Any) -> Any:
+        """
+        Load the particle data from this track writer.
+
+        Additional keyword arguments are accepted but currently unused.
+
+        Parameters
+        ----------
+        dtype: str
+            The data format to return the data in ('arrow' or 'pandas').
+
+        Returns
+        -------
+        Any:
+            The return type of this method is determined by `dtype`.
+        """
+        if dtype == "arrow":
+            return self.__data
+        elif dtype == "pandas":
+            return self.__data.to_pandas()
+        else:
+            raise ValueError(
+                (
+                    f"Unknown format '{dtype}' for TrackWriter. "
+                    "We currently only support ['arrow', 'pandas']."
+                )
+            )
+
+    def __repr__(self) -> str:
+        """
+        Return a string representation of this class.
+        """
+        return f"TrackWriter('{self.config['name']}')"
-- 
GitLab