Skip to content

Commit 9cec9d6

Browse files
committed
Merge #243: mpgen: support primitive std::optional struct fields
6dbfa56 mpgen: support primitive std::optional struct fields (Ryan Ofsky) 8d1277d mpgen refactor: add AccessorType function (Ryan Ofsky) db716bb mpgen refactor: Move field handling code to FieldList class (Ryan Ofsky) Pull request description: Currently C++ structs with primitive `std::optional` members (ints, bools, floats) cannot easily by mapped to Cap'n Proto structs because Cap'n Proto does not provide a way to leave primitive fields unset, so there isn't a natural way to represent `std::nullopt` values. This PR makes it possible to map C++ structs with fields like: ```c++ std::optional<int> foo; ``` to Cap'n Proto structs by using extra `Bool` fields prefixed with "has" for primitive optional members: ```capnp foo @3 :Int32; hasFoo @4 :Bool; ``` Boolean "has" fields were already supported by the code generator and used to pass primitive `std::optional` parameters and return values, so this PR just extends it work with all struct fields, not just fields in params and result structs. Note: Motivation for this change is dealing with the [`CreatedTransactionResult::change_pos`](https://github.com/furszy/bitcoin-gui/blob/4c0d4f6f93f371a8ad097735945d32510a7e83bb/src/wallet/types.h#L39) field introduced to the wallet interface in bitcoin-core/gui#807. This also could have been useful in bitcoin/bitcoin#33965 (comment) ACKs for top commit: xyzconstant: Tested ACK 6dbfa56 Tree-SHA512: d1191c67bfe7232594ec51fa939f9700f393daaf6db7ece9abcd38a7319c72c25370126b2df3f8dec7d2805650a263598338e6e9f6b77c9ea396cbc94eba58d8
2 parents 4aaff11 + 6dbfa56 commit 9cec9d6

5 files changed

Lines changed: 142 additions & 113 deletions

File tree

src/mp/gen.cpp

Lines changed: 130 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,109 @@ static bool BoxedType(const ::capnp::Type& type)
125125
type.isFloat64() || type.isEnum());
126126
}
127127

