/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.labkit.segmentation.weka;

import com.google.gson.JsonElement;
import hr.irb.fastRandomForest.FastRandomForest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.swing.JFrame;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imagej.axis.CalibratedAxis;
import net.imglib2.Cursor;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.cache.img.CellLoader;
import net.imglib2.cache.img.DiskCachedCellImg;
import net.imglib2.cache.img.DiskCachedCellImgFactory;
import net.imglib2.cache.img.DiskCachedCellImgOptions;
import net.imglib2.cache.img.SingleCellArrayImg;
import net.imglib2.img.display.imagej.ImgPlusViews;
import net.imglib2.labkit.inputimage.ImgPlusViewsOld;
import net.imglib2.labkit.labeling.Label;
import net.imglib2.labkit.labeling.Labeling;
import net.imglib2.labkit.labeling.Labelings;
import net.imglib2.labkit.segmentation.Segmenter;
import net.imglib2.labkit.segmentation.weka.TrainableSegmentationSettingsDialog;
import net.imglib2.labkit.utils.LabkitUtils;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.sparse.SparseRandomAccessIntType;
import net.imglib2.trainable_segmentation.classification.Training;
import net.imglib2.trainable_segmentation.gson.GsonUtils;
import net.imglib2.trainable_segmentation.pixel_feature.calculator.FeatureCalculator;
import net.imglib2.trainable_segmentation.pixel_feature.filter.GroupedFeatures;
import net.imglib2.trainable_segmentation.pixel_feature.filter.SingleFeatures;
import net.imglib2.trainable_segmentation.pixel_feature.settings.ChannelSetting;
import net.imglib2.trainable_segmentation.pixel_feature.settings.FeatureSetting;
import net.imglib2.trainable_segmentation.pixel_feature.settings.FeatureSettings;
import net.imglib2.trainable_segmentation.pixel_feature.settings.GlobalSettings;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.ARGBType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.util.Pair;
import net.imglib2.view.ExtendedRandomAccessibleInterval;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import net.imglib2.view.composite.Composite;
import net.imglib2.view.composite.CompositeIntervalView;
import org.scijava.Context;
import weka.classifiers.Classifier;
import weka.core.WekaException;

/**
 * Labkit {@link Segmenter} backed by the pixel classifier of imglib2-trainable-segmentation
 * ({@code net.imglib2.trainable_segmentation.classification.Segmenter}) with a
 * {@link FastRandomForest} as classifier.
 * <p>
 * Images with a time axis are always processed slice by slice along time. Images with a Z axis
 * are processed slice by slice along Z when the feature settings are two dimensional.
 */
