BatchRun.java

/*-
 * #%L
 * MATSim Episim
 * %%
 * Copyright (C) 2020 matsim-org
 * %%
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * #L%
 */
package org.matsim.episim;

import com.google.common.collect.Lists;
import com.google.inject.Module;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.matsim.core.config.Config;
import org.matsim.core.config.ConfigUtils;
import org.matsim.episim.analysis.OutputAnalysis;
import org.matsim.episim.analysis.RValuesFromEvents;

import javax.annotation.Nullable;
import java.io.IOException;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.LocalDate;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Interface for defining the setup procedure of a batch run and the corresponding parameter class.
 * The batch runner will create the cross-product of all possible parameters configuration and prepare
 * the config for each.
 *
 * @param <T> Class holding the available parameters.
 */
public interface BatchRun<T> {

	/**
	 * Find calibration parameter for given params from a list of csv records.
	 * <p>
	 * Every non-ignored field of {@code params} is converted with {@code EpisimUtils.asString} and compared against
	 * the record column with the same name. Columns that are not present in a record are skipped. The first record
	 * without any mismatch (and with at least one readable field) is used and its {@code param} column is parsed.
	 *
	 * @param params  params to lookup
	 * @param records parsed records with parameter
	 * @param ignore  fields that will be ignored and not matched
	 * @return calibration parameter if present or NaN.
	 */
	static double lookup(Object params, List<CSVRecord> records, String... ignore) {

		List<String> ignoreList = Arrays.asList(ignore);

		// Collect the fields taking part in the matching once, instead of re-checking them for every record
		List<Field> fields = new ArrayList<>();
		for (Field field : params.getClass().getDeclaredFields()) {
			// synthetic (compiler generated) fields are never parameters
			if (field.isSynthetic() || ignoreList.contains(field.getName()))
				continue;

			// prepare() assigns non-public fields reflectively, so they must be readable here as well.
			// If access can not be granted, field.get() below fails and the field is simply skipped.
			field.trySetAccessible();
			fields.add(field);
		}

		outer:
		for (CSVRecord record : records) {

			int matched = 0;

			for (Field field : fields) {
				try {
					String value = EpisimUtils.asString(field.get(params));
					try {

						String cmp = record.get(field.getName());
						if (!cmp.equals(value))
							continue outer;

					} catch (IllegalArgumentException e) {
						// column is not present in this record, nothing to compare
					}

					matched++;
				} catch (ReflectiveOperationException e) {
					// field can not be read, it does not take part in the matching
				}
			}

			// when no mismatches occurred this record is used
			if (matched > 0) {
				String param = record.get("param");
				if (param.isEmpty())
					return Double.NaN;

				return Double.parseDouble(param);
			}
		}

		return Double.NaN;
	}

