1 module protoc_gen_d;
2 
3 import google.protobuf;
4 import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
5 import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
6     OneofDescriptorProto;
7 
/// Entry point of the protoc plugin.
///
/// Reads one serialized `CodeGeneratorRequest` from stdin, runs a fresh
/// `CodeGenerator` on it and writes the serialized `CodeGeneratorResponse`
/// to stdout.
int main()
{
    import std.algorithm : joiner;
    import std.array : array;
    import std.stdio : stdin, stdout;

    // Read the whole request before decoding it: decoding chunk by chunk would
    // split a request larger than one chunk into several bogus messages (and
    // emit one response per chunk). `byChunk` reuses its buffer, so the bytes
    // are copied element-wise into a fresh array by `joiner.array`.
    auto input = stdin.byChunk(64 * 1024).joiner.array;

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}
25 
/// Signals that a `.proto` input cannot be turned into D code.
/// `CodeGenerator.handle` catches it and copies its message into the
/// response's `error` field.
class CodeGeneratorException : Exception
{
    /// Forwards every argument unchanged to `Exception`.
    this(string message = null, string file = __FILE__, size_t line = __LINE__, Throwable next = null)
        pure nothrow @safe
    {
        super(message, file, line, next);
    }
}
34 
/// Translates a protoc `CodeGeneratorRequest` into D source modules.
///
/// `handle` first indexes every message, enum and file of the request by name,
/// then renders one D module per file whose package is not `google.protobuf`.
class CodeGenerator
{
    private enum indentSize = 4;

    /// Builds the response for one plugin invocation.
    ///
    /// A `CodeGeneratorException` raised during generation is not propagated;
    /// its message is reported through the response's `error` field instead.
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // Reset all per-request state so a reused generator neither emits a stale
        // version nor resolves types against a previous request's descriptors.
        protocVersion = null;
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;
        collectedFiles.clear;

        // Encoded as one decimal number: major, then minor and patch zero-padded to 3 digits.
        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /// Indexes every file, message (including nested messages) and enum of
    /// `request`. Type keys are built as "." ~ package ~ "." ~ name (nested names
    /// appended with "."), the form looked up via `field.typeName`; files are
    /// keyed by their `name`.
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            collectedFiles[file.name] = file;

            // Test emptiness rather than null-ness (as `moduleName` does), so an
            // empty but non-null package cannot produce keys starting with "..".
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /// Generates the output file for `fileDescriptor`; its path is the module
    /// name with dots replaced by slashes, plus ".d".
    ///
    /// Throws: `CodeGeneratorException` if the file is not declared "proto3".
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Renders the complete D source of the module for `fileDescriptor`: header
    /// comment, module declaration, imports, optional `protocVersion` constant,
    /// then all top-level messages followed by all top-level enums.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependencyModuleName(dependency));

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /// Returns the module name to import for the dependency file `fileName`.
    ///
    /// Generated modules are named from package + file base name (see `generate`),
    /// not from the `.proto` path, so imports must use the same derivation to match
    /// the emitted `module` declarations. Only when the dependency's descriptor is
    /// absent from the request does this fall back to the path-based name.
    private string dependencyModuleName(string fileName)
    {
        if (auto dependencyDescriptor = fileName in collectedFiles)
            return (*dependencyDescriptor).moduleName;

        return fileName.moduleName;
    }

