Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions lib/Kernel/KernelImplementationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,22 @@ namespace kernel {
namespace {

using tensor4d = std::vector<std::vector<std::vector<std::vector<int>>>>;
using tensor3d = std::vector<std::vector<std::vector<int>>>;

std::function<int(const std::vector<int64_t>&)> getDataValueFn(tensor4d data) {
// Builds an accessor over a rank-4 nested vector. The returned callable
// treats the first four entries of a domain point as indices into a copy of
// `data` (captured by value, so it stays valid after the caller's tensor goes
// away). Indices are not bounds-checked.
std::function<int(const std::vector<int64_t>&)> getDataValueFn4D(
    tensor4d data) {
  return [data](const std::vector<int64_t>& domainPoint) -> int {
    const int64_t i0 = domainPoint[0];
    const int64_t i1 = domainPoint[1];
    const int64_t i2 = domainPoint[2];
    const int64_t i3 = domainPoint[3];
    return data[i0][i1][i2][i3];
  };
}

// Builds an accessor over a rank-3 nested vector. The returned callable
// treats the first three entries of a domain point as indices into a copy of
// `data` (captured by value, so it stays valid after the caller's tensor goes
// away). Indices are not bounds-checked.
std::function<int(const std::vector<int64_t>&)> getDataValueFn3D(
    tensor3d data) {
  return [data](const std::vector<int64_t>& domainPoint) -> int {
    const int64_t i0 = domainPoint[0];
    const int64_t i1 = domainPoint[1];
    const int64_t i2 = domainPoint[2];
    return data[i0][i1][i2];
  };
}

// Parametrize over whether the kernel is unrolled and whether rows are
// interchanged
class KernelImplementationTest
Expand Down Expand Up @@ -447,7 +456,7 @@ TEST_P(KernelImplementationTest, TestConv2dNchwFchwStride2) {

auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
std::vector<std::vector<int>> packedData =
evaluateLayout(dataLayout, getDataValueFn(data));
evaluateLayout(dataLayout, getDataValueFn4D(data));

SmallVector<int64_t> strides = {2, 2};
auto filterLayout = get2dConvChwFchwFilterDiagonalizedRelation(
Expand Down Expand Up @@ -493,6 +502,66 @@ TEST_P(KernelImplementationTest, TestConv2dNchwFchwStride2) {
EXPECT_EQ(actualUnpacked, expected);
}

// Checks a stride-2 1-D convolution (1x2x5 data, 2x2x3 filter) evaluated as a
// Halevi-Shoup matrix-vector product over 16 slots. The test parameters select
// kernel unrolling (tuple element 0) and row interchange (tuple element 1).
TEST_P(KernelImplementationTest, TestConv1dCwFcwStride2) {
  MLIRContext context;
  RankedTensorType dataType =
      RankedTensorType::get({1, 2, 5}, mlir::IndexType::get(&context));
  RankedTensorType filterType =
      RankedTensorType::get({2, 2, 3}, mlir::IndexType::get(&context));

  int numSlots = 16;
  // 1x2x5 input data
  // [[0,1,2,3,4],[5,6,7,8,9]]
  tensor3d data = {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}};
  // 2x2x3 filter data
  // [[3 4 1], [1 5 2]]
  // [[1 2 3], [2 2 2]]
  tensor3d filter = {{{3, 4, 1}, {1, 5, 2}}, {{1, 2, 3}, {2, 2, 2}}};

  // Pack the input into ciphertext slots using a row-major layout.
  auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
  std::vector<std::vector<int>> packedData =
      evaluateLayout(dataLayout, getDataValueFn3D(data));

  // Pack the filter with the diagonalized layout, reusing the same rank-3
  // accessor helper as for the data.
  int64_t stride = 2;
  auto filterLayout = get1dConvCwFcwFilterDiagonalizedRelation(
      filterType, dataType, stride, 0, numSlots,
      /*interchangeRows=*/std::get<1>(GetParam()));
  ASSERT_TRUE(succeeded(filterLayout));
  std::vector<std::vector<int>> packedFilter =
      evaluateLayout(filterLayout.value(), getDataValueFn3D(filter));
  auto expandedFilterShape = get1dConvCwFcwFilterExpandedType(
      filterType, dataType, stride, /*padding=*/0);

  // 1x2x2 output: each filter is applied at window offsets 0 and 2, e.g.
  //   expected[0][0][0] = (3*0 + 4*1 + 1*2) + (1*5 + 5*6 + 2*7) = 55
  //   expected[0][1][1] = (1*2 + 2*3 + 3*4) + (2*7 + 2*8 + 2*9) = 68
  tensor3d expected = {{{55, 87}, {44, 68}}};

  LiteralValue matrixInput = packedFilter;
  LiteralValue vectorInput = packedData[0];

  auto dag = implementHaleviShoup(
      vectorInput, matrixInput, expandedFilterShape.getShape(),
      DagType::intTensor(32, {numSlots}),
      /*zeroDiagonals=*/{}, /*unroll=*/std::get<0>(GetParam()));
  LiteralValue actual = evalKernel(dag)[0];
  // The kernel result is a single flat vector of slot values.
  auto actualSlots = std::get<std::vector<int>>(actual.get());

  RankedTensorType outputType =
      RankedTensorType::get({1, 2, 2}, mlir::IndexType::get(&context));
  auto resultLayout =
      get1dConvResultRelation(outputType, stride, 0, numSlots,
                              /*interchangeRows=*/std::get<1>(GetParam()));

  auto actualUnpacked =
      unpackLayoutTo3DTensor<int>(resultLayout, {actualSlots}, {1, 2, 2});

  // The unpacked result is a single 1x2x2 tensor.
  EXPECT_EQ(actualUnpacked, expected);
}

TEST_P(KernelImplementationTest,
TestConv2dNchwFchwStride2InterchangedLargeSlots) {
MLIRContext context;
Expand All @@ -512,7 +581,7 @@ TEST_P(KernelImplementationTest,

auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
std::vector<std::vector<int>> packedData =
evaluateLayout(dataLayout, getDataValueFn(data));
evaluateLayout(dataLayout, getDataValueFn4D(data));

SmallVector<int64_t> strides = {2, 2};
auto filterLayout = get2dConvChwFchwFilterDiagonalizedRelation(
Expand Down Expand Up @@ -599,7 +668,7 @@ TEST_P(KernelImplementationTest, TestConv2dNchwFchwOrionFigure4) {

auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
std::vector<std::vector<int>> packedData =
evaluateLayout(dataLayout, getDataValueFn(data));
evaluateLayout(dataLayout, getDataValueFn4D(data));

SmallVector<int64_t> strides = {1, 1};
int64_t padding = 1;
Expand Down Expand Up @@ -713,7 +782,7 @@ TEST_P(KernelImplementationTest, TestConv2dNchwFchwStride2MultiInput) {

auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
std::vector<std::vector<int>> packedData =
evaluateLayout(dataLayout, getDataValueFn(data));
evaluateLayout(dataLayout, getDataValueFn4D(data));

SmallVector<int64_t> strides = {2, 2};
auto filterLayout = get2dConvChwFchwFilterDiagonalizedRelation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,140 @@ struct ConvertLinalgConv2D
bool unrollKernels;
};

// Conversion pattern for linalg.conv_1d_ncw_fcw. When the op carries the
// MatvecDiagonal kernel attribute and the materialized filter has
// power-of-two leading dimensions, the convolution is rewritten as a
// Halevi-Shoup matrix-vector product followed by an add of the op's output
// operand. Any other configuration is reported as an error.
struct ConvertLinalgConv1DNcwFcw
    : public ContextAwareOpConversionPattern<linalg::Conv1DNcwFcwOp> {
 public:
  using ContextAwareOpConversionPattern<
      linalg::Conv1DNcwFcwOp>::ContextAwareOpConversionPattern;

  // `unrollKernels` is forwarded to the kernel generator as its `unroll`
  // flag. The pattern is registered with benefit 10.
  ConvertLinalgConv1DNcwFcw(
      const ContextAwareTypeConverter& contextAwareTypeConverter,
      MLIRContext* context, bool unrollKernels = true)
      : ContextAwareOpConversionPattern(contextAwareTypeConverter, context,
                                        /*benefit=*/10),
        unrollKernels(unrollKernels) {}

  // Returns the layout attribute the type converter associates with `value`,
  // or a null attribute when the lookup fails or yields a different
  // attribute kind.
  LayoutAttr getLayoutAttr(Value value) const {
    auto layoutLookup = getTypeConverter()->getContextualAttr(value);
    if (failed(layoutLookup)) {
      return nullptr;
    }
    return dyn_cast<LayoutAttr>(layoutLookup.value());
  }

  // Returns true when the op can be lowered with the expanded Halevi-Shoup
  // kernel: the materialized filter must have at least two dimensions, the
  // first two must be powers of two, and the op must be annotated with the
  // MatvecDiagonal kernel.
  bool supportsExpandedHaleviShoup(linalg::Conv1DNcwFcwOp op,
                                   OpAdaptor adaptor) const {
    Value filter = adaptor.getInputs().back();
    auto materializedFilterType = cast<RankedTensorType>(filter.getType());

    auto dimensions = materializedFilterType.getShape();
    // Guard the indexing below: a filter of rank < 2 has no (rows, cols)
    // pair to inspect, so it cannot use this kernel.
    if (dimensions.size() < 2) {
      LLVM_DEBUG(llvm::dbgs()
                 << "supports expanded conv1d as matvec with halevi-shoup: "
                 << "materialized filter rank is less than 2\n");
      return false;
    }

    // If one of these dimensions is not a power of two, then we can't do
    // the Halevi-Shoup or Squat Packing Matrix Multiplication conversion.
    int64_t numRows = dimensions[0];
    int64_t numCols = dimensions[1];
    bool isPowerOfTwoDims = isPowerOfTwo(numRows) && isPowerOfTwo(numCols);

    auto kernelAttr = op->getAttrOfType<secret::KernelAttr>(
        secret::SecretDialect::kKernelAttrName);
    bool isConv1dAsMatvec =
        kernelAttr && kernelAttr.getName() == KernelName::MatvecDiagonal;

    LLVM_DEBUG(llvm::dbgs()
               << "supports expanded conv1d as matvec with halevi-shoup: "
               << "isPowerOfTwoDims=" << isPowerOfTwoDims
               << " isConv1dAsMatvec=" << isConv1dAsMatvec << "\n");

    return isPowerOfTwoDims && isConv1dAsMatvec;
  }

  // Replaces `op` with the materialized Halevi-Shoup kernel applied to the
  // converted data (vector) and filter (matrix) operands, plus an add of the
  // converted output operand. The op's layout attribute is copied onto the
  // kernel result and onto the final add.
  void haleviShoupKernel(
      linalg::Conv1DNcwFcwOp op, OpAdaptor adaptor,
      ContextAwareConversionPatternRewriter& rewriter) const {
    LLVM_DEBUG(
        llvm::dbgs()
        << "Converting linalg.conv_1d_ncw_fcw op with halevi shoup kernel: "
        << op << "\n");

    TypedValue<RankedTensorType> data =
        cast<TypedValue<RankedTensorType>>(adaptor.getInputs()[0]);
    SSAValue vectorLeaf(data);
    TypedValue<RankedTensorType> matrix =
        cast<TypedValue<RankedTensorType>>(adaptor.getInputs()[1]);
    SSAValue matrixLeaf(matrix);

    // The original matrix shape is the shape of the expanded filter before
    // diagonalization.
    RankedTensorType expandedMatrixType = get1dConvCwFcwFilterExpandedType(
        cast<RankedTensorType>(op.getInputs()[1].getType()),
        cast<RankedTensorType>(op.getInputs()[0].getType()),
        llvm::to_vector(op.getStrides().getValues<int64_t>()).front(),
        /*padding=*/0);

    // Zero diagonals of the filter matrix. Left empty for now, so the kernel
    // is generated without skipping any diagonal.
    // TODO(#2897): Enable this once row interchange is supported.
    // LayoutAttr filterLayout = getLayoutAttr(adaptor.getInputs()[1]);
    // auto filterRelation = filterLayout.getIntegerRelation();
    // PointCollector collector;
    // getCtComplementPoints(filterRelation, collector, matrix.getType());
    // for (const auto& point : collector.points) {
    //   zeroDiagonals[point[0]] = true;
    // }
    std::map<int, bool> zeroDiagonals;

    auto dagType = kernel::mlirTypeToDagType(data.getType(),
                                             data.getType().getShape().back());
    std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel =
        implementHaleviShoup(vectorLeaf, matrixLeaf,
                             expandedMatrixType.getShape(), dagType,
                             zeroDiagonals,
                             /*unroll=*/unrollKernels);

    // Materialize the kernel DAG as IR right after the original op, tagging
    // every created op as materialized.
    rewriter.setInsertionPointAfter(op);
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    IRMaterializingVisitor visitor(data.getType(), [&](Operation* createdOp) {
      setMaterializedAttr(createdOp);
    });
    Value finalOutput = visitor.process(implementedKernel, b)[0];

    auto layoutAttr = cast<LayoutAttr>(op->getAttr(kLayoutAttrName));
    auto finalOutputOp = finalOutput.getDefiningOp();
    finalOutputOp->setAttr(kLayoutAttrName, layoutAttr);
    setMaterializedAttr(finalOutputOp);

    // Add the initial accumulator value.
    Value result = adaptor.getOutputs()[0];
    Operation* addBias =
        makeAppropriatelyTypedAddOp(b, op->getLoc(), finalOutput, result);
    addBias->setAttr(kLayoutAttrName, layoutAttr);
    setMaterializedAttr(addBias);
    rewriter.replaceOp(op, addBias);
  }

  // Entry point of the pattern. Fails to match when either the data or the
  // filter operand has no layout attribute; lowers via the Halevi-Shoup
  // kernel when supported; otherwise emits an error on the op.
  LogicalResult matchAndRewrite(
      linalg::Conv1DNcwFcwOp op, OpAdaptor adaptor,
      ContextAwareConversionPatternRewriter& rewriter) const final {
    Value data = adaptor.getInputs().front();
    Value filter = adaptor.getInputs().back();
    LayoutAttr dataLayout = getLayoutAttr(data);
    LayoutAttr filterLayout = getLayoutAttr(filter);

    if (!dataLayout || !filterLayout) {
      return rewriter.notifyMatchFailure(
          op, "missing new layout attribute for data and filter");
    }

    if (supportsExpandedHaleviShoup(op, adaptor)) {
      haleviShoupKernel(op, adaptor, rewriter);
      return success();
    }

    return op.emitError() << "unsupported layout for 1d conv";
  }

 private:
  // Whether generated kernels are unrolled; passed through to
  // implementHaleviShoup.
  bool unrollKernels;
};

struct ConvertLinalgConv2DNchwFchw
: public ContextAwareOpConversionPattern<linalg::Conv2DNchwFchwOp> {
public:
Expand Down Expand Up @@ -2455,8 +2589,9 @@ struct ConvertToCiphertextSemantics
ConvertTensorInsertLayout, ConvertTensorInsertSlice>(
typeConverter, context);
patterns.add<ConvertLinalgMatvecLayout, ConvertLinalgConv1D,
ConvertLinalgConv2D, ConvertLinalgConv2DNchwFchw>(
typeConverter, context, unrollKernels);
ConvertLinalgConv2D, ConvertLinalgConv2DNchwFchw,
ConvertLinalgConv1DNcwFcw>(typeConverter, context,
unrollKernels);
patterns.add<ConvertAssignLayout>(typeConverter, context, ciphertextSize);

ConversionConfig config;
Expand Down
Loading
Loading