1 module protoc_gen_d;
2 
3 import google.protobuf;
4 import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
5 import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
6     OneofDescriptorProto;
7 
/**
 * Plugin entry point: decodes one serialized `CodeGeneratorRequest` from stdin and writes the
 * serialized `CodeGeneratorResponse` produced by `CodeGenerator.handle` to stdout.
 *
 * Returns: always 0; generator errors are carried inside the response, not the exit code.
 */
int main()
{
    import std.algorithm : joiner;
    import std.array : array;
    import std.stdio : stdin, stdout;

    // Read the complete request before decoding. Decoding every stdin chunk as its own request
    // breaks as soon as the serialized request is larger than a single chunk.
    auto input = stdin.byChunk(1024 * 1024).joiner.array;
    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}
25 
/// Error raised while turning a .proto description into D code.
/// `CodeGenerator.handle` catches this type and copies its message into the response error.
class CodeGeneratorException : Exception
{
    /// Passes the message, origin (file/line) and chained exception straight to `Exception`.
    this(string message = null,
         string file = __FILE__,
         size_t line = __LINE__,
         Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}
34 
/**
 * Translates a protoc `CodeGeneratorRequest` into D source files.
 *
 * The instance keeps lookup tables (messages, enums and files keyed by name) built from the
 * request being handled; `handle` resets all of them first, so an instance may be reused.
 */
class CodeGenerator
{
    private enum indentSize = 4;

    // Emit messages as D structs instead of classes ("message-as-struct" plugin parameter).
    private bool messageAsStruct = false;

