/*
 * Decompiled with CFR 0.152.
 */
package qupath.lib.images.writers.ome.zarr;

import com.bc.zarr.ArrayParams;
import com.bc.zarr.Compressor;
import com.bc.zarr.CompressorFactory;
import com.bc.zarr.DataType;
import com.bc.zarr.DimensionSeparator;
import com.bc.zarr.ZarrArray;
import com.bc.zarr.ZarrGroup;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import loci.formats.gui.AWTImageTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServers;
import qupath.lib.images.servers.PixelType;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.images.servers.TransformedServerBuilder;
import qupath.lib.images.writers.ome.zarr.OMEZarrAttributesCreator;
import qupath.lib.regions.ImageRegion;

public class OMEZarrWriter
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(OMEZarrWriter.class);
    private final ImageServer<BufferedImage> server;
    private final Map<Integer, ZarrArray> levelArrays;
    private final ExecutorService executorService;

    private OMEZarrWriter(Builder builder) throws IOException {
        TransformedServerBuilder transformedServerBuilder = new TransformedServerBuilder(ImageServers.pyramidalizeTiled(builder.server, (int)OMEZarrWriter.getChunkSize(builder.tileWidth > 0 ? builder.tileWidth : builder.server.getMetadata().getPreferredTileWidth(), builder.maxNumberOfChunks, builder.server.getWidth()), (int)OMEZarrWriter.getChunkSize(builder.tileHeight > 0 ? builder.tileHeight : builder.server.getMetadata().getPreferredTileHeight(), builder.maxNumberOfChunks, builder.server.getHeight()), (double[])(builder.downsamples.length == 0 ? builder.server.getPreferredDownsamples() : builder.downsamples)));
        if (builder.zStart != 0 || builder.zEnd != builder.server.nZSlices() || builder.tStart != 0 || builder.tEnd != builder.server.nTimepoints()) {
            transformedServerBuilder.slice(builder.zStart, builder.zEnd, builder.tStart, builder.tEnd);
        }
        if (builder.boundingBox != null) {
            transformedServerBuilder.crop(builder.boundingBox);
        }
        this.server = transformedServerBuilder.build();
        OMEZarrAttributesCreator attributes = new OMEZarrAttributesCreator(this.server.getMetadata().getName(), this.server.nZSlices(), this.server.nTimepoints(), this.server.nChannels(), this.server.getMetadata().getPixelCalibration(), this.server.getMetadata().getTimeUnit(), this.server.getPreferredDownsamples(), this.server.getMetadata().getChannels(), this.server.isRGB(), this.server.getPixelType());
        this.levelArrays = OMEZarrWriter.createLevelArrays(this.server, ZarrGroup.create((String)builder.path, attributes.getGroupAttributes()), attributes.getLevelAttributes(), builder.compressor);
        this.executorService = Executors.newFixedThreadPool(builder.numberOfThreads);
    }

    @Override
    public void close() throws InterruptedException {
        this.executorService.shutdown();
        this.executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
    }

    public void writeImage() {
        for (TileRequest tileRequest : this.server.getTileRequestManager().getAllTileRequests()) {
            this.writeTile(tileRequest);
        }
    }

    public void writeTile(TileRequest tileRequest) {
        this.executorService.execute(() -> {
            try {
                this.levelArrays.get(tileRequest.getLevel()).write(this.getData((BufferedImage)this.server.readRegion(tileRequest.getRegionRequest())), this.getDimensionsOfTile(tileRequest), this.getOffsetsOfTile(tileRequest));
            }
            catch (Exception e) {
                logger.error("Error when writing tile", (Throwable)e);
            }
        });
    }

    public ImageServer<BufferedImage> getReaderServer() {
        return this.server;
    }

    private static int getChunkSize(int tileSize, int maxNumberOfChunks, int imageSize) {
        return Math.min(imageSize, maxNumberOfChunks > 0 ? Math.max(tileSize, imageSize / maxNumberOfChunks) : tileSize);
    }

    private static Map<Integer, ZarrArray> createLevelArrays(ImageServer<BufferedImage> server, ZarrGroup root, Map<String, Object> levelAttributes, Compressor compressor) throws IOException {
        HashMap<Integer, ZarrArray> levelArrays = new HashMap<Integer, ZarrArray>();
        for (int level = 0; level < server.getMetadata().nLevels(); ++level) {
            Integer n = level;
            String string = "s" + level;
            ArrayParams arrayParams = new ArrayParams().shape(OMEZarrWriter.getDimensionsOfImage(server, level)).chunks(OMEZarrWriter.getChunksOfImage(server)).compressor(compressor);
            levelArrays.put(n, root.createArray(string, arrayParams.dataType(switch (server.getPixelType()) {
                default -> throw new MatchException(null, null);
                case PixelType.UINT8 -> DataType.u1;
                case PixelType.INT8 -> DataType.i1;
                case PixelType.UINT16 -> DataType.u2;
                case PixelType.INT16 -> DataType.i2;
                case PixelType.UINT32 -> DataType.u4;
                case PixelType.INT32 -> DataType.i4;
                case PixelType.FLOAT32 -> DataType.f4;
                case PixelType.FLOAT64 -> DataType.f8;
            }).dimensionSeparator(DimensionSeparator.SLASH), levelAttributes));
        }
        return levelArrays;
    }

    private static int[] getDimensionsOfImage(ImageServer<BufferedImage> server, int level) {
        ArrayList<Integer> dimensions = new ArrayList<Integer>();
        if (server.nTimepoints() > 1) {
            dimensions.add(server.nTimepoints());
        }
        if (server.nChannels() > 1) {
            dimensions.add(server.nChannels());
        }
        if (server.nZSlices() > 1) {
            dimensions.add(server.nZSlices());
        }
        dimensions.add((int)((double)server.getHeight() / server.getDownsampleForResolution(level)));
        dimensions.add((int)((double)server.getWidth() / server.getDownsampleForResolution(level)));
        return dimensions.stream().mapToInt(i -> i).toArray();
    }

    private static int[] getChunksOfImage(ImageServer<BufferedImage> server) {
        ArrayList<Integer> chunks = new ArrayList<Integer>();
        if (server.nTimepoints() > 1) {
            chunks.add(1);
        }
        if (server.nChannels() > 1) {
            chunks.add(1);
        }
        if (server.nZSlices() > 1) {
            chunks.add(1);
        }
        chunks.add(server.getMetadata().getPreferredTileHeight());
        chunks.add(server.getMetadata().getPreferredTileWidth());
        return chunks.stream().mapToInt(i -> i).toArray();
    }

    private Object getData(BufferedImage image) {
        Object pixels = AWTImageTools.getPixels((BufferedImage)image);
        if (this.server.isRGB()) {
            int[][] data = (int[][])pixels;
            int[] output = new int[this.server.nChannels() * image.getWidth() * image.getHeight()];
            int i = 0;
            for (int c = 0; c < this.server.nChannels(); ++c) {
                for (int y = 0; y < image.getHeight(); ++y) {
                    for (int x = 0; x < image.getWidth(); ++x) {
                        output[i] = data[c][x + image.getWidth() * y];
                        ++i;
                    }
                }
            }
            return output;
        }
        return switch (this.server.getPixelType()) {
            default -> throw new MatchException(null, null);
            case PixelType.UINT8, PixelType.INT8 -> {
                byte[][] data = (byte[][])pixels;
                byte[] output = new byte[this.server.nChannels() * image.getWidth() * image.getHeight()];
                int i = 0;
                for (int c = 0; c < this.server.nChannels(); ++c) {
                    for (int y = 0; y < image.getHeight(); ++y) {
                        for (int x = 0; x < image.getWidth(); ++x) {
                            output[i] = data[c][x + image.getWidth() * y];
                            ++i;
                        }
                    }
                }
                yield output;
            }
            case PixelType.UINT16, PixelType.INT16 -> {
                short[][] data = (short[][])pixels;
                short[] output = new short[this.server.nChannels() * image.getWidth() * image.getHeight()];
                int i = 0;
                for (int c = 0; c < this.server.nChannels(); ++c) {
                    for (int y = 0; y < image.getHeight(); ++y) {
                        for (int x = 0; x < image.getWidth(); ++x) {
                            output[i] = data[c][x + image.getWidth() * y];
                            ++i;
                        }
                    }
                }
                yield (Object[])output;
            }
            case PixelType.UINT32, PixelType.INT32 -> {
                int[][] data = (int[][])pixels;
                int[] output = new int[this.server.nChannels() * image.getWidth() * image.getHeight()];
                int i = 0;
                for (int c = 0; c < this.server.nChannels(); ++c) {
                    for (int y = 0; y < image.getHeight(); ++y) {
                        for (int x = 0; x < image.getWidth(); ++x) {
                            output[i] = data[c][x + image.getWidth() * y];
                            ++i;
                        }
                    }
                }
                yield (Object[])output;
            }
            case PixelType.FLOAT32 -> {
                float[][] data = (float[][])pixels;
                float[] output = new float[this.server.nChannels() * image.getWidth() * image.getHeight()];
                int i = 0;
                for (int c = 0; c < this.server.nChannels(); ++c) {
                    for (int y = 0; y < image.getHeight(); ++y) {
                        for (int x = 0; x < image.getWidth(); ++x) {
                            output[i] = data[c][x + image.getWidth() * y];
                            ++i;
                        }
                    }
                }
                yield (Object[])output;
            }
            case PixelType.FLOAT64 -> {
                double[][] data = (double[][])pixels;
                double[] output = new double[this.server.nChannels() * image.getWidth() * image.getHeight()];
                int i = 0;
                for (int c = 0; c < this.server.nChannels(); ++c) {
                    for (int y = 0; y < image.getHeight(); ++y) {
                        for (int x = 0; x < image.getWidth(); ++x) {
                            output[i] = data[c][x + image.getWidth() * y];
                            ++i;
                        }
                    }
                }
                yield (Object[])output;
            }
        };
    }

    private int[] getDimensionsOfTile(TileRequest tileRequest) {
        ArrayList<Integer> dimensions = new ArrayList<Integer>();
        if (this.server.nTimepoints() > 1) {
            dimensions.add(1);
        }
        if (this.server.nChannels() > 1) {
            dimensions.add(this.server.nChannels());
        }
        if (this.server.nZSlices() > 1) {
            dimensions.add(1);
        }
        dimensions.add(tileRequest.getTileHeight());
        dimensions.add(tileRequest.getTileWidth());
        return dimensions.stream().mapToInt(i -> i).toArray();
    }

    private int[] getOffsetsOfTile(TileRequest tileRequest) {
        ArrayList<Integer> offset = new ArrayList<Integer>();
        if (this.server.nTimepoints() > 1) {
            offset.add(tileRequest.getT());
        }
        if (this.server.nChannels() > 1) {
            offset.add(0);
        }
        if (this.server.nZSlices() > 1) {
            offset.add(tileRequest.getZ());
        }
        offset.add(tileRequest.getTileY());
        offset.add(tileRequest.getTileX());
        return offset.stream().mapToInt(i -> i).toArray();
    }

    public static class Builder {
        private static final String FILE_EXTENSION = ".ome.zarr";
        private final ImageServer<BufferedImage> server;
        private final String path;
        private Compressor compressor = CompressorFactory.createDefaultCompressor();
        private int numberOfThreads = 12;
        private double[] downsamples = new double[0];
        private int maxNumberOfChunks = 50;
        private int tileWidth = 512;
        private int tileHeight = 512;
        private ImageRegion boundingBox = null;
        private int zStart = 0;
        private int zEnd;
        private int tStart = 0;
        private int tEnd;

        public Builder(ImageServer<BufferedImage> server, String path) {
            if (!path.endsWith(FILE_EXTENSION)) {
                throw new IllegalArgumentException(String.format("The provided path (%s) does not have the OME-Zarr extension (%s)", path, FILE_EXTENSION));
            }
            this.server = server;
            this.path = path;
            this.zEnd = this.server.nZSlices();
            this.tEnd = this.server.nTimepoints();
        }

        public Builder setCompressor(Compressor compressor) {
            this.compressor = compressor;
            return this;
        }

        public Builder setNumberOfThreads(int numberOfThreads) {
            this.numberOfThreads = numberOfThreads;
            return this;
        }

        public Builder setDownsamples(double ... downsamples) {
            this.downsamples = downsamples;
            return this;
        }

        public Builder setMaxNumberOfChunksOnEachSpatialDimension(int maxNumberOfChunks) {
            this.maxNumberOfChunks = maxNumberOfChunks;
            return this;
        }

        public Builder setTileWidth(int tileWidth) {
            this.tileWidth = tileWidth;
            return this;
        }

        public Builder setTileHeight(int tileHeight) {
            this.tileHeight = tileHeight;
            return this;
        }

        public Builder setBoundingBox(ImageRegion boundingBox) {
            this.boundingBox = boundingBox;
            return this;
        }

        public Builder setZSlices(int zStart, int zEnd) {
            this.zStart = zStart;
            this.zEnd = zEnd;
            return this;
        }

        public Builder setTimepoints(int tStart, int tEnd) {
            this.tStart = tStart;
            this.tEnd = tEnd;
            return this;
        }

        public OMEZarrWriter build() throws IOException {
            return new OMEZarrWriter(this);
        }
    }
}

