/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.trainable_segmention.pixel_feature.filter.gabor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiConsumer;
import net.imagej.ops.OpEnvironment;
import net.imglib2.Cursor;
import net.imglib2.Dimensions;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.fft2.FFTConvolution;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.trainable_segmention.RevampUtils;
import net.imglib2.trainable_segmention.pixel_feature.filter.AbstractFeatureOp;
import net.imglib2.trainable_segmention.pixel_feature.filter.FeatureOp;
import net.imglib2.trainable_segmention.pixel_feature.settings.GlobalSettings;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
import net.imglib2.view.composite.CompositeIntervalView;
import net.imglib2.view.composite.GenericComposite;
import org.scijava.plugin.Parameter;
import org.scijava.plugin.Plugin;

@Plugin(type=FeatureOp.class, label="Gabor")
public class SingleGaborFeature
extends AbstractFeatureOp {
    @Parameter
    private double sigma;
    @Parameter
    private double gamma;
    @Parameter
    private double psi;
    @Parameter
    private double frequency;
    @Parameter
    private int nAngles;
    @Parameter
    private boolean legacyNormalize = false;
    private List<Img<FloatType>> kernels;

    public void initialize() {
        if (this.sigma == 0.0) {
            throw new AssertionError((Object)"sigma must be non zero.");
        }
        this.kernels = this.initGaborKernels(this.sigma, this.gamma, this.psi, this.frequency, this.nAngles);
    }

    @Override
    public int count() {
        return 2;
    }

    @Override
    public void apply(RandomAccessible<FloatType> in, List<RandomAccessibleInterval<FloatType>> out) {
        this.gaborProcessChannel(this.kernels, in, out.get(0), out.get(1));
    }

    @Override
    public List<String> attributeLabels() {
        String details = "_" + this.sigma + "_" + this.gamma + "_" + (int)(this.psi / 0.7853981633974483) + "_" + this.frequency;
        return Arrays.asList("Gabor_1" + details, "Gabor_2" + details);
    }

    @Override
    public boolean checkGlobalSettings(GlobalSettings globals) {
        return globals.numDimensions() == 2;
    }

    private List<Img<FloatType>> initGaborKernels(double sigma, double gamma, double psi, double frequency, int nAngles) {
        int largerSigma;
        double sigma_x = sigma;
        double sigma_y = sigma / gamma;
        int n = largerSigma = sigma_x > sigma_y ? (int)sigma_x : (int)sigma_y;
        if (largerSigma < 1) {
            largerSigma = 1;
        }
        int filterSizeX = 6 * largerSigma + 1;
        int filterSizeY = 6 * largerSigma + 1;
        int middleX = Math.round(filterSizeX / 2);
        int middleY = Math.round(filterSizeY / 2);
        ArrayList<Img<FloatType>> kernels = new ArrayList<Img<FloatType>>();
        double rotationAngle = Math.PI / (double)nAngles;
        FinalInterval interval = new FinalInterval(new long[]{-middleX, -middleY}, new long[]{middleX, middleX});
        for (int i = 0; i < nAngles; ++i) {
            double theta = rotationAngle * (double)i;
            Img kernel = this.ops().create().img((Dimensions)interval, (NativeType)new FloatType());
            SingleGaborFeature.garborKernel((Img<FloatType>)kernel, psi, frequency, sigma_x, sigma_y, theta);
            kernels.add((Img<FloatType>)kernel);
        }
        return kernels;
    }

    private static void garborKernel(Img<FloatType> kernel, double psi, double frequency, double sigma_x, double sigma_y, double theta) {
        Cursor cursor = kernel.cursor();
        double filterSizeX = kernel.max(0) - kernel.min(0) + 1L;
        double sigma_x2 = sigma_x * sigma_x;
        double sigma_y2 = sigma_y * sigma_y;
        while (cursor.hasNext()) {
            cursor.next();
            double x = cursor.getDoublePosition(0);
            double y = cursor.getDoublePosition(1);
            double xPrime = x * Math.cos(theta) + y * Math.sin(theta);
            double yPrime = y * Math.cos(theta) - x * Math.sin(theta);
            double a = 1.0 / (Math.PI * 2 * sigma_x * sigma_y) * Math.exp(-0.5 * (xPrime * xPrime / sigma_x2 + yPrime * yPrime / sigma_y2));
            double c = Math.cos(Math.PI * 2 * (frequency * xPrime) / filterSizeX + psi);
            ((FloatType)cursor.get()).set((float)(a * c));
        }
    }

    private RandomAccessibleInterval<FloatType> gaborProcessChannel(List<Img<FloatType>> kernels, Img<FloatType> channel, String labelDetails) {
        Img max = this.ops().create().img(channel);
        Img min = this.ops().create().img(channel);
        this.gaborProcessChannel(kernels, (RandomAccessible<FloatType>)Views.extendBorder(channel), (RandomAccessibleInterval<FloatType>)max, (RandomAccessibleInterval<FloatType>)min);
        return Views.stack((RandomAccessibleInterval[])new RandomAccessibleInterval[]{max, min});
    }

    private void gaborProcessChannel(List<Img<FloatType>> kernels, RandomAccessible<FloatType> channel, RandomAccessibleInterval<FloatType> max, RandomAccessibleInterval<FloatType> min) {
        RandomAccessibleInterval<FloatType> interval = min;
        Img stack = this.ops().create().img((Dimensions)RevampUtils.appendDimensionToInterval(interval, 0L, kernels.size() - 1), (NativeType)new FloatType());
        FFTConvolution fftConvolution = new FFTConvolution(channel, interval, (RandomAccessible)kernels.get(0), (Interval)kernels.get(0), (ImgFactory)new ArrayImgFactory());
        fftConvolution.setKeepImgFFT(true);
        for (int i = 0; i < kernels.size(); ++i) {
            Img<FloatType> kernel = kernels.get(i);
            IntervalView slice = Views.hyperSlice((RandomAccessibleInterval)stack, (int)2, (long)i);
            fftConvolution.setKernel(kernel);
            fftConvolution.setOutput((RandomAccessibleInterval)slice);
            fftConvolution.convolve();
            if (!this.legacyNormalize) continue;
            SingleGaborFeature.normalize(this.ops(), (RandomAccessibleInterval<FloatType>)slice);
        }
        SingleGaborFeature.maxAndMinProjection((Img<FloatType>)stack, max, min);
    }

    private static void maxAndMinProjection(Img<FloatType> stack, RandomAccessibleInterval<FloatType> max, RandomAccessibleInterval<FloatType> min) {
        CompositeIntervalView collapsed = Views.collapse(stack);
        long size = stack.max(2) - stack.min(2) + 1L;
        SingleGaborFeature.map(collapsed, max, (in, out) -> out.set(SingleGaborFeature.max((GenericComposite<FloatType>)in, size)));
        SingleGaborFeature.map(collapsed, min, (in, out) -> out.set(SingleGaborFeature.min((GenericComposite<FloatType>)in, size)));
    }

    private static <I, O> void map(RandomAccessible<I> in, RandomAccessibleInterval<O> out, BiConsumer<I, O> operation) {
        Views.interval((RandomAccessible)Views.pair(in, out), out).forEach(p -> operation.accept(p.getA(), p.getB()));
    }

    private static float max(GenericComposite<FloatType> in, long size) {
        float max = Float.NEGATIVE_INFINITY;
        int i = 0;
        while ((long)i < size) {
            max = Math.max(max, ((FloatType)in.get((long)i)).get());
            ++i;
        }
        return max;
    }

    private static float min(GenericComposite<FloatType> in, long size) {
        float min = Float.POSITIVE_INFINITY;
        int i = 0;
        while ((long)i < size) {
            min = Math.min(min, ((FloatType)in.get((long)i)).get());
            ++i;
        }
        return min;
    }

    static void normalize(OpEnvironment ops, RandomAccessibleInterval<FloatType> image2) {
        DoubleType mean = (DoubleType)ops.stats().mean((Iterable)Views.iterable(image2));
        DoubleType stdDev = (DoubleType)ops.stats().stdDev((Iterable)Views.iterable(image2));
        float mean2 = (float)mean.get();
        float invStdDev = stdDev.get() == 0.0 ? 1.0f : (float)(1.0 / stdDev.get());
        Views.iterable(image2).forEach(value -> value.set((value.get() - mean2) * invStdDev));
    }
}

