Transition.java

  1. package org.matsim.episim.model;

  2. import com.google.common.base.Objects;
  3. import com.typesafe.config.Config;
  4. import com.typesafe.config.ConfigFactory;
  5. import com.typesafe.config.ConfigRenderOptions;
  6. import com.typesafe.config.ConfigValue;
  7. import org.apache.commons.math3.util.FastMath;
  8. import org.matsim.episim.EpisimPerson.DiseaseStatus;
  9. import org.matsim.episim.EpisimUtils;

  10. import java.util.*;

  11. /**
  12.  * Describes how long a person stays in a certain state.
  13.  * Also provides factory methods for all available transitions.
  14.  * <p>
  15.  * Please note that it is not possible nor intended to inherit from this class outside of this package,
  16.  * as this would break serialization.
  17.  */
  18. public abstract class Transition {

  19.     /**
  20.      * Inheritance is prohibited for external classes.
  21.      */
  22.     Transition() {
  23.     }

  24.     /**
  25.      * Parse Transition builder from a config file.
  26.      */
  27.     public static Builder parse(Config config) {
  28.         return new Builder(config);
  29.     }

  30.     /**
  31.      * Create a new transition config builder.
  32.      */
  33.     public static Builder config() {
  34.         return new Builder((String) null);
  35.     }

  36.     /**
  37.      * Create a new transition config builder with a filename, that will be used if the config is persisted.
  38.      */
  39.     public static Builder config(String filename) {
  40.         return new Builder(filename);
  41.     }

  42.     /**
  43.      * Creates a to transition, to be used in conjunction with the {@link Builder}.
  44.      *
  45.      * @param status target state
  46.      * @param t      desired transition
  47.      */
  48.     public static ToHolder to(DiseaseStatus status, Transition t) {
  49.         return new ToHolder(status, t);
  50.     }

  51.     /**
  52.      * Deterministic transition at day {@code day}.
  53.      */
  54.     public static Transition fixed(int day) {
  55.         return new FixedTransition(day);
  56.     }

  57.     /**
  58.      * Probabilistic transition with log normal distribution.
  59.      * Parameter of the distribution will be calculated from given mean und standard deviation.
  60.      *
  61.      * @param mean desired mean of the distribution
  62.      * @param std  desired standard deviation
  63.      * @see LogNormalTransition
  64.      */
  65.     public static Transition logNormalWithMean(double mean, double std) {

  66.         // mean==median if std=0
  67.         if (std == 0) return logNormalWithMedianAndSigma(mean, 0);

  68.         double mu = Math.log((mean * mean) / Math.sqrt(mean * mean + std * std));
  69.         double sigma = Math.log(1 + (std * std) / (mean * mean));

  70.         return new LogNormalTransition(mu, Math.sqrt(sigma));
  71.     }

  72.     /**
  73.      * Same as {@link #logNormalWithMean(double, double)}.
  74.      */
  75.     public static Transition logNormalWithMeanAndStd(double mean, double std) {
  76.         return logNormalWithMean(mean, std);
  77.     }

  78.     /**
  79.      * Probabilistic state transition with log normal distribution.
  80.      *
  81.      * @param median desired median, i.e. exp(mu)
  82.      * @param sigma  sigma parameter
  83.      * @see LogNormalTransition
  84.      */
  85.     public static Transition logNormalWithMedianAndSigma(double median, double sigma) {

  86.         double mu = Math.log(median);
  87.         return new LogNormalTransition(mu, sigma);
  88.     }

  89.     /**
  90.      * Probabilistic state transition with log normal distribution.
  91.      *
  92.      * @param median desired median, i.e. exp(mu)
  93.      * @param std    desired standard deviation
  94.      * @see LogNormalTransition
  95.      */
  96.     public static Transition logNormalWithMedianAndStd(double median, double std) {

  97.         // equation below is numerical unstable for std near zero
  98.         if (std == 0) return logNormalWithMedianAndSigma(median, 0);

  99.         double mu = Math.log(median);

  100.         // Given the formula for std:
  101.         // \sqrt{\operatorname{Var}(X)}= \sqrt{\mathrm{e}^{2\mu+\sigma^{2}}(\mathrm{e}^{\sigma^{2}}-1)}=\mathrm{e}^{\mu+\frac{\sigma^{2}}{2}}\cdot\sqrt{\mathrm{e}^{\sigma^{2}}-1}

  102.         // solve for sigma
  103.         // https://www.wolframalpha.com/input/?i=solve+e%5E%28mu+%2B+s+%2F+2%29+*+sqrt%28e%5Es+-+1%29+%3D+x+for+s
  104.         double ssigma = Math.log(0.5 * Math.exp(-2 * mu) * (Math.exp(2 * mu) + Math.sqrt(Math.exp(4 * mu) + 4 * Math.exp(2 * mu) * std * std)));

  105.         return new LogNormalTransition(mu, Math.sqrt(ssigma));
  106.     }

  107.     /**
  108.      * Returns the day when the transition should occur.
  109.      */
  110.     public abstract int getTransitionDay(SplittableRandom rnd);

  111.     /**
  112.      * Implementation for a fixed transition.
  113.      */
  114.     private static final class FixedTransition extends Transition {

  115.         private final int day;

  116.         private FixedTransition(int day) {
  117.             this.day = day;
  118.         }

  119.         @Override
  120.         public int getTransitionDay(SplittableRandom rnd) {
  121.             return day;
  122.         }

  123.         @Override
  124.         public boolean equals(Object o) {
  125.             if (this == o) return true;
  126.             if (o == null || getClass() != o.getClass()) return false;
  127.             FixedTransition that = (FixedTransition) o;
  128.             return day == that.day;
  129.         }

  130.         @Override
  131.         public int hashCode() {
  132.             return Objects.hashCode(day);
  133.         }
  134.     }

  135.     /**
  136.      * Implementation for log normal distributed transition.
  137.      *
  138.      * @see EpisimUtils#nextLogNormal(SplittableRandom, double, double)
  139.      */
  140.     private static final class LogNormalTransition extends Transition {

  141.         private final double mu;
  142.         private final double sigma;

  143.         private LogNormalTransition(double mu, double sigma) {
  144.             this.mu = mu;
  145.             this.sigma = sigma;

  146.             if (sigma < 0 || Double.isNaN(sigma))
  147.                 throw new IllegalArgumentException("Sigma must be >= 0");
  148.         }

  149.         @Override
  150.         public int getTransitionDay(SplittableRandom rnd) {
  151.             return (int) FastMath.round(EpisimUtils.nextLogNormal(rnd, mu, sigma));
  152.         }

  153.         @Override
  154.         public boolean equals(Object o) {
  155.             if (this == o) return true;
  156.             if (o == null || getClass() != o.getClass()) return false;
  157.             LogNormalTransition that = (LogNormalTransition) o;
  158.             return Double.compare(that.mu, mu) == 0 &&
  159.                     Double.compare(that.sigma, sigma) == 0;
  160.         }

  161.         @Override
  162.         public int hashCode() {
  163.             return Objects.hashCode(mu, sigma);
  164.         }
  165.     }

  166.     /**
  167.      * Builder for a transition config.
  168.      */
  169.     public static final class Builder {

  170.         private final String origin;
  171.         private final Map<DiseaseStatus, Map<DiseaseStatus, Transition>> transitions = new EnumMap<>(DiseaseStatus.class);

  172.         private Builder(String origin) {
  173.             this.origin = origin;
  174.         }

  175.         /**
  176.          * Initialize from config.
  177.          */
  178.         @SuppressWarnings("unchecked")
  179.         private Builder(Config config) {

  180.             for (Map.Entry<String, ConfigValue> e : config.root().entrySet()) {

  181.                 DiseaseStatus status = DiseaseStatus.valueOf(e.getKey());
  182.                 Config toConfig = config.getConfig(e.getKey());

  183.                 List<ToHolder> tos = new ArrayList<>();

  184.                 for (Map.Entry<String, ConfigValue> to : toConfig.root().entrySet()) {

  185.                     Map<String, String> params = (Map<String, String>) to.getValue().unwrapped();

  186.                     DiseaseStatus toStatus = DiseaseStatus.valueOf(to.getKey());
  187.                     Transition t;
  188.                     if (params.get("type").equals("FixedTransition"))
  189.                         t = new FixedTransition(Integer.parseInt(params.get("day")));
  190.                     else if (params.get("type").equals("LogNormalTransition"))
  191.                         t = new LogNormalTransition(Double.parseDouble(params.get("mu")), Double.parseDouble(params.get("sigma")));
  192.                     else
  193.                         throw new IllegalArgumentException("Could not parse transition: " + params);

  194.                     tos.add(to(toStatus, t));
  195.                 }

  196.                 from(status, tos.toArray(new ToHolder[0]));
  197.             }

  198.             this.origin = config.origin().description();
  199.         }

  200.         /**
  201.          * Defines which transitions should be taken from the state {@code} status to the states defined in {@code to}.
  202.          *
  203.          * @param status the current disease status
  204.          * @param to     collection of target states and their transitions.
  205.          * @see #to(DiseaseStatus, Transition)
  206.          */
  207.         public Builder from(DiseaseStatus status, ToHolder... to) {
  208.             if (to.length == 0) throw new IllegalArgumentException("No target states specified");

  209.             for (ToHolder t : to) {
  210.                 transitions.computeIfAbsent(status, (k) -> new EnumMap<>(DiseaseStatus.class))
  211.                         .put(t.status, t.t);
  212.             }
  213.             return this;
  214.         }

  215.         /**
  216.          * Creates a config representation.
  217.          */
  218.         public Config build() {

  219.             Map<String, Object> config = new LinkedHashMap<>();

  220.             for (Map.Entry<DiseaseStatus, Map<DiseaseStatus, Transition>> e : transitions.entrySet()) {

  221.                 Map<String, Object> toConfig = new LinkedHashMap<>();

  222.                 for (Map.Entry<DiseaseStatus, Transition> to : e.getValue().entrySet()) {
  223.                     // params of the transition
  224.                     Map<String, String> params = new LinkedHashMap<>();

  225.                     Transition t = to.getValue();

  226.                     params.put("type", t.getClass().getSimpleName());

  227.                     if (t instanceof FixedTransition) {
  228.                         params.put("day", String.valueOf(((FixedTransition) t).day));
  229.                     } else if (t instanceof LogNormalTransition) {
  230.                         params.put("mu", String.valueOf(((LogNormalTransition) t).mu));
  231.                         params.put("sigma", String.valueOf(((LogNormalTransition) t).sigma));
  232.                     } else
  233.                         throw new IllegalArgumentException("Can not serialize unknown transition " + t);

  234.                     toConfig.put(to.getKey().name(), params);
  235.                 }

  236.                 config.put(e.getKey().name(), toConfig);
  237.             }

  238.             return ConfigFactory.parseMap(config, origin);
  239.         }


  240.         /**
  241.          * Returns the config as matrix with entries as transition from -> to, according to {@link DiseaseStatus#ordinal()}.
  242.          * Not defined transitions will be null.
  243.          */
  244.         public Transition[] asArray() {
  245.             Transition[] array = new Transition[DiseaseStatus.values().length * DiseaseStatus.values().length];

  246.             for (Map.Entry<DiseaseStatus, Map<DiseaseStatus, Transition>> e : transitions.entrySet()) {
  247.                 for (Map.Entry<DiseaseStatus, Transition> to : e.getValue().entrySet()) {
  248.                     array[e.getKey().ordinal() * DiseaseStatus.values().length + to.getKey().ordinal()] = to.getValue();
  249.                 }
  250.             }

  251.             return array;
  252.         }

  253.         @Override
  254.         public String toString() {
  255.             return build().root().render(ConfigRenderOptions.concise().setJson(false));
  256.         }
  257.     }

  258.     /**
  259.      * Holder class that saves the target status and desired transition.
  260.      */
  261.     public static final class ToHolder {

  262.         public final DiseaseStatus status;
  263.         public final Transition t;

  264.         private ToHolder(DiseaseStatus status, Transition t) {
  265.             this.status = status;
  266.             this.t = t;
  267.         }

  268.         @Override
  269.         public String toString() {
  270.             return "ToHolder{" +
  271.                     "status=" + status +
  272.                     ", t=" + t +
  273.                     '}';
  274.         }
  275.     }
  276. }