Skip to content

Commit 911bf92

Browse files
author
Grok Compression
committed
sched: add 16-bit DWT support to ScxEngine scheduler
Enable the fused-strip DWT path for 16-bit (int16_t) wavelet buffers, supporting both 5/3 reversible and 9/7 irreversible transforms.
- Add FusedStripShared16 and FusedStripJob16 structures for 16-bit strips
- Implement v_cascade_p0_16_53, v_cascade_p1_16_53, v_cascade_16_53, v_cascade_strip_16_53, and v_cascade_strip_16_97 kernels
- Remove is16BitDwt() exclusion from Phase 2 gate setup
- Add 16-bit branch to Phase 4 with qmfbid-based dispatch
1 parent 170a73a commit 911bf92

3 files changed

Lines changed: 487 additions & 11 deletions

File tree

src/lib/core/scheduling/freebyrd/SchedulerFreebyrd.cpp

Lines changed: 218 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,8 @@ bool SchedulerFreebyrd::decodeAndTransformScx(ITileProcessor* tileProcessor)
603603
continue;
604604

605605
auto tilec = tileProcessor->getTile()->comps_ + compno;
606-
// Only use strip-aware gates for whole-tile, non-16-bit
607-
if(!tilec->isWholeTileDecoding() || tilec->is16BitDwt())
606+
// Only use strip-aware gates for whole-tile decode
607+
if(!tilec->isWholeTileDecoding())
608608
continue;
609609

610610
auto& gates = compGates[compno];
@@ -1024,13 +1024,113 @@ bool SchedulerFreebyrd::decodeAndTransformScx(ITileProcessor* tileProcessor)
10241024
}
10251025
};
10261026

