Skip to content

Commit a3db491

Browse files
committed
flake8
1 parent bf7c083 commit a3db491

3 files changed

Lines changed: 63 additions & 46 deletions

File tree

detectree2/models/outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,11 @@ def clean_predictions(directory, iou_threshold=0.7):
619619
rescaled_coords = [
620620
[crown_coords[i], crown_coords[i + 1]]
621621
for i in range(0, len(crown_coords), 2)
622-
]
622+
]
623623

624624
crowns = pd.concat([crowns, gpd.GeoDataFrame({'Confidence_score': shp['score'],
625-
'geometry': [Polygon(rescaled_coords)]},
626-
geometry='geometry')])
625+
'geometry': [Polygon(rescaled_coords)]},
626+
geometry='geometry')])
627627

628628
crowns = crowns.reset_index(drop=True)
629629
crowns, indices = clean_outputs(crowns, iou_threshold)

detectree2/models/train.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def after_step(self):
377377
img = nn.functional.interpolate(img.unsqueeze(0),
378378
size=output["instances"].image_size).squeeze(0)
379379
img = np.transpose(img[:3], (1, 2, 0))
380-
v = Visualizer(img, metadata=MetadataCatalog.get(self.trainer.cfg.DATASETS.TEST[0]), scale=1)
380+
metadata = MetadataCatalog.get(self.trainer.cfg.DATASETS.TEST[0])
381+
v = Visualizer(img, metadata=metadata, scale=1)
381382
# v = v.draw_instance_predictions(output['instances'][output['instances'].scores > 0.5].to("cpu"))
382383

383384
masks = output["instances"].pred_masks.to("cpu").numpy()
@@ -392,12 +393,13 @@ def after_step(self):
392393
else:
393394
geoms.append(None)
394395

395-
gdf = gpd.GeoDataFrame(data={
396-
"Confidence_score": scores,
397-
"indices": list(range(len(scores)))
398-
},
399-
geometry=geoms,
400-
crs="EPSG:3857")
396+
gdf = gpd.GeoDataFrame(
397+
data={
398+
"Confidence_score": scores,
399+
"indices": list(range(len(scores)))
400+
},
401+
geometry=geoms,
402+
crs="EPSG:3857")
401403

402404
gdf = clean_crowns(gdf, iou_threshold=0.3, confidence=0.3, area_threshold=0, verbose=False)
403405

