/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.trainable_segmention.pixel_feature.calculator;

import java.util.List;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.imagej.ops.OpEnvironment;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.trainable_segmention.RevampUtils;
import net.imglib2.trainable_segmention.pixel_feature.calculator.ColorInputPreprocessor;
import net.imglib2.trainable_segmention.pixel_feature.calculator.GrayInputPreprocessor;
import net.imglib2.trainable_segmention.pixel_feature.calculator.InputPreprocessor;
import net.imglib2.trainable_segmention.pixel_feature.calculator.MultiChannelInputPreprocessor;
import net.imglib2.trainable_segmention.pixel_feature.filter.FeatureJoiner;
import net.imglib2.trainable_segmention.pixel_feature.filter.FeatureOp;
import net.imglib2.trainable_segmention.pixel_feature.settings.ChannelSetting;
import net.imglib2.trainable_segmention.pixel_feature.settings.FeatureSettings;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;

public class FeatureCalculator {
    private final FeatureJoiner joiner;
    private final FeatureSettings settings;
    private final InputPreprocessor preprocessor;

    public FeatureCalculator(OpEnvironment ops, FeatureSettings settings) {
        this.settings = settings;
        List<FeatureOp> featureOps = settings.features().stream().map(x -> x.newInstance(ops, settings.globals())).collect(Collectors.toList());
        this.joiner = new FeatureJoiner(featureOps);
        this.preprocessor = this.initPreprocessor(settings.globals().channelSetting());
    }

    private InputPreprocessor initPreprocessor(ChannelSetting channelSetting) {
        if (ChannelSetting.RGB.equals(channelSetting)) {
            return new ColorInputPreprocessor(this.settings.globals());
        }
        if (ChannelSetting.SINGLE.equals(channelSetting)) {
            return new GrayInputPreprocessor(this.settings.globals());
        }
        if (channelSetting.isMultiple()) {
            return new MultiChannelInputPreprocessor(this.settings.globals());
        }
        throw new UnsupportedOperationException("Unsupported channel setting: " + this.settings().globals().channelSetting());
    }

    public OpEnvironment ops() {
        return this.joiner.ops();
    }

    public FeatureSettings settings() {
        return this.settings;
    }

    public List<FeatureOp> features() {
        return this.joiner.features();
    }

    public int count() {
        return this.joiner.count() * this.channelCount();
    }

    public List<String> attributeLabels() {
        return FeatureCalculator.prepend(this.settings.globals().channelSetting().channels(), this.joiner.attributeLabels());
    }

    public void apply(RandomAccessible<?> input, List<RandomAccessibleInterval<FloatType>> output) {
        List<RandomAccessible<FloatType>> channels = this.preprocessor.getChannels(input);
        List<List<RandomAccessibleInterval<FloatType>>> outputs = FeatureCalculator.split(output, channels.size());
        for (int i = 0; i < channels.size(); ++i) {
            this.joiner.apply(channels.get(i), outputs.get(i));
        }
    }

    public RandomAccessibleInterval<FloatType> apply(RandomAccessibleInterval<?> image) {
        return this.apply((RandomAccessible<?>)Views.extendBorder(image), this.preprocessor.outputIntervalFromInput(image));
    }

    public RandomAccessibleInterval<FloatType> apply(RandomAccessible<?> extendedImage, Interval interval) {
        if (interval.numDimensions() != this.settings().globals().numDimensions()) {
            throw new IllegalArgumentException("Wrong dimension of the output interval.");
        }
        Img result = this.ops().create().img((Dimensions)RevampUtils.appendDimensionToInterval(interval, 0L, this.count() - 1), (NativeType)new FloatType());
        this.apply(extendedImage, RevampUtils.slices(result));
        return result;
    }

    public Interval outputIntervalFromInput(RandomAccessibleInterval<?> image) {
        return this.preprocessor.outputIntervalFromInput(image);
    }

    private int channelCount() {
        return this.settings.globals().channelSetting().channels().size();
    }

    private static List<String> prepend(List<String> prepend, List<String> labels) {
        return labels.stream().flatMap(label -> prepend.stream().map(pre -> pre.isEmpty() ? label : pre + "_" + label)).collect(Collectors.toList());
    }

    private static <T> List<List<T>> split(List<T> input, int count) {
        return IntStream.range(0, count).mapToObj(i -> FeatureCalculator.filterByIndexPredicate(input, index -> index % count == i)).collect(Collectors.toList());
    }

    private static <T> List<T> filterByIndexPredicate(List<T> in, IntPredicate predicate) {
        return IntStream.range(0, in.size()).filter(predicate).mapToObj(in::get).collect(Collectors.toList());
    }
}

