Restriction.java
package org.matsim.episim.policy;
import com.typesafe.config.Config;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.poi.util.Beta;
import org.matsim.api.core.v01.Id;
import org.matsim.episim.model.FaceMask;
import org.matsim.facilities.ActivityFacility;
import javax.annotation.Nullable;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Represent the current restrictions on an activity type.
 * <p>
 * Most attributes are optional ({@code null} means "not set"), which allows partial restrictions to be
 * combined via {@link #update(Restriction)} and {@link #merge(Map)}.
 * Instances are created with the static factory methods or {@link #fromConfig(Config)}.
 */
public final class Restriction {

    private static final Logger log = LogManager.getLogger(Restriction.class);

    /**
     * Number of seconds of one day, used to map timestamps to seconds of day.
     */
    private static final int SECONDS_PER_DAY = 86400;

    /**
     * Fraction of activities still performed.
     */
    @Nullable
    private Double remainingFraction;

    /**
     * Contact intensity correction factor.
     */
    @Nullable
    private Double ciCorrection;

    /**
     * Maximum allowed group size of activities.
     */
    @Nullable
    private Integer maxGroupSize;

    /**
     * Reduced group size for activities.
     */
    @Nullable
    private Integer reducedGroupSize;

    /**
     * Ids of closed facilities.
     */
    @Nullable
    private Set<Id<ActivityFacility>> closed;

    /**
     * {@link ClosingHours} when activity is closed.
     */
    @Nullable
    private ClosingHours closingHours;

    /**
     * Maps mask type to the cumulative probability of persons wearing it, see {@link #determineMask(SplittableRandom)}.
     */
    private final Map<FaceMask, Double> maskUsage = new EnumMap<>(FaceMask.class);

    /**
     * Maps district name to its location-based remaining fraction.
     * May be {@code null} if set so via {@link #setLocationBasedRf(Map)} or {@link #ofLocationBasedRf(Map)}.
     */
    private Map<String, Double> locationBasedRf;

    /**
     * Fraction of susceptible persons, performing an activity.
     */
    @Nullable
    private Double susceptibleRf;

    /**
     * Remaining fraction for vaccinated persons.
     */
    @Nullable
    private Double vaccinatedRf;

    /**
     * Stores if this restriction should be extrapolated.
     */
    private boolean extrapolate = false;

    /**
     * Validating constructor used by the static factory methods.
     * Given mask usage rates are converted into cumulative probabilities; the remainder up to 1 is assigned to {@link FaceMask#NONE}.
     *
     * @throws IllegalArgumentException if a fraction, the ci correction or the mask usage is out of its valid range
     */
    private Restriction(@Nullable Double remainingFraction, @Nullable Double ciCorrection, @Nullable Integer maxGroupSize, @Nullable Integer reducedGroupSize,
                        @Nullable List<String> closed, @Nullable ClosingHours closingHours, @Nullable Map<FaceMask, Double> maskUsage, Map<String, Double> locationBasedRf,
                        @Nullable Double susceptibleRf, @Nullable Double vaccinatedRf) {

        if (remainingFraction != null && !Objects.equals(remainingFraction, ShutdownPolicy.REG_HOSPITAL) && (Double.isNaN(remainingFraction) || remainingFraction < 0 || remainingFraction > 1))
            throw new IllegalArgumentException("remainingFraction must be between 0 and 1 but is=" + remainingFraction);
        if (ciCorrection != null && (Double.isNaN(ciCorrection) || ciCorrection < 0))
            throw new IllegalArgumentException("contact intensity correction must be larger than 0 but is=" + ciCorrection);
        if (maskUsage != null && maskUsage.values().stream().anyMatch(p -> p < 0 || p > 1))
            throw new IllegalArgumentException("Mask usage probabilities must be between [0, 1]");
        if (susceptibleRf != null && (Double.isNaN(susceptibleRf) || susceptibleRf < 0 || susceptibleRf > 1))
            throw new IllegalArgumentException("Susceptible fraction must be between 0 and 1 but is="+ susceptibleRf);
        // NOTE(review): vaccinatedRf is not range checked, unlike susceptibleRf - confirm whether this is intended

        this.remainingFraction = remainingFraction;
        this.ciCorrection = ciCorrection;
        this.maxGroupSize = maxGroupSize;
        this.reducedGroupSize = reducedGroupSize;
        this.closingHours = closingHours;
        this.susceptibleRf = susceptibleRf;
        this.vaccinatedRf = vaccinatedRf;

        if (closed != null) {
            this.closed = closed.stream().map(s -> Id.create(s, ActivityFacility.class)).collect(Collectors.toSet());
        }

        // Compute cumulative probabilities
        if (maskUsage != null && !maskUsage.isEmpty()) {
            if (maskUsage.containsKey(FaceMask.NONE))
                throw new IllegalArgumentException("Mask usage for NONE must not be given");

            // stream must be sorted or the order is undefined, which can result in different sums
            double total = maskUsage.values().stream().sorted().mapToDouble(Double::doubleValue).sum();
            // a sum of exactly 1 is valid, it leaves probability 0 for wearing no mask
            if (total > 1) throw new IllegalArgumentException("Sum of mask usage rates must be <= 1");

            double sum = 1 - total;
            this.maskUsage.put(FaceMask.NONE, sum);
            for (FaceMask m : FaceMask.values()) {
                if (maskUsage.containsKey(m)) {
                    sum += maskUsage.get(m);
                    if (Double.isNaN(sum))
                        throw new IllegalArgumentException("Mask usage contains NaN value!");
                    this.maskUsage.put(m, sum);
                }
            }
        }

        this.locationBasedRf = locationBasedRf;
    }

    /**
     * Create from other restriction.
     * No validation is performed and mask usage values are copied as they are, i.e. they are not accumulated again.
     *
     * @param maskUsage will only be used if other is null
     */
    Restriction(@Nullable Double remainingFraction, @Nullable Double ciCorrection, @Nullable Integer maxGroupSize, @Nullable Integer reducedGroupSize,
                @Nullable List<String> closed, @Nullable ClosingHours closingHours, @Nullable Map<FaceMask, Double> maskUsage, Map<String, Double> locationBasedRf,
                @Nullable Double susceptibleRf, @Nullable Double vaccinatedRf, Restriction other) {

        this.remainingFraction = remainingFraction;
        this.ciCorrection = ciCorrection;
        this.maxGroupSize = maxGroupSize;
        this.reducedGroupSize = reducedGroupSize;
        this.closingHours = closingHours;

        // guard against both sources being absent, putAll(null) would throw
        Map<FaceMask, Double> masks = other != null ? other.maskUsage : maskUsage;
        if (masks != null)
            this.maskUsage.putAll(masks);

        this.locationBasedRf = locationBasedRf;
        this.susceptibleRf = susceptibleRf;
        this.vaccinatedRf = vaccinatedRf;

        if (closed != null) {
            this.closed = closed.stream().map(s -> Id.create(s, ActivityFacility.class)).collect(Collectors.toSet());
        }
    }

    /**
     * Replaces the location-based remaining fractions. The given map is stored as is, without copying.
     */
    public void setLocationBasedRf(Map<String, Double> locationBasedRf) {
        this.locationBasedRf = locationBasedRf;
    }

    /**
     * Location-based remaining fractions by district name. Returns the internal map, not a copy.
     */
    public Map<String, Double> getLocationBasedRf() {
        return this.locationBasedRf;
    }

    /**
     * Restriction that allows everything.
     */
    public static Restriction none() {
        return new Restriction(1d, 1d, Integer.MAX_VALUE, Integer.MAX_VALUE, null,
                new ClosingHours(0, 0), Map.of(), new HashMap<>(), 1d, 1d);
    }

    /**
     * Restriction only reducing the {@link #remainingFraction}.
     */
    public static Restriction of(double remainingFraction) {
        return new Restriction(remainingFraction, null, null, null, null, null, null, new HashMap<>(), null, null);
    }

    /**
     * Restriction with remaining fraction and ci correction.
     */
    public static Restriction of(double remainingFraction, double ciCorrection) {
        return new Restriction(remainingFraction, ciCorrection, null, null, null, null, null, new HashMap<>(), null, null);
    }

    /**
     * Restriction with remaining fraction, ci correction and mask usage.
     */
    public static Restriction of(double remainingFraction, double ciCorrection, Map<FaceMask, Double> maskUsage) {
        return new Restriction(remainingFraction, ciCorrection, null, null, null, null, maskUsage, new HashMap<>(), null, null);
    }

    /**
     * Helper function for restriction with one mask compliance.
     * See {@link #ofMask(FaceMask, double)}.
     */
    public static Restriction of(double remainingFraction, FaceMask mask, double maskCompliance) {
        return new Restriction(remainingFraction, null, null, null, null, null, Map.of(mask, maskCompliance), new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction with one mask type and its compliance rates.
     *
     * @see #ofMask(Map)
     */
    public static Restriction ofMask(FaceMask mask, double complianceRate) {
        return new Restriction(null, null, null, null, null, null, Map.of(mask, complianceRate), new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction with required masks and compliance rates. Sum must not be larger than 1.
     * Not defined probability goes into the {@link FaceMask#NONE}.
     */
    public static Restriction ofMask(Map<FaceMask, Double> maskUsage) {
        return new Restriction(null, null, null, null, null, null, maskUsage, new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction with certain facilities closed. Should not be combined with other restrictions.
     */
    public static Restriction ofClosedFacilities(List<String> closed) {
        return new Restriction(null, null, null, null, closed, null, null, new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction, which has only a contact intensity correction set.
     */
    public static Restriction ofCiCorrection(double ciCorrection) {
        return new Restriction(null, ciCorrection, null, null, null, null, null, new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction with limited maximum group size of activities.
     */
    public static Restriction ofGroupSize(int maxGroupSize) {
        return new Restriction(null, null, maxGroupSize, null, null, null, null, new HashMap<>(), null, null);
    }

    /**
     * Restriction that reduces group sizes of activities.
     * Unlike {@link #ofGroupSize(int)} activities above {@code maxGroupSize} are not closed completely, but have reduced participation.
     *
     * @param maxGroupSize maximum allowed group size
     */
    public static Restriction ofReducedGroupSize(int maxGroupSize) {
        return new Restriction(null, null, null, maxGroupSize, null, null, null, new HashMap<>(), null, null);
    }

    /**
     * Creates a restriction for activity to be closed during certain hours.
     * If {@code fromHour} is larger than {@code toHour} it is assumed that the closing hour is over midnight.
     *
     * @param fromHour hour of the day when activity will close
     * @param toHour   hour of the day when activity will open again.
     */
    public static Restriction ofClosingHours(int fromHour, int toHour) {
        // closing hours are stored as seconds of day
        ClosingHours closed = asClosingHours(List.of(fromHour * 3600, toHour * 3600));
        return new Restriction(null, null, null, null, null, closed, null, new HashMap<>(), null, null);
    }

    /**
     * Restrict percentage of activities for unvaccinated and susceptible persons.
     *
     * @param remainingFraction remaining fraction of unvaccinated persons, i.e. 0 bans this activity for all unvaccinated persons.
     */
    @Beta
    public static Restriction ofSusceptibleRf(double remainingFraction) {
        return new Restriction(null, null, null, null, null, null, null, new HashMap<>(), remainingFraction, null);
    }

    /**
     * Restrict percentage of activities for vaccinated persons.
     *
     * @param remainingFraction remaining fraction for vaccinated persons.
     */
    @Beta
    public static Restriction ofVaccinatedRf(double remainingFraction) {
        return new Restriction(null, null, null, null, null, null, null, new HashMap<>(), null, remainingFraction);
    }

    /**
     * Creates a restriction, which has only location-based remaining fractions set.
     *
     * @param locationBasedRf remaining fraction by district name, stored as is without copying
     */
    public static Restriction ofLocationBasedRf(Map<String, Double> locationBasedRf) {
        return new Restriction(null, null, null, null, null, null, null, locationBasedRf, null, null);
    }

    /**
     * Creates a restriction from a config entry.
     * Mask usage values are taken over as they are (no accumulation), matching the format written by {@link #asMap()}.
     */
    public static Restriction fromConfig(Config config) {

        // Values could be integer or double, so they are read as Number and converted
        Map<?, ?> nameMap = (Map<?, ?>) config.getValue("masks").unwrapped();
        Map<FaceMask, Double> enumMap = new EnumMap<>(FaceMask.class);
        if (nameMap != null)
            nameMap.forEach((k, v) -> enumMap.put(FaceMask.valueOf((String) k), ((Number) v).doubleValue()));

        // Copy into a map of real doubles; an unchecked cast would keep Integer values,
        // which fail with a ClassCastException as soon as they are read as Double
        Map<String, Double> locationBasedRf = new HashMap<>();
        if (!config.getIsNull("locationBasedRf")) {
            Map<?, ?> valueFromConfig = (Map<?, ?>) config.getValue("locationBasedRf").unwrapped();
            if (valueFromConfig != null)
                valueFromConfig.forEach((k, v) -> locationBasedRf.put((String) k, toDouble(v)));
        }

        return new Restriction(
                config.getIsNull("fraction") ? null : config.getDouble("fraction"),
                config.getIsNull("ciCorrection") ? null : config.getDouble("ciCorrection"),
                config.getIsNull("maxGroupSize") ? null : config.getInt("maxGroupSize"),
                !config.hasPath("reducedGroupSize") || config.getIsNull("reducedGroupSize") ? null : config.getInt("reducedGroupSize"),
                !config.hasPath("closed") || config.getIsNull("closed") ? null : config.getStringList("closed"),
                !config.hasPath("closingHours") || config.getIsNull("closingHours") ? null : asClosingHours(config.getIntList("closingHours")),
                enumMap, locationBasedRf,
                config.getIsNull("susceptibleRf") ? null : config.getDouble("susceptibleRf"),
                config.getIsNull("vaccinatedRf") ? null : config.getDouble("vaccinatedRf"),
                null
        );
    }

    /**
     * Convert list of ints to closing hour instances.
     *
     * @param closingHours list holding from and to (as seconds of day) at index 0 and 1
     * @return closing hours or {@code null} if the list is {@code null}
     */
    private static ClosingHours asClosingHours(List<Integer> closingHours) {
        if (closingHours == null)
            return null;

        return new ClosingHours(closingHours.get(0), closingHours.get(1));
    }

    /**
     * Converts a numeric value of any type to double.
     *
     * @return the double value or {@code null} if the value is {@code null}
     * @throws ClassCastException if the value is not a {@link Number}
     */
    @Nullable
    private static Double toDouble(@Nullable Object value) {
        if (value == null)
            return null;

        return ((Number) value).doubleValue();
    }

    /**
     * Creates a copy of a restriction.
     */
    static Restriction clone(Restriction restriction) {
        return new Restriction(restriction.remainingFraction, restriction.ciCorrection, restriction.maxGroupSize, restriction.reducedGroupSize,
                restriction.closed == null ? null : restriction.closed.stream().map(Objects::toString).collect(Collectors.toList()),
                restriction.closingHours,
                null,
                // copy the map, so that the clone never shares mutable state with the original
                restriction.locationBasedRf == null ? new HashMap<>() : new HashMap<>(restriction.locationBasedRf),
                restriction.susceptibleRf,
                restriction.vaccinatedRf,
                restriction);
    }

    /**
     * Determines / Randomly draws which mask a persons wears while this restriction is in place.
     * The stored cumulative probabilities are checked in the iteration order of the mask map.
     * A random number is drawn at most once, and only if needed.
     *
     * @throws IllegalStateException if no entry matches the drawn number
     */
    public FaceMask determineMask(SplittableRandom rnd) {

        if (maskUsage.isEmpty()) return FaceMask.NONE;

        double p = Double.NaN;
        for (Map.Entry<FaceMask, Double> e : maskUsage.entrySet()) {

            // cumulative probability of 1 always matches, no random number needed
            if (e.getValue() == 1d) return e.getKey();
            else if (Double.isNaN(p))
                p = rnd.nextDouble();

            if (p < e.getValue())
                return e.getKey();
        }

        throw new IllegalStateException("Could not determine mask. Probabilities are likely wrong.");
    }

    /**
     * Calculates the overlap with the closing hours for one end of a time interval that lies within the closing hours.
     * Requires {@link #closingHours} to be set. The arithmetic assumes that {@code sod} is contained in the closing hours.
     *
     * @param sod        timestamp as seconds of day
     * @param adjustFrom when true {@code sod} is treated as start of the interval and the closed seconds from {@code sod} until the
     *                   closing ends are returned, otherwise {@code sod} is treated as the end and the closed seconds from start of the closing until {@code sod} are returned
     * @return number of overlapping seconds
     */
    double calculateOverlap(double sod, boolean adjustFrom) {

        // seconds of day
        ClosingHours ch = closingHours;

        if (adjustFrom) {
            if (ch.overnight)
                return sod >= ch.from ? ch.length - (sod - ch.from) : ch.length - (sod + SECONDS_PER_DAY - ch.from);

            return ch.length - (sod - ch.from);
        } else {
            if (ch.overnight)
                return sod <= ch.to ? ch.length - (ch.to - sod) : sod - ch.from;

            return ch.length - (ch.to - sod);
        }
    }

    /**
     * Calculate how many seconds are overlapped by the closing hour, given a time interval.
     *
     * @param from start of the interval in seconds
     * @param to   end of the interval in seconds
     * @return overlap in seconds; 0 if no closing hours are set or none of the handled cases applies;
     * {@code to - from} if both ends are within the closing hours; {@link Integer#MAX_VALUE} if the closing spans a whole day or more
     */
    public double overlapWithClosingHour(double from, double to) {

        if (closingHours == null)
            return 0;

        // closing of 0-24 needs to be handled separately, as overlap would be infinite
        if (closingHours.length >= SECONDS_PER_DAY)
            return Integer.MAX_VALUE;

        double fSod = from % SECONDS_PER_DAY;
        double tSod = to % SECONDS_PER_DAY;

        ClosingHours ch = closingHours;

        boolean containsFrom = ch.contains(fSod);
        boolean containsTo = ch.contains(tSod);
        // interval ends at an earlier second of day than it starts, i.e. it crosses midnight
        boolean actOvernight = tSod < fSod;

        if (containsFrom && containsTo) {
            // whole time nullified
            return to - from;
        } else if (containsFrom) {
            return calculateOverlap(fSod, true);
        } else if (containsTo) {
            return calculateOverlap(tSod, false);
        } else if (ch.includedIn(fSod, tSod) && (ch.overnight == actOvernight)) {
            // reduce by time of closing hour length
            return ch.length;
        } else if (actOvernight && !ch.overnight && tSod >= ch.to) {
            // also covered completely
            return ch.length;
        }

        return 0;
    }

    /**
     * This method is also used to write the restriction to csv.
     */
    @Override
    public String toString() {
        return String.format(Locale.ENGLISH, "%.2f_%.2f_%s", remainingFraction, ciCorrection, maskUsage);
    }

    /**
     * Set restriction values from other restriction update.
     */
    void update(Restriction r) {
        // All values may be optional and are only set if present
        if (r.getRemainingFraction() != null)
            remainingFraction = r.getRemainingFraction();

        if (r.getCiCorrection() != null)
            ciCorrection = r.getCiCorrection();

        if (r.getMaxGroupSize() != null)
            maxGroupSize = r.getMaxGroupSize();

        if (r.getReducedGroupSize() != null)
            reducedGroupSize = r.getReducedGroupSize();

        if (r.getSusceptibleRf() != null)
            susceptibleRf = r.getSusceptibleRf();

        if (r.getVaccinatedRf() != null)
            vaccinatedRf = r.getVaccinatedRf();

        if (r.closed != null)
            closed = r.closed;

        if (r.closingHours != null)
            closingHours = r.closingHours;

        if (!r.maskUsage.isEmpty()) {
            maskUsage.clear();
            maskUsage.putAll(r.maskUsage);
        }

        if (r.locationBasedRf != null && !r.locationBasedRf.isEmpty()) {
            locationBasedRf = new HashMap<>();
            locationBasedRf.putAll(r.locationBasedRf);
        }
    }

    /**
     * Merges another restriction into this one. Attributes already set on this restriction are never overwritten;
     * a warning is logged if the other restriction contains a different value for such an attribute.
     * <p>
     * localRf: if this restriction already has a Rf, the localRf of the other restriction will NOT be included in merged result;
     * reason: to aid users who manually enter restrictions, but forget to clear the localRf.
     * <p>
     * Numeric double values may be given as any {@link Number}. Missing {@code masks} or {@code locationBasedRf} entries are treated as empty.
     *
     * @param restriction attributes of the other restriction, in the format of {@link #asMap()}
     * @see #asMap()
     */
    void merge(Map<String, Object> restriction) {

        Double otherRf = toDouble(restriction.get("fraction"));
        Double otherE = toDouble(restriction.get("ciCorrection"));
        Integer otherGroup = (Integer) restriction.get("maxGroupSize");
        Integer otherReduced = (Integer) restriction.get("reducedGroupSize");
        Double otherSRf = toDouble(restriction.get("susceptibleRf"));
        Double otherVRf = toDouble(restriction.get("vaccinatedRf"));

        @SuppressWarnings("unchecked")
        List<Integer> otherHours = (List<Integer>) restriction.get("closingHours");
        ClosingHours otherClosingH = asClosingHours(otherHours);

        Map<FaceMask, Double> otherMasks = new EnumMap<>(FaceMask.class);
        Map<?, ?> rawMasks = (Map<?, ?>) restriction.get("masks");
        if (rawMasks != null)
            rawMasks.forEach((k, v) -> otherMasks.put(FaceMask.valueOf((String) k), toDouble(v)));

        Map<String, Double> otherLocationBasedRf = new HashMap<>();
        Map<?, ?> rawLocationBasedRf = (Map<?, ?>) restriction.get("locationBasedRf");
        if (rawLocationBasedRf != null)
            rawLocationBasedRf.forEach((k, v) -> otherLocationBasedRf.put((String) k, toDouble(v)));

        // the own map may be null when set so via setter or factory
        boolean ownLocationBasedRfEmpty = locationBasedRf == null || locationBasedRf.isEmpty();

        // if new and old Restriction have different Rfs: warn and keep the existing one
        if (remainingFraction != null && otherRf != null && !remainingFraction.equals(otherRf)) {
            log.warn("Duplicated remainingFraction " + remainingFraction + " and " + otherRf);
        }
        // if this Restriction doesn't have value for Rf:
        // 1) take over the Rf of the other Restriction
        // 2) take over the localRf of the other Restriction, unless this Restriction has its own
        else if (remainingFraction == null) {
            remainingFraction = otherRf;

            if (ownLocationBasedRfEmpty) {
                // assign a fresh copy instead of modifying the existing map, which may be unmodifiable or shared
                locationBasedRf = new HashMap<>(otherLocationBasedRf);
            } else if (!otherLocationBasedRf.isEmpty() && !locationBasedRf.equals(otherLocationBasedRf)) {
                log.warn("Duplicated locationBasedRf usage; existing value=" + locationBasedRf + "; new value=" + otherLocationBasedRf + "; keeping existing value.");
            }
        }
        // Rf is already set, the localRf of the other Restriction is dropped; only warn if there was something to drop
        else if (ownLocationBasedRfEmpty && !otherLocationBasedRf.isEmpty()) {
            log.warn("localRf removed during merge, as new remainingFraction was provided");
        }

        if (ciCorrection != null && otherE != null && !ciCorrection.equals(otherE))
            log.warn("Duplicated ci correction " + ciCorrection + " and " + otherE);
        else if (ciCorrection == null)
            ciCorrection = otherE;

        if (maxGroupSize != null && otherGroup != null && !maxGroupSize.equals(otherGroup))
            log.warn("Duplicated max group size " + maxGroupSize + " and " + otherGroup);
        else if (maxGroupSize == null)
            maxGroupSize = otherGroup;

        if (reducedGroupSize != null && otherReduced != null && !reducedGroupSize.equals(otherReduced))
            log.warn("Duplicated reduced group size " + reducedGroupSize + " and " + otherReduced);
        else if (reducedGroupSize == null)
            reducedGroupSize = otherReduced;

        if (closingHours != null && otherClosingH != null && !closingHours.equals(otherClosingH))
            log.warn("Duplicated max closing hours " + closingHours + " and " + otherClosingH);
        else if (closingHours == null)
            closingHours = otherClosingH;

        if (susceptibleRf != null && otherSRf != null && !susceptibleRf.equals(otherSRf))
            log.warn("Duplicated susceptible fraction " + susceptibleRf + " and " + otherSRf);
        else if (susceptibleRf == null)
            susceptibleRf = otherSRf;

        if (vaccinatedRf != null && otherVRf != null && !vaccinatedRf.equals(otherVRf))
            log.warn("Duplicated vaccinated fraction " + vaccinatedRf + " and " + otherVRf);
        else if (vaccinatedRf == null)
            vaccinatedRf = otherVRf;

        if (!maskUsage.isEmpty() && !otherMasks.isEmpty() && !maskUsage.equals(otherMasks)) {
            log.warn("Duplicated mask usage; existing value=" + maskUsage + "; new value=" + otherMasks + "; keeping existing value.");
            log.warn("(full new restriction=" + restriction + ")");
        } else if (maskUsage.isEmpty())
            maskUsage.putAll(otherMasks);
    }

    /**
     * Whether this restriction should be extrapolated.
     */
    boolean isExtrapolate() {
        return extrapolate;
    }

    /**
     * Sets whether this restriction should be extrapolated.
     */
    void setExtrapolate(boolean extrapolate) {
        this.extrapolate = extrapolate;
    }

    /**
     * Remaining fraction or {@code null} if not set.
     */
    @Nullable
    public Double getRemainingFraction() {
        return remainingFraction;
    }

    /**
     * Sets the remaining fraction without any validation.
     */
    void setRemainingFraction(double remainingFraction) {
        this.remainingFraction = remainingFraction;
    }

    /**
     * Contact intensity correction or {@code null} if not set.
     */
    @Nullable
    public Double getCiCorrection() {
        return ciCorrection;
    }

    /**
     * Whether the given id is contained in the closed facilities. Always {@code false} if no facilities are closed.
     */
    public boolean isClosed(Id<?> containerId) {
        if (closed == null)
            return false;

        return closed.contains(containerId);
    }

    @Nullable
    ClosingHours getClosingHours() {
        return closingHours;
    }

    /**
     * Whether closing hours are set, i.e. present with differing from and to.
     */
    public boolean hasClosingHours() {
        return closingHours != null && closingHours.to != closingHours.from;
    }

    @Nullable
    public Integer getMaxGroupSize() {
        return maxGroupSize;
    }

    @Nullable
    public Integer getReducedGroupSize() {
        return reducedGroupSize;
    }

    /**
     * Cumulative mask usage probabilities. Returns the internal map, not a copy.
     */
    Map<FaceMask, Double> getMaskUsage() {
        return maskUsage;
    }

    @Nullable
    public Double getSusceptibleRf() {
        return susceptibleRf;
    }

    @Nullable
    public Double getVaccinatedRf() {
        return vaccinatedRf;
    }

    /**
     * Attributes of this restriction as map.
     * Unset attributes are contained with {@code null} value, except {@code closed} and {@code closingHours}, which are omitted.
     */
    public Map<String, Object> asMap() {

        Map<String, Object> map = new HashMap<>();
        map.put("fraction", remainingFraction);
        map.put("ciCorrection", ciCorrection);
        map.put("maxGroupSize", maxGroupSize);
        map.put("reducedGroupSize", reducedGroupSize);

        if (closed != null)
            map.put("closed", closed.stream().map(Object::toString).collect(Collectors.toList()));

        // Must be converted to map with strings
        Map<String, Double> nameMap = new LinkedHashMap<>();
        maskUsage.forEach((k, v) -> nameMap.put(k.name(), v));
        map.put("masks", nameMap);

        if (closingHours != null) {
            map.put("closingHours", List.of(closingHours.from, closingHours.to));
        }

        map.put("locationBasedRf", locationBasedRf);
        map.put("susceptibleRf", susceptibleRf);
        map.put("vaccinatedRf", vaccinatedRf);

        return map;
    }

    /**
     * Hours when an activity is closed. Both bounds are given as seconds of day and treated as inclusive by {@link #contains(double)}.
     */
    public static final class ClosingHours {

        /**
         * Starting second when activity is closed (inclusive).
         */
        public final int from;

        /**
         * Seconds until when activity is still closed (inclusive).
         */
        public final int to;

        /**
         * Closed overnight, which is the case if {@link #from} is larger than {@link #to}.
         */
        public final boolean overnight;

        /**
         * Length of closing hours in seconds.
         */
        public final int length;

        ClosingHours(int from, int to) {
            this.from = from;
            this.to = to;
            this.overnight = from > to;
            // overnight closing consists of the part until midnight and the part after it
            this.length = overnight ? (SECONDS_PER_DAY - from) + to : to - from;
        }

        /**
         * Check whether timestamp is contained in the closing hours.
         *
         * @param sod timestamp as seconds of day
         */
        public boolean contains(double sod) {
            if (overnight) {
                return sod >= from || sod <= to;
            } else
                return sod >= from && sod <= to;
        }

        /**
         * Closing hour is completely included in this interval (as seconds of day).
         */
        public boolean includedIn(double fSod, double tSod) {
            return fSod < from && tSod > to;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            ClosingHours that = (ClosingHours) o;
            return from == that.from &&
                    to == that.to;
        }

        @Override
        public int hashCode() {
            return Objects.hash(from, to);
        }

        @Override
        public String toString() {
            return "ClosingHours{" +
                    "from=" + from +
                    ", to=" + to +
                    '}';
        }
    }
}