Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {

qupathExtension {
name = "qupath-extension-djl"
version = "0.4.0"
version = "0.4.1-SNAPSHOT"
group = "io.github.qupath"
description = "QuPath extension to use Deep Java Library"
automaticModule = "qupath.extension.djl"
Expand Down
102 changes: 53 additions & 49 deletions src/main/java/qupath/ext/djl/DjlZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,6 @@

package qupath.ext.djl;

import java.awt.image.BandedSampleModel;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferFloat;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

import ai.djl.repository.MRL;
import org.locationtech.jts.geom.util.AffineTransformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
Expand All @@ -55,13 +29,14 @@
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Joints;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.translator.BigGANTranslator;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
Expand All @@ -72,6 +47,9 @@
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import org.locationtech.jts.geom.util.AffineTransformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.images.ContourTracing;
import qupath.lib.analysis.images.SimpleImage;
import qupath.lib.geom.Point2;
Expand All @@ -95,6 +73,27 @@
import qupath.lib.roi.RoiTools;
import qupath.lib.roi.interfaces.ROI;

import java.awt.image.BandedSampleModel;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferFloat;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Helper class for working with DeepJavaLibrary model zoos.
*
Expand Down Expand Up @@ -316,13 +315,13 @@ private static TranslatorFactory getTranslatorFactory(String factoryClass) {
* @return
*/
public static ROI createROI(DetectedObject obj, ImageRegion region) {
var box = obj.getBoundingBox();
if (box instanceof Mask) {
return createROI((Mask)box, region, 0.5);
} else if (box instanceof Landmark) {
return createROI((Landmark)box, region);
BoundingBox box = obj.getBoundingBox();
if (box instanceof Mask mask) {
return createROI(mask, region, 0.5);
} else if (box instanceof Landmark landmark) {
return createROI(landmark, region);
} else
return createROI((BoundingBox)box, region);
return createROI(box, region);
}

/**
Expand All @@ -335,17 +334,21 @@ public static ROI createROI(BoundingBox box, ImageRegion region) {
var bounds = box.getBounds();
double xo = 0.0;
double yo = 0.0;
double xScale = 1.0;
double yScale = 1.0;
var plane = ImagePlane.getDefaultPlane();
if (region != null) {
plane = region.getImagePlane();
xo = region.getMinX();
yo = region.getMinY();
xScale = region.getWidth();
yScale = region.getHeight();
}
return ROIs.createRectangleROI(
xo + bounds.getX() * region.getWidth(),
yo + bounds.getY() * region.getHeight(),
bounds.getWidth() * region.getWidth(),
bounds.getHeight() * region.getHeight(),
xo + bounds.getX() * xScale,
yo + bounds.getY() * yScale,
bounds.getWidth() * xScale,
bounds.getHeight() * yScale,
plane);
}

Expand All @@ -358,17 +361,16 @@ public static ROI createROI(BoundingBox box, ImageRegion region) {
*/
public static ROI createROI(Mask mask, ImageRegion region, double threshold) {
float[][] probs = mask.getProbDist();
int w = probs.length;
int h = probs[0].length;
int h = probs.length;
int w = probs[0].length;
var buffer = new DataBufferFloat(w * h, 1);
var sampleModel = new BandedSampleModel(buffer.getDataType(), w, h, 1);
var raster = WritableRaster.createWritableRaster(sampleModel, buffer, null);
for (int x = 0; x < w; x++) {
float[] col = probs[x];
for (int y = 0; y < h; y++) {
raster.setSample(x, y, 0, col[y]);
}
}
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
raster.setSample(x, y, 0, probs[y][x]);
}
}
if (region == null)
region = ImageRegion.createInstance(0, 0, w, h, 0, 0);
var geometry = ContourTracing.createTracedGeometry(raster, threshold, Double.POSITIVE_INFINITY, 0, null);
Expand All @@ -377,8 +379,10 @@ public static ROI createROI(Mask mask, ImageRegion region, double threshold) {

var transform = new AffineTransformation();
transform.scale(1.0/raster.getWidth(), 1.0/raster.getHeight());
transform.scale(bounds.getWidth(), bounds.getHeight());
transform.translate(bounds.getX(), bounds.getY());
if(!mask.isFullImageMask()) {
transform.scale(bounds.getWidth(), bounds.getHeight());
transform.translate(bounds.getX(), bounds.getY());
}
transform.scale(region.getWidth(), region.getHeight());
transform.translate(region.getX(), region.getY());

Expand Down