/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.labkit.models;

import java.util.AbstractList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.imagej.ImgPlus;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.cell.CellImg;
import net.imglib2.img.cell.CellImgFactory;
import net.imglib2.labkit.inputimage.InputImage;
import net.imglib2.labkit.labeling.Labeling;
import net.imglib2.labkit.models.ImageLabelingModel;
import net.imglib2.labkit.models.SegmentationModel;
import net.imglib2.labkit.models.SegmenterListModel;
import net.imglib2.labkit.segmentation.Segmenter;
import net.imglib2.labkit.utils.DimensionUtils;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import org.scijava.Context;

/**
 * Default {@link SegmentationModel} for a single input image.
 * <p>
 * Holds an {@link ImageLabelingModel} created from the input image and a
 * {@link SegmenterListModel} whose training data is set to a one-element list
 * pairing the image used for segmentation with the current labeling.
 */
public class DefaultSegmentationModel
implements SegmentationModel {
    private final Context context;
    private final ImageLabelingModel imageLabelingModel;
    private final SegmenterListModel segmenterList;

    /**
     * Creates the model for the given image.
     *
     * @param context SciJava context; also passed to the {@link SegmenterListModel}
     * @param inputImage image wrapped by the new {@link ImageLabelingModel}
     */
    public DefaultSegmentationModel(Context context, InputImage inputImage) {
        this.context = context;
        this.imageLabelingModel = new ImageLabelingModel(inputImage);
        this.segmenterList = new SegmenterListModel(context);
        // Use the field directly: calling the overridable segmenterList()
        // accessor from a constructor would expose a partially built object
        // to subclasses.
        this.segmenterList.trainingData().set(new SingletonTrainingData(this.imageLabelingModel));
    }

    @Override
    public Context context() {
        return this.context;
    }

    @Override
    public ImageLabelingModel imageLabelingModel() {
        return this.imageLabelingModel;
    }

    @Override
    public SegmenterListModel segmenterList() {
        return this.segmenterList;
    }

    /**
     * Segments the current image with every trained segmenter.
     * <p>
     * For each trained segmenter a new cell image with the dimensions of the
     * image for segmentation is allocated and filled by
     * {@code Segmenter.segment}.
     *
     * @param type pixel type instance used to create the label images
     * @param <T> integer pixel type of the label images
     * @return one newly allocated label image per trained segmenter, in the
     *         order of the segmenter list; empty if no segmenter is trained
     */
    public <T extends IntegerType<T> & NativeType<T>> List<RandomAccessibleInterval<T>> getSegmentations(T type) {
        ImgPlus<?> image = this.imageLabelingModel().imageForSegmentation().get();
        return this.getTrainedSegmenters().map(segmenter -> {
            RandomAccessibleInterval<T> labels = new CellImgFactory<>(type).create(image);
            segmenter.segment(image, labels);
            return labels;
        }).collect(Collectors.toList());
    }

    /**
     * Computes the prediction images of every trained segmenter.
     * <p>
     * For each trained segmenter a {@link FloatType} cell image is allocated
     * over {@code DimensionUtils.appendDimensionToInterval(image, 0, numberOfClasses - 1)}
     * (presumably the image interval plus one axis indexing the segmenter's
     * classes — confirm against DimensionUtils) and filled by
     * {@code Segmenter.predict}.
     *
     * @return one newly allocated prediction image per trained segmenter, in
     *         the order of the segmenter list; empty if no segmenter is trained
     */
    public List<RandomAccessibleInterval<FloatType>> getPredictions() {
        ImgPlus<?> image = this.imageLabelingModel().imageForSegmentation().get();
        return this.getTrainedSegmenters().map(segmenter -> {
            int numberOfClasses = segmenter.classNames().size();
            // NOTE(review): a segmenter with zero class names yields max = -1 < min = 0 here.
            RandomAccessibleInterval<FloatType> prediction = new CellImgFactory<>(new FloatType())
                .create(DimensionUtils.appendDimensionToInterval(image, 0L, numberOfClasses - 1));
            segmenter.predict(image, prediction);
            return prediction;
        }).collect(Collectors.toList());
    }

    /**
     * @return {@code true} if at least one segmenter in the list reports
     *         {@code isTrained()}
     */
    public boolean isTrained() {
        return this.getTrainedSegmenters().findAny().isPresent();
    }

    /**
     * @return a fresh stream over the segmenters in the list whose
     *         {@code isTrained()} returns {@code true}, in list order
     */
    private Stream<Segmenter> getTrainedSegmenters() {
        // NOTE(review): the identity map presumably widens the element type to
        // Segmenter in case segmenters() yields a wildcard-typed list; confirm
        // before removing.
        return this.segmenterList.segmenters().get().stream().filter(Segmenter::isTrained).map(x -> x);
    }

    /**
     * Read-only, one-element training-data list.
     * <p>
     * Its single element pairs the image for segmentation with the labeling.
     * Both are fetched from the {@link ImageLabelingModel} on every
     * {@link #get(int)} call, so the pair reflects the model's state at access
     * time. Static, because it needs no reference to the enclosing model.
     */
    private static class SingletonTrainingData
    extends AbstractList<Pair<ImgPlus<?>, Labeling>> {
        private final ImageLabelingModel imageLabelingModel;

        public SingletonTrainingData(ImageLabelingModel imageLabelingModel) {
            this.imageLabelingModel = imageLabelingModel;
        }

        /**
         * @param index must be {@code 0}
         * @return the (image, labeling) pair read from the model at call time
         * @throws IndexOutOfBoundsException if {@code index != 0}, as required
         *         by the {@link List#get(int)} contract
         */
        @Override
        public Pair<ImgPlus<?>, Labeling> get(int index) {
            if (index != 0) {
                throw new IndexOutOfBoundsException("Index: " + index + ", Size: 1");
            }
            ImgPlus<?> image = this.imageLabelingModel.imageForSegmentation().get();
            Labeling labeling = this.imageLabelingModel.labeling().get();
            return new ValuePair<ImgPlus<?>, Labeling>(image, labeling);
        }

        @Override
        public int size() {
            return 1;
        }
    }
}

