From d2e8018b995db0d8b4058a0209630ab954380c43 Mon Sep 17 00:00:00 2001 From: Remy Prechelt <prechelt@hawaii.edu> Date: Tue, 30 Mar 2021 22:58:05 -1000 Subject: [PATCH] Port TrackWriter to the Parquet data format. --- corsika/detail/modules/TrackWriter.inl | 57 ++++++------ .../modules/writers/TrackWriterParquet.inl | 68 +++++++++++++++ corsika/modules/TrackWriter.hpp | 18 ++-- .../modules/writers/TrackWriterParquet.hpp | 60 +++++++++++++ python/corsika/io/outputs/__init__.py | 3 +- python/corsika/io/outputs/track_writer.py | 87 +++++++++++++++++++ 6 files changed, 249 insertions(+), 44 deletions(-) create mode 100644 corsika/detail/modules/writers/TrackWriterParquet.inl create mode 100644 corsika/modules/writers/TrackWriterParquet.hpp create mode 100644 python/corsika/io/outputs/track_writer.py diff --git a/corsika/detail/modules/TrackWriter.inl b/corsika/detail/modules/TrackWriter.inl index 27b47cf65..5a82ee75c 100644 --- a/corsika/detail/modules/TrackWriter.inl +++ b/corsika/detail/modules/TrackWriter.inl @@ -8,53 +8,48 @@ #pragma once -#include <corsika/modules/TrackWriter.hpp> - #include <corsika/framework/core/ParticleProperties.hpp> +#include <corsika/framework/core/PhysicalUnits.hpp> -#include <corsika/setup/SetupStack.hpp> -#include <corsika/setup/SetupTrajectory.hpp> +// #include <corsika/setup/SetupStack.hpp> +// #include <corsika/setup/SetupTrajectory.hpp> -#include <iomanip> #include <limits> namespace corsika { - TrackWriter::TrackWriter(std::string const& filename) - : filename_(filename) { - using namespace std::string_literals; - - file_.open(filename_); - file_ - << "# PID, E / eV, start coordinates / m, displacement vector to end / m, steplength / m "s - << '\n'; - } + template <typename TOutput> + TrackWriter<TOutput>::TrackWriter() {} + template <typename TOutput> template <typename TParticle, typename TTrack> - ProcessReturn TrackWriter::doContinuous(const TParticle& vP, const TTrack& vT) { + ProcessReturn 
TrackWriter<TOutput>::doContinuous(const TParticle& vP, + const TTrack& vT) { auto const start = vT.getPosition(0).getCoordinates(); - auto const delta = vT.getPosition(1).getCoordinates() - start; - auto const pdg = static_cast<int>(get_PDG(vP.getPID())); - - // clang-format off - file_ << std::setw(7) << pdg - << std::setw(width_) << std::scientific << std::setprecision(precision_) << vP.getEnergy() / 1_eV - << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[0] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[1] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << start[2] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[0] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[1] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta[2] / 1_m - << std::setw(width_) << std::scientific << std::setprecision(precision_) << delta.getNorm() / 1_m - << '\n'; - // clang-format on + auto const end = vT.getPosition(1).getCoordinates(); + + // pass PID, energy and both track end points to the inherited output writer + this->write(vP.getPID(), vP.getEnergy(), start, end); return ProcessReturn::Ok; } + template <typename TOutput> template <typename TParticle, typename TTrack> - LengthType TrackWriter::getMaxStepLength(const TParticle&, const TTrack&) { + LengthType TrackWriter<TOutput>::getMaxStepLength(const TParticle&, const TTrack&) { return meter * std::numeric_limits<double>::infinity(); } + template <typename TOutput> + YAML::Node TrackWriter<TOutput>::getConfig() const { + using namespace units::si; + + YAML::Node node; + + // describe this output: its type tag and the units string for the written values + node["type"] = "TrackWriter"; + node["units"] = "GeV | m"; + + return node; + } } // namespace corsika diff --git 
a/corsika/detail/modules/writers/TrackWriterParquet.inl b/corsika/detail/modules/writers/TrackWriterParquet.inl new file mode 100644 
index 000000000..27c7b3c78 --- /dev/null +++ b/corsika/detail/modules/writers/TrackWriterParquet.inl @@ -0,0 +1,68 @@ +/* + * (c) Copyright 2021 CORSIKA Project, corsika-project@lists.kit.edu + * + * This software is distributed under the terms of the GNU General Public + * Licence version 3 (GPL Version 3). See file LICENSE for a full version of + * the license. + */ + +#pragma once + +namespace corsika { + // NOTE: every definition below is `inline` because this file is included from a header. + inline TrackWriterParquet::TrackWriterParquet() + : output_() {} + + inline void TrackWriterParquet::startOfLibrary(std::filesystem::path const& directory) { + + // set up the streamer; tracks go to <directory>/tracks.parquet + output_.initStreamer((directory / "tracks.parquet").string()); + + // build the schema: PDG code, energy, start point and end point + output_.addField("pdg", parquet::Repetition::REQUIRED, parquet::Type::INT32, + parquet::ConvertedType::INT_32); + output_.addField("energy", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("start_x", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("start_y", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("start_z", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("end_x", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("end_y", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + output_.addField("end_z", parquet::Repetition::REQUIRED, parquet::Type::FLOAT, + parquet::ConvertedType::NONE); + + // and build the streamer + output_.buildStreamer(); + } + + inline void TrackWriterParquet::endOfShower() { ++shower_; } + + inline void TrackWriterParquet::endOfLibrary() { output_.closeStreamer(); } + + 
inline void TrackWriterParquet::write(Code const& pid, units::si::HEPEnergyType const& energy, + QuantityVector<length_d> const& start, + QuantityVector<length_d> const& end) { + using namespace units::si; + + // write 
the next row (energy in GeV, lengths in m) - we must write `shower_` first. + // clang-format off + *(output_.getWriter()) + << shower_ + << static_cast<int>(get_PDG(pid)) + << static_cast<float>(energy / 1_GeV) + << static_cast<float>(start[0] / 1_m) + << static_cast<float>(start[1] / 1_m) + << static_cast<float>(start[2] / 1_m) + << static_cast<float>(end[0] / 1_m) + << static_cast<float>(end[1] / 1_m) + << static_cast<float>(end[2] / 1_m) + << parquet::EndRow; + // clang-format on + } + +} // namespace corsika diff --git a/corsika/modules/TrackWriter.hpp b/corsika/modules/TrackWriter.hpp index c437bc7d1..9a9addb96 100644 --- a/corsika/modules/TrackWriter.hpp +++ b/corsika/modules/TrackWriter.hpp @@ -8,18 +8,17 @@ #pragma once -#include <corsika/framework/core/PhysicalUnits.hpp> #include <corsika/framework/process/ContinuousProcess.hpp> - -#include <fstream> -#include <string> +#include <corsika/modules/writers/TrackWriterParquet.hpp> namespace corsika { - class TrackWriter : public ContinuousProcess<TrackWriter> { + template <typename TOutputWriter = TrackWriterParquet> + class TrackWriter : public ContinuousProcess<TrackWriter<TOutputWriter>>, + public TOutputWriter { public: - TrackWriter(std::string const& filename); + TrackWriter(); template <typename TParticle, typename TTrack> ProcessReturn doContinuous(TParticle const&, TTrack const&); @@ -27,12 +26,7 @@ namespace corsika { template <typename TParticle, typename TTrack> LengthType getMaxStepLength(TParticle const&, TTrack const&); - private: - std::string const filename_; - std::ofstream file_; - - int width_ = 14; - int precision_ = 6; + YAML::Node getConfig() const; }; } // namespace corsika diff --git a/corsika/modules/writers/TrackWriterParquet.hpp b/corsika/modules/writers/TrackWriterParquet.hpp new file mode 100644 index 000000000..5298222f6 --- /dev/null +++ b/corsika/modules/writers/TrackWriterParquet.hpp @@ -0,0 +1,60 @@ +/* + * (c) Copyright 2021 CORSIKA Project, corsika-project@lists.kit.edu + * + 
* This software is 
distributed under the terms of the GNU General Public + * Licence version 3 (GPL Version 3). See file LICENSE for a full version of + * the license. + */ + +#pragma once + +#include <corsika/output/BaseOutput.hpp> +#include <corsika/output/ParquetStreamer.hpp> +#include <corsika/framework/core/ParticleProperties.hpp> +#include <corsika/framework/core/PhysicalUnits.hpp> + +namespace corsika { + + class TrackWriterParquet : public BaseOutput { + + ParquetStreamer output_; ///< The primary output file. + + public: + /** + * Construct a new writer. + * + * Takes no arguments; the Parquet streamer is set up in startOfLibrary(). + */ + TrackWriterParquet(); + + /** + * Called at the start of each library; sets up `tracks.parquet` in `directory`. + */ + void startOfLibrary(std::filesystem::path const& directory) final override; + + /** + * Called at the end of each shower; increments the shower counter. + */ + void endOfShower() final override; + + /** + * Called at the end of each library. + * + * Closes the Parquet streamer that was set up in + * startOfLibrary(). + */ + void endOfLibrary() final override; + + protected: + /** + * Write a track to the file. 
+ */ + void write(Code const& pid, units::si::HEPEnergyType const& energy, + QuantityVector<length_d> const& start, + QuantityVector<length_d> const& end); + + }; // class TrackWriterParquet + +} // namespace corsika + +#include <corsika/detail/modules/writers/TrackWriterParquet.inl> diff --git a/python/corsika/io/outputs/__init__.py b/python/corsika/io/outputs/__init__.py index 67fb35a7b..fa38acdde 100644 --- a/python/corsika/io/outputs/__init__.py +++ b/python/corsika/io/outputs/__init__.py @@ -8,6 +8,7 @@ """ from .observation_plane import ObservationPlane +from .track_writer import TrackWriter from .output import Output -__all__ = ["Output", "ObservationPlane"] +__all__ = ["Output", "ObservationPlane", "TrackWriter"] diff --git a/python/corsika/io/outputs/track_writer.py b/python/corsika/io/outputs/track_writer.py new file mode 100644 index 000000000..0859660d5 --- /dev/null +++ b/python/corsika/io/outputs/track_writer.py @@ -0,0 +1,87 @@ +""" + Read data written by TrackWriter + + (c) Copyright 2020 CORSIKA Project, corsika-project@lists.kit.edu + + This software is distributed under the terms of the GNU General Public + Licence version 3 (GPL Version 3). See file LICENSE for a full version of + the license. +""" +import logging +import os.path as op +from typing import Any + +import pyarrow.parquet as pq + +from .output import Output + + +class TrackWriter(Output): + """ + Read track data written by a TrackWriter + """ + + def __init__(self, path: str): + """ + Load the track data from `tracks.parquet` into a table. + + Parameters + ---------- + path: str + The path to the directory containing this output. 
+ """ + super().__init__(path) + + self.__data = None # default so is_good() works when the load below fails + try: + self.__data = pq.read_table(op.join(path, "tracks.parquet")) + except Exception as e: + logging.getLogger("corsika").warning( + f"An error occured loading a TrackWriter: {e}" + ) + + def is_good(self) -> bool: + """ + Returns true if this output has been read successfully + and has the correct files/state/etc. + + Returns + ------- + bool: + True if this is a good output. + """ + return self.__data is not None + + def astype(self, dtype: str = "pandas", **kwargs: Any) -> Any: + """ + Return the track data loaded by this track writer. + + Additional keyword arguments are accepted but currently unused. + + Parameters + ---------- + dtype: str + The data format to return the data in: 'arrow' or 'pandas'. + + Returns + ------- + Any: + The return type of this method is determined by `dtype`. + """ + if dtype == "arrow": + return self.__data + elif dtype == "pandas": + return self.__data.to_pandas() + else: + raise ValueError( + ( + f"Unknown format '{dtype}' for TrackWriter. " + "We currently only support ['arrow', 'pandas']." + ) + ) + + def __repr__(self) -> str: + """ + Return a string representation of this class. + """ + return f"TrackWriter('{self.config['name']}')" -- GitLab