Skip to content

Commit 4f1ba9c

Browse files
committed
[Synth][CutRewriter] Add timing-preserving area-flow reselection
1 parent 5e21c95 commit 4f1ba9c

4 files changed

Lines changed: 381 additions & 43 deletions

File tree

include/circt/Dialect/Synth/Transforms/CutRewriter.h

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/Support/Allocator.h"
2929
#include "llvm/Support/LogicalResult.h"
3030
#include "llvm/Support/raw_ostream.h"
31+
#include <limits>
3132
#include <memory>
3233
#include <optional>
3334
#include <utility>
@@ -131,6 +132,12 @@ struct LogicNetworkGate {
131132
/// inversion bit is encoded in each edge.
132133
Signal edges[3];
133134

135+
/// Number of uses by logic gates in this network.
136+
unsigned logicFanoutCount = 0;
137+
138+
/// Number of uses outside the logic network.
139+
unsigned externalUseCount = 0;
140+
134141
LogicNetworkGate() : opAndKind(nullptr, Constant), edges{} {}
135142
LogicNetworkGate(Operation *op, Kind kind,
136143
llvm::ArrayRef<Signal> operands = {})
@@ -171,11 +178,18 @@ struct LogicNetworkGate {
171178
return k == And2 || k == Xor2 || k == Maj3 || k == Identity || k == Choice;
172179
}
173180

174-
/// Check if this should always be a cut input (PI or constant).
175-
bool isAlwaysCutInput() const {
181+
/// Check if this gate is a cut leaf (PI or constant).
182+
bool isCutLeaf() const {
176183
Kind k = getKind();
177184
return k == PrimaryInput || k == Constant;
178185
}
186+
187+
unsigned getTotalRefCount() const {
188+
unsigned refCount = logicFanoutCount + externalUseCount;
189+
return refCount == 0 ? 1 : refCount;
190+
}
191+
192+
bool isPrimaryOutput() const { return externalUseCount != 0; }
179193
};
180194

181195
/// Flat logic network representation for efficient cut enumeration.
@@ -258,6 +272,16 @@ class LogicNetwork {
258272
/// Get the total number of nodes in the network.
259273
size_t size() const { return gates.size(); }
260274

275+
/// Get the total reference count used by area-flow estimation.
276+
unsigned getTotalRefCount(uint32_t index) const {
277+
return gates[index].getTotalRefCount();
278+
}
279+
280+
/// Check if a node is observed outside the logic network.
281+
bool isPrimaryOutput(uint32_t index) const {
282+
return gates[index].isPrimaryOutput();
283+
}
284+
261285
/// Add a primary input to the network.
262286
uint32_t addPrimaryInput(Value value);
263287

@@ -279,6 +303,9 @@ class LogicNetwork {
279303
void clear();
280304

281305
private:
306+
void recordLogicUse(uint32_t index) { ++gates[index].logicFanoutCount; }
307+
void recordExternalUse(uint32_t index) { ++gates[index].externalUseCount; }
308+
282309
/// Map from MLIR Value to network index.
283310
llvm::DenseMap<Value, uint32_t> valueToIndex;
284311

@@ -349,27 +376,49 @@ class MatchedPattern {
349376
private:
350377
const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
351378
SmallVector<DelayType, 1>
352-
arrivalTimes; ///< Arrival times of outputs from this pattern
353-
double area = 0.0; ///< Area cost of this pattern
379+
arrivalTimes; ///< Arrival times of outputs from this pattern
380+
/// Saved match data we reuse during area-flow reselection.
381+
MatchResult matchResult;
382+
SmallVector<unsigned, 6> patternInputToCutInput;
354383

355384
public:
356385
/// Default constructor creates an invalid matched pattern.
357386
MatchedPattern() = default;
358387

359388
/// Constructor for a valid matched pattern.
360389
MatchedPattern(const CutRewritePattern *pattern,
361-
SmallVector<DelayType, 1> arrivalTimes, double area)
362-
: pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
390+
SmallVector<DelayType, 1> arrivalTimes,
391+
MatchResult matchResult,
392+
ArrayRef<unsigned> patternInputToCutInput)
393+
: pattern(pattern), arrivalTimes(std::move(arrivalTimes)),
394+
matchResult(std::move(matchResult)),
395+
patternInputToCutInput(patternInputToCutInput.begin(),
396+
patternInputToCutInput.end()) {}
363397

364398
/// Get the arrival time of signals through this pattern.
365399
DelayType getArrivalTime(unsigned outputIndex) const;
366400
ArrayRef<DelayType> getArrivalTimes() const;
401+
DelayType getWorstOutputArrivalTime() const;
367402

368403
/// Get the library pattern that was matched.
369404
const CutRewritePattern *getPattern() const;
370405

371406
/// Get the area cost of using this pattern.
372407
double getArea() const;
408+
409+
/// Get the per-input delays used when scoring this match.
410+
ArrayRef<DelayType> getDelays() const;
411+
412+
/// Get the cached match payload used to rebuild this match.
413+
const MatchResult &getMatchResult() const { return matchResult; }
414+
415+
/// Get the mapping from pattern input indices to cut input indices.
416+
ArrayRef<unsigned> getInputPermutation() const {
417+
return patternInputToCutInput;
418+
}
419+
420+
/// Get the delay for a cut input after accounting for input permutation.
421+
DelayType getDelayForCutInput(unsigned cutInputIndex) const;
373422
};
374423

375424
/// Represents a cut in the combinational logic network.
@@ -529,6 +578,15 @@ class CutSet {
529578
bool isFrozen = false; ///< Whether cut set is finalized
530579

531580
public:
581+
/// Latest time this node is allowed to arrive.
582+
DelayType requiredTime = std::numeric_limits<DelayType>::max();
583+
584+
/// Arrival time of the currently selected cut.
585+
DelayType bestArrivalTime = 0;
586+
587+
/// Current area-flow score for the selected cut.
588+
double areaFlow = 0.0;
589+
532590
/// Check if this cut set has a valid matched pattern.
533591
bool isMatched() const { return bestCut; }
534592

@@ -551,6 +609,9 @@ class CutSet {
551609

552610
/// Get read-only access to all cuts in this set.
553611
ArrayRef<Cut *> getCuts() const;
612+
613+
/// Replace the currently selected cut during area recovery.
614+
void setBestCut(Cut *cut) { bestCut = cut; }
554615
};
555616

556617
/// Configuration options for the cut-based rewriting algorithm.
@@ -658,6 +719,12 @@ class CutEnumerator {
658719

659720
void dump() const;
660721

722+
/// Compute required times from the current timing-feasible seed mapping.
723+
void computeRequiredTimes();
724+
725+
/// Re-select cuts using area-flow while preserving required times.
726+
void reselectCutsForAreaFlow();
727+
661728
/// Get cut sets (indexed by LogicNetwork index).
662729
const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
663730
return cutSets;

0 commit comments

Comments
 (0)