    /**
     * Generates one D module for every file in `request.protoFiles` except those whose package is
     * `google.protobuf` (the well known types).
     *
     * Params:
     *   request = the decoded plugin request; its `parameter` is a comma-separated option list.
     * Returns: a response holding the generated files, or its `error` set to the message of a
     *   `CodeGeneratorException`. Other exception types are not caught here.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map, canFind, splitter;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // Reset all per-request state so options and type tables never leak between requests.
        messageAsStruct = request.parameter.splitter(",").canFind("message-as-struct");
        protocVersion = null;

        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        collectedMessageTypes.clear;
        collectedEnumTypes.clear;
        collectedFiles.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Indexes every file, message (including nested ones) and enum of the request.
     *
     * Message and enum keys are absolute names of the form ".package.Outer.Inner", matching the
     * format this generator looks up through `FieldDescriptorProto.typeName`.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            collectedFiles[file.name] = file;

            // Use .empty (not string truthiness) so a non-null but empty package adds no extra dot.
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Produces the output file for one .proto file.
     *
     * Throws: `CodeGeneratorException` when the file is not proto3.
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // The output path mirrors the module name: "a.b.c" -> "a/b/c.d".
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Renders the full D module text: header, imports, optional protoc version, messages, enums.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
        {
            // Import the module name the dependency itself declares (package + base name). The
            // path-derived name is only a fallback, because it differs whenever the directory
            // layout does not follow the package.
            if (auto dependencyDescriptor = dependency in collectedFiles)
                result ~= "import %s;\n".format((*dependencyDescriptor).moduleName);
            else
                result ~= "import %s;\n".format(dependency.moduleName);
        }

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Renders a message (and its nested messages and enums) as a D class or struct.
     *
     * Fields are emitted in field-number order. Each oneof is emitted once, at the position of its
     * lowest-numbered member. Map entry messages produce no output.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto indentation = "%*s".format(indent, "");

        auto result = appender!string;
        result ~= "\n";
        result ~= indentation;
        result ~= indent > 0 ? "static " : "";
        result ~= messageAsStruct ? "struct" : "class";
        result ~= " ";
        result ~= messageType.name.escapeKeywords;
        result ~= "\n";
        result ~= indentation;
        result ~= "{\n";

        // Sort a copy so the request's descriptor is not reordered as a side effect.
        auto sortedFields = messageType.fields.dup;
        sortedFields.sort!((a, b) => a.number < b.number);

        int[] generatedOneofs;
        foreach (field; sortedFields)
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            auto oneofFields = sortedFields.filter!(a => a.oneofIndex == field.oneofIndex).array;
            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex], oneofFields, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= indentation;
        result ~= "}\n";

        return result.data;
    }

    /// Renders a plain field as `@Proto(...) Type name = default;`, with the initializer optional.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders a oneof: its case enum and accessors, followed by the union holding its members.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Renders `enum <Name>Case`, the `_<name>Case` field that defaults to `<name>NotSet`, a
     * `<name>Case` getter and a `clear<Name>()` method.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto upperName = oneof.name.underscoresToCamelCase(true);
        auto lowerName = oneof.name.underscoresToCamelCase(false);

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", upperName);
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", lowerName);
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "", upperName, lowerName);
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            upperName, lowerName);
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            upperName, lowerName);

        return result.data;
    }

    /// Renders the `@Oneof` union; only the first member gets an initializer.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (i, field; fields)
            result ~= generateOneofField(field, indent + indentSize, i == 0);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Renders one oneof member as a `_name` field plus a `mixin(oneofAccessors!_name)` line.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders a D enum. Both the enum name and the value names are keyword-escaped.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        // Escape the name exactly like message names so declarations match the references
        // produced by baseTypeName.
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message named by `field.typeName`, or null if none is known.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        if (auto found = field.typeName in collectedMessageTypes)
            return *found;
        return null;
    }

    /// Returns the collected enum named by `field.typeName`, or null if none is known.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        if (auto found = field.typeName in collectedEnumTypes)
            return *found;
        return null;
    }

    /**
     * Selects the `Wire` encoding flag for a field.
     *
     * For map fields, the key's flag is shifted into bits 2-3 and the value's flag into bits 4-5.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // The shift/or expression has type int; D has no implicit int -> enum conversion.
                return cast(Wire) (keyWire << 2 | valueWire << 4);
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D element type of a field, ignoring `repeated` and map-ness.
     *
     * Message and enum types resolve to their simple (keyword-escaped) name.
     * Throws: `CodeGeneratorException` when a message or enum type was not collected.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type `" ~ field.typeName ~ "`");
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type `" ~ field.typeName ~ "`");
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// True for scalar numeric and bool types, which this generator allows to be packed.
    private bool isBaseTypePackable(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
        case TYPE_UINT32: case TYPE_FIXED32:
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
        case TYPE_UINT64: case TYPE_FIXED64:
        case TYPE_FLOAT:
        case TYPE_DOUBLE:
            return true;
        case TYPE_STRING:
        case TYPE_BYTES:
        case TYPE_MESSAGE:
        // NOTE(review): proto3 packs repeated enums by default too; kept unpacked here, presumably
        // to match what the runtime supports — confirm before changing.
        case TYPE_ENUM:
            return false;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the full D type of a field: `V[K]` for maps, `T[]` for repeated fields, else `T`.
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /// Builds the `@Proto` argument list "number[, wire[, packed]]", dropping trailing defaults.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        string packedByField(FieldDescriptorProto field)
        {
            if (field.label != FieldDescriptorProto.Label.LABEL_REPEATED)
                return "No.packed";

            if (!isBaseTypePackable(field))
                return "No.packed";

            // NOTE(review): when options exist for another reason (e.g. deprecated) and `packed`
            // was not set explicitly, this reads it as false — confirm whether field presence
            // is available on the options type.
            return (!field.options || field.options.packed) ? "Yes.packed" : "No.packed";
        }

        return [field.number.to!string, wireByField(field).toString, packedByField(field)]
            .stripRight("No.packed")
            .stripRight("Wire.none")
            .join(", ");
    }

    /// Returns " = protoDefaultValue!T", with T parenthesized when it is an array/AA type.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    // Decimal "MMMmmmppp"-style compiler version for the generated `protocVersion` enum, or null.
    private string protocVersion;
    // Absolute ".package.Message" name -> descriptor, for every message of the current request.
    private DescriptorProto[string] collectedMessageTypes;
    // Absolute ".package.Enum" name -> descriptor, for every enum of the current request.
    private EnumDescriptorProto[string] collectedEnumTypes;
    // .proto file name (as listed in dependencies) -> descriptor, for the current request.
    private FileDescriptorProto[string] collectedFiles;
}
439 
/**
 * Looks up a field of `messageType` by its field number.
 *
 * Returns: the first field whose number equals `fieldNumber`.
 * Throws: `CodeGeneratorException` if the message has no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    // `find` leaves the slice starting at the first match (or an empty slice on a miss).
    auto fromMatch = messageType.fields.find!(candidate => candidate.number == fieldNumber);
    bool found = !fromMatch.empty;

    enforce!CodeGeneratorException(found,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    auto match = fromMatch[0];
    return match;
}
454 
/// True when the message carries options with `mapEntry` set, i.e. it is a map's entry type.
private bool isMap(DescriptorProto messageType)
{
    auto options = messageType.options;
    if (options is null)
        return false;
    return options.mapEntry;
}
459 
/// Field numbers of the key and value fields inside a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value, // follows key, so it is 2
}
465 
/**
 * Wire-encoding flags for `@Proto` attributes.
 *
 * Bits 0-1 hold the flag of a plain field (fixed / zigzag), bits 2-3 that of a map key and
 * bits 4-5 that of a map value. The combined members are the unions where both a key and a
 * value flag are set.
 */
private enum Wire
{
    none                    = 0b00_00_00,
    fixed                   = 0b00_00_01,
    zigzag                  = 0b00_00_10,
    fixed_key               = 0b00_01_00,
    zigzag_key              = 0b00_10_00,
    fixed_value             = 0b01_00_00,
    zigzag_value            = 0b10_00_00,
    fixed_key_fixed_value   = 0b01_01_00,
    fixed_key_zigzag_value  = 0b10_01_00,
    zigzag_key_fixed_value  = 0b01_10_00,
    zigzag_key_zigzag_value = 0b10_10_00,
}
480 
/// Spells a `Wire` value as the `Wire.*` expression written into generated `@Proto` attributes.
private string toString(Wire wire)
{
    string spelled;

    final switch (wire)
    {
    case Wire.none:                    spelled = "Wire.none"; break;
    case Wire.fixed:                   spelled = "Wire.fixed"; break;
    case Wire.zigzag:                  spelled = "Wire.zigzag"; break;
    case Wire.fixed_key:               spelled = "Wire.fixedKey"; break;
    case Wire.zigzag_key:              spelled = "Wire.zigzagKey"; break;
    case Wire.fixed_value:             spelled = "Wire.fixedValue"; break;
    case Wire.zigzag_value:            spelled = "Wire.zigzagValue"; break;
    case Wire.fixed_key_fixed_value:   spelled = "Wire.fixedKeyFixedValue"; break;
    case Wire.fixed_key_zigzag_value:  spelled = "Wire.fixedKeyZigzagValue"; break;
    case Wire.zigzag_key_fixed_value:  spelled = "Wire.zigzagKeyFixedValue"; break;
    case Wire.zigzag_key_zigzag_value: spelled = "Wire.zigzagKeyZigzagValue"; break;
    }

    return spelled;
}
509 
/**
 * D module name for a .proto file: the file's base name without ".proto", prefixed by the
 * proto package when one is declared, with keyword components escaped.
 */
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    auto leaf = fileDescriptor.name.baseName(".proto");
    auto qualified = fileDescriptor.package_.empty ? leaf : fileDescriptor.package_ ~ "." ~ leaf;