	/**
	 * Loads the defined parameters and executes the {@link #prepareConfig(int, Object)} procedure for every
	 * combination of them (cross-product of all annotated parameter values).
	 * <p>
	 * An optional base case (see {@link #baseCase(int)}) is registered with id 0. Parameter combinations receive
	 * consecutive ids starting at {@link #getOffset()} + 1; an id is consumed even if
	 * {@link #prepareConfig(int, Object)} returns null for it.
	 *
	 * @param clazz      setup class, needs a no-arg constructor
	 * @param paramClazz class holding the parameters, needs a no-arg constructor
	 * @param <T>        params type
	 * @return all prepared runs of this batch
	 * @throws IllegalArgumentException if a class can not be instantiated or a parameter annotation is invalid
	 */
	static <T> PreparedRun prepare(Class<? extends BatchRun<T>> clazz, Class<T> paramClazz) {

		Logger log = LogManager.getLogger(BatchRun.class);
		List<Field> fields = new ArrayList<>();
		List<List<Object>> allParams = new ArrayList<>();

		for (Field field : paramClazz.getDeclaredFields()) {
			Parameter param = field.getAnnotation(Parameter.class);
			if (param != null) {
				allParams.add(DoubleStream.of(param.value()).boxed().collect(Collectors.toList()));
				fields.add(field);
			}
			StartDates dateParam = field.getAnnotation(StartDates.class);
			if (dateParam != null) {
				if (!field.getName().equals("startDate"))
					throw new IllegalArgumentException("StartDates field must be called 'startDate'");

				allParams.add(Stream.of(dateParam.value()).map(LocalDate::parse).collect(Collectors.toList()));
				fields.add(field);
			}
			IntParameter intParam = field.getAnnotation(IntParameter.class);
			if (intParam != null) {
				allParams.add(IntStream.of(intParam.value()).boxed().collect(Collectors.toList()));
				fields.add(field);
			}
			// LongParameter was declared but previously never evaluated, so such fields were silently ignored
			LongParameter longParam = field.getAnnotation(LongParameter.class);
			if (longParam != null) {
				allParams.add(Arrays.stream(longParam.value()).boxed().collect(Collectors.toList()));
				fields.add(field);
			}
			StringParameter stringParam = field.getAnnotation(StringParameter.class);
			if (stringParam != null) {
				allParams.add(Arrays.asList(stringParam.value()));
				fields.add(field);
			}
			EnumParameter enumParam = field.getAnnotation(EnumParameter.class);
			if (enumParam != null) {
				try {
					Method m = enumParam.value().getDeclaredMethod("values");
					Object[] invoke = (Object[]) m.invoke(null);
					List<Object> enums = Lists.newArrayList(invoke);
					// remove the ignored enums, matched by their string representation
					enums.removeIf(p -> ArrayUtils.contains(enumParam.ignore(), p.toString()));

					allParams.add(enums);
					fields.add(field);
				} catch (ReflectiveOperationException e) {
					throw new IllegalStateException(e);
				}
			}
			ClassParameter classParam = field.getAnnotation(ClassParameter.class);
			if (classParam != null) {
				allParams.add(Arrays.asList(classParam.value()));
				fields.add(field);
			}
			GenerateSeeds seed = field.getAnnotation(GenerateSeeds.class);
			if (seed != null) {
				// the first seed is always fixed, so at least one slot is required
				if (seed.value() < 1)
					throw new IllegalArgumentException("GenerateSeeds on field '" + field.getName() + "' must generate at least one seed");

				Random rnd = new Random(seed.seed());
				Object[] seeds = IntStream.range(0, seed.value()).mapToLong(i -> rnd.nextLong()).boxed().toArray();
				seeds[0] = seed.first();

				allParams.add(Arrays.asList(seeds));
				fields.add(field);
			}
		}

		List<PreparedRun.Run> runs = new ArrayList<>();
		BatchRun<T> setup;
		try {
			setup = clazz.getDeclaredConstructor().newInstance();
		} catch (ReflectiveOperationException e) {
			log.error("Could not create run class", e);
			throw new IllegalArgumentException(e);
		}

		Config base = setup.baseCase(0);
		if (base != null)
			runs.add(new PreparedRun.Run(0, Lists.newArrayList("base"), base, null));

		List<List<Object>> combinations = Lists.cartesianProduct(Lists.newArrayList(allParams));

		// parameter fields may be non-public; grant access once instead of once per combination
		for (Field field : fields)
			field.setAccessible(true);

		int id = setup.getOffset();
		for (List<Object> params : combinations) {

			try {
				T inst = paramClazz.getDeclaredConstructor().newInstance();
				for (int i = 0; i < params.size(); i++) {
					fields.get(i).set(inst, params.get(i));
				}

				Config config = setup.prepareConfig(++id, inst);

				if (config != null)
					runs.add(new PreparedRun.Run(id, params, config, inst));

			} catch (ReflectiveOperationException e) {
				log.error("Could not create param class", e);
				throw new IllegalArgumentException(e);
			}
		}

		log.info("Prepared {} runs for {} with params {}", runs.size(), clazz.getSimpleName(), paramClazz.getName());

		return new PreparedRun(setup, fields.stream().map(Field::getName).collect(Collectors.toList()), allParams, runs);
	}


	/**
	 * Resolve input path automatically using given input, or cluster input directory.
	 * When the system property {@code EPISIM_ON_CLUSTER} is {@code "true"}, the fixed cluster input
	 * directory is used instead of {@code input}.
	 *
	 * @param input input path to resolve for
	 * @param name  file name
	 * @return resolved input file name, always using {@code /} as separator
	 */
	static String resolveForCluster(Path input, String name) {
		if (System.getProperty("EPISIM_ON_CLUSTER", "false").equals("true"))
			input = Path.of("/scratch/projects/bzz0020/episim-input");

		// convert windows path separators
		return input.resolve(name).toString().replace("\\", "/");
	}

	/**
	 * Get the offset for this run, that is used to generate ids.
	 * This can be used to concatenate multiple runs together if they have disjoint ids.
	 */
	default int getOffset() {
		return 0;
	}

	/**
	 * The default start of the scenario as day in real world. Only needed if there are multiple start dates in the batch run.
	 */
	default LocalDate getDefaultStartDate() {
		return LocalDate.now();
	}

	/**
	 * Returns the metadata (region and name) of this batch run.
	 */
	default Metadata getMetadata() {
		return Metadata.of("region", "default");
	}

	/**
	 * List of options that will be added to the metadata.
	 */
	default List<Option> getOptions() {
		return List.of();
	}

	/**
	 * Return the module that should be used for configuring custom guice bindings. May also be parametrized.
	 *
	 * @param id     task id
	 * @param params parameters to use, will be null for the base case.
	 * @return module with additional bindings, or null if not needed
	 */
	@Nullable
	default Module getBindings(int id, @Nullable T params) {
		return null;
	}

	/**
	 * Provide a base case without any parametrization.
	 *
	 * @param id task id of the base case
	 * @return config of the base case, or null if there is none
	 */
	@Nullable
	default Config baseCase(int id) {
		return null;
	}

