1 module protoc_gen_d;
2 
3 import google.protobuf;
4 import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
5 import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
6     OneofDescriptorProto;
7 
/**
 * Plugin entry point.
 *
 * Reads a serialized CodeGeneratorRequest from stdin, runs the code generator
 * on it and writes the serialized CodeGeneratorResponse to stdout.
 *
 * Returns: always 0; generator failures are reported inside the response.
 */
int main()
{
    import std.array : array;
    import std.stdio : stdin, stdout;

    // Gather all of stdin before decoding. Decoding each chunk on its own
    // would split a request larger than the chunk size into several
    // independent (and individually invalid) messages.
    ubyte[] input;
    foreach (chunk; stdin.byChunk(1024 * 1024))
        input ~= chunk; // appending copies the bytes, so byChunk may reuse its buffer

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}
25 
/// Error raised when the generator cannot produce code for the given input.
class CodeGeneratorException : Exception
{
    /**
     * Params:
     *   message = human readable description of the failure
     *   file = source file where the exception was raised
     *   line = source line where the exception was raised
     *   next = chained exception, if any
     */
    this(
        string message = null,
        string file = __FILE__,
        size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        // Everything is forwarded unchanged to the base Exception.
        super(message, file, line, next);
    }
}
34 
35 class CodeGenerator
36 {
37     private enum indentSize = 4;
38     private bool messageAsStruct = false;
39 
40     CodeGeneratorResponse handle(CodeGeneratorRequest request)
41     {
42         import std.algorithm : filter, map, canFind, splitter;
43         import std.array : array;
44         import std.conv : to;
45         import std.format : format;
46 
47         if (request.parameter.splitter(",").canFind("message-as-struct"))
48             messageAsStruct = true;
49 
50         if (request.compilerVersion) with (request.compilerVersion)
51             protocVersion = format!"%d%03d%03d"(major, minor, patch);
52 
53         collectedMessageTypes.clear;
54         auto response = new CodeGeneratorResponse;
55         try
56         {
57             collectMessageAndEnumTypes(request);
58             response.files = request.protoFiles
59                 .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
60                 .map!(a => generate(a)).array;
61         }
62         catch (CodeGeneratorException generatorException)
63         {
64             response.error = generatorException.message.to!string;
65         }
66 
67         return response;
68     }
69 
70     private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
71     {
72         void collect(DescriptorProto messageType, string prefix)
73         {
74             auto absoluteName = prefix ~ "." ~ messageType.name;
75 
76             if (absoluteName in collectedMessageTypes)
77                 return;
78 
79             collectedMessageTypes[absoluteName] = messageType;
80 
81             foreach (nestedType; messageType.nestedTypes)
82                 collect(nestedType, absoluteName);
83 
84             foreach (enumType; messageType.enumTypes)
85                 collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
86         }
87 
88         foreach (file; request.protoFiles)
89         {
90             auto packagePrefix = file.package_ ? "." ~ file.package_ : "";
91 
92             foreach (messageType; file.messageTypes)
93                 collect(messageType, packagePrefix);
94 
95             foreach (enumType; file.enumTypes)
96                 collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
97         }
98     }
99 
100     private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
101     {
102         import std.array : replace;
103         import std.exception : enforce;
104 
105         enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
106             "Can only generate D code for proto3 .proto files.\n" ~
107             "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");
108 
109         auto file = new CodeGeneratorResponse.File;
110 
111         file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
112         file.content = generateFile(fileDescriptor);
113 
114         return file;
115     }
116 
117     private string generateFile(FileDescriptorProto fileDescriptor)
118     {
119         import std.array : appender, empty;
120         import std.format : format;
121 
122         auto result = appender!string;
123         result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
124         result ~= "// source: %s\n\n".format(fileDescriptor.name);
125         result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
126         result ~= "import google.protobuf;\n";
127 
128         foreach (dependency; fileDescriptor.dependencies)
129             result ~= "import %s;\n".format(dependency.moduleName);
130 
131         if (!protocVersion.empty)
132             result ~= "\nenum protocVersion = %s;\n".format(protocVersion);
133 
134         foreach (messageType; fileDescriptor.messageTypes)
135             result ~= generateMessage(messageType);
136 
137         foreach (enumType; fileDescriptor.enumTypes)
138             result ~= generateEnum(enumType);
139 
140         return result.data;
141     }
142 
143     private string generateMessage(DescriptorProto messageType, size_t indent = 0)
144     {
145         import std.algorithm : canFind, filter, sort;
146         import std.array : appender, array;
147         import std.format : format;
148 
149         // Don't generate MapEntry messages, they are generated as associative arrays
150         if (messageType.isMap)
151             return "";
152 
153         auto indentation = "%*s".format(indent, "");
154 
155         auto result = appender!string;
156         result ~= "\n";
157         result ~= indentation;
158         result ~= indent > 0 ? "static " : "";
159         result ~= messageAsStruct ? "struct" : "class";
160         result ~= " ";
161         result ~= messageType.name.escapeKeywords;
162         result ~= "\n";
163         result ~= indentation;
164         result ~= "{\n";
165 
166         int[] generatedOneofs;
167         foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
168         {
169             if (field.oneofIndex < 0)
170             {
171                 result ~= generateField(field, indent + indentSize);
172                 continue;
173             }
174 
175             if (generatedOneofs.canFind(field.oneofIndex))
176                 continue;
177 
178             result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
179                 messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
180             generatedOneofs ~= field.oneofIndex;
181         }
182 
183         foreach (nestedType; messageType.nestedTypes)
184             result ~= generateMessage(nestedType, indent + indentSize);
185 
186         foreach (enumType; messageType.enumTypes)
187             result ~= generateEnum(enumType, indent + indentSize);
188 
189         result ~= indentation;
190         result ~= "}\n";
191 
192         return result.data;
193     }
194 
195     private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
196     {
197         import std.format : format;
198 
199         return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
200             field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
201     }
202 
203     private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
204     {
205         return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
206     }
207 
208     private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
209     {
210         import std.format : format;
211         import std.array : appender;
212 
213         auto result = appender!string;
214         result ~= "%*senum %sCase\n".format(indent, "", oneof.name.underscoresToCamelCase(true));
215         result ~= "%*s{\n".format(indent, "");
216         result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", oneof.name.underscoresToCamelCase(false));
217         foreach (field; fields)
218             result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
219                 field.number);
220         result ~= "%*s}\n".format(indent, "");
221         result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
222             oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
223         result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
224             oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
225         result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
226             oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
227 
228         return result.data;
229     }
230 
231     private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
232     {
233         import std.format : format;
234         import std.array : appender;
235 
236         auto result = appender!string;
237         result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
238         result ~= "%*s{\n".format(indent, "");
239         foreach (field; fields)
240             result ~= generateOneofField(field, indent + indentSize, field == fields[0]);
241         result ~= "%*s}\n".format(indent, "");
242 
243         return result.data;
244     }
245 
246     private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
247     {
248         import std.format : format;
249 
250         return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
251             typeName(field), field.name.underscoresToCamelCase(false),
252             printInitializer ? fieldInitializer(field) : "");
253     }
254 
255     private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
256     {
257         import std.array : appender, array;
258         import std.format : format;
259 
260         auto result = appender!string;
261         result ~= "\n%*senum %s\n".format(indent, "", enumType.name);
262         result ~= "%*s{\n".format(indent, "");
263 
264         foreach (value; enumType.values)
265             result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);
266 
267         result ~= "%*s}\n".format(indent, "");
268 
269         return result.data;
270     }
271 
272     private DescriptorProto messageType(FieldDescriptorProto field)
273     {
274         return field.typeName in collectedMessageTypes ? collectedMessageTypes[field.typeName] : null;
275     }
276 
277     private EnumDescriptorProto enumType(FieldDescriptorProto field)
278     {
279         return field.typeName in collectedEnumTypes ? collectedEnumTypes[field.typeName] : null;
280     }
281 
    /**
     * Determines the Wire flags for a field from its declared type.
     *
     * Returns: Wire.zigzag for sint32/sint64, Wire.fixed for the
     *   (s)fixed32/(s)fixed64 types, and Wire.none for the other scalar,
     *   string, bytes and enum types. For a message field whose type is a
     *   known map entry, the wires of the entry's key and value fields are
     *   combined; any other message field yields Wire.none.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // Shifting by 2 maps fixed/zigzag onto the *_key flags and
                // shifting by 4 onto the *_value flags of the Wire enum.
                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            // Groups and the error type are not supported by this generator.
            assert(0, "Invalid field type");
        }
    }
310 
311     private string baseTypeName(FieldDescriptorProto field)
312     {
313         import std.exception : enforce;
314 
315         final switch (field.type) with (FieldDescriptorProto.Type)
316         {
317         case TYPE_BOOL:
318             return "bool";
319         case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
320             return "int";
321         case TYPE_UINT32: case TYPE_FIXED32:
322             return "uint";
323         case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
324             return "long";
325         case TYPE_UINT64: case TYPE_FIXED64:
326             return "ulong";
327         case TYPE_FLOAT:
328             return "float";
329         case TYPE_DOUBLE:
330             return "double";
331         case TYPE_STRING:
332             return "string";
333         case TYPE_BYTES:
334             return "bytes";
335         case TYPE_MESSAGE:
336         {
337             auto fieldMessageType = messageType(field);
338             enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
339                 "' has unknown message type " ~ field.typeName ~ "`");
340             return fieldMessageType.name;
341         }
342         case TYPE_ENUM:
343         {
344             auto fieldEnumType = enumType(field);
345             enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
346                 "' has unknown enum type ' " ~field.typeName ~ "`");
347             return fieldEnumType.name;
348         }
349         case TYPE_GROUP: case TYPE_ERROR:
350             assert(0, "Invalid field type");
351         }
352     }
353 
354     string typeName(FieldDescriptorProto field)
355     {
356         import std.format : format;
357 
358         string fieldBaseTypeName = baseTypeName(field);
359 
360         auto fieldMessageType = messageType(field);
361 
362         if (fieldMessageType !is null && fieldMessageType.isMap)
363         {
364             auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
365             auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);
366 
367             return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
368         }
369 
370         if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
371             return fieldBaseTypeName ~ "[]";
372         else
373             return fieldBaseTypeName;
374     }
375 
376     private string fieldProtoFields(FieldDescriptorProto field)
377     {
378         import std.algorithm : stripRight;
379         import std.conv : to;
380         import std.range : join;
381 
382         static string packedByField(FieldDescriptorProto field)
383         {
384             return (field.options && field.options.packed) ? "Yes.packed" : "No.packed";
385         }
386 
387         return [field.number.to!string, wireByField(field).toString, packedByField(field)]
388             .stripRight("No.packed")
389             .stripRight("Wire.none")
390             .join(", ");
391     }
392 
393     private string fieldInitializer(FieldDescriptorProto field)
394     {
395         import std.algorithm : endsWith;
396         import std.format : format;
397 
398         auto fieldTypeName = typeName(field);
399         if (fieldTypeName.endsWith("]"))
400             return " = protoDefaultValue!(%s)".format(fieldTypeName);
401         else
402             return " = protoDefaultValue!%s".format(fieldTypeName);
403     }
404 
405     private string protocVersion;
406     private DescriptorProto[string] collectedMessageTypes;
407     private EnumDescriptorProto[string] collectedEnumTypes;
408 }
409 
/**
 * Finds the first field of a message with the given field number.
 *
 * Throws: CodeGeneratorException when the message has no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto remaining = messageType.fields.find!(candidate => candidate.number == fieldNumber);
    immutable found = !remaining.empty;

    enforce!CodeGeneratorException(found,
        format("Message '%s' has no field with tag %s", messageType.name, fieldNumber));

    return remaining[0];
}
424 
/// Tells whether the message has options with the mapEntry flag set.
private bool isMap(DescriptorProto messageType)
{
    auto options = messageType.options;

    return options && options.mapEntry;
}
429 
/// Field numbers used to look up the key and value fields of a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}
435 
/**
 * Wire flags of a field as computed by the generator.
 *
 * `fixed` and `zigzag` apply to a plain field. The `*_key` flags are the
 * same bits shifted left by 2 and the `*_value` flags the same bits shifted
 * left by 4; the remaining members are key/value combinations.
 */
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}
450 
/// Returns the source text (e.g. "Wire.fixedKey") emitted into generated code for a Wire value.
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    // Plain fields.
    case none: return "Wire.none";
    case fixed: return "Wire.fixed";
    case zigzag: return "Wire.zigzag";

    // Map fields where only the key or only the value has a special wire.
    case fixed_key: return "Wire.fixedKey";
    case zigzag_key: return "Wire.zigzagKey";
    case fixed_value: return "Wire.fixedValue";
    case zigzag_value: return "Wire.zigzagValue";

    // Map fields where both key and value have a special wire.
    case fixed_key_fixed_value: return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value: return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value: return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value: return "Wire.zigzagKeyZigzagValue";
    }
}
479 
/**
 * Derives the D module name of a proto file: the file's base name without
 * ".proto", prefixed by the package when there is one, with D keywords
 * escaped in every dot separated part.
 */
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string stem = fileDescriptor.name.baseName(".proto");
    string qualified = fileDescriptor.package_.empty
        ? stem
        : fileDescriptor.package_ ~ "." ~ stem;

    return qualified.escapeKeywords;
}
492 
/**
 * Derives a D module name from a proto file path: strips a trailing
 * ".proto", turns "/" into "." and escapes D keywords in every part.
 */
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    string withoutExtension = fileName.chomp(".proto");
    string dotted = withoutExtension.replace("/", ".");

    return dotted.escapeKeywords;
}
500 
/**
 * Converts an underscore separated name to camel case and escapes D keywords.
 *
 * Underscores are dropped and the byte following one is upper-cased when it
 * is an ASCII lowercase letter.
 *
 * Params:
 *   input = the name to convert
 *   capitalizeNextLetter = whether the first letter should be upper-cased
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto converted = appender!string;

    foreach (ubyte current; input)
    {
        if (c_isUnderscore(current))
        {
            // Swallow the underscore and remember to capitalize what follows.
            capitalizeNextLetter = true;
        }
        else
        {
            immutable isLowercase = 'a' <= current && current <= 'z';

            if (isLowercase && capitalizeNextLetter)
                current += 'A' - 'a';

            converted ~= current;
            capitalizeNextLetter = false;
        }
    }

    return converted.data.escapeKeywords;
}

/// True when the byte is the ASCII underscore.
private bool c_isUnderscore(ubyte value)
{
    return value == '_';
}
524 
/// Identifiers that escapeKeywords suffixes with an underscore: D keywords and
/// special tokens, plus common aliases such as `string` and `size_t`.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];
539 
/**
 * Appends '_' to every part of `input` that appears in `keywords`.
 *
 * Params:
 *   input = the name to escape
 *   separator = the text separating the parts of `input`
 *
 * Returns: the parts, escaped where needed, joined by `separator` again.
 */
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    auto parts = input.splitter(separator);
    auto escapedParts = parts.map!(part => keywords.canFind(part) ? part ~ '_' : part);

    return escapedParts.joiner(separator).to!string;
}