1 module protoc_gen_d;
2 
3 import google.protobuf;
4 import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
5 import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
6     OneofDescriptorProto;
7 
/**
 * Plugin entry point.
 *
 * Reads a single serialized `CodeGeneratorRequest` from stdin, hands it to a
 * `CodeGenerator` and writes the serialized `CodeGeneratorResponse` to stdout.
 *
 * Returns: always 0.
 */
int main()
{
    import std.array : array;
    import std.stdio : stdin, stdout;

    // A request can be larger than one chunk, so gather all of stdin before
    // decoding. `byChunk` reuses its buffer; appending copies the bytes out.
    ubyte[] input;
    foreach (chunk; stdin.byChunk(1024 * 1024))
        input ~= chunk;

    // Nothing was sent: produce no output, exactly as before.
    if (input.length == 0)
        return 0;

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}
25 
/// Error type raised by the generator; `CodeGenerator.handle` catches it and
/// reports its message through the response's `error` field.
class CodeGeneratorException : Exception
{
    /// Passes every argument straight through to the `Exception` constructor.
    pure nothrow @safe this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null)
    {
        super(message, file, line, next);
    }
}
34 
/**
 * Generates D source code for the `.proto` files contained in a protoc
 * plugin request.
 *
 * The type lookup tables and the protoc version are rebuilt on every call to
 * `handle`, so one instance may serve several requests in sequence.
 */
class CodeGenerator
{
    /// Number of spaces added per nesting level of the generated code.
    private enum indentSize = 4;