    return qualified.escapeKeywords;
}
522 
/**
 * D module name derived from a .proto path, e.g. "dir/file.proto" -> "dir.file".
 *
 * The ".proto" suffix is dropped, '/' becomes '.', and keyword components are escaped.
 */
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp; // was `std..string`, which is not valid D syntax

    auto withoutExtension = fileName.chomp(".proto");
    auto dotted = withoutExtension.replace("/", ".");

    return dotted.escapeKeywords;
}
530 
/**
 * Converts snake_case to camelCase: every '_' is removed and the ASCII lowercase letter right
 * after it is upper-cased; other bytes are copied unchanged. The first letter is upper-cased only
 * when `capitalizeNextLetter` is true. The result is keyword-escaped.
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto camel = appender!string;
    bool upcase = capitalizeNextLetter;

    foreach (char ch; input)
    {
        if (ch == '_')
        {
            upcase = true;
        }
        else
        {
            immutable bool lower = ch >= 'a' && ch <= 'z';
            camel ~= (upcase && lower) ? cast(char)(ch - ('a' - 'A')) : ch;
            upcase = false;
        }
    }

    return camel.data.escapeKeywords;
}
554 
/// Identifiers that must not appear verbatim in generated code: D keywords, special tokens and
/// a few commonly used aliases (string types, size_t, ptrdiff_t). Matches get a '_' suffix.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__FILE_FULL_PATH__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared",
    "__traits", "__vector", "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__",
    "__EOF__", "__TIME__", "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];
569 
/**
 * Splits `input` on `separator`, appends '_' to every component found in `keywords`, and joins
 * the components back with the same separator.
 */
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, splitter;

    string escaped;
    bool first = true;

    foreach (component; input.splitter(separator))
    {
        if (!first)
            escaped ~= separator;
        first = false;

        escaped ~= component;
        if (keywords.canFind(component))
            escaped ~= '_';
    }

    return escaped;
}