2525#include <math.h>
2626#include <stdbool.h>
2727#include <stdlib.h>
28+ #include <string.h>
2829#include "common.h"
2930#include "descale.h"
3031
@@ -531,11 +532,127 @@ static void process_plane_v_c(int height, int current_height, int current_width,
531532 }
532533}
533534
/* Ignore-mask predicate: yields 1 for mask samples of 128 and above, 0 for anything lower. */
static inline int check_imask(unsigned char value)
{
    if (value < 128)
        return 0;
    return 1;
}
538+
539+ static void process_plane_masked (int dst_dim , int src_dim , int vector_count , enum DescaleDir dir , int bandwidth , int * restrict weights_left_idx , int * restrict weights_right_idx ,
540+ int weights_columns , float * restrict weights , double * restrict multiplied_weights ,
541+ int src_stride , int imask_stride , int dst_stride , const float * restrict srcp , const unsigned char * restrict imaskp , float * restrict dstp )
542+ {
543+ double * modified_ldlt = calloc (dst_dim * bandwidth , sizeof (double ));
544+ int c = bandwidth / 2 ;
545+
546+ int imuls = dir == DESCALE_DIR_HORIZONTAL ? src_stride : 1 ;
547+ int jmuls = dir == DESCALE_DIR_HORIZONTAL ? 1 : src_stride ;
548+
549+ int imuli = dir == DESCALE_DIR_HORIZONTAL ? imask_stride : 1 ;
550+ int jmuli = dir == DESCALE_DIR_HORIZONTAL ? 1 : imask_stride ;
551+
552+ int imuld = dir == DESCALE_DIR_HORIZONTAL ? dst_stride : 1 ;
553+ int jmuld = dir == DESCALE_DIR_HORIZONTAL ? 1 : dst_stride ;
554+
555+ double eps = DBL_EPSILON ;
556+
557+ for (int i = 0 ; i < vector_count ; i ++ ) {
558+
559+ int same_mask = i > 0 ;
560+ if (i > 0 ) {
561+ for (int j = 0 ; j < src_dim ; j ++ ) {
562+ if (check_imask (imaskp [i * imuli + j * jmuli ]) != check_imask (imaskp [(i - 1 ) * imuli + j * jmuli ])) {
563+ same_mask = false;
564+ break ;
565+ }
566+ }
567+ }
568+
569+ if (!same_mask ) {
570+ int imask_start = 0 ;
571+
572+ for (int j = 0 ; j < src_dim ; j ++ ) {
573+ imask_start = j ;
574+ if (check_imask (imaskp [i * imuli + j * jmuli ]))
575+ break ;
576+ }
577+
578+ // Restore the ldlt to the original multiplied weights
579+ memcpy (modified_ldlt , multiplied_weights , dst_dim * bandwidth * sizeof (double ));
580+
581+ // Subtract the multiplied masked weights to obtain the new matrix:
582+ // M' P M = M' M - M' (I - P) M
583+ for (int j = imask_start ; j < src_dim ; j ++ ) {
584+ if (!check_imask (imaskp [i * imuli + j * jmuli ]))
585+ continue ;
586+ for (int r = 0 ; r < dst_dim ; r ++ ) {
587+ if (j < weights_left_idx [r ] || j >= weights_right_idx [r ]) continue ;
588+ for (int s = r ; s < dst_dim ; s ++ ) {
589+ if (j < weights_left_idx [s ] || j >= weights_right_idx [s ]) continue ;
590+ modified_ldlt [r * bandwidth + s - r ] -= weights [r * weights_columns + j - weights_left_idx [r ]] * weights [s * weights_columns + j - weights_left_idx [s ]];
591+ }
592+ }
593+ }
594+
595+ // Now, redo the LDLT decomposition
596+ for (int i = 0 ; i < dst_dim ; i ++ ) {
597+ int end = DSMIN (c + 1 , dst_dim - i );
598+
599+ for (int j = 1 ; j < end ; j ++ ) {
600+ double d = modified_ldlt [i * bandwidth + j ] / (modified_ldlt [i * bandwidth ] + eps );
601+
602+ for (int k = 0 ; k < end - j ; k ++ ) {
603+ modified_ldlt [(i + j ) * bandwidth + k ] -= d * modified_ldlt [i * bandwidth + j + k ];;
604+ }
605+ }
606+
607+ double e = 1.0 / (modified_ldlt [i * bandwidth ] + eps );
608+ for (int j = 1 ; j < end ; j ++ ) {
609+ modified_ldlt [i * bandwidth + j ] *= e ;
610+ }
611+ }
612+ }
613+
614+ // Now we can do the usual forward/backward substitution
615+ for (int j = 0 ; j < dst_dim ; j ++ ) {
616+ float sum = 0.0f ;
617+ int start = DSMAX (0 , j - c );
618+
619+ // A' b
620+ for (int k = weights_left_idx [j ]; k < weights_right_idx [j ]; ++ k )
621+ sum += weights [j * weights_columns + k - weights_left_idx [j ]] * srcp [i * imuls + k * jmuls ] * (1 - check_imask (imaskp [i * imuli + k * jmuli ]));
622+
623+ // Solve LD y = A' b
624+ for (int k = start ; k < j ; k ++ ) {
625+ sum -= modified_ldlt [k * bandwidth + j - k ] * modified_ldlt [k * bandwidth ] * dstp [i * imuld + k * jmuld ];
626+ }
627+
628+ dstp [i * imuld + j * jmuld ] = sum / (eps + modified_ldlt [j * bandwidth ]);
629+ }
630+
631+ // Solve L' x = y
632+ for (int j = dst_dim - 2 ; j >= 0 ; j -- ) {
633+ float sum = 0.0f ;
634+ int start = DSMIN (dst_dim - 1 , j + c );
635+
636+ for (int k = start ; k > j ; k -- ) {
637+ sum += modified_ldlt [j * bandwidth + k - j ] * dstp [i * imuld + k * jmuld ];
638+ }
639+
640+ dstp [i * imuld + j * jmuld ] -= sum ;
641+ }
642+ }
643+
644+ free (modified_ldlt );
645+ }
646+
534647
535648static void descale_process_vectors_c (struct DescaleCore * core , enum DescaleDir dir , int vector_count ,
536- int src_stride , int dst_stride , const float * srcp , float * dstp )
649+ int src_stride , int imask_stride , int dst_stride , const float * srcp , const unsigned char * imaskp , float * dstp )
537650{
538- if (dir == DESCALE_DIR_HORIZONTAL ) {
651+
652+ if (imaskp ) {
653+ process_plane_masked (core -> dst_dim , core -> src_dim , vector_count , dir , core -> bandwidth , core -> weights_left_idx , core -> weights_right_idx ,
654+ core -> weights_columns , core -> weights , core -> multiplied_weights , src_stride , imask_stride , dst_stride , srcp , imaskp , dstp );
655+ } else if (dir == DESCALE_DIR_HORIZONTAL ) {
539656 if (core -> bandwidth == 3 )
540657 process_plane_h_b3_c (core -> dst_dim , core -> src_dim , vector_count , core -> bandwidth , core -> weights_left_idx , core -> weights_right_idx ,
541658 core -> weights_columns , core -> weights , core -> lower , core -> upper , core -> diagonal , src_stride , dst_stride , srcp , dstp );
@@ -592,7 +709,7 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
592709 double * weights ;
593710 double * transposed_weights ;
594711 double * multiplied_weights ;
595- double * lower ;
712+ double * ldlt ;
596713
597714 scaling_weights (params -> mode , support , dst_dim , src_dim , params -> param1 , params -> param2 , params -> shift , params -> active_dim , params -> border_handling , & params -> custom_kernel , & weights );
598715 transpose_matrix (src_dim , dst_dim , weights , & transposed_weights );
@@ -615,9 +732,10 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
615732 }
616733
617734 multiply_sparse_matrices (dst_dim , src_dim , core .weights_left_idx , core .weights_right_idx , transposed_weights , weights , & multiplied_weights );
618- banded_ldlt_decomposition (dst_dim , core .bandwidth , multiplied_weights );
619- transpose_matrix (dst_dim , dst_dim , multiplied_weights , & lower );
620- multiply_banded_matrix_with_diagonal (dst_dim , core .bandwidth , lower );
735+
736+ ldlt = calloc (dst_dim * dst_dim , sizeof (double ));
737+ memcpy (ldlt , multiplied_weights , dst_dim * dst_dim * sizeof (double ));
738+ banded_ldlt_decomposition (dst_dim , core .bandwidth , ldlt );
621739
622740 int max = 0 ;
623741 for (int i = 0 ; i < dst_dim ; i ++ ) {
@@ -633,12 +751,24 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
633751 }
634752 }
635753
636- extract_compressed_lower_upper_diagonal (dst_dim , core .bandwidth , lower , multiplied_weights , & core .lower , & core .upper , & core .diagonal );
637-
754+ if (params -> has_ignore_mask ) {
755+ core .multiplied_weights = calloc (dst_dim * core .bandwidth , sizeof (double ));
756+ for (int i = 0 ; i < dst_dim ; i ++ ) {
757+ for (int j = 0 ; j < core .bandwidth ; j ++ ) {
758+ core .multiplied_weights [i * core .bandwidth + j ] = multiplied_weights [i * dst_dim + i + j ];
759+ }
760+ }
761+ } else {
762+ double * lower ;
763+ transpose_matrix (dst_dim , dst_dim , ldlt , & lower );
764+ multiply_banded_matrix_with_diagonal (dst_dim , core .bandwidth , lower );
765+ extract_compressed_lower_upper_diagonal (dst_dim , core .bandwidth , lower , ldlt , & core .lower , & core .upper , & core .diagonal );
766+ free (lower );
767+ }
638768 free (weights );
639769 free (transposed_weights );
640770 free (multiplied_weights );
641- free (lower );
771+ free (ldlt );
642772
643773 struct DescaleCore * corep = malloc (sizeof core );
644774 * corep = core ;
@@ -652,8 +782,9 @@ static void free_core(struct DescaleCore *core)
652782 free (core -> weights );
653783 free (core -> weights_left_idx );
654784 free (core -> weights_right_idx );
785+ free (core -> multiplied_weights );
655786 free (core -> diagonal );
656- for (int i = 0 ; i < core -> bandwidth / 2 ; i ++ ) {
787+ for (int i = 0 ; core -> upper && i < core -> bandwidth / 2 ; i ++ ) {
657788 free (core -> lower [i ]);
658789 free (core -> upper [i ]);
659790 }
0 commit comments