1 module protoc_gen_d;
2 
3 import google.protobuf;
4 import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
5 import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
6     OneofDescriptorProto;
7 
/**
 * Plugin entry point.
 *
 * Reads one serialized `CodeGeneratorRequest` from stdin, runs it through a
 * fresh `CodeGenerator` and writes the serialized `CodeGeneratorResponse`
 * to stdout.
 *
 * Returns: always 0; generator failures are reported inside the response.
 */
int main()
{
    import std.algorithm : joiner;
    import std.array : array;
    import std.stdio : stdin, stdout;

    // The request is a single message that ends at EOF. Decoding each chunk
    // separately would split any request larger than the chunk size, so the
    // whole stream is gathered first. `byChunk` reuses its buffer, which is
    // fine here because `joiner` copies every byte out before advancing.
    ubyte[] input = stdin.byChunk(64 * 1024).joiner.array;

    // Nothing on stdin: emit nothing, as before.
    if (input.length == 0)
        return 0;

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}
25 
/// Error type used by the generator to abort code generation with a message.
class CodeGeneratorException : Exception
{
    /**
     * Creates the exception; every argument is handed unchanged to the
     * `Exception` constructor.
     *
     * Params:
     *   message = human readable description of the failure
     *   file = source file recorded as the throw location
     *   line = source line recorded as the throw location
     *   next = optional chained throwable
     */
    this(
        string message = null,
        string file = __FILE__,
        size_t line = __LINE__,
        Throwable next = null
    ) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}
34 
/**
 * Turns the file descriptors of a `CodeGeneratorRequest` into D source files.
 *
 * `handle` is the entry point; the private `generate*` methods render the
 * individual declarations as text, and the `collected*` tables map fully
 * qualified proto type names (with a leading '.') to their descriptors.
 */
class CodeGenerator
{
    /// Number of spaces added per nesting level of the generated source.
    private enum indentSize = 4;
    /// When true, messages are emitted as `struct` instead of `class`.
    private bool messageAsStruct = false;

