/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.trainable_segmention.pixel_feature.filter.hessian;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.gradient.PartialDerivative;
import net.imglib2.img.Img;
import net.imglib2.trainable_segmention.RevampUtils;
import net.imglib2.trainable_segmention.pixel_feature.filter.AbstractFeatureOp;
import net.imglib2.trainable_segmention.pixel_feature.filter.FeatureOp;
import net.imglib2.trainable_segmention.pixel_feature.filter.hessian.EigenValues;
import net.imglib2.trainable_segmention.pixel_feature.settings.GlobalSettings;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;
import net.imglib2.view.composite.CompositeIntervalView;
import net.imglib2.view.composite.RealComposite;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;

@Plugin(type=FeatureOp.class, label="Hessian")
public class SingleHessian3DFeature
extends AbstractFeatureOp {
    @Parameter
    double sigma = 4.0;
    @Parameter
    boolean absoluteValues = true;

    @Override
    public int count() {
        return 3;
    }

    @Override
    public List<String> attributeLabels() {
        return Stream.of("largest", "middle", "smallest").map(x -> "Hessian_" + x + "_" + this.sigma + "_true").collect(Collectors.toList());
    }

    @Override
    public void apply(RandomAccessible<FloatType> input, List<RandomAccessibleInterval<FloatType>> output) {
        this.calculateHessianOnChannel(input, (RandomAccessibleInterval<FloatType>)Views.stack(output), this.sigma);
    }

    @Override
    public boolean checkGlobalSettings(GlobalSettings globals) {
        return globals.numDimensions() == 3;
    }

    private void calculateHessianOnChannel(RandomAccessible<FloatType> image, RandomAccessibleInterval<FloatType> out, double sigma) {
        double[] sigmas = new double[]{0.4 * sigma, 0.4 * sigma, 0.4 * sigma};
        Interval secondDerivativeInterval = RevampUtils.removeLastDimension(out);
        FinalInterval firstDerivativeInterval = Intervals.expand((Interval)secondDerivativeInterval, (long)1L);
        FinalInterval blurredInterval = Intervals.expand((Interval)firstDerivativeInterval, (long)1L);
        RandomAccessibleInterval<FloatType> blurred = RevampUtils.gauss(this.ops(), image, (Interval)blurredInterval, sigmas);
        RandomAccessibleInterval<FloatType> dx = this.derive((RandomAccessible<FloatType>)blurred, (Interval)firstDerivativeInterval, 0);
        RandomAccessibleInterval<FloatType> dy = this.derive((RandomAccessible<FloatType>)blurred, (Interval)firstDerivativeInterval, 1);
        RandomAccessibleInterval<FloatType> dz = this.derive((RandomAccessible<FloatType>)blurred, (Interval)firstDerivativeInterval, 2);
        RandomAccessibleInterval<RealComposite<FloatType>> secondDerivatives = this.calculateSecondDerivatives(secondDerivativeInterval, dx, dy, dz);
        CompositeIntervalView eigenValues = Views.collapseReal(out);
        Views.interval((RandomAccessible)Views.pair(secondDerivatives, (RandomAccessible)eigenValues), (Interval)eigenValues).forEach(p -> this.calculateEigenValues((RealComposite<FloatType>)((RealComposite)p.getA()), (RealComposite<FloatType>)((RealComposite)p.getB())));
    }

    private RandomAccessibleInterval<RealComposite<FloatType>> calculateSecondDerivatives(Interval secondDerivativeInterval, RandomAccessibleInterval<FloatType> dx, RandomAccessibleInterval<FloatType> dy, RandomAccessibleInterval<FloatType> dz) {
        Img secondDerivatives = this.ops().create().img((Dimensions)RevampUtils.appendDimensionToInterval(secondDerivativeInterval, 0L, 5L), (NativeType)new FloatType());
        List slices = RevampUtils.slices(secondDerivatives);
        PartialDerivative.gradientCentralDifference(dx, slices.get(0), (int)0);
        PartialDerivative.gradientCentralDifference(dx, slices.get(1), (int)1);
        PartialDerivative.gradientCentralDifference(dx, slices.get(2), (int)2);
        PartialDerivative.gradientCentralDifference(dy, slices.get(3), (int)1);
        PartialDerivative.gradientCentralDifference(dy, slices.get(4), (int)2);
        PartialDerivative.gradientCentralDifference(dz, slices.get(5), (int)2);
        return Views.collapseReal((RandomAccessibleInterval)secondDerivatives);
    }

    private void calculateEigenValues(RealComposite<FloatType> derivatives, RealComposite<FloatType> eigenValues) {
        EigenValues.Vector3D v = new EigenValues.Vector3D();
        EigenValues.eigenvalues(v, ((FloatType)derivatives.get(0L)).getRealDouble(), ((FloatType)derivatives.get(1L)).getRealDouble(), ((FloatType)derivatives.get(2L)).getRealDouble(), ((FloatType)derivatives.get(3L)).getRealDouble(), ((FloatType)derivatives.get(4L)).getRealDouble(), ((FloatType)derivatives.get(5L)).getRealDouble());
        if (this.absoluteValues) {
            EigenValues.abs(v);
        }
        EigenValues.sort(v);
        ((FloatType)eigenValues.get(0L)).setReal(v.x);
        ((FloatType)eigenValues.get(1L)).setReal(v.y);
        ((FloatType)eigenValues.get(2L)).setReal(v.z);
    }

    private RandomAccessibleInterval<FloatType> derive(RandomAccessible<FloatType> source, Interval interval, int dimension) {
        Img target = this.ops().create().img((Dimensions)interval, (NativeType)new FloatType());
        PartialDerivative.gradientCentralDifference(source, (RandomAccessibleInterval)target, (int)dimension);
        return target;
    }
}