	/**
	 * Prepare a config using the given parameters that will be used for this batch run.
	 * Any other defined config is replaced.
	 *
	 * @param id     task id
	 * @param params parameters to use
	 * @return initialized config, or null to skip this parameter combination
	 */
	@Nullable
	Config prepareConfig(int id, T params);

	/**
	 * Return the {@link OutputAnalysis} instances that should be executed on each individual run.
	 * By default this is {@link RValuesFromEvents}.
	 */
	default Collection<OutputAnalysis> postProcessing() {
		return List.of(new RValuesFromEvents().withArgs());
	}

	/**
	 * Write additionally needed files to {@code directory}, if any are needed.
	 * By default the policy config and the progression config are written, when they are set.
	 * Don't write the config!
	 */
	default void writeAuxiliaryFiles(Path directory, Config config) throws IOException {
		EpisimConfigGroup episimConfig = ConfigUtils.addOrGetModule(config, EpisimConfigGroup.class);
		if (episimConfig.getPolicyConfig() != null)
			Files.writeString(directory.resolve(episimConfig.getPolicyConfig()), episimConfig.getPolicy().root().render());

		if (!episimConfig.getProgressionConfig().isEmpty())
			Files.writeString(directory.resolve(episimConfig.getProgressionConfigName()), episimConfig.getProgressionConfig().root().render());
	}

	/**
	 * This declares a field as parameter for a batch run.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface Parameter {
		/**
		 * All values this parameter should attain.
		 */
		double[] value();
	}

	/**
	 * Declares parameter as dates. Receiver must be {@link LocalDate} and named {@code startDate}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface StartDates {
		/**
		 * Desired start dates in the form yyyy-mm-dd.
		 */
		String[] value();
	}

	/**
	 * See {@link Parameter}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface IntParameter {
		int[] value();
	}

	/**
	 * See {@link Parameter}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface LongParameter {
		long[] value();
	}

	/**
	 * See {@link Parameter}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface StringParameter {
		String[] value();
	}


	/**
	 * See {@link Parameter}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface EnumParameter {
		/**
		 * Desired enum class, by default all values will be used.
		 */
		Class<? extends Enum<?>> value();

		/**
		 * Enum constants to exclude, matched against their {@code toString()} representation.
		 */
		String[] ignore() default {};
	}

	/**
	 * See {@link Parameter}.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface ClassParameter {
		/**
		 * List of classes to use as parameters.
		 */
		Class<?>[] value();
	}

	/**
	 * Generates desired number of seeds by using a different random number generator.
	 */
	@Target(ElementType.FIELD)
	@Retention(RetentionPolicy.RUNTIME)
	@interface GenerateSeeds {
		/**
		 * Number of seeds to generate, including the fixed first seed. Must be at least one.
		 */
		int value();

		/**
		 * The first seed, which is fixed and not generated.
		 */
		long first() default 4711L;

		/**
		 * Seed of the random number generator that produces the remaining seeds.
		 */
		int seed() default 1;
	}


	/**
	 * Describes one option group of parameters with multiple measures.
	 */
	final class Option {

		public final String heading;
		public final String subheading;
		public final int day;

		/**
		 * Tuples of (title, paramName).
		 */
		public final List<Pair<String, String>> measures = new ArrayList<>();

		private Option(String heading, String subheading, int day) {
			this.heading = heading;
			this.subheading = subheading;
			this.day = day;
		}

		/**
		 * Creates a new option group.
		 *
		 * @param heading    header shown in ui
		 * @param subheading description shown in ui
		 * @param day        day when it will be in effect
		 */
		public static Option of(String heading, String subheading, int day) {
			return new Option(heading, subheading, day);
		}

		/**
		 * See {@link #of(String, String, int)}. Uses an empty subheading.
		 */
		public static Option of(String heading, int day) {
			return new Option(heading, "", day);
		}

		/**
		 * See {@link #of(String, String, int)}. Uses an empty subheading and sets the day to -1.
		 */
		public static Option of(String heading) {
			return new Option(heading, "", -1);
		}

		/**
		 * Adds a measure to this option.
		 *
		 * @param title title shown in ui
		 * @param param name of the parameter in code
		 * @return this option, for chaining
		 */
		public Option measure(String title, String param) {
			measures.add(Pair.of(title, param));
			return this;
		}
	}

	/**
	 * Contains the metadata of a batch run.
	 */
	final class Metadata {

		public final String region;
		public final String name;

		/**
		 * End date for the ui.
		 */
		String endDate = null;

		public Metadata(String region, String name) {
			this.region = region;
			this.name = name;
		}

		/**
		 * Creates new metadata instance.
		 */
		public static Metadata of(String region, String name) {
			return new Metadata(region, name);
		}

		/**
		 * Sets the end date and returns the same instance.
		 */
		public Metadata withEndDate(String date) {
			this.endDate = date;
			return this;
		}

	}

}