128+
struct Field
129+
{
130+
::capnp::StructSchema::Field param;
131+
bool param_is_set = false;
132+
::capnp::StructSchema::Field result;
133+
bool result_is_set = false;
134+
int args = 0;
135+
bool retval = false;
136+
bool optional = false;
137+
bool requested = false;
138+
bool skip = false;
139+
kj::StringPtr exception;
140+
};
141+
142+
struct FieldList
143+
{
144+
std::vector<Field> fields;
145+
std::map<kj::StringPtr, int> field_idx; // name -> args index
146+
bool has_result = false;
147+
148+
void addField(const ::capnp::StructSchema::Field& schema_field, bool param, bool result)
149+
{
150+
auto field_name = schema_field.getProto().getName();
151+
auto inserted = field_idx.emplace(field_name, fields.size());
152+
if (inserted.second) {
153+
fields.emplace_back();
154+
}
155+
auto& field = fields[inserted.first->second];
156+
if (param) {
157+
field.param = schema_field;
158+
field.param_is_set = true;
159+
}
160+
if (result) {
161+
field.result = schema_field;
162+
field.result_is_set = true;
163+
}
164+
165+
if (!param && field_name == kj::StringPtr{"result"}) {
166+
field.retval = true;
167+
has_result = true;
168+
}
169+
170+
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
171+
field.skip = true;
172+
}
173+
GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);
174+
175+
int32_t count = 1;
176+
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
177+
if (schema_field.getType().isStruct()) {
178+
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
179+
COUNT_ANNOTATION_ID, &count);
180+
} else if (schema_field.getType().isInterface()) {
181+
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
182+
COUNT_ANNOTATION_ID, &count);
183+
}
184+
}
185+
186+
187+
if (inserted.second && !field.retval && !field.exception.size()) {
188+
field.args = count;
189+
}
190+
}
191+
192+
void mergeFields()
193+
{
194+
for (auto& field : field_idx) {
195+
auto has_field = field_idx.find("has" + Cap(field.first));
196+
if (has_field != field_idx.end()) {
197+
fields[has_field->second].skip = true;
198+
fields[field.second].optional = true;
199+
}
200+
auto want_field = field_idx.find("want" + Cap(field.first));
201+
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
202+
fields[want_field->second].skip = true;
203+
fields[field.second].requested = true;
204+
}
205+
}
206+
}
207+
};
208+
209+
std::string AccessorType(kj::StringPtr base_name, const Field& field)
210+
{
211+
const auto& f = field.param_is_set ? field.param : field.result;
212+
const auto field_name = f.getProto().getName();
213+
const auto field_type = f.getType();
214+
215+
std::ostringstream out;
216+
out << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", ";
217+
if (!field.param_is_set) {
218+
out << "FIELD_OUT";
219+
} else if (field.result_is_set) {
220+
out << "FIELD_IN | FIELD_OUT";
221+
} else {
222+
out << "FIELD_IN";
223+
}
224+
if (field.optional) out << " | FIELD_OPTIONAL";
225+
if (field.requested) out << " | FIELD_REQUESTED";
226+
if (BoxedType(field_type)) out << " | FIELD_BOXED";
227+
out << ">";
228+
return out.str();
229+
}
230+
128231
// src_file is path to .capnp file to generate stub code from.
129232
//
130233
// src_prefix can be used to generate outputs in a different directory than the
@@ -335,6 +438,13 @@ static void Generate(kj::StringPtr src_prefix,
335438

336439
if (node.getProto().isStruct()) {
337440
const auto& struc = node.asStruct();
441+
442+
FieldList fields;
443+
for (const auto schema_field : struc.getFields()) {
444+
fields.addField(schema_field, true, true);
445+
}
446+
fields.mergeFields();
447+
338448
std::ostringstream generic_name;
339449
generic_name << node_name;
340450
dec << "template<";
@@ -355,22 +465,18 @@ static void Generate(kj::StringPtr src_prefix,
355465
dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n";
356466
dec << "{\n";
357467
dec << " using Struct = " << message_namespace << "::" << generic_name.str() << ";\n";
358-
for (const auto field : struc.getFields()) {
359-
auto field_name = field.getProto().getName();
468+
for (const auto& field : fields.fields) {
469+
auto field_name = field.param.getProto().getName();
360470
add_accessor(field_name);
361-
dec << " using " << Cap(field_name) << "Accessor = Accessor<" << base_name
362-
<< "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT";
363-
if (BoxedType(field.getType())) dec << " | FIELD_BOXED";
364-
dec << ">;\n";
471+
dec << " using " << Cap(field_name) << "Accessor = "
472+
<< AccessorType(base_name, field) << ";\n";
365473
}
366474
dec << " using Accessors = std::tuple<";
367475
size_t i = 0;
368-
for (const auto field : struc.getFields()) {
369-
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
370-
continue;
371-
}
476+
for (const auto& field : fields.fields) {
477+
if (field.skip) continue;
372478
if (i) dec << ", ";
373-
dec << Cap(field.getProto().getName()) << "Accessor";
479+
dec << Cap(field.param.getProto().getName()) << "Accessor";
374480
++i;
375481
}
376482
dec << ">;\n";
@@ -384,13 +490,11 @@ static void Generate(kj::StringPtr src_prefix,
384490
inl << "public:\n";
385491
inl << " using Struct = " << message_namespace << "::" << node_name << ";\n";
386492
size_t i = 0;
387-
for (const auto field : struc.getFields()) {
388-
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
389-
continue;
390-
}
391-
auto field_name = field.getProto().getName();
493+
for (const auto& field : fields.fields) {
494+
if (field.skip) continue;
495+
auto field_name = field.param.getProto().getName();
392496
auto member_name = field_name;
393-
GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
497+
GetAnnotationText(field.param.getProto(), NAME_ANNOTATION_ID, &member_name);
394498
inl << " static decltype(auto) get(std::integral_constant<size_t, " << i << ">) { return "
395499
<< "&" << proxied_class_type << "::" << member_name << "; }\n";
396500
++i;
@@ -433,85 +537,14 @@ static void Generate(kj::StringPtr src_prefix,
433537
const bool is_construct = method_name == kj::StringPtr{"construct"};
434538
const bool is_destroy = method_name == kj::StringPtr{"destroy"};
435539

436-
struct Field
437-
{
438-
::capnp::StructSchema::Field param;
439-
bool param_is_set = false;
440-
::capnp::StructSchema::Field result;
441-
bool result_is_set = false;
442-
int args = 0;
443-
bool retval = false;
444-
bool optional = false;
445-
bool requested = false;
446-
bool skip = false;
447-
kj::StringPtr exception;
448-
};
449-
450-
std::vector<Field> fields;
451-
std::map<kj::StringPtr, int> field_idx; // name -> args index
452-
bool has_result = false;
453-
454-
auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) {
455-
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
456-
return;
457-
}
458-
459-
auto field_name = schema_field.getProto().getName();
460-
auto inserted = field_idx.emplace(field_name, fields.size());
461-
if (inserted.second) {
462-
fields.emplace_back();
463-
}
464-
auto& field = fields[inserted.first->second];
465-
if (param) {
466-
field.param = schema_field;
467-
field.param_is_set = true;
468-
} else {
469-
field.result = schema_field;
470-
field.result_is_set = true;
471-
}
472-
473-
if (!param && field_name == kj::StringPtr{"result"}) {
474-
field.retval = true;
475-
has_result = true;
476-
}
477-
478-
GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);
479-
480-
int32_t count = 1;
481-
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
482-
if (schema_field.getType().isStruct()) {
483-
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
484-
COUNT_ANNOTATION_ID, &count);
485-
} else if (schema_field.getType().isInterface()) {
486-
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
487-
COUNT_ANNOTATION_ID, &count);
488-
}
489-
}
490-
491-
492-
if (inserted.second && !field.retval && !field.exception.size()) {
493-
field.args = count;
494-
}
495-
};
496-
540+
FieldList fields;
497541
for (const auto schema_field : method.getParamType().getFields()) {
498-
add_field(schema_field, true);
542+
fields.addField(schema_field, true, false);
499543
}
500544
for (const auto schema_field : method.getResultType().getFields()) {
501-
add_field(schema_field, false);
502-
}
503-
for (auto& field : field_idx) {
504-
auto has_field = field_idx.find("has" + Cap(field.first));
505-
if (has_field != field_idx.end()) {
506-
fields[has_field->second].skip = true;
507-
fields[field.second].optional = true;
508-
}
509-
auto want_field = field_idx.find("want" + Cap(field.first));
510-
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
511-
fields[want_field->second].skip = true;
512-
fields[field.second].requested = true;
513-
}
545+
fields.addField(schema_field, false, true);
514546
}
547+
fields.mergeFields();
515548