    /**
     * Generates D code for every file of the request.
     *
     * Params:
     *   request = the decoded plugin request
     *
     * Returns: a response holding one generated file per input file whose
     * package is not "google.protobuf"; if a `CodeGeneratorException` is
     * thrown while generating, its message is stored in `error` instead.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map, canFind, splitter;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        if (request.parameter.splitter(",").canFind("message-as-struct"))
            messageAsStruct = true;

        // Encoded as major, then minor and patch zero-padded to three digits.
        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        // Both lookup tables are rebuilt per request; clearing only one of
        // them would leave stale enum entries behind when the generator is
        // reused.
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;
        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Fills `collectedMessageTypes` and `collectedEnumTypes` with every
     * message and enum (top level and nested) of all files in the request,
     * keyed by ".package.Outer.Inner" style names.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        // Registers a message under prefix ~ "." ~ name and recurses into
        // its nested messages and enums.
        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            // Already registered: skip, its children were handled as well.
            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            // Test for emptiness rather than string truthiness: a non-null
            // but empty package must not produce a ".." prefix.
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Produces the response entry (file name and content) for one .proto file.
     *
     * Throws: `CodeGeneratorException` if the file's syntax is not "proto3".
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // The output path mirrors the D module name: dots become directories.
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /**
     * Renders the complete D module text for one .proto file: header
     * comment, module declaration, imports, the optional `protocVersion`
     * constant, then all top-level messages followed by all top-level enums.
     */
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependency.moduleName);

        // Only emitted when the request carried a compiler version.
        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Renders one message as a D class or struct, including its fields,
     * oneofs, nested messages and nested enums.
     *
     * Params:
     *   messageType = descriptor of the message to render
     *   indent = number of leading spaces for the declaration
     *
     * Returns: the declaration text, or "" for map entry messages.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto indentation = "%*s".format(indent, "");

        auto result = appender!string;
        result ~= "\n";
        result ~= indentation;
        // Nested aggregates are declared static.
        result ~= indent > 0 ? "static " : "";
        result ~= messageAsStruct ? "struct" : "class";
        result ~= " ";
        result ~= messageType.name.escapeKeywords;
        result ~= "\n";
        result ~= indentation;
        result ~= "{\n";

        // Indices of the oneofs that were already emitted.
        int[] generatedOneofs;
        // `sort` reorders messageType.fields in place, so the filter further
        // down also sees the fields in ascending tag order.
        foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
        {
            // A negative index is treated as "not part of any oneof"
            // (presumably the descriptor's default value; verify against
            // the descriptor definition).
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            // A oneof is emitted once, at the position of its lowest-tagged
            // member, together with all of its member fields.
            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= indentation;
        result ~= "}\n";

        return result.data;
    }

    /**
     * Renders a regular field as `@Proto(...) Type name [= initializer];`.
     *
     * Params:
     *   field = descriptor of the field
     *   indent = number of leading spaces
     *   printInitializer = whether to append the default value initializer
     */
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Renders a oneof: its case enum with accessors, followed by its union.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Renders the `<Name>Case` enum of a oneof (a `NotSet = 0` member plus
     * one member per field, valued with the field number), the backing
     * `_<name>Case` variable, its getter property and the `clear<Name>`
     * method.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        // Both spellings are needed several times below; compute them once.
        auto upperName = oneof.name.underscoresToCamelCase(true);
        auto lowerName = oneof.name.underscoresToCamelCase(false);

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", upperName);
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", lowerName);
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
            upperName, lowerName);
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            upperName, lowerName);
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            upperName, lowerName);

        return result.data;
    }

    /**
     * Renders the anonymous `@Oneof` union holding the member fields of a
     * oneof. Only the first member is given an initializer.
     */
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (field; fields)
            result ~= generateOneofField(field, indent + indentSize, field == fields[0]);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /**
     * Renders one member of a oneof union: the underscore-prefixed storage
     * field followed by the `oneofAccessors` mixin for it.
     */
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /**
     * Renders an enum declaration with all of its values.
     *
     * The enum name is keyword-escaped, matching how message names are
     * declared and how `baseTypeName` refers to the type.
     */
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender, array;
        import std.format : format;

        auto result = appender!string;
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message descriptor the field's type name refers to, or null.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        // `in` yields a pointer to the stored value, avoiding a second lookup.
        if (auto found = field.typeName in collectedMessageTypes)
            return *found;

        return null;
    }

    /// Returns the collected enum descriptor the field's type name refers to, or null.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        if (auto found = field.typeName in collectedEnumTypes)
            return *found;

        return null;
    }

    /**
     * Determines the `Wire` annotation for a field.
     *
     * sint* types map to `zigzag`, (s)fixed* types to `fixed`; for map
     * fields the key and value encodings are combined into the
     * corresponding `*_key` / `*_value` flags. Everything else is `none`.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // Shifting by 2 turns fixed/zigzag into the *_key flags,
                // shifting by 4 into the *_value flags.
                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D element type used for a field, ignoring repetition.
     *
     * Message and enum names are keyword-escaped so that the reference
     * matches the name under which the type is declared.
     *
     * Throws: `CodeGeneratorException` if a message or enum type name was
     * not collected from the request.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type `" ~ field.typeName ~ "`");
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type `" ~ field.typeName ~ "`");
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the full D type of a field: `Value[Key]` for map fields,
     * `Base[]` for repeated fields and `Base` otherwise.
     */
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        // Map fields reference a map entry message; its key and value
        // fields determine the associative array type.
        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /**
     * Returns the argument list of the `@Proto(...)` attribute: the field
     * number, followed by the wire annotation unless it is `Wire.none`.
     */
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        // Trailing default arguments are dropped from the list.
        return [field.number.to!string, wireByField(field).toString].stripRight("").stripRight("Wire.none").join(", ");
    }

    /**
     * Returns the ` = protoDefaultValue!...` initializer for a field.
     * Array and associative array types are wrapped in parentheses, since
     * the short `!Type` form only accepts a single token.
     */
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    /// Compiler version text emitted into generated files; empty if the request had none.
    private string protocVersion;
    /// Message descriptors keyed by fully qualified proto name (leading '.').
    private DescriptorProto[string] collectedMessageTypes;
    /// Enum descriptors keyed by fully qualified proto name (leading '.').
    private EnumDescriptorProto[string] collectedEnumTypes;
}
401 
/**
 * Looks up the field of a message that carries the given tag number.
 *
 * Params:
 *   messageType = message descriptor whose fields are searched
 *   fieldNumber = tag number to look for
 *
 * Returns: the first field whose number equals `fieldNumber`.
 * Throws: `CodeGeneratorException` if the message has no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.format : format;

    foreach (candidate; messageType.fields)
    {
        if (candidate.number == fieldNumber)
            return candidate;
    }

    throw new CodeGeneratorException(
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));
}
416 
/// Tells whether the message has options set and those options flag it as a map entry.
private bool isMap(DescriptorProto messageType)
{
    auto options = messageType.options;

    // No options at all: cannot be a map entry.
    if (!options)
        return false;

    if (!options.mapEntry)
        return false;

    return true;
}
421 
/// Tag numbers used to look up the key and value fields of a map entry message.
private enum MapFieldNumber : int
{
    /// Tag of the key field.
    key = 1,
    /// Tag of the value field.
    value = 2,
}
427 
/**
 * Bit flags describing a non-default wire encoding of a field.
 *
 * Bits 0-1 apply to plain fields, bits 2-3 to map keys and bits 4-5 to
 * map values; the combined members pair one key flag with one value flag.
 */