1027+
// Shared state per resolution for 16-bit (per-thread intermediate buffers)
1028+
struct FusedStripShared16
1029+
{
1030+
WaveletReverse* wavelet = nullptr;
1031+
uint32_t intermediateStride = 0;
1032+
uint32_t maxLRows = 0;
1033+
uint32_t maxHRows = 0;
1034+
std::vector<std::unique_ptr<int16_t[]>> threadBufL;
1035+
std::vector<std::unique_ptr<int16_t[]>> threadBufH;
1036+
};
1037+
1038+
// Per-strip job context for 16-bit (both 5/3 and 9/7)
1039+
struct FusedStripJob16
1040+
{
1041+
FusedStripShared16* shared;
1042+
ScxEngine* engine;
1043+
StripGeometry geom;
1044+
Buffer2dSimple<int16_t> llBand, hlBand, lhBand, hhBand;
1045+
Buffer2dSimple<int16_t> winDest;
1046+
uint32_t hSn, hDn, hParity;
1047+
uint32_t resWidth;
1048+
uint8_t qmfbid;
1049+
DcShiftParam dcShift;
1050+
std::vector<uint32_t> dependentGates;
1051+
1052+
static void execute([[maybe_unused]] size_t i, size_t thread_id, void* ud)
1053+
{
1054+
auto* job = static_cast<FusedStripJob16*>(ud);
1055+
auto* wav = job->shared->wavelet;
1056+
1057+
const uint32_t stride = job->shared->intermediateStride;
1058+
int16_t* tempL = job->shared->threadBufL[thread_id].get();
1059+
int16_t* tempH = job->shared->threadBufH[thread_id].get();
1060+
1061+
// === Step 1: H-DWT for L rows ===
1062+
{
1063+
wav->horizPool16_[thread_id].sn = job->hSn;
1064+
wav->horizPool16_[thread_id].dn = job->hDn;
1065+
wav->horizPool16_[thread_id].parity = job->hParity;
1066+
1067+
auto winL = job->llBand;
1068+
auto winH = job->hlBand;
1069+
winL.incY_IN_PLACE(job->geom.rangeL.lo);
1070+
winH.incY_IN_PLACE(job->geom.rangeL.lo);
1071+
Buffer2dSimple<int16_t> dest(tempL, stride, job->geom.rangeL.count());
1072+
1073+
if(job->qmfbid == 1)
1074+
wav->h_strip_16_53(&wav->horizPool16_[thread_id], 0, job->geom.rangeL.count(), winL, winH,
1075+
dest);
1076+
else
1077+
wav->h_strip_16_97(&wav->horizPool16_[thread_id], 0, job->geom.rangeL.count(), winL, winH,
1078+
dest);
1079+
}
1080+
1081+
// === Step 2: H-DWT for H rows ===
1082+
{
1083+
auto winL = job->lhBand;
1084+
auto winH = job->hhBand;
1085+
winL.incY_IN_PLACE(job->geom.rangeH.lo);
1086+
winH.incY_IN_PLACE(job->geom.rangeH.lo);
1087+
Buffer2dSimple<int16_t> dest(tempH, stride, job->geom.rangeH.count());
1088+
1089+
if(job->qmfbid == 1)
1090+
wav->h_strip_16_53(&wav->horizPool16_[thread_id], 0, job->geom.rangeH.count(), winL, winH,
1091+
dest);
1092+
else
1093+
wav->h_strip_16_97(&wav->horizPool16_[thread_id], 0, job->geom.rangeH.count(), winL, winH,
1094+
dest);
1095+
}
1096+
1097+
// === Step 3: Cascade V-DWT ===
1098+
{
1099+
uint32_t localSn = job->geom.rangeL.count();
1100+
uint32_t localDn = job->geom.rangeH.count();
1101+
1102+
wav->vertPool16_[thread_id].sn = localSn;
1103+
wav->vertPool16_[thread_id].dn = localDn;
1104+
wav->vertPool16_[thread_id].parity = job->geom.localParity;
1105+
1106+
Buffer2dSimple<int16_t> winL(tempL, stride, localSn);
1107+
Buffer2dSimple<int16_t> winH(tempH, stride, localDn);
1108+
1109+
if(job->qmfbid == 1)
1110+
wav->v_cascade_strip_16_53(&wav->vertPool16_[thread_id], 0, job->resWidth, winL, winH,
1111+
job->winDest, job->dcShift, job->geom.outputStartInStripe,
1112+
job->geom.outCount);
1113+
else
1114+
wav->v_cascade_strip_16_97(&wav->vertPool16_[thread_id], 0, job->resWidth, winL, winH,
1115+
job->winDest, job->dcShift, job->geom.outputStartInStripe,
1116+
job->geom.outCount);
1117+
}
1118+
1119+
// === Step 4: Signal downstream strip gates ===
1120+
for(auto gateId : job->dependentGates)
1121+
scx_engine_signal_gate(job->engine, gateId);
1122+
}
1123+
};
1124+
10271125
struct CompDwtState
10281126
{
10291127
std::unique_ptr<WaveletReverse> wavelet;
1030-
std::vector<FusedStripShared> shared; // one per resolution (5/3)
1031-
std::vector<FusedStripJob> jobs; // all strip jobs (5/3)
1032-
std::vector<FusedStripShared97> shared97; // one per resolution (9/7)
1033-
std::vector<FusedStripJob97> jobs97; // all strip jobs (9/7)
1128+
std::vector<FusedStripShared> shared; // one per resolution (5/3 32-bit)
1129+
std::vector<FusedStripJob> jobs; // all strip jobs (5/3 32-bit)
1130+
std::vector<FusedStripShared97> shared97; // one per resolution (9/7 float)
1131+
std::vector<FusedStripJob97> jobs97; // all strip jobs (9/7 float)
1132+
std::vector<FusedStripShared16> shared16; // one per resolution (16-bit)
1133+
std::vector<FusedStripJob16> jobs16; // all strip jobs (16-bit)
10341134
bool usedScxDwt = false;
10351135
};
10361136
std::vector<CompDwtState> dwtStates(numcomps_);
@@ -1045,8 +1145,8 @@ bool SchedulerFreebyrd::decodeAndTransformScx(ITileProcessor* tileProcessor)
10451145
auto maxDim = std::max(tileProcessor->getCodingParams()->t_width_,
10461146
tileProcessor->getCodingParams()->t_height_);
10471147

