/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.trainable_segmention.classification;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.imagej.ops.OpEnvironment;
import net.imagej.ops.Ops;
import net.imagej.ops.special.hybrid.AbstractUnaryHybridCF;
import net.imagej.ops.special.hybrid.UnaryHybridCF;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.trainable_segmention.RevampUtils;
import net.imglib2.trainable_segmention.classification.ClassifierSerialization;
import net.imglib2.trainable_segmention.classification.CompositeInstance;
import net.imglib2.trainable_segmention.classification.Training;
import net.imglib2.trainable_segmention.pixel_feature.calculator.FeatureCalculator;
import net.imglib2.trainable_segmention.pixel_feature.settings.FeatureSettings;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;
import net.imglib2.view.composite.Composite;
import net.imglib2.view.composite.CompositeIntervalView;
import net.imglib2.view.composite.GenericComposite;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Pixel classifier that combines a {@link FeatureCalculator} with a Weka {@link Classifier}.
 * <p>
 * For every pixel, the feature calculator produces a feature vector. The classifier then maps
 * that vector to one of two outputs:
 * <ul>
 * <li>a single class index, via {@link #segment} and {@code classifyInstance};</li>
 * <li>a per-class distribution, via {@link #predict} and {@code distributionForInstance}.</li>
 * </ul>
 * Training samples are collected through {@link #training()}.
 */
public class Segmenter {
    private final FeatureCalculator features;
    // Unmodifiable private copy of the class labels; their order defines the values of the Weka "class" attribute.
    private final List<String> classNames;
    private Classifier classifier;
    // Becomes true once Training#train() has built the classifier.
    private boolean isTrained = false;
    private final OpEnvironment ops;

    /**
     * Creates a segmenter.
     *
     * @param ops        op environment used to create output images and run map ops
     * @param classNames class labels; the list is copied, so later changes by the caller have no effect
     * @param features   calculator producing the per-pixel feature vectors
     * @param classifier Weka classifier used for training and prediction
     * @throws NullPointerException if any argument is {@code null}
     */
    public Segmenter(OpEnvironment ops, List<String> classNames, FeatureCalculator features, Classifier classifier) {
        this.ops = Objects.requireNonNull(ops);
        // Defensive copy: wrapping the caller's list directly would expose later caller mutations.
        this.classNames = Collections.unmodifiableList(new ArrayList<>(Objects.requireNonNull(classNames)));
        this.features = Objects.requireNonNull(features);
        this.classifier = Objects.requireNonNull(classifier);
    }

    /**
     * Creates a segmenter whose feature calculator is built from the given settings.
     *
     * @param ops        op environment used to create the feature calculator and output images
     * @param classNames class labels (copied)
     * @param features   settings from which a new {@link FeatureCalculator} is constructed
     * @param classifier Weka classifier used for training and prediction
     */
    public Segmenter(OpEnvironment ops, List<String> classNames, FeatureSettings features, Classifier classifier) {
        this(ops, classNames, new FeatureCalculator(ops, features), classifier);
    }

    /** @return the feature calculator used to compute per-pixel feature vectors */
    public FeatureCalculator features() {
        return this.features;
    }

    /** @return the settings of the underlying feature calculator */
    public FeatureSettings settings() {
        return this.features.settings();
    }

    /**
     * Segments {@code image} into a newly allocated {@link UnsignedByteType} label image.
     *
     * @param image input image
     * @return image holding, for every output pixel, the index of the predicted class
     */
    public Img<UnsignedByteType> segment(RandomAccessibleInterval<?> image) {
        return this.segment(image, new UnsignedByteType());
    }

    /**
     * Segments {@code image} into a newly allocated label image of the given pixel type.
     * <p>
     * The output interval is obtained from {@link FeatureCalculator#outputIntervalFromInput}. The
     * input is border-extended, so that feature computations reading beyond the image bounds see
     * replicated edge values.
     *
     * @param image input image
     * @param type  pixel type of the returned label image
     * @return image holding the predicted class index for every output pixel
     * @throws NullPointerException if an argument is {@code null}
     */
    public <T extends IntegerType<T> & NativeType<T>> Img<T> segment(RandomAccessibleInterval<?> image, T type) {
        Objects.requireNonNull(image);
        Objects.requireNonNull(type);
        Interval outputInterval = this.features.outputIntervalFromInput(image);
        Img result = this.ops.create().img((Dimensions)outputInterval, (NativeType<T>)type);
        this.segment((RandomAccessibleInterval<? extends IntegerType<?>>)result, (RandomAccessible<?>)Views.extendBorder(image));
        return result;
    }

    /**
     * Computes features of {@code image} over the interval of {@code out}. It then writes the
     * predicted class index of each pixel into {@code out}.
     *
     * @param out   destination; its interval selects the pixels to classify
     * @param image (typically extended) input from which features are computed
     * @throws NullPointerException if an argument is {@code null}
     */
    public void segment(RandomAccessibleInterval<? extends IntegerType<?>> out, RandomAccessible<?> image) {
        Objects.requireNonNull(out);
        Objects.requireNonNull(image);
        RandomAccessibleInterval<FloatType> featureValues = this.features.apply(image, (Interval)out);
        // Collapse the last (feature) dimension so each pixel becomes one feature vector, then map it to a class.
        this.ops.run(Ops.Map.class, new Object[]{out, Views.collapseReal(featureValues), this.pixelClassificationOp()});
    }

    /**
     * Computes per-class distributions for {@code image} into a newly allocated float image.
     * <p>
     * The backing image has one extra trailing dimension of size {@code classNames().size()}. It is
     * returned collapsed, so that each pixel is a composite of class scores.
     *
     * @param image input image (border-extended internally)
     * @return collapsed view of the per-class distribution image
     * @throws NullPointerException if {@code image} is {@code null}
     */
    public RandomAccessibleInterval<? extends Composite<? extends RealType<?>>> predict(RandomAccessibleInterval<?> image) {
        Objects.requireNonNull(image);
        Interval outputInterval = this.features.outputIntervalFromInput(image);
        Img img = this.ops.create().img((Dimensions)RevampUtils.appendDimensionToInterval(outputInterval, 0L, this.classNames.size()), (NativeType)new FloatType());
        CompositeIntervalView collapsed = Views.collapseReal((RandomAccessibleInterval)img);
        this.predict((RandomAccessibleInterval<? extends Composite<? extends RealType<?>>>)collapsed, (RandomAccessible<?>)Views.extendBorder(image));
        return collapsed;
    }

    /**
     * Computes features of {@code image} over the interval of {@code out}. It then writes the
     * classifier's distribution for each pixel into the corresponding composite of {@code out}.
     *
     * @param out   destination composites; each must hold at least as many entries as the classifier returns
     * @param image (typically extended) input from which features are computed
     * @throws NullPointerException if an argument is {@code null}
     */
    public void predict(RandomAccessibleInterval<? extends Composite<? extends RealType<?>>> out, RandomAccessible<?> image) {
        Objects.requireNonNull(out);
        Objects.requireNonNull(image);
        RandomAccessibleInterval<FloatType> featureValues = this.features.apply(image, (Interval)out);
        this.ops.run(Ops.Map.class, new Object[]{out, Views.collapseReal(featureValues), this.pixelPredictionOp()});
    }

    /** @return a new op mapping one feature vector to the classifier's per-class distribution */
    public UnaryHybridCF<Composite<? extends RealType<?>>, Composite<? extends RealType<?>>> pixelPredictionOp() {
        return new PixelPredictionOp();
    }

    /** @return a new op mapping one feature vector to the predicted class index */
    public UnaryHybridCF<Composite<? extends RealType<?>>, IntegerType<?>> pixelClassificationOp() {
        return new PixelClassifierOp();
    }

    /** @return unmodifiable list of class labels */
    public List<String> classNames() {
        return this.classNames;
    }

    /**
     * Starts a new, empty training set. Call {@link Training#train()} to build the classifier from
     * the samples added to it.
     *
     * @return a fresh training-data collector bound to this segmenter
     */
    public Training training() {
        return new MyTrainingData();
    }

    /** @return {@code true} once a {@link Training} obtained from this segmenter has been trained */
    public boolean isTrained() {
        return this.isTrained;
    }

    /**
     * Serializes this segmenter to a JSON object.
     * <p>
     * The object has the keys {@code "features"}, {@code "classNames"} and {@code "classifier"}.
     *
     * @return JSON representation readable by {@link #fromJson}
     */
    public JsonElement toJsonTree() {
        JsonObject json = new JsonObject();
        json.add("features", this.features.settings().toJson());
        json.add("classNames", new Gson().toJsonTree(this.classNames));
        json.add("classifier", ClassifierSerialization.wekaToJson(this.classifier));
        return json;
    }

    /**
     * Restores a segmenter from the JSON produced by {@link #toJsonTree()}.
     * <p>
     * NOTE(review): {@link #isTrained()} reports {@code false} for the restored instance, even if the
     * stored classifier was trained. Confirm whether callers rely on this.
     *
     * @param ops  op environment for the new segmenter
     * @param json JSON object with the keys {@code "features"}, {@code "classNames"} and {@code "classifier"}
     * @return the deserialized segmenter
     */
    public static Segmenter fromJson(OpEnvironment ops, JsonElement json) {
        JsonObject object = json.getAsJsonObject();
        List<String> classNames = new Gson().fromJson(object.get("classNames"), new TypeToken<List<String>>(){}.getType());
        return new Segmenter(ops, classNames, FeatureSettings.fromJson(object.get("features")), ClassifierSerialization.jsonToWeka(object.get("classifier")));
    }

    // Same attributes as attributes(), as an array for CompositeInstance.
    private Attribute[] attributesAsArray() {
        List<Attribute> attributes = this.attributes();
        return attributes.toArray(new Attribute[attributes.size()]);
    }

    // One numeric attribute per feature label, followed by a nominal "class" attribute over classNames.
    private List<Attribute> attributes() {
        Stream<Attribute> featureAttributes = this.features.attributeLabels().stream().map(Attribute::new);
        Stream<Attribute> classAttribute = Stream.of(new Attribute("class", this.classNames));
        return Stream.concat(featureAttributes, classAttribute).collect(Collectors.toList());
    }

    /** Maps a feature vector to the classifier's per-class distribution. */
    private class PixelPredictionOp
    extends AbstractUnaryHybridCF<Composite<? extends RealType<?>>, Composite<? extends RealType<?>>> {
        // Reusable Weka view over the current input composite; mutated on every compute call.
        CompositeInstance compositeInstance;

        private PixelPredictionOp() {
            this.compositeInstance = new CompositeInstance(null, Segmenter.this.attributesAsArray());
        }

        @Override
        public UnaryHybridCF<Composite<? extends RealType<?>>, Composite<? extends RealType<?>>> getIndependentInstance() {
            // Each copy gets its own CompositeInstance, because compute() mutates it.
            return new PixelPredictionOp();
        }

        @Override
        public void compute(Composite<? extends RealType<?>> input, Composite<? extends RealType<?>> output) {
            this.compositeInstance.setSource(input);
            double[] result = RevampUtils.wrapException(() -> Segmenter.this.classifier.distributionForInstance((Instance)this.compositeInstance));
            int n = result.length;
            for (int i = 0; i < n; ++i) {
                ((RealType)output.get((long)i)).setReal(result[i]);
            }
        }

        @Override
        public Composite<? extends RealType<?>> createOutput(Composite<? extends RealType<?>> input) {
            return new GenericComposite((RandomAccess)ArrayImgs.doubles((long[])new long[]{this.compositeInstance.numClasses()}).randomAccess());
        }
    }

    /** Maps a feature vector to the index of the predicted class. */
    private class PixelClassifierOp
    extends AbstractUnaryHybridCF<Composite<? extends RealType<?>>, IntegerType<?>> {
        // Reusable Weka view over the current input composite; mutated on every compute call.
        CompositeInstance compositeInstance;

        private PixelClassifierOp() {
            this.compositeInstance = new CompositeInstance(null, Segmenter.this.attributesAsArray());
        }

        @Override
        public UnaryHybridCF<Composite<? extends RealType<?>>, IntegerType<?>> getIndependentInstance() {
            // Each copy gets its own CompositeInstance, because compute() mutates it.
            return new PixelClassifierOp();
        }

        @Override
        public IntegerType<?> createOutput(Composite<? extends RealType<?>> input) {
            return new UnsignedByteType();
        }

        @Override
        public void compute(Composite<? extends RealType<?>> input, IntegerType<?> output) {
            this.compositeInstance.setSource(input);
            // classifyInstance returns the class index as a double; it is truncated to int here.
            RevampUtils.wrapException(() -> output.setInteger((int)Segmenter.this.classifier.classifyInstance((Instance)this.compositeInstance)));
        }
    }

    /** Collects labeled feature vectors into Weka {@link Instances} and trains the classifier on them. */
    private class MyTrainingData
    implements Training {
        final Instances instances;
        final int featureCount;

        MyTrainingData() {
            this.instances = new Instances("segment", new ArrayList<>(Segmenter.this.attributes()), 1);
            this.featureCount = Segmenter.this.features.count();
            // The class attribute comes right after the feature attributes (see attributes()).
            this.instances.setClassIndex(this.featureCount);
        }

        @Override
        public void add(Composite<? extends RealType<?>> featureVector, int classIndex) {
            this.instances.add((Instance)RevampUtils.getInstance(this.featureCount, classIndex, featureVector));
        }

        @Override
        public void train() {
            RevampUtils.wrapException(() -> Segmenter.this.classifier.buildClassifier(this.instances));
            // Previously never set, so isTrained() always returned false.
            // Assumes wrapException rethrows on failure; if so, the flag is only set on success.
            Segmenter.this.isTrained = true;
        }
    }
}