@@ -546,12 +548,12 @@ def resume_or_load(self, resume=True):
546548
logger = logging.getLogger("detectree2")
547549
if input_channels_in_checkpoint != 3:
548550
logger.warning(
549-
"Input channel modification only works if checkpoint was trained on RGB images (3 channels). " \
551+
"Input channel modification only works if checkpoint was trained on RGB images (3 channels). "
550552
"The first three channels will be copied and then repeated in the model."
551553
)
552554
logger.warning(
553-
"Mismatch in input channels in checkpoint and model, meaning fvcommon would not have been able to automatically load them. " \
554-
"Adjusting weights for 'backbone.bottom_up.stem.conv1.weight' manually."
555+
"Mismatch in input channels in checkpoint and model, meaning fvcommon would not have been able "
556+
"to automatically load them. Adjusting weights for 'backbone.bottom_up.stem.conv1.weight' manually."
555557
)
556558
with torch.no_grad():
557559
self.model.backbone.bottom_up.stem.conv1.weight[:, :3] = checkpoint[:, :3]
@@ -1079,10 +1081,12 @@ def setup_cfg(
10791081
default_pixel_mean = cfg.MODEL.PIXEL_MEAN
10801082
default_pixel_std = cfg.MODEL.PIXEL_STD
10811083
# Extend or truncate the PIXEL_MEAN and PIXEL_STD based on num_bands
1082-
cfg.MODEL.PIXEL_MEAN = (default_pixel_mean * (num_bands // len(default_pixel_mean)) +
1083-
default_pixel_mean[:num_bands % len(default_pixel_mean)])
1084-
cfg.MODEL.PIXEL_STD = (default_pixel_std * (num_bands // len(default_pixel_std)) +
1085-
default_pixel_std[:num_bands % len(default_pixel_std)])
1084+
cfg.MODEL.PIXEL_MEAN = (
1085+
default_pixel_mean * (num_bands // len(default_pixel_mean))
1086+
+ default_pixel_mean[:num_bands % len(default_pixel_mean)])
1087+
cfg.MODEL.PIXEL_STD = (
1088+
default_pixel_std * (num_bands // len(default_pixel_std))
1089+
+ default_pixel_std[:num_bands % len(default_pixel_std)])
10861090
if visualize_training:
10871091
cfg.VALIDATION_VIS_PERIOD = eval_period
10881092
else:
@@ -1204,8 +1208,8 @@ def multiply_conv1_weights(model):
12041208

12051209
# Multiply weights round-robin
12061210
for i in range(in_channels):
1207-
new_weights[:, i, :, :] = old_weights[:, (i + 1) % 3, :, :] / ((in_channels // 3) +
1208-
(1 if i % 3 < in_channels % 3 else 0))
1211+
new_weights[:, i, :, :] = old_weights[:, (i + 1) % 3, :, :] / (
1212+
(in_channels // 3) + (1 if i % 3 < in_channels % 3 else 0))
12091213

12101214
# Create a fresh Conv2d that has the correct shape
12111215
new_conv = Conv2d(in_channels=in_channels,

detectree2/preprocessing/tiling.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,7 @@ def process_tile_train(
481481
image_statistics: List[Dict[str, float]] = None,
482482
ignore_bands_indices: List[int] = [],
483483
use_convex_mask: bool = True,
484-
enhance_rgb_contrast: bool = True
485-
) -> None:
484+
enhance_rgb_contrast: bool = True) -> None:
486485
"""Process a single tile for training data.
487486
488487
Args:
@@ -584,17 +583,24 @@ def _calculate_tile_placements(
584583
int(math.ceil(data.bounds[1])) + buffer, int(data.bounds[3] - tile_height - buffer), tile_height)
585584
]
586585
if overlapping_tiles:
587-
grid_coords.extend([(int(minx), int(miny)) for minx in np.arange(
588-
int(math.ceil(data.bounds[0])) + buffer + tile_width // 2, int(data.bounds[2] - tile_width - buffer -
589-
tile_width // 2), tile_width) for miny in np.arange(
590-
int(math.ceil(data.bounds[1])) + buffer + tile_height // 2, int(data.bounds[3] - tile_height - buffer -
591-
tile_height // 2), tile_height)])
586+
grid_coords.extend([
587+
(int(minx), int(miny)) for minx in np.arange(
588+
int(math.ceil(data.bounds[0])) + buffer + tile_width // 2,
589+
int(data.bounds[2] - tile_width - buffer - tile_width // 2),
590+
tile_width)
591+
for miny in np.arange(
592+
int(math.ceil(data.bounds[1])) + buffer + tile_height // 2,
593+
int(data.bounds[3] - tile_height - buffer - tile_height // 2),
594+
tile_height)
595+
])
592596
coordinates = grid_coords
593597
elif tile_placement == "adaptive":
594598

595599
if crowns is None:
596600
logger.warning(
597-
'Crowns must be supplied if tile_placement="adaptive" (crowns is None). Assuming tiling for test dataset, and tile placement will be done with tile_placement == "grid" instead.'
601+
'Crowns must be supplied if tile_placement="adaptive" (crowns is None). '
602+
'Assuming tiling for test dataset, and tile placement will be done with '
603+
'tile_placement == "grid" instead.'
598604
)
599605
return _calculate_tile_placements(img_path, buffer, tile_width, tile_height)
600606

@@ -619,8 +625,7 @@ def _calculate_tile_placements(
619625
bar = gpd.GeoSeries([
620626
box(crowns.total_bounds[0] - x_offset, crowns.total_bounds[1] - y_offset + row * tile_height,
621627
crowns.total_bounds[2] + x_offset, crowns.total_bounds[1] - y_offset + (row + 1) * tile_height)
622-
],
623-
crs=crowns.crs)
628+
], crs=crowns.crs)
624629

625630
intersection = unioned_crowns.intersection(bar)
626631
if intersection.is_empty.all():
@@ -691,9 +696,10 @@ def calc_on_everything():
691696
# Buffer for percentiles
692697
percentile_buffer = []
693698
buffer_size = 0
694-
MAX_BUFFER = 5_000_000 # 5 million pixels ~ 40MB
699+
MAX_BUFFER = 5_000_000 # 5 million pixels ~ 40MB
695700

696-
for row_off in tqdm(range(0, height, chunk_height), desc=f"Calculating stats for band {band_idx}", leave=False):
701+
for row_off in tqdm(range(0, height, chunk_height),
702+
desc=f"Calculating stats for band {band_idx}", leave=False):
697703
h = min(chunk_height, height - row_off)
698704
window = rasterio.windows.Window(0, row_off, width, h)
699705

@@ -711,8 +717,10 @@ def calc_on_everything():
711717
c_sum_sq = np.sum(valid_chunk ** 2)
712718
c_count = valid_chunk.size
713719

714-
if c_min < global_min: global_min = c_min
715-
if c_max > global_max: global_max = c_max
720+
if c_min < global_min:
721+
global_min = c_min
722+
if c_max > global_max:
723+
global_max = c_max
716724
total_sum += c_sum
717725
total_sum_sq += c_sum_sq
718726
total_count += c_count
@@ -857,9 +865,9 @@ def tile_data(
857865
"""Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles.
858866
859867
Tiles up large rasters into manageable tiles for training and prediction. If crowns are not supplied, the function
860-
will tile up the entire landscape for prediction. If crowns are supplied, the function will tile these with the image
861-
and skip tiles without a minimum coverage of crowns. The 'threshold' can be varied to ensure good coverage of
862-
crowns across a training tile. Tiles that do not have sufficient coverage are skipped.
868+
will tile up the entire landscape for prediction. If crowns are supplied, the function will tile these with the
869+
image and skip tiles without a minimum coverage of crowns. The 'threshold' can be varied to ensure good coverage
870+
of crowns across a training tile. Tiles that do not have sufficient coverage are skipped.
863871
864872
Args:
865873
img_path: Path to the orthomosaic
@@ -875,7 +883,8 @@ def tile_data(
875883
class_column: Name of the column in `crowns` DataFrame for class-based tiling
876884
tile_placement: Strategy for placing tiles.
877885
"grid" for fixed grid placement based on the bounds of the input image, optimized for speed.
878-
"adaptive" for dynamic placement of tiles based on crowns, adjusts based on data features for better coverage.
886+
"adaptive" for dynamic placement of tiles based on crowns, adjusts based on data features for better
887+
coverage.
879888
880889
Returns:
881890
None
@@ -893,24 +902,27 @@ def tile_data(
893902
tile_coordinates = _calculate_tile_placements(img_path, buffer, tile_width, tile_height, crowns, tile_placement,
894903
overlapping_tiles)
895904

896-
image_statistics = calculate_image_statistics(img_path,
897-
values_to_ignore=additional_nodata,
898-
mode=mode,
899-
ignore_bands_indices=ignore_bands_indices) if mode == "ms" else None # Only needed for multispectral data
905+
# Only needed for multispectral data
906+
image_statistics = calculate_image_statistics(
907+
img_path,
908+
values_to_ignore=additional_nodata,
909+
mode=mode,
910+
ignore_bands_indices=ignore_bands_indices) if mode == "ms" else None
900911

901912
tile_args = [
902913
(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold,
903914
nan_threshold, mode, class_column, mask_gdf, additional_nodata, image_statistics, ignore_bands_indices,
904915
use_convex_mask, enhance_rgb_contrast) for minx, miny in tile_coordinates
905916
if mask_path is None or (mask_path is not None and mask_gdf.intersects(
906-
box(minx, miny, minx + tile_width, miny + tile_height) #TODO maybe add to_crs here
917+
box(minx, miny, minx + tile_width, miny + tile_height) # TODO maybe add to_crs here
907918
).any())
908919
]
909920

910921
if random_subset > -1:
911922
if random_subset > len(tile_args):
912923
logger.warning(
913-
f"random_subset is larger than the amount of tile places ({len(tile_args)}>{random_subset}). Using all possible tiles instead."
924+
f"random_subset is larger than the amount of tile places ({len(tile_args)}>{random_subset}). "
925+
f"Using all possible tiles instead."
914926
)
915927
else:
916928
tile_args = random.sample(tile_args, random_subset)
@@ -947,8 +959,8 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path],
947959
Path to the folder containing multispectral .tif files, along with any .geojson, train, or test subdirectories.
948960
out_dir (str or Path, optional):
949961
Path to the output directory where RGB images will be saved. If None, a default folder with a suffix
950-
"_<conversion>-rgb" is created alongside the input tile folder. If `out_dir` already exists and is not empty,
951-
we append also append the current date and time to avoid overwriting.
962+
"_<conversion>-rgb" is created alongside the input tile folder. If `out_dir` already exists and is not
963+
empty, we append also append the current date and time to avoid overwriting.
952964
conversion (str, optional):
953965
The method of converting multispectral imagery to three bands:
954966
- "pca": perform a principal-component analysis reduction to three components.
@@ -1115,7 +1127,8 @@ def create_RGB_from_MS(tile_folder_path: Union[str, Path],
11151127
data = src.read(indexes=[1, 2, 3])
11161128
if np.nanmax(data) > 255:
11171129
logger.exception(
1118-
"The input folder seems to be an RGB folder and you are taking the first three bands. This will not change the output. Did you choose the wrong folder? Aborting."
1130+
"The input folder seems to be an RGB folder and you are taking the first three bands. "
1131+
"This will not change the output. Did you choose the wrong folder? Aborting."
11191132
)
11201133
return
11211134
except RasterioIOError as e:

0 commit comments

Comments
 (0)