516549
if (!is_construct && !is_destroy && (&method_interface == &interface)) {
517550
methods << "template<>\n";
@@ -527,25 +560,11 @@ static void Generate(kj::StringPtr src_prefix,
527560
std::ostringstream server_invoke_start;
528561
std::ostringstream server_invoke_end;
529562
int argc = 0;
530-
for (const auto& field : fields) {
563+
for (const auto& field : fields.fields) {
531564
if (field.skip) continue;
532565

533566
const auto& f = field.param_is_set ? field.param : field.result;
534567
auto field_name = f.getProto().getName();
535-
auto field_type = f.getType();
536-
537-
std::ostringstream field_flags;
538-
if (!field.param_is_set) {
539-
field_flags << "FIELD_OUT";
540-
} else if (field.result_is_set) {
541-
field_flags << "FIELD_IN | FIELD_OUT";
542-
} else {
543-
field_flags << "FIELD_IN";
544-
}
545-
if (field.optional) field_flags << " | FIELD_OPTIONAL";
546-
if (field.requested) field_flags << " | FIELD_REQUESTED";
547-
if (BoxedType(field_type)) field_flags << " | FIELD_BOXED";
548-
549568
add_accessor(field_name);
550569

551570
std::ostringstream fwd_args;
@@ -572,8 +591,7 @@ static void Generate(kj::StringPtr src_prefix,
572591
client_invoke << "MakeClientParam<";
573592
}
574593

575-
client_invoke << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
576-
<< field_flags.str() << ">>(";
594+
client_invoke << AccessorType(base_name, field) << ">(";
577595

578596
if (field.retval) {
579597
client_invoke << field_name;
@@ -589,8 +607,7 @@ static void Generate(kj::StringPtr src_prefix,
589607
} else {
590608
server_invoke_start << "MakeServerField<" << field.args;
591609
}
592-
server_invoke_start << ", Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
593-
<< field_flags.str() << ">>(";
610+
server_invoke_start << ", " << AccessorType(base_name, field) << ">(";
594611
server_invoke_end << ")";
595612
}
596613

@@ -606,12 +623,12 @@ static void Generate(kj::StringPtr src_prefix,
606623
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
607624
<< "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
608625
<< "(" << super_str << client_args.str() << ") {\n";
609-
if (has_result) {
626+
if (fields.has_result) {
610627
def_client << " typename M" << method_ordinal << "::Result result;\n";
611628
}
612629
def_client << " clientInvoke(" << self_str << ", &" << message_namespace << "::" << node_name
613630
<< "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
614-
if (has_result) def_client << " return result;\n";
631+
if (fields.has_result) def_client << " return result;\n";
615632
def_client << "}\n";
616633

617634
server << " kj::Promise<void> " << method_name << "(" << Cap(method_name)

test/mp/test/foo-types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <mp/type-map.h>
2121
#include <mp/type-message.h>
2222
#include <mp/type-number.h>
23+
#include <mp/type-optional.h>
2324
#include <mp/type-pointer.h>
2425
#include <mp/type-set.h>
2526
#include <mp/type-string.h>

test/mp/test/foo.capnp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
5555
name @0 :Text;
5656
setint @1 :List(Int32);
5757
vbool @2 :List(Bool);
58+
optionalInt @3 :Int32 $Proxy.name("optional_int");
59+
hasOptionalInt @4 :Bool;
5860
}
5961

6062
struct FooCustom $Proxy.wrap("mp::test::FooCustom") {

test/mp/test/foo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <functional>
1010
#include <map>
1111
#include <memory>
12+
#include <optional>
1213
#include <string>
1314
#include <set>
1415
#include <vector>
@@ -21,6 +22,7 @@ struct FooStruct
2122
std::string name;
2223
std::set<int> setint;
2324
std::vector<bool> vbool;
25+
std::optional<int> optional_int;
2426
};
2527

2628
enum class FooEnum : uint8_t { ONE = 1, TWO = 2, };

test/mp/test/test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ KJ_TEST("Call FooInterface methods")
146146
in.vbool.push_back(false);
147147
in.vbool.push_back(true);
148148
in.vbool.push_back(false);
149+
in.optional_int = 3;
149150
FooStruct out = foo->pass(in);
150151
KJ_EXPECT(in.name == out.name);
151152
KJ_EXPECT(in.setint.size() == out.setint.size());
@@ -156,6 +157,12 @@ KJ_TEST("Call FooInterface methods")
156157
for (size_t i = 0; i < in.vbool.size(); ++i) {
157158
KJ_EXPECT(in.vbool[i] == out.vbool[i]);
158159
}
160+
KJ_EXPECT(in.optional_int == out.optional_int);
161+
162+
// Additional checks for std::optional member
163+
KJ_EXPECT(foo->pass(in).optional_int == 3);
164+
in.optional_int.reset();
165+
KJ_EXPECT(!foo->pass(in).optional_int);
159166

160167
FooStruct err;
161168
try {

0 commit comments

Comments
 (0)