private enum Wire
{
    none = 0x00,
    fixed = 0x01,
    zigzag = 0x02,
    fixed_key = 0x04,
    zigzag_key = 0x08,
    fixed_value = 0x10,
    zigzag_value = 0x20,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}
442 
/**
 * Returns the text written into generated code for a `Wire` value:
 * "Wire." followed by the camelCase spelling of the member name.
 */
private string toString(Wire wire)
{
    // `final switch` keeps the compiler checking that every member is handled.
    final switch (wire)
    {
    case Wire.none: return "Wire.none";
    case Wire.fixed: return "Wire.fixed";
    case Wire.zigzag: return "Wire.zigzag";
    case Wire.fixed_key: return "Wire.fixedKey";
    case Wire.zigzag_key: return "Wire.zigzagKey";
    case Wire.fixed_value: return "Wire.fixedValue";
    case Wire.zigzag_value: return "Wire.zigzagValue";
    case Wire.fixed_key_fixed_value: return "Wire.fixedKeyFixedValue";
    case Wire.fixed_key_zigzag_value: return "Wire.fixedKeyZigzagValue";
    case Wire.zigzag_key_fixed_value: return "Wire.zigzagKeyFixedValue";
    case Wire.zigzag_key_zigzag_value: return "Wire.zigzagKeyZigzagValue";
    }
}
471 
/**
 * Derives the D module name of a .proto file from its descriptor: the
 * file's base name without the ".proto" extension, prefixed by the package
 * when one is set, with keyword segments escaped.
 */
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string stem = fileDescriptor.name.baseName(".proto");
    string qualified = fileDescriptor.package_.empty
        ? stem
        : fileDescriptor.package_ ~ "." ~ stem;

    return qualified.escapeKeywords;
}
484 
/**
 * Derives a D module name from the path of a .proto file: a trailing
 * ".proto" is removed, every '/' becomes '.', and keyword segments are
 * escaped.
 */
private string moduleName(string fileName)
{
    import std.array : replace;
    // `std..string` (double dot) is not a valid module path and does not compile.
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}
492 
/**
 * Converts an underscore separated identifier to camel case.
 *
 * Underscores are dropped and a lower case ASCII letter directly following
 * one is upper-cased. The result is passed through `escapeKeywords`.
 *
 * Params:
 *   input = identifier to convert
 *   capitalizeNextLetter = whether the first letter is upper-cased as well
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto converted = appender!string;

    // Work on raw code units; only ASCII letters are ever modified.
    foreach (ubyte unit; input)
    {
        if (unit != '_')
        {
            immutable isLowerAscii = unit >= 'a' && unit <= 'z';

            if (isLowerAscii && capitalizeNextLetter)
                unit -= 'a' - 'A';

            converted ~= unit;
            // Any emitted character ends a pending capitalization request.
            capitalizeNextLetter = false;
        }
        else
        {
            capitalizeNextLetter = true;
        }
    }

    return converted.data.escapeKeywords;
}
516 
/**
 * Identifiers that `escapeKeywords` must not emit verbatim: D keywords,
 * special tokens and a few common type aliases.
 *
 * Stored as immutable module-level data rather than an `enum`: an enum
 * array is a manifest constant whose literal is re-allocated at every
 * runtime use, which here would mean once per checked name segment.
 */
private immutable string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];
531 
/**
 * Appends '_' to every segment of `input` that appears in `keywords`.
 *
 * Params:
 *   input = identifier, possibly made of several segments
 *   separator = text that delimits the segments; it is kept in the result
 *
 * Returns: the identifier with each listed segment suffixed by '_'.
 */
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    auto escapedSegments = input
        .splitter(separator)
        .map!(segment => keywords.canFind(segment) ? segment ~ '_' : segment);

    return escapedSegments.joiner(separator).to!string;
}