/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.labkit.segmentation;

import java.util.Arrays;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.cache.img.CellLoader;
import net.imglib2.cache.img.DiskCachedCellImgFactory;
import net.imglib2.cache.img.DiskCachedCellImgOptions;
import net.imglib2.img.Img;
import net.imglib2.img.cell.CellGrid;
import net.imglib2.labkit.inputimage.ImgPlusViewsOld;
import net.imglib2.labkit.segmentation.Segmenter;
import net.imglib2.labkit.utils.DimensionUtils;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;

/**
 * Static helpers that build cell-based cached images whose cells are filled by
 * a {@link Segmenter}.
 */
public class SegmentationUtils {
    private SegmentationUtils() {
        // Static utility class; not meant to be instantiated.
    }

    /**
     * Creates a cached {@link FloatType} image whose cells are filled by
     * {@link Segmenter#predict}.
     *
     * <p>The image grid is derived from {@link #intervalNoChannels(ImgPlus)} and
     * {@code segmenter.suggestCellSize(image)}, extended by one trailing dimension
     * whose size is the number of entries in {@code segmenter.classNames()}.
     *
     * @param segmenter supplies the cell size, the class names and the prediction
     * @param image     the image handed to the segmenter
     * @return the cached image created by {@link DiskCachedCellImgFactory}
     */
    public static Img<FloatType> createCachedProbabilityMap(Segmenter segmenter, ImgPlus<?> image) {
        int[] cellSize = segmenter.suggestCellSize(image);
        CellLoader<FloatType> loader =
            target -> segmenter.predict(image, ensureCellSize(segmenter, cellSize, target));
        Interval interval = intervalNoChannels(image);
        int classCount = segmenter.classNames().size();
        CellGrid gridWithoutChannels = new CellGrid(Intervals.dimensionsAsLongArray(interval), cellSize);
        CellGrid gridWithChannel = addDimensionToGrid(classCount, gridWithoutChannels);
        return setupCachedImage(loader, gridWithChannel, new FloatType());
    }

    /**
     * Creates a cached {@link ShortType} image whose cells are filled by
     * {@link Segmenter#segment}.
     *
     * <p>The image grid is derived from {@link #intervalNoChannels(ImgPlus)} and
     * {@code segmenter.suggestCellSize(image)}.
     *
     * @param segmenter supplies the cell size and the segmentation
     * @param image     the image handed to the segmenter
     * @return the cached image created by {@link DiskCachedCellImgFactory}
     */
    public static Img<ShortType> createCachedSegmentation(Segmenter segmenter, ImgPlus<?> image) {
        int[] cellSize = segmenter.suggestCellSize(image);
        CellLoader<ShortType> loader =
            target -> segmenter.segment(image, ensureCellSize(segmenter, cellSize, target));
        Interval interval = intervalNoChannels(image);
        CellGrid grid = new CellGrid(Intervals.dimensionsAsLongArray(interval), cellSize);
        return setupCachedImage(loader, grid, new ShortType());
    }

    /**
     * Returns a new grid with one extra trailing dimension of the given size;
     * the cell extent along that new dimension is {@code size} as well.
     */
    private static CellGrid addDimensionToGrid(int size, CellGrid grid) {
        // The explicit widening selects the long[] overload of DimensionUtils.extend.
        long[] imgDimensions = DimensionUtils.extend(grid.getImgDimensions(), (long) size);
        int[] cellDimensions = DimensionUtils.extend(getCellDimensions(grid), size);
        return new CellGrid(imgDimensions, cellDimensions);
    }

    /**
     * Returns {@code target} unchanged unless the segmenter requires a fixed cell
     * size and the dimensions of {@code target} differ from {@code cellSize}. In
     * that case the result is a view on the zero-extended target that starts at
     * the target's minimum and spans {@code cellSize} pixels per dimension.
     */
    private static <T extends NativeType<T> & NumericType<T>> RandomAccessibleInterval<T> ensureCellSize(
        Segmenter segmenter, int[] cellSize, RandomAccessibleInterval<T> target)
    {
        if (!segmenter.requiresFixedCellSize()) {
            return target;
        }
        if (Arrays.equals(cellSize, Intervals.dimensionsAsIntArray(target))) {
            return target;
        }
        long[] min = Intervals.minAsLongArray(target);
        long[] max = new long[min.length];
        for (int d = 0; d < max.length; d++) {
            max[d] = min[d] + (long) cellSize[d] - 1L;
        }
        return Views.interval(Views.extendZero(target), min, max);
    }

    /**
     * Returns the interval of {@code image}; if the image has a
     * {@link Axes#CHANNEL} axis, the interval of its hyperslice at channel
     * position 0 is returned instead.
     *
     * @param image the image to inspect
     * @return a new {@link FinalInterval} copy of the selected interval
     */
    public static Interval intervalNoChannels(ImgPlus<?> image) {
        boolean hasChannelAxis = ImgPlusViewsOld.hasAxis(image, Axes.CHANNEL);
        return new FinalInterval(hasChannelAxis ? ImgPlusViewsOld.hyperSlice(image, Axes.CHANNEL, 0L) : image);
    }

    /**
     * Creates the cached cell image for the given grid, using {@code loader} to
     * fill its cells. Cells are configured to be initialized as dirty.
     */
    private static <T extends NativeType<T>> Img<T> setupCachedImage(CellLoader<T> loader, CellGrid grid, T type) {
        int[] cellDimensions = getCellDimensions(grid);
        long[] imgDimensions = grid.getImgDimensions();
        // Clamp each cell extent so that it never exceeds the image extent.
        for (int i = 0; i < cellDimensions.length; i++) {
            cellDimensions[i] = (int) Math.min((long) cellDimensions[i], imgDimensions[i]);
        }
        DiskCachedCellImgOptions factoryOptions = DiskCachedCellImgOptions.options()
            .cellDimensions(cellDimensions)
            .initializeCellsAsDirty(true);
        DiskCachedCellImgFactory<T> factory = new DiskCachedCellImgFactory<>(type, factoryOptions);
        DiskCachedCellImgOptions createOptions = DiskCachedCellImgOptions.options().initializeCellsAsDirty(true);
        return factory.create(imgDimensions, loader, createOptions);
    }

    /** Copies the cell dimensions of {@code grid} into a freshly allocated array. */
    private static int[] getCellDimensions(CellGrid grid) {
        int[] cellDimensions = new int[grid.numDimensions()];
        grid.cellDimensions(cellDimensions);
        return cellDimensions;
    }
}

