package org.geotools.image.io;

// Adapted from org.geotools.image.io.ImageIOExt (GeoTools)

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.awt.image.ComponentColorModel;
import java.awt.image.RenderedImage;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import javax.imageio.ImageIO;
import javax.imageio.ImageReadParam;
import javax.imageio.ImageReader;
import javax.imageio.metadata.IIOMetadata;
import javax.imageio.spi.ImageReaderSpi;
import javax.imageio.stream.ImageInputStream;
import javax.media.jai.PlanarImage;
import org.geotools.image.ImageWorker;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;


/**
 * Image decoding helper adapted from GeoTools' {@code org.geotools.image.io.ImageIOExt}.
 *
 * <p>The {@link ImageReaderSpi} of the reader selected by service lookup is cached and reused to
 * create readers for later inputs, as long as that provider reports it can decode them.
 *
 * <p>{@link #read(Object)} mutates the {@code reader} and {@code spi} fields without
 * synchronization, so a single instance should not be shared across threads without external
 * locking.
 *
 * @author mark
 */
public class ImageIOExtWeps {

    /** Class name of the JDK's built-in PNG reader, which needs the tRNS workaround. */
    private static final String JDK_PNG_READER_CLASS = "com.sun.imageio.plugins.png.PNGImageReader";

    /** Reader used by the most recent {@link #read(Object)} call; disposed when that call ends. */
    ImageReader reader;

    /** Provider of the last reader found by service lookup, reused to create new readers. */
    ImageReaderSpi spi;

    /** Creates a helper with no cached reader or provider. */
    public ImageIOExtWeps() {
        reader = null;
        spi = null;
    }

    /**
     * Reads the first image of {@code input} as a {@link BufferedImage}.
     *
     * @param input an {@link ImageInputStream}, or any source accepted by {@link
     *     ImageIO#createImageInputStream(Object)}; the stream is closed before returning
     * @return the decoded image, or {@code null} if no reader could decode the input
     * @throws IllegalArgumentException if {@code input} is {@code null}
     * @throws IOException if no stream can be created for {@code input}, or on I/O failure
     */
    public BufferedImage readBufferedImage(Object input) throws IOException {
        RenderedImage ri = read(input);
        if (ri == null) {
            return null;
        } else if (ri instanceof BufferedImage) {
            return (BufferedImage) ri;
        } else {
            return PlanarImage.wrapRenderedImage(ri).getAsBufferedImage();
        }
    }

    /**
     * Reads the first image of {@code input}.
     *
     * <p>For the JDK PNG reader, an RGB image without alpha whose metadata declares a {@code
     * tRNS_RGB} colour is post-processed so that colour becomes transparent (JDK-6788458).
     *
     * @param input an {@link ImageInputStream}, or any source accepted by {@link
     *     ImageIO#createImageInputStream(Object)}; the stream is closed before returning, even
     *     when it was supplied by the caller
     * @return the decoded image, or {@code null} if no reader recognizes the input or decoding
     *     fails with an {@link javax.imageio.IIOException}, {@link IllegalStateException} or
     *     {@link IndexOutOfBoundsException}
     * @throws IllegalArgumentException if {@code input} is {@code null}
     * @throws IOException if no stream can be created for {@code input}, or on other I/O failure
     */
    public RenderedImage read(Object input) throws IOException {
        if (input == null) {
            throw new IllegalArgumentException("input == null!");
        }

        // accept a ready-made stream as-is; otherwise let ImageIO build one for the source
        ImageInputStream opened =
                input instanceof ImageInputStream
                        ? (ImageInputStream) input
                        : ImageIO.createImageInputStream(input);
        if (opened == null) {
            throw new javax.imageio.IIOException("Can't create an ImageInputStream!");
        }

        try (ImageInputStream stream = opened) {
            ImageReader current = createReader(stream);
            if (current == null) {
                return null;
            }
            reader = current;
            try {
                // decode AND post-process before disposing: dispose() resets the reader's input,
                // after which its image metadata can no longer be read
                return decode(current, stream);
            } finally {
                try {
                    current.dispose();
                } catch (IllegalStateException ignored) {
                    // best effort cleanup; a failing dispose must not mask the decode result
                }
            }
        }
    }

    /**
     * Returns a new reader for {@code stream}, preferring the cached provider.
     *
     * @return a fresh reader, or {@code null} if no registered reader recognizes the stream
     */
    private ImageReader createReader(ImageInputStream stream) throws IOException {
        // canDecodeInput marks/resets the stream, so the probe does not consume input
        if (spi != null && spi.canDecodeInput(stream)) {
            return spi.createReaderInstance();
        }
        Iterator<ImageReader> candidates = ImageIO.getImageReaders(stream);
        if (!candidates.hasNext()) {
            return null;
        }
        ImageReader found = candidates.next();
        spi = found.getOriginatingProvider();
        return found;
    }

