diff --git a/build.gradle.kts b/build.gradle.kts index a1334f9..27d8222 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -5,7 +5,7 @@ plugins { qupathExtension { name = "qupath-extension-djl" - version = "0.4.0" + version = "0.4.1-SNAPSHOT" group = "io.github.qupath" description = "QuPath extension to use Deep Java Library" automaticModule = "qupath.extension.djl" diff --git a/src/main/java/qupath/ext/djl/DjlZoo.java b/src/main/java/qupath/ext/djl/DjlZoo.java index 36e9436..8794fbb 100644 --- a/src/main/java/qupath/ext/djl/DjlZoo.java +++ b/src/main/java/qupath/ext/djl/DjlZoo.java @@ -16,32 +16,6 @@ package qupath.ext.djl; -import java.awt.image.BandedSampleModel; -import java.awt.image.BufferedImage; -import java.awt.image.DataBufferFloat; -import java.awt.image.WritableRaster; -import java.io.IOException; -import java.lang.reflect.Type; -import java.net.URI; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; -import java.util.stream.Collectors; - -import ai.djl.repository.MRL; -import org.locationtech.jts.geom.util.AffineTransformation; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import ai.djl.Application; import ai.djl.MalformedModelException; import ai.djl.Model; @@ -55,13 +29,14 @@ import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.DetectedObjects.DetectedObject; import ai.djl.modality.cv.output.Joints; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.types.LayoutType; -import ai.djl.ndarray.types.Shape; import ai.djl.modality.cv.output.Landmark; import ai.djl.modality.cv.output.Mask; import ai.djl.modality.cv.translator.BigGANTranslator; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.LayoutType; +import ai.djl.ndarray.types.Shape; import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; @@ -72,6 +47,9 @@ import ai.djl.util.ClassLoaderUtils; import ai.djl.util.Pair; import ai.djl.util.PairList; +import org.locationtech.jts.geom.util.AffineTransformation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import qupath.lib.analysis.images.ContourTracing; import qupath.lib.analysis.images.SimpleImage; import qupath.lib.geom.Point2; @@ -95,6 +73,27 @@ import qupath.lib.roi.RoiTools; import qupath.lib.roi.interfaces.ROI; +import java.awt.image.BandedSampleModel; +import java.awt.image.BufferedImage; +import java.awt.image.DataBufferFloat; +import java.awt.image.WritableRaster; +import java.io.IOException; +import java.lang.reflect.Type; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; + /** * Helper class for working with DeepJavaLibrary model zoos. * @@ -316,13 +315,13 @@ private static TranslatorFactory getTranslatorFactory(String factoryClass) { * @return */ public static ROI createROI(DetectedObject obj, ImageRegion region) { - var box = obj.getBoundingBox(); - if (box instanceof Mask) { - return createROI((Mask)box, region, 0.5); - } else if (box instanceof Landmark) { - return createROI((Landmark)box, region); + BoundingBox box = obj.getBoundingBox(); + if (box instanceof Mask mask) { + return createROI(mask, region, 0.5); + } else if (box instanceof Landmark landmark) { + return createROI(landmark, region); } else - return createROI((BoundingBox)box, region); + return createROI(box, region); } /** @@ -335,17 +334,21 @@ public static ROI createROI(BoundingBox box, ImageRegion region) { var bounds = box.getBounds(); double xo = 0.0; double yo = 0.0; + double xScale = 1.0; + double yScale = 1.0; var plane = ImagePlane.getDefaultPlane(); if (region != null) { plane = region.getImagePlane(); xo = region.getMinX(); yo = region.getMinY(); + xScale = region.getWidth(); + yScale = region.getHeight(); } return ROIs.createRectangleROI( - xo + bounds.getX() * region.getWidth(), - yo + bounds.getY() * region.getHeight(), - bounds.getWidth() * region.getWidth(), - bounds.getHeight() * region.getHeight(), + xo + bounds.getX() * xScale, + yo + bounds.getY() * yScale, + bounds.getWidth() * xScale, + bounds.getHeight() * yScale, plane); } @@ -358,17 +361,16 @@ public static ROI createROI(BoundingBox box, ImageRegion region) { */ public static ROI createROI(Mask mask, ImageRegion region, double threshold) { float[][] probs = mask.getProbDist(); - int w = probs.length; - int h = probs[0].length; + int h = probs.length; + int w = probs[0].length; var buffer = new DataBufferFloat(w * h, 1); var sampleModel = new BandedSampleModel(buffer.getDataType(), w, h, 1); var raster = WritableRaster.createWritableRaster(sampleModel, buffer, null); - for (int x = 0; x < w; x++) { - float[] col = probs[x]; - for (int y = 0; y < h; y++) { - raster.setSample(x, y, 0, col[y]); - } - } + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + raster.setSample(x, y, 0, probs[y][x]); + } + } if (region == null) region = ImageRegion.createInstance(0, 0, w, h, 0, 0); var geometry = ContourTracing.createTracedGeometry(raster, threshold, Double.POSITIVE_INFINITY, 0, null); @@ -377,8 +379,10 @@ public static ROI createROI(Mask mask, ImageRegion region, double threshold) { var transform = new AffineTransformation(); transform.scale(1.0/raster.getWidth(), 1.0/raster.getHeight()); - transform.scale(bounds.getWidth(), bounds.getHeight()); - transform.translate(bounds.getX(), bounds.getY()); + if(!mask.isFullImageMask()) { + transform.scale(bounds.getWidth(), bounds.getHeight()); + transform.translate(bounds.getX(), bounds.getY()); + } transform.scale(region.getWidth(), region.getHeight()); transform.translate(region.getX(), region.getY());