    /// Renders `messageType` as a D class (a `static` nested class when
    /// `indent > 0`), with fields in ascending field-number order, oneofs emitted
    /// at the position of their lowest-numbered member, then nested messages and
    /// enums. Map-entry messages produce nothing.
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto result = appender!string;
        result ~= "\n%*s%sclass %s\n".format(indent, "", indent > 0 ? "static " : "", messageType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        // Sort a copy: the descriptor is shared with the lookup tables and must
        // not be reordered as a side effect of generating code.
        auto fields = messageType.fields.dup;
        fields.sort!((a, b) => a.number < b.number);

        int[] generatedOneofs;
        foreach (field; fields)
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Renders one `@Proto(...)` field declaration, optionally with its
    /// `protoDefaultValue` initializer.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders a oneof: its case enum and bookkeeping members, then the union.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /// Renders the `<Name>Case` enum (`<name>NotSet = 0` plus one member per
    /// field, valued by field number), the `_<name>Case` storage, its getter and
    /// a `clear<Name>()` method that resets it to `<name>NotSet`.
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", oneof.name.underscoresToCamelCase(true));
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", oneof.name.underscoresToCamelCase(false));
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));

        return result.data;
    }

    /// Renders the `@Oneof` union holding the oneof's fields. Only the first
    /// field gets an initializer, since overlapping union members cannot all be
    /// default-initialized.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (field; fields)
            result ~= generateOneofField(field, indent + indentSize, field is fields[0]);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Renders one union member `_<name>` plus a `mixin(oneofAccessors!_<name>)`.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders `enumType` as a D enum. Like class names, the enum's name is
    /// keyword-escaped so declarations and field type references agree.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender, array;
        import std.format : format;

        auto result = appender!string;
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message descriptor named by `field.typeName`, or `null`.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        return collectedMessageTypes.get(field.typeName, null);
    }

    /// Returns the collected enum descriptor named by `field.typeName`, or `null`.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        return collectedEnumTypes.get(field.typeName, null);
    }

    /// Returns the wire flags for `field`. For map fields the key's scalar flag is
    /// shifted into bits 2-3 and the value's into bits 4-5.
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // Shifts promote the enum operands to int; every combination is a
                // declared `Wire` member, so the explicit conversion back is safe.
                return cast(Wire) (keyWire << 2 | valueWire << 4);
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the D element type of `field`, ignoring repetition and maps.
    ///
    /// Throws: `CodeGeneratorException` when a message/enum type was not collected.
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type '" ~ field.typeName ~ "'");
            // Escaped exactly like the class declaration in `generateMessage`.
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type '" ~ field.typeName ~ "'");
            // Escaped exactly like the enum declaration in `generateEnum`.
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the full D type of `field`: `V[K]` for map fields, `T[]` for other
    /// repeated fields, and the base type otherwise.
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        string fieldBaseTypeName = baseTypeName(field);

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /// Returns the `@Proto` attribute arguments: the field number, followed by the
    /// wire flags unless they are `Wire.none`.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.conv : to;

        auto number = field.number.to!string;
        auto wire = wireByField(field);

        return wire == Wire.none ? number : number ~ ", " ~ wire.toString;
    }

    /// Returns the ` = protoDefaultValue!...` initializer for `field`. Array and
    /// associative-array types are parenthesized, because the `!T` shorthand only
    /// accepts a single token.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    private string protocVersion;
    private DescriptorProto[string] collectedMessageTypes;
    private EnumDescriptorProto[string] collectedEnumTypes;
    private FileDescriptorProto[string] collectedFiles;
}
387 
/// Returns the first field of `messageType` whose tag number is `fieldNumber`.
///
/// Throws: `CodeGeneratorException` when the message declares no such field.
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.format : format;

    foreach (candidate; messageType.fields)
    {
        if (candidate.number == fieldNumber)
            return candidate;
    }

    throw new CodeGeneratorException(format("Message '%s' has no field with tag %s", messageType.name, fieldNumber));
}
402 
/// True when `messageType` has options with `mapEntry` set, i.e. it is a
/// map-entry message that the generator renders as an associative array.
private bool isMap(DescriptorProto messageType)
{
    if (!messageType.options)
        return false;

    return messageType.options.mapEntry;
}
407 
/// Field numbers of the `key` and `value` fields inside a map-entry message.
private enum MapFieldNumber { key = 1, value = 2 }
413 
/// Wire-encoding flags written into generated `@Proto` attributes.
///
/// Bits 0-1 describe a scalar field. For map fields the key's scalar flag
/// occupies bits 2-3 and the value's bits 4-5, giving the combined members.
private enum Wire
{
    none = 0,
    fixed = 0b00_0001,
    zigzag = 0b00_0010,
    fixed_key = 0b00_0100,
    zigzag_key = 0b00_1000,
    fixed_value = 0b01_0000,
    zigzag_value = 0b10_0000,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/// Returns the source spelling of `wire` as emitted into generated code. The
/// member names are camelCase, presumably matching the `Wire` enum of the
/// `google.protobuf` runtime the generated module imports (TODO confirm).
private string toString(Wire wire)
{
    string spelling;

    final switch (wire) with (Wire)
    {
    case none:                    spelling = "Wire.none"; break;
    case fixed:                   spelling = "Wire.fixed"; break;
    case zigzag:                  spelling = "Wire.zigzag"; break;
    case fixed_key:               spelling = "Wire.fixedKey"; break;
    case zigzag_key:              spelling = "Wire.zigzagKey"; break;
    case fixed_value:             spelling = "Wire.fixedValue"; break;
    case zigzag_value:            spelling = "Wire.zigzagValue"; break;
    case fixed_key_fixed_value:   spelling = "Wire.fixedKeyFixedValue"; break;
    case fixed_key_zigzag_value:  spelling = "Wire.fixedKeyZigzagValue"; break;
    case zigzag_key_fixed_value:  spelling = "Wire.zigzagKeyFixedValue"; break;
    case zigzag_key_zigzag_value: spelling = "Wire.zigzagKeyZigzagValue"; break;
    }

    return spelling;
}
457 
/// Derives the D module name of the file generated for `fileDescriptor`: the
/// `.proto` file's base name (directory and ".proto" stripped), prefixed with
/// the package when one is set, with keywords escaped per dotted component.
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    auto stem = baseName(fileDescriptor.name, ".proto");
    auto qualified = fileDescriptor.package_.empty ? stem : fileDescriptor.package_ ~ "." ~ stem;

    return escapeKeywords(qualified);
}
470 
/// Path-based module name for a `.proto` file name: drops a trailing ".proto",
/// turns "/" into "." and escapes keywords per dotted component.
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    auto withoutExtension = chomp(fileName, ".proto");
    auto dotted = replace(withoutExtension, "/", ".");

    return escapeKeywords(dotted);
}
478 
/// Converts a snake_case identifier to camelCase. Every `_` is dropped, and the
/// character following it is upper-cased if it is an ASCII lowercase letter. The
/// first character is treated the same way when `capitalizeNextLetter` is true.
/// The result is passed through `escapeKeywords`.
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto buffer = appender!string;
    bool upperNext = capitalizeNextLetter;

    foreach (char ch; input)
    {
        if (ch == '_')
        {
            upperNext = true;
        }
        else
        {
            const bool lowerAscii = ch >= 'a' && ch <= 'z';
            buffer.put(upperNext && lowerAscii ? cast(char)(ch - ('a' - 'A')) : ch);
            // Only the very next character is affected, whatever it is.
            upperNext = false;
        }
    }

    return escapeKeywords(buffer.data);
}
502 
/// D keywords and reserved/predefined names that must not appear verbatim as
/// identifiers in generated code.
///
/// Declared `immutable` rather than as a manifest constant (`enum`): an `enum`
/// array literal is materialised as a new GC allocation at every use site,
/// which here meant on every `escapeKeywords` call.
private immutable string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/// Appends `_` to every `separator`-delimited component of `input` that is listed
/// in `keywords`, e.g. `escapeKeywords("foo.in")` yields "foo.in_".
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}