    /**
     * Decodes image 0 from {@code stream} with {@code imageReader}, applying the PNG tRNS
     * workaround where relevant. The caller remains responsible for disposing the reader.
     *
     * @return the decoded (possibly post-processed) image, or {@code null} if decoding failed
     */
    private static RenderedImage decode(ImageReader imageReader, ImageInputStream stream)
            throws IOException {
        // work around PNG with transparent RGB color if needed
        // we can remove it once we run on JDK 11, see
        // https://bugs.openjdk.java.net/browse/JDK-6788458
        boolean isJdkPNGReader = JDK_PNG_READER_CLASS.equals(imageReader.getClass().getName());
        // if it's the JDK PNG reader, we cannot skip the metadata, the tRNS section will be in
        // there
        imageReader.setInput(stream, true, !isJdkPNGReader);

        BufferedImage bi;
        try {
            ImageReadParam param = imageReader.getDefaultReadParam();
            bi = imageReader.read(0, param);
        } catch (javax.imageio.IIOException | IllegalStateException | IndexOutOfBoundsException ex) {
            // undecodable content is reported to callers as "no image", not as an error
            return null;
        }

        if (isJdkPNGReader) {
            RenderedImage transparent = applyTrnsRgb(imageReader, bi);
            if (transparent != null) {
                return transparent;
            }
        }
        return bi;
    }

    /**
     * Makes the colour declared by the PNG {@code tRNS/tRNS_RGB} metadata node transparent.
     *
     * <p>Best effort. Returns {@code null}, so the caller keeps the original image, when the image
     * is not a 3-component {@link ComponentColorModel} image without alpha, when no complete
     * tRNS_RGB entry with 0-255 components exists, or when the metadata cannot be read.
     *
     * @param imageReader a reader that has read image 0 and has NOT been disposed yet
     * @param bi the image decoded by {@code imageReader}
     */
    private static RenderedImage applyTrnsRgb(ImageReader imageReader, BufferedImage bi)
            throws IOException {
        if (bi == null
                || !(bi.getColorModel() instanceof ComponentColorModel)
                || bi.getColorModel().hasAlpha()
                || bi.getColorModel().getNumComponents() != 3) {
            return null;
        }
        try {
            IIOMetadata imageMetadata = imageReader.getImageMetadata(0);
            Node tree = imageMetadata.getAsTree(imageMetadata.getNativeMetadataFormatName());
            Node trnsRgb = getNodeFromPath(tree, Arrays.asList("tRNS", "tRNS_RGB"));
            if (trnsRgb == null) {
                return null;
            }
            NamedNodeMap attributes = trnsRgb.getAttributes();
            Integer red = getIntegerAttribute(attributes, "red");
            Integer green = getIntegerAttribute(attributes, "green");
            Integer blue = getIntegerAttribute(attributes, "blue");
            // java.awt.Color only accepts 0-255; larger (e.g. 16-bit) values would throw
            if (!isColorComponent(red) || !isColorComponent(green) || !isColorComponent(blue)) {
                return null;
            }
            ImageWorker iw = new ImageWorker(bi);
            iw.makeColorTransparent(new Color(red, green, blue));
            return iw.getRenderedImage();
        } catch (javax.imageio.IIOException | IllegalStateException ex) {
            // transparency is an enhancement: fall back to the image as decoded
            return null;
        }
    }

    /** Returns whether {@code value} is present and within the 0-255 range of {@link Color}. */
    private static boolean isColorComponent(Integer value) {
        return value != null && value >= 0 && value <= 255;
    }

    /**
     * Locates a node in the tree by following a list of path components (child node names).
     *
     * @return the first matching node along the path, or {@code null} if any component is missing
     */
    private static Node getNodeFromPath(Node root, List<String> pathComponents) {
        if (pathComponents.isEmpty()) {
            return root;
        }

        String firstComponent = pathComponents.get(0);
        NodeList childNodes = root.getChildNodes();
        for (int i = 0; i < childNodes.getLength(); i++) {
            Node child = childNodes.item(i);

            if (firstComponent.equals(child.getNodeName())) {
                return getNodeFromPath(child, pathComponents.subList(1, pathComponents.size()));
            }
        }

        // not found
        return null;
    }

    /**
     * Parses the named attribute as an integer.
     *
     * @return the parsed value, or {@code null} if the attribute is absent or has no value
     * @throws NumberFormatException if the attribute value is not a decimal integer
     */
    private static Integer getIntegerAttribute(NamedNodeMap attributes, String attributeName) {
        return Optional.ofNullable(attributes.getNamedItem(attributeName))
                .map(Node::getNodeValue)
                .map(Integer::valueOf)
                .orElse(null);
    }
}