1048-
// Don't use ScxEngine DWT for partial decompress or 16-bit (handled in fallback)
1049-
if(!tilec->isWholeTileDecoding() || tilec->is16BitDwt())
1148+
// Don't use ScxEngine DWT for partial decompress (handled in fallback)
1149+
if(!tilec->isWholeTileDecoding())
10501150
continue;
10511151

10521152
auto& ds = dwtStates[compno];
@@ -1063,11 +1163,117 @@ bool SchedulerFreebyrd::decodeAndTransformScx(ITileProcessor* tileProcessor)
10631163
auto& gates = compGates[compno];
10641164

10651165
auto bandLL = tilec->resolutions_;
1066-
auto tileBuffer = tilec->getWindow();
10671166

1068-
if(cw.qmfbid == 1)
1167+
if(tilec->is16BitDwt())
1168+
{
1169+
// === 16-bit DWT path (both 5/3 and 9/7) ===
1170+
auto tileBuffer16 = tilec->getWindow16();
1171+
wav->horizPool16_ = std::make_unique<dwt_scratch<int16_t>[]>(num_threads);
1172+
wav->vertPool16_ = std::make_unique<dwt_scratch<int16_t>[]>(num_threads);
1173+
1174+
ds.shared16.resize(cw.numRes - 1);
1175+
1176+
for(uint8_t res = 1; res < cw.numRes; ++res)
1177+
{
1178+
wav->horiz_.sn = bandLL->width();
1179+
wav->vert_.sn = bandLL->height();
1180+
for(uint32_t i = 0; i < num_threads; ++i)
1181+
{
1182+
wav->horizPool16_[i].sn = bandLL->width();
1183+
wav->vertPool16_[i].sn = bandLL->height();
1184+
}
1185+
++bandLL;
1186+
auto resWidth = bandLL->width();
1187+
auto resHeight = bandLL->height();
1188+
if(resWidth == 0 || resHeight == 0)
1189+
continue;
1190+
wav->horiz_.dn = resWidth - wav->horiz_.sn;
1191+
wav->horiz_.parity = bandLL->x0 & 1;
1192+
wav->vert_.dn = resHeight - wav->vert_.sn;
1193+
wav->vert_.parity = bandLL->y0 & 1;
1194+
for(uint32_t i = 0; i < num_threads; ++i)
1195+
{
1196+
wav->horizPool16_[i].dn = resWidth - wav->horizPool16_[i].sn;
1197+
wav->horizPool16_[i].parity = bandLL->x0 & 1;
1198+
wav->horizPool16_[i].allocatedMem = (int16_t*)waveletPoolData_->getHoriz(i);
1199+
wav->horizPool16_[i].mem = (int16_t*)waveletPoolData_->getHoriz(i);
1200+
1201+
wav->vertPool16_[i].dn = resHeight - wav->vertPool16_[i].sn;
1202+
wav->vertPool16_[i].parity = bandLL->y0 & 1;
1203+
wav->vertPool16_[i].allocatedMem = (int16_t*)waveletPoolData_->getVert(i);
1204+
wav->vertPool16_[i].mem = (int16_t*)waveletPoolData_->getVert(i);
1205+
}
1206+
1207+
auto& stripGeoms = gates.stripGeoms[res - 1];
1208+
if(stripGeoms.empty())
1209+
continue;
1210+
1211+
uint32_t intermediateStride = (resWidth + 15U) & ~15U;
1212+
1213+
uint32_t maxLRows = 0, maxHRows = 0;
1214+
for(auto& sg : stripGeoms)
1215+
{
1216+
maxLRows = std::max(maxLRows, sg.rangeL.count());
1217+
maxHRows = std::max(maxHRows, sg.rangeH.count());
1218+
}
1219+
1220+
auto& sh = ds.shared16[res - 1];
1221+
sh.wavelet = wav;
1222+
sh.intermediateStride = intermediateStride;
1223+
sh.maxLRows = maxLRows;
1224+
sh.maxHRows = maxHRows;
1225+
sh.threadBufL.resize(num_threads);
1226+
sh.threadBufH.resize(num_threads);
1227+
for(uint32_t t = 0; t < num_threads; ++t)
1228+
{
1229+
sh.threadBufL[t] = std::make_unique<int16_t[]>((size_t)intermediateStride * maxLRows);
1230+
sh.threadBufH[t] = std::make_unique<int16_t[]>((size_t)intermediateStride * maxHRows);
1231+
}
1232+
1233+
DcShiftParam dcShift = (res == cw.numRes - 1) ? wav->dcShift_ : DcShiftParam{};
1234+
1235+
auto llBand = tileBuffer16->getResWindowBufferSimple((uint8_t)(res - 1U));
1236+
auto hlBand = tileBuffer16->getBandWindowBufferPaddedSimple(res, t1::BAND_ORIENT_HL);
1237+
auto lhBand = tileBuffer16->getBandWindowBufferPaddedSimple(res, t1::BAND_ORIENT_LH);
1238+
auto hhBand = tileBuffer16->getBandWindowBufferPaddedSimple(res, t1::BAND_ORIENT_HH);
1239+
auto winDest = tileBuffer16->getResWindowBufferSimple(res);
1240+
1241+
size_t jobBase = ds.jobs16.size();
1242+
1243+
for(size_t s = 0; s < stripGeoms.size(); ++s)
1244+
{
1245+
FusedStripJob16 job;
1246+
job.shared = &sh;
1247+
job.engine = engine;
1248+
job.geom = stripGeoms[s];
1249+
job.llBand = llBand;
1250+
job.hlBand = hlBand;
1251+
job.lhBand = lhBand;
1252+
job.hhBand = hhBand;
1253+
job.winDest = winDest;
1254+
job.winDest.incY_IN_PLACE(stripGeoms[s].outStart);
1255+
job.hSn = wav->horiz_.sn;
1256+
job.hDn = wav->horiz_.dn;
1257+
job.hParity = wav->horiz_.parity;
1258+
job.resWidth = resWidth;
1259+
job.qmfbid = cw.qmfbid;
1260+
job.dcShift = dcShift;
1261+
job.dependentGates = gates.crossResDepGates[res - 1][s];
1262+
ds.jobs16.push_back(std::move(job));
1263+
}
1264+
1265+
for(size_t s = 0; s < stripGeoms.size(); ++s)
1266+
{
1267+
uint32_t stripGate = gates.stripGateIds[res - 1][s];
1268+
scx_engine_submit_full_batch(engine, cw.dwtDomainId, 0, 1, FusedStripJob16::execute,
1269+
&ds.jobs16[jobBase + s], stripGate, SCX_NO_GATE);
1270+
}
1271+
}
1272+
}
1273+
else if(cw.qmfbid == 1)
10691274
{
1070-
// === 5/3 reversible DWT path ===
1275+
// === 5/3 reversible DWT path (32-bit) ===
1276+
auto tileBuffer = tilec->getWindow();
10711277
wav->horizPool_ = std::make_unique<dwt_scratch<int32_t>[]>(num_threads);
10721278
wav->vertPool_ = std::make_unique<dwt_scratch<int32_t>[]>(num_threads);
10731279

@@ -1179,6 +1385,7 @@ bool SchedulerFreebyrd::decodeAndTransformScx(ITileProcessor* tileProcessor)
11791385
else
11801386
{
11811387
// === 9/7 irreversible DWT path ===
1388+
auto tileBuffer = tilec->getWindow();
11821389
ds.shared97.resize(cw.numRes - 1);
11831390

11841391
for(uint8_t res = 1; res < cw.numRes; ++res)

0 commit comments

Comments
 (0)