-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathidct_8x8.py
More file actions
283 lines (215 loc) · 7.85 KB
/
idct_8x8.py
File metadata and controls
283 lines (215 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# h264/transform/idct_8x8.py
"""H.264 8x8 Integer Inverse Transform for High Profile.
The 8x8 IDCT is used for I_8x8 macroblocks in High profile.
Uses 8-point butterfly operations with integer arithmetic.
H.264 Spec Reference: Section 8.5.12 - Inverse transform process
The inverse transform uses a separable 2D approach:
1. Apply 1D inverse transform to each row (horizontal pass)
2. Apply 1D inverse transform to each column (vertical pass)
3. Normalize with rounding: (value + 32) >> 6
The 8-point butterfly uses scaled integer coefficients to avoid
floating-point operations while maintaining precision.
"""
import logging
import numpy as np
# Import zigzag scan from entropy tables (already defined there)
from entropy.tables import ZIGZAG_8x8, ZIGZAG_8x8_INV
logger = logging.getLogger(__name__)

# Public alias for the 8x8 zigzag scan in (row, col) form.
# ZIGZAG_8x8 (imported from entropy.tables) holds flat raster indices 0..63;
# each index is split into row = index // 8 and col = index % 8 so callers
# can address a 2D block directly.
ZIGZAG_SCAN_8x8 = tuple((flat // 8, flat % 8) for flat in ZIGZAG_8x8)
# H.264 8x8 field scan (coefficient order used for field / interlaced
# macroblocks), stored as flat raster indices (row * 8 + col).
# The order favours vertical frequencies: it runs down column 0 before
# moving right.  The values follow the spec's idx -> c_ij mapping for 8x8
# blocks (the same order as FIELD_SCAN8x8 in the JM reference software).
# NOTE(review): the previous table was a permutation of 0..63 but diverged
# from the standard order from entry 12 onwards (e.g. entry 12 must be 56,
# i.e. row 7 / col 0) - confirm the table number cited for this scan.
_FIELD_SCAN_8x8_FLAT = np.array([
     0,  8, 16,  1,  9, 24, 32, 17,
     2, 25, 40, 48, 56, 33, 10,  3,
    18, 41, 49, 57, 26, 11,  4, 19,
    34, 42, 50, 58, 27, 12,  5, 20,
    35, 43, 51, 59, 28, 13,  6, 21,
    36, 44, 52, 60, 29, 14, 22, 37,
    45, 53, 61, 30,  7, 15, 38, 46,
    54, 62, 23, 31, 39, 47, 55, 63,
], dtype=np.int32)

# (row, col) tuples of plain Python ints, same format as ZIGZAG_SCAN_8x8.
# .tolist() converts the np.int32 scalars so the tuples hold built-in ints.
FIELD_SCAN_8x8 = tuple(divmod(idx, 8) for idx in _FIELD_SCAN_8x8_FLAT.tolist())
# Integer 8x8 transform matrix C (scaled approximation of the DCT-II basis,
# cited in this file as H.264 Table 8-12).  The 2D transform of a block X is
# T = C * X * C^T.
# Even rows are mirror-symmetric, odd rows are mirror-antisymmetric.
TRANSFORM_MATRIX_8x8 = np.array(
    (
        (8, 8, 8, 8, 8, 8, 8, 8),
        (12, 10, 6, 3, -3, -6, -10, -12),
        (8, 4, -4, -8, -8, -4, 4, 8),
        (10, -3, -12, -6, 6, 12, 3, -10),
        (8, -8, -8, 8, 8, -8, -8, 8),
        (6, -12, 3, 10, -10, -3, 12, -6),
        (4, -8, 8, -4, -4, 8, -8, 4),
        (3, -6, 10, -12, 12, -10, 6, -3),
    ),
    dtype=np.int32,
)
# Position-dependent scaling (normalization) factors for the 8x8 transform.
# The table is a 2x2 parity pattern repeated over the block:
#   even row / even col -> 64, mixed parity -> 68, odd row / odd col -> 72.
# NOTE(review): this file attributes the values to H.264 Table 8-14 - confirm
# against the spec before relying on them.
SCALING_FACTORS_8x8 = np.array(
    [
        [64 + 4 * (row % 2) + 4 * (col % 2) for col in range(8)]
        for row in range(8)
    ],
    dtype=np.int32,
)
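# H.264 8x8 IDCT butterfly (integer approximation of the 8-point DCT basis).
# The 1D transforms below use only additions, subtractions and arithmetic
# right shifts, so no multiplier constants are defined in this file.
# NOTE(review): an earlier note here listed "a=8, b=10, c=9, d=6, e=4, f=2, g=1"
# as scaled Table 8-12 coefficients; no code in this file uses those values -
# confirm against the spec before relying on them.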
def idct_1d_8(x: np.ndarray) -> np.ndarray:
    """Run the 1D 8-point integer inverse transform on one vector.

    The butterfly uses only additions, subtractions and arithmetic right
    shifts, and mirrors the 1D stage of the JM reference decoder's
    inverse8x8() (H.264 Section 8.5.12).

    Args:
        x: Input coefficients; elements 0..7 are read after conversion
            to int32.

    Returns:
        int32 array holding the 8 transformed samples.
    """
    vec = x.astype(np.int32)
    c = [vec[k] for k in range(8)]

    # Even half: built from coefficients 0, 2, 4 and 6.
    sum04 = c[0] + c[4]
    diff04 = c[0] - c[4]
    rot_a = c[6] - (c[2] >> 1)
    rot_b = c[2] + (c[6] >> 1)

    e0 = sum04 + rot_b
    e2 = diff04 - rot_a
    e4 = diff04 + rot_a
    e6 = sum04 - rot_b

    # Odd half: built from coefficients 1, 3, 5 and 7.
    t0 = -c[3] + c[5] - c[7] - (c[7] >> 1)
    t1 = c[1] + c[7] - c[3] - (c[3] >> 1)
    t2 = -c[1] + c[7] + c[5] + (c[5] >> 1)
    t3 = c[3] + c[5] + c[1] + (c[1] >> 1)

    o1 = t0 + (t3 >> 2)
    o3 = t1 + (t2 >> 2)
    o5 = t2 - (t1 >> 2)
    o7 = t3 - (t0 >> 2)

    # Outputs 0..3 and 4..7 combine the same even/odd pairs with opposite
    # signs on the odd term.
    first_half = [e0 + o7, e2 - o5, e4 + o3, e6 + o1]
    second_half = [e6 - o1, e4 - o3, e2 + o5, e0 - o7]
    return np.array(first_half + second_half, dtype=np.int32)
def idct_8x8(coeffs: np.ndarray) -> np.ndarray:
    """Apply the 8x8 integer inverse transform (IDCT) to one block.

    This is the core transform used in H.264 High profile to convert
    frequency-domain coefficients back to spatial-domain residuals
    for 8x8 blocks.

    The process is:
        1. Apply the 1D transform to each row (horizontal pass).
        2. Apply the 1D transform to each column (vertical pass).
        3. Normalize with rounding: ``(value + 32) >> 6`` (divide by 64).

    Args:
        coeffs: 8x8 transform coefficients. Any array-like accepted by
            ``np.asarray`` works; values are converted to int32.

    Returns:
        8x8 spatial residuals (int32), normalized.

    Raises:
        ValueError: If ``coeffs`` does not have shape (8, 8).

    H.264 Spec: Section 8.5.12
    """
    # Accept nested lists/tuples as well as arrays; a no-op for ndarrays.
    coeffs = np.asarray(coeffs)
    if coeffs.shape != (8, 8):
        raise ValueError(f"Expected 8x8 block, got {coeffs.shape}")

    # %-style arguments defer formatting of the array until a handler
    # actually emits the record, so this per-block call costs almost
    # nothing when DEBUG logging is disabled.
    logger.debug("IDCT 8x8 input:\n%s", coeffs)

    # All intermediate arithmetic is carried out in int32.
    temp = coeffs.astype(np.int32)

    # Step 1: horizontal pass - transform each row
    # (JM inverse8x8 order: horizontal first, then vertical).
    row_result = np.zeros((8, 8), dtype=np.int32)
    for i in range(8):
        row_result[i, :] = idct_1d_8(temp[i, :])

    # Step 2: vertical pass - transform each column of the row results.
    col_result = np.zeros((8, 8), dtype=np.int32)
    for j in range(8):
        col_result[:, j] = idct_1d_8(row_result[:, j])

    # Step 3: normalize with rounding, r[i][j] = (f[i][j] + 32) >> 6.
    result = (col_result + 32) >> 6

    logger.debug("IDCT 8x8 output:\n%s", result)
    return result
def forward_1d_8(x: np.ndarray) -> np.ndarray:
    """Run the 1D 8-point integer forward transform on one vector.

    Mirrors the 1D stage of JM forward8x8(): mirror-pair sums produce the
    even-indexed outputs, mirror-pair differences the odd-indexed ones.

    Args:
        x: Input samples; elements 0..7 are read after conversion to int32.

    Returns:
        int32 array holding the 8 transform coefficients.
    """
    vec = x.astype(np.int32)
    s = [vec[k] for k in range(8)]

    # Sums of mirrored sample pairs (0,7), (1,6), (2,5), (3,4).
    sum07 = s[0] + s[7]
    sum16 = s[1] + s[6]
    sum25 = s[2] + s[5]
    sum34 = s[3] + s[4]

    outer_sum = sum07 + sum34
    inner_sum = sum16 + sum25
    outer_diff = sum07 - sum34
    inner_diff = sum16 - sum25

    # Differences of the same mirrored pairs.
    dif07 = s[0] - s[7]
    dif16 = s[1] - s[6]
    dif25 = s[2] - s[5]
    dif34 = s[3] - s[4]

    odd_a = dif16 + dif25 + ((dif07 >> 1) + dif07)
    odd_b = dif07 - dif34 - ((dif25 >> 1) + dif25)
    odd_c = dif07 + dif34 - ((dif16 >> 1) + dif16)
    odd_d = dif16 - dif25 + ((dif34 >> 1) + dif34)

    # Outputs 0, 2, 4, 6.
    even_out = (
        outer_sum + inner_sum,
        outer_diff + (inner_diff >> 1),
        outer_sum - inner_sum,
        (outer_diff >> 1) - inner_diff,
    )
    # Outputs 1, 3, 5, 7.
    odd_out = (
        odd_a + (odd_d >> 2),
        odd_b + (odd_c >> 2),
        odd_c - (odd_b >> 2),
        (odd_a >> 2) - odd_d,
    )

    # Interleave even/odd outputs back into natural order 0..7.
    ordered = [value for pair in zip(even_out, odd_out) for value in pair]
    return np.array(ordered, dtype=np.int32)
def forward_8x8(block: np.ndarray) -> np.ndarray:
    """Apply the 8x8 integer forward transform, following JM forward8x8().

    Encoder-side counterpart of idct_8x8, used primarily for testing the
    round-trip accuracy of the transform.  The 1D forward transform is run
    over every row and then over every column of the row results; no
    normalization is applied here.

    Args:
        block: 8x8 spatial block; converted to int32.

    Returns:
        8x8 transform coefficients (int32).

    Raises:
        ValueError: If ``block`` does not have shape (8, 8).

    H.264 Spec: This is the encoder-side transform, inverse of Section 8.5.12
    """
    if block.shape != (8, 8):
        raise ValueError(f"Expected 8x8 block, got {block.shape}")

    samples = block.astype(np.int32)

    # Horizontal pass: one 1D transform per row.
    horizontal = np.zeros((8, 8), dtype=np.int32)
    for r, row in enumerate(samples):
        horizontal[r] = forward_1d_8(row)

    # Vertical pass: iterating the transpose yields the columns in order.
    coefficients = np.zeros((8, 8), dtype=np.int32)
    for c, column in enumerate(horizontal.T):
        coefficients[:, c] = forward_1d_8(column)

    return coefficients
def idct_8x8_batch(blocks: np.ndarray) -> np.ndarray:
    """Run idct_8x8 over a stack of 8x8 coefficient blocks.

    Args:
        blocks: Array of shape (N, 8, 8) holding N coefficient blocks.

    Returns:
        int32 array of shape (N, 8, 8) holding the N residual blocks.

    Raises:
        ValueError: If ``blocks`` is not 3-dimensional with trailing
            shape (8, 8).

    Note: Blocks are processed one at a time in a plain Python loop; a
    vectorized implementation could be substituted later for speed.
    """
    has_block_layout = blocks.ndim == 3 and blocks.shape[1:] == (8, 8)
    if not has_block_layout:
        raise ValueError(f"Expected shape (N, 8, 8), got {blocks.shape}")

    residuals = np.zeros_like(blocks, dtype=np.int32)
    # Iterating the array walks its first axis, i.e. one 8x8 block per step.
    for index, block in enumerate(blocks):
        residuals[index] = idct_8x8(block)
    return residuals