public class TrainableSegmentationSegmenter
implements Segmenter {
    private final Context context;
    private boolean useGpu;
    private FeatureSettings featureSettings;
    private net.imglib2.trainable_segmentation.classification.Segmenter segmenter;

    /**
     * Creates an untrained segmenter with GPU usage disabled and no feature settings yet.
     *
     * @param context SciJava context passed to the classifier and the settings dialog
     * @throws NullPointerException if {@code context} is {@code null}
     */
    public TrainableSegmentationSegmenter(Context context) {
        this.context = Objects.requireNonNull(context);
        this.useGpu = false;
        this.segmenter = null;
        this.featureSettings = null;
    }

    /**
     * Returns the class names of the trained (or loaded) classifier.
     *
     * @throws IllegalStateException if the segmenter has not been trained or loaded yet
     */
    @Override
    public List<String> classNames() {
        return trainedSegmenter().classNames();
    }

    /**
     * Shows the settings dialog. If no feature settings exist yet, defaults are first derived
     * from {@code trainingData}. When the user confirms, the chosen feature settings and GPU
     * flag are adopted.
     */
    @Override
    public void editSettings(JFrame dialogParent, List<Pair<ImgPlus<?>, Labeling>> trainingData) {
        initFeatureSettings(trainingData);
        TrainableSegmentationSettingsDialog dialog = new TrainableSegmentationSettingsDialog(context, dialogParent, useGpu, featureSettings);
        dialog.show();
        if (dialog.okClicked()) {
            featureSettings = dialog.featureSettings();
            setUseGpu(dialog.useGpu());
        }
    }

    /**
     * Writes a class index for every pixel of {@code image} into {@code labels}.
     * Time series, and Z stacks segmented with 2D features, are processed slice by slice;
     * the slice axis of {@code labels} is its last dimension.
     *
     * @throws IllegalStateException if the segmenter has not been trained or loaded yet
     */
    @Override
    public void segment(ImgPlus<?> image, RandomAccessibleInterval<? extends IntegerType<?>> labels) {
        if (ImgPlusViewsOld.hasAxis(image, Axes.TIME)) {
            applyOnSlices(this::segment, image, labels, image.dimensionIndex(Axes.TIME), labels.numDimensions() - 1);
        } else if (ImgPlusViewsOld.hasAxis(image, Axes.Z) && is2D()) {
            applyOnSlices(this::segment, image, labels, image.dimensionIndex(Axes.Z), labels.numDimensions() - 1);
        } else {
            trainedSegmenter().segment(labels, Views.extendBorder(image));
        }
    }

    /**
     * Writes per-class predictions for {@code image} into {@code prediction}.
     * Slicing works as in {@link #segment}, except the slice axis of {@code prediction} is its
     * second-to-last dimension (the last one is the class axis).
     *
     * @throws IllegalStateException if the segmenter has not been trained or loaded yet
     */
    @Override
    public void predict(ImgPlus<?> image, RandomAccessibleInterval<? extends RealType<?>> prediction) {
        if (ImgPlusViewsOld.hasAxis(image, Axes.TIME)) {
            applyOnSlices(this::predict, image, prediction, image.dimensionIndex(Axes.TIME), prediction.numDimensions() - 2);
        } else if (ImgPlusViewsOld.hasAxis(image, Axes.Z) && is2D()) {
            applyOnSlices(this::predict, image, prediction, image.dimensionIndex(Axes.Z), prediction.numDimensions() - 2);
        } else {
            trainedSegmenter().predict(prediction, Views.extendBorder(image));
        }
    }

    /** Returns true if the trained classifier uses two dimensional features. */
    private boolean is2D() {
        return trainedSegmenter().features().settings().globals().numDimensions() == 2;
    }

    /**
     * Returns the trained classifier.
     * Fails with a descriptive message instead of a NullPointerException when it is missing.
     */
    private net.imglib2.trainable_segmentation.classification.Segmenter trainedSegmenter() {
        if (segmenter == null) {
            throw new IllegalStateException("The segmenter has not been trained or loaded yet.");
        }
        return segmenter;
    }

    /**
     * Trains a new random forest classifier on the labeled pixels of {@code trainingData}.
     * The classes are the distinct label names, in order of first appearance. The previous
     * classifier is replaced only if training succeeds.
     *
     * @throws CancellationException if Weka reports that there are not enough training
     *         instances (the original exception is attached as the cause)
     */
    @Override
    public void train(List<Pair<ImgPlus<?>, Labeling>> trainingData) {
        try {
            initFeatureSettings(trainingData);
            List<String> classes = collectLabels(trainingData.stream().map(Pair::getB).collect(Collectors.toList()));
            net.imglib2.trainable_segmentation.classification.Segmenter newSegmenter = new net.imglib2.trainable_segmentation.classification.Segmenter(context, classes, featureSettings, new FastRandomForest());
            newSegmenter.setUseGpu(useGpu);
            Training training = newSegmenter.training();
            for (Pair<ImgPlus<?>, Labeling> pair : trainingData) {
                trainStack(training, classes, pair.getB(), pair.getA(), newSegmenter.features());
            }
            training.train();
            segmenter = newSegmenter;
        }
        catch (RuntimeException e) {
            if (isNotEnoughTrainingInstances(e.getCause())) {
                CancellationException cancellation = new CancellationException("The training requires some labeled regions.");
                // CancellationException has no (message, cause) constructor; keep the cause for debugging.
                cancellation.initCause(e);
                throw cancellation;
            }
            throw e;
        }
    }

    /** Returns true if {@code cause} is Weka's "Not enough training instances" error (null-safe). */
    private static boolean isNotEnoughTrainingInstances(Throwable cause) {
        if (!(cause instanceof WekaException)) {
            return false;
        }
        String message = cause.getMessage();
        return message != null && message.contains("Not enough training instances");
    }

    /** Sets the GPU flag and forwards it to the current classifier, if any. */
    public void setUseGpu(boolean useGpu) {
        this.useGpu = useGpu;
        if (this.segmenter != null) {
            this.segmenter.setUseGpu(this.useGpu);
        }
    }

    /** Returns the distinct label names of all labelings, in order of first appearance. */
    private static List<String> collectLabels(List<? extends Labeling> labelings) {
        return labelings.stream().flatMap(labeling -> labeling.getLabels().stream()).map(Label::name).distinct().collect(Collectors.toList());
    }

    /**
     * Adds training samples from {@code image}. Time series, and Z stacks used with 2D
     * features, are split into slices recursively. All other images are trained as one frame.
     */
    private void trainStack(Training training, List<String> classes, Labeling labeling, ImgPlus<?> image, FeatureCalculator featuresCalculator) {
        if (ImgPlusViewsOld.hasAxis(image, Axes.TIME)) {
            trainSlices(training, classes, labeling, image, Axes.TIME, featuresCalculator);
        } else if (ImgPlusViewsOld.hasAxis(image, Axes.Z) && featuresCalculator.settings().globals().numDimensions() == 2) {
            trainSlices(training, classes, labeling, image, Axes.Z, featuresCalculator);
        } else {
            trainFrame(training, classes, labeling, image, featuresCalculator);
        }
    }

    /** Splits image and labeling along {@code axis} and trains on each pair of slices. */
    private void trainSlices(Training training, List<String> classes, Labeling labeling, ImgPlus<?> image, AxisType axis, FeatureCalculator featuresCalculator) {
        List<ImgPlus<?>> imageSlices = ImgPlusViewsOld.hyperSlices(image, axis);
        List<Labeling> labelSlices = Labelings.slices(labeling);
        for (int i = 0; i < imageSlices.size(); ++i) {
            trainStack(training, classes, labelSlices.get(i), imageSlices.get(i), featuresCalculator);
        }
    }

    /**
     * Adds the feature vectors of all labeled pixels of one frame to {@code training}.
     * Features are computed lazily into a disk-cached image, which is always shut down afterwards.
     */
    private void trainFrame(Training training, List<String> classes, Labeling labeling, ImgPlus<?> image, FeatureCalculator featuresCalculator) {
        SparseRandomAccessIntType classIndices = getClassIndices(labeling, classes);
        if (classIndices.sparsityPattern().size() == 0L) {
            return; // nothing labeled in this frame: skip the expensive feature computation
        }
        DiskCachedCellImg<FloatType, ?> cachedFeatureBlock = cachedFeatureBlock(featuresCalculator, image);
        try {
            addSamples(training, classIndices, Views.collapse(cachedFeatureBlock));
        }
        finally {
            cachedFeatureBlock.shutdown();
        }
    }

    /**
     * Creates a disk-cached image whose cells compute the features of {@code image} on demand.
     * The feature channels are appended as the last dimension.
     *
     * @throws IllegalArgumentException if the calculator reports no features
     */
    private DiskCachedCellImg<FloatType, ?> cachedFeatureBlock(FeatureCalculator feature, ImgPlus<?> image) {
        int count = feature.count();
        if (count <= 0) {
            throw new IllegalArgumentException("The feature calculator must produce at least one feature, but count() returned " + count + ".");
        }
        long[] dimensions = LabkitUtils.extend(Intervals.dimensionsAsLongArray(feature.outputIntervalFromInput(image)), (long) count);
        int[] cellDimensions = LabkitUtils.extend(suggestCellSize(image), count);
        DiskCachedCellImgOptions featureOpts = DiskCachedCellImgOptions.options().cellDimensions(cellDimensions).dirtyAccesses(false);
        DiskCachedCellImgFactory<FloatType> featureFactory = new DiskCachedCellImgFactory<>(new FloatType(), featureOpts);
        RandomAccessible<?> input = Views.extendBorder(image);
        CellLoader<FloatType> loader = target -> feature.apply(input, target);
        return featureFactory.create(dimensions, loader);
    }

    /** Adds, for each non-default entry of {@code classIndices}, the feature vector at that position. */
    private void addSamples(Training training, SparseRandomAccessIntType classIndices, RandomAccessible<? extends Composite<FloatType>> features) {
        Cursor<IntType> classIndicesCursor = classIndices.sparseCursor();
        RandomAccess<? extends Composite<FloatType>> featureAccess = features.randomAccess();
        while (classIndicesCursor.hasNext()) {
            int classIndex = classIndicesCursor.next().get();
            featureAccess.setPosition(classIndicesCursor);
            training.add(featureAccess.get(), classIndex);
        }
    }

    /**
     * Builds a sparse image holding, for each labeled pixel, the index in {@code classes} of its
     * label. For pixels with several labels the smallest class index wins. All other pixels
     * keep the default value -1.
     */
    private SparseRandomAccessIntType getClassIndices(Labeling labeling, List<String> classes) {
        SparseRandomAccessIntType result = new SparseRandomAccessIntType((Interval) labeling, -1);
        Map<Set<Label>, Integer> classIndexCache = new HashMap<>();
        Function<Set<Label>, Integer> compute = labelSet -> classIndexOf(labelSet, classes);
        Cursor<?> cursor = labeling.sparsityCursor();
        RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
        RandomAccess<IntType> out = result.randomAccess();
        while (cursor.hasNext()) {
            cursor.fwd();
            randomAccess.setPosition(cursor);
            LabelingType<Label> labels = randomAccess.get();
            if (labels.isEmpty()) continue;
            // NOTE: randomAccess.get() may return a reused object whose contents follow the
            // access position, so it must not be stored as a map key. Use a snapshot instead.
            Integer classIndex = classIndexCache.computeIfAbsent(new HashSet<>(labels), compute);
            out.setPosition(cursor);
            out.get().set(classIndex.intValue());
        }
        return result;
    }

    /** Returns the smallest index in {@code classes} of any label name in {@code labels}, or -1. */
    private static int classIndexOf(Set<Label> labels, List<String> classes) {
        return labels.stream().mapToInt(label -> classes.indexOf(label.name())).filter(i -> i >= 0).min().orElse(-1);
    }

    @Override
    public boolean isTrained() {
        return this.segmenter != null;
    }

    /**
     * Writes the trained classifier as JSON to {@code path}.
     *
     * @throws IllegalStateException if the segmenter has not been trained or loaded yet
     */
    @Override
    public synchronized void saveModel(String path) {
        GsonUtils.write(trainedSegmenter().toJsonTree(), path);
    }

    /** Loads a classifier from the JSON file at {@code path} and adopts its feature settings. */
    @Override
    public void openModel(String path) {
        this.segmenter = net.imglib2.trainable_segmentation.classification.Segmenter.fromJson(this.context, GsonUtils.read(path));
        this.segmenter.setUseGpu(this.useGpu);
        this.featureSettings = this.segmenter.features().settings();
    }

    /**
     * Suggests a cell size for {@code image}: 128 along spatial axes for images with up to two
     * spatial dimensions, 32 otherwise, doubled when the GPU is used. Non-spatial axes get 1.
     * A channel axis is removed before the computation.
     */
    @Override
    public int[] suggestCellSize(ImgPlus<?> image) {
        if (ImgPlusViewsOld.hasAxis(image, Axes.CHANNEL)) {
            image = ImgPlusViewsOld.hyperSlice(image, Axes.CHANNEL, 0L);
        }
        int spatialDimensions = ImgPlusViewsOld.numberOfSpatialDimensions(image);
        int cellSize = spatialDimensions <= 2 ? 128 : 32;
        if (this.useGpu) {
            cellSize *= 2;
        }
        int[] cellDimensions = new int[image.numDimensions()];
        for (int i = 0; i < cellDimensions.length; ++i) {
            CalibratedAxis axis = image.axis(i);
            cellDimensions[i] = axis.type().isSpatial() ? cellSize : 1;
        }
        return cellDimensions;
    }

    @Override
    public boolean requiresFixedCellSize() {
        return this.useGpu;
    }

    /** Returns pixel sizes relative to X: [1, y/x] plus z/x if the image has a Z axis. */
    private static List<Double> getPixelSize(ImgPlus<?> image) {
        List<Double> pixelSize = new ArrayList<>();
        double x = getPixelSize(image, Axes.X);
        double y = getPixelSize(image, Axes.Y);
        pixelSize.add(1.0);
        pixelSize.add(y / x);
        if (ImgPlusViewsOld.hasAxis(image, Axes.Z)) {
            double z = getPixelSize(image, Axes.Z);
            pixelSize.add(z / x);
        }
        return pixelSize;
    }

    /** Returns the average scale of {@code axis}, or 1.0 when it is NaN or zero. */
    private static double getPixelSize(ImgPlus<?> image, AxisType axis) {
        double scale = image.averageScale(image.dimensionIndex(axis));
        return Double.isNaN(scale) || scale == 0.0 ? 1.0 : scale;
    }

    /** Initializes default feature settings from the training data, unless already set. */
    private void initFeatureSettings(List<Pair<ImgPlus<?>, Labeling>> trainingData) {
        if (this.featureSettings != null) {
            return;
        }
        GlobalSettings globalSettings = initGlobalSettings(trainingData);
        this.featureSettings = new FeatureSettings(globalSettings, new FeatureSetting[]{SingleFeatures.identity(), GroupedFeatures.gauss(), GroupedFeatures.differenceOfGaussians(), GroupedFeatures.gradient(), GroupedFeatures.laplacian(), GroupedFeatures.hessian()});
    }

    /**
     * Derives global settings from the first training image: its spatial dimensionality,
     * channel setting and relative pixel size, with a sigma range of 1 to 8. Falls back to
     * the 2D defaults when there is no training data.
     */
    private GlobalSettings initGlobalSettings(List<Pair<ImgPlus<?>, Labeling>> trainingData) {
        if (trainingData.isEmpty()) {
            return GlobalSettings.default2d().build();
        }
        ImgPlus<?> image = trainingData.get(0).getA();
        ChannelSetting channelSetting = getChannelSetting(image);
        return GlobalSettings.default2d()
            .dimensions(ImgPlusViewsOld.numberOfSpatialDimensions(image))
            .channels(channelSetting)
            .sigmaRange(1.0, 8.0)
            .pixelSize(getPixelSize(image))
            .build();
    }

    /** Chooses multi-channel, RGB or single-channel feature input for {@code image}. */
    private static ChannelSetting getChannelSetting(ImgPlus<?> image) {
        if (ImgPlusViewsOld.hasAxis(image, Axes.CHANNEL)) {
            return ChannelSetting.multiple((int) ImgPlusViewsOld.getDimension(image, Axes.CHANNEL));
        }
        return image.firstElement() instanceof ARGBType ? ChannelSetting.RGB : ChannelSetting.SINGLE;
    }

    /**
     * Applies {@code action} to matching hyperslices of {@code image} and {@code target} for every
     * position of {@code target} along {@code targetSliceAxis}.
     *
     * @throws IllegalStateException if the target's extent along the slice axis exceeds the image's
     */
    private <T> void applyOnSlices(BiConsumer<ImgPlus<?>, RandomAccessibleInterval<T>> action, ImgPlus<?> image, RandomAccessibleInterval<T> target, int imageSliceAxis, int targetSliceAxis) {
        long min = target.min(targetSliceAxis);
        long max = target.max(targetSliceAxis);
        if (min < image.min(imageSliceAxis) || max > image.max(imageSliceAxis)) {
            throw new IllegalStateException("Last dimensions must fit.");
        }
        for (long pos = min; pos <= max; ++pos) {
            RandomAccessibleInterval<T> targetSlice = Views.hyperSlice(target, targetSliceAxis, pos);
            ImgPlus<?> imageSlice = ImgPlusViews.hyperSlice(Cast.unchecked(image), imageSliceAxis, pos);
            action.accept(imageSlice, targetSlice);
        }
    }

    /** Replaces the feature settings used by the next training; the current classifier is kept. */
    public void setFeatureSettings(FeatureSettings featureSettings) {
        this.featureSettings = featureSettings;
    }
}