    /**
     * Processes one plugin request.
     *
     * Params:
     *     request = the decoded request received from protoc
     *
     * Returns: a response with one generated file per input file outside the
     *     `google.protobuf` package. If a `CodeGeneratorException` is thrown
     *     while generating, its message is stored in the response's `error`.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // The compiler version is an optional part of the request; without the
        // guard a request that lacks it would dereference null.
        protocVersion = null;
        if (request.compilerVersion !is null)
            with (request.compilerVersion)
                protocVersion = format!"%d%03d%03d"(major, minor, patch);

        // Start from empty tables so nothing leaks in from an earlier request.
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Fills `collectedMessageTypes` and `collectedEnumTypes` with every message
     * and enum of the request, keyed by ".package.Outer.Inner" style names
     * (a leading dot, then the package if any, then the nesting path).
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            // Test the length rather than string truthiness, which is true for
            // an empty but non-null string.
            string packagePrefix = file.package_.length ? "." ~ file.package_ : "";

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            // Top-level enums live in the file's package as well; they must be
            // keyed the same way as messages or lookups by type name fail.
            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Builds the response entry (output path and content) for one input file.
     *
     * Throws: `CodeGeneratorException` if the file's syntax is not "proto3".
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // The module name doubles as the output path: dots become directories.
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /**
     * Renders the complete D module for one input file: header comment, module
     * declaration, imports, the optional `protocVersion` constant, then all
     * top-level messages followed by all top-level enums.
     */
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependency.moduleName);

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Renders a message as a D class, including its fields, oneofs, nested
     * messages and nested enums.
     *
     * Params:
     *     messageType = the message to render
     *     indent = indentation, in spaces, of the class declaration; nested
     *         classes (indent > 0) are emitted as `static`
     *
     * Returns: the generated code, or an empty string for map entry messages.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto result = appender!string;
        result ~= "\n%*s%sclass %s\n".format(indent, "", indent > 0 ? "static " : "", messageType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        // Indices of the oneofs already emitted; each is generated once, at
        // the position of its lowest-numbered member.
        int[] generatedOneofs;
        foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /**
     * Renders one regular (non-oneof) field declaration with its `@Proto`
     * attribute and, if `printInitializer` is set, its default initializer.
     */
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders a oneof: the case enum with its accessors, then the union of members.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Renders the `<Name>Case` enum of a oneof (a `NotSet` member with value 0
     * plus one member per field, valued with the field number), the backing
     * case variable, its getter and the `clear<Name>` method.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        // Both spellings of the oneof name are used repeatedly below.
        string upperName = oneof.name.underscoresToCamelCase(true);
        string lowerName = oneof.name.underscoresToCamelCase(false);

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", upperName);
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", lowerName);
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "", upperName, lowerName);
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            upperName, lowerName);
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            upperName, lowerName);

        return result.data;
    }

    /**
     * Renders the anonymous union holding the members of a oneof. Only the
     * first member is given an initializer.
     */
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (index, field; fields)
            result ~= generateOneofField(field, indent + indentSize, index == 0);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Renders one oneof member: an underscore-prefixed field plus the accessor mixin.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /**
     * Renders an enum declaration with all its values.
     *
     * The enum name is keyword-escaped so the declaration matches the name
     * that `baseTypeName` emits for fields of this enum type.
     */
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message descriptor for the field's type name, or null if there is none.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedMessageTypes;
        return found ? *found : null;
    }

    /// Returns the collected enum descriptor for the field's type name, or null if there is none.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedEnumTypes;
        return found ? *found : null;
    }

    /**
     * Determines the `Wire` modifier of a field.
     *
     * sint* types yield `Wire.zigzag` and (s)fixed* types `Wire.fixed`. For a
     * map field the modifiers of the entry's key and value fields are combined,
     * shifted by 2 and 4 bits respectively. Everything else is `Wire.none`.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Maps a field's protobuf type to the D element type used in generated
     * code, ignoring repetition.
     *
     * Message and enum names are keyword-escaped so they match the names used
     * by the generated declarations.
     *
     * Throws: `CodeGeneratorException` if a message or enum type name is not
     *     among the collected types.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type '" ~ field.typeName ~ "'");
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type '" ~ field.typeName ~ "'");
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the full D type of a field: `Value[Key]` for map fields, `T[]`
     * for repeated fields and `T` otherwise.
     *
     * Throws: `CodeGeneratorException` if the field refers to an unknown type.
     */
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        // Resolved first so that an unknown type is reported before anything else.
        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /**
     * Builds the argument list of the `@Proto` attribute: the field number,
     * followed by the wire modifier unless that is `Wire.none`.
     */
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        return [field.number.to!string, wireByField(field).toString].stripRight("").stripRight("Wire.none").join(", ");
    }

    /**
     * Builds the ` = protoDefaultValue!...` initializer of a field. Type names
     * ending in `]` are wrapped in parentheses.
     */
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    /// protoc version as major, then minor and patch zero-padded to three digits; empty if unknown.
    private string protocVersion;
    /// Messages of the current request, keyed by absolute name with a leading dot.
    private DescriptorProto[string] collectedMessageTypes;
    /// Enums of the current request, keyed by absolute name with a leading dot.
    private EnumDescriptorProto[string] collectedEnumTypes;
}
385 
/**
 * Looks up the field of `messageType` whose number equals `fieldNumber`.
 *
 * Returns: the first matching field.
 * Throws: `CodeGeneratorException` if the message has no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.format : format;

    foreach (candidate; messageType.fields)
    {
        if (candidate.number == fieldNumber)
            return candidate;
    }

    throw new CodeGeneratorException(
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));
}
400 
/// Tells whether the message has options set and those options flag it as a map entry.
private bool isMap(DescriptorProto messageType)
{
    auto options = messageType.options;

    if (!options)
        return false;

    return options.mapEntry;
}
405 
/// Field numbers used to fetch the key and the value field of a map entry
/// message via `fieldByNumber`.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}
411 
/// Wire modifiers emitted into the `@Proto` attribute of generated fields.
///
/// `fixed` and `zigzag` are single-bit flags. The `_key` and `_value` members
/// are those same flags shifted left by 2 and 4 bits, which is how
/// `CodeGenerator.wireByField` combines the modifiers of a map field's key
/// and value. The remaining members name every key/value combination.
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}
426 
/// Returns the spelling of a `Wire` member as it is written into generated code
/// (camelCase member names prefixed with `Wire.`).
private string toString(Wire wire)
{
    string spelling;

    final switch (wire)
    {
    case Wire.none:
        spelling = "Wire.none";
        break;
    case Wire.fixed:
        spelling = "Wire.fixed";
        break;
    case Wire.zigzag:
        spelling = "Wire.zigzag";
        break;
    case Wire.fixed_key:
        spelling = "Wire.fixedKey";
        break;
    case Wire.zigzag_key:
        spelling = "Wire.zigzagKey";
        break;
    case Wire.fixed_value:
        spelling = "Wire.fixedValue";
        break;
    case Wire.zigzag_value:
        spelling = "Wire.zigzagValue";
        break;
    case Wire.fixed_key_fixed_value:
        spelling = "Wire.fixedKeyFixedValue";
        break;
    case Wire.fixed_key_zigzag_value:
        spelling = "Wire.fixedKeyZigzagValue";
        break;
    case Wire.zigzag_key_fixed_value:
        spelling = "Wire.zigzagKeyFixedValue";
        break;
    case Wire.zigzag_key_zigzag_value:
        spelling = "Wire.zigzagKeyZigzagValue";
        break;
    }

    return spelling;
}
455 
/**
 * Derives the D module name of a `.proto` file: the file's base name without
 * the `.proto` extension, prefixed with the package when one is set, with
 * every dot-separated component keyword-escaped.
 */
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string stem = baseName(fileDescriptor.name, ".proto");
    string qualified = fileDescriptor.package_.empty ? stem : fileDescriptor.package_ ~ "." ~ stem;

    return escapeKeywords(qualified);
}
468 
/**
 * Derives a D module name from the path of a `.proto` file, as listed in a
 * file's dependencies: the `.proto` suffix is dropped and path separators
 * become dots.
 *
 * Each component is keyword-escaped, matching what the `FileDescriptorProto`
 * overload does for module declarations, so that a generated `import` can
 * name a generated `module`.
 */
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}
476 
/**
 * Converts an underscore separated identifier to camel case.
 *
 * Underscores are dropped and an ASCII lower-case letter directly following
 * one is upper-cased. The first character is upper-cased as well when
 * `capitalizeNextLetter` is true. The result is passed through
 * `escapeKeywords` before it is returned.
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto camel = appender!string;
    bool upperNext = capitalizeNextLetter;

    foreach (ubyte unit; input)
    {
        if (unit != '_')
        {
            ubyte emitted = unit;

            // Only ASCII lower-case letters have an upper-case counterpart here.
            if (upperNext && unit >= 'a' && unit <= 'z')
                emitted = cast(ubyte)(unit - ('a' - 'A'));

            camel ~= emitted;
            upperNext = false;
        }
        else
        {
            upperNext = true;
        }
    }

    return escapeKeywords(camel.data);
}
500 
/// Identifiers that `escapeKeywords` suffixes with an underscore: D keywords
/// and special tokens, plus common aliases such as `string` and `size_t`.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];
515 
/**
 * Makes an identifier path safe to use in D code.
 *
 * The input is split at `separator`; every component found in `keywords`
 * gets an underscore appended, and the components are joined again with the
 * same separator.
 */
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    // Escapes a single component when it collides with a reserved word.
    static string escapePart(string part)
    {
        if (keywords.canFind(part))
            return part ~ '_';

        return part;
    }

    auto parts = input.splitter(separator);
    auto escapedParts = parts.map!escapePart;

    return escapedParts.joiner(separator).to!string;
}