module protoc_gen_d;

import google.protobuf;
import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    OneofDescriptorProto;

/**
 * protoc plugin entry point.
 *
 * Reads one serialized `CodeGeneratorRequest` from stdin, generates the D sources for it and writes the serialized
 * `CodeGeneratorResponse` to stdout.
 *
 * The whole request is buffered before it is decoded: `byChunk` hands out pieces of at most 1 MiB, and decoding
 * each piece on its own would split any larger request into several broken messages.
 */
int main()
{
    import std.array : appender, array;
    import std.stdio : stdin, stdout;

    // byChunk reuses its buffer between iterations, so every chunk is copied into the appender.
    auto inputBuffer = appender!(ubyte[]);
    foreach (chunk; stdin.byChunk(1024 * 1024))
        inputBuffer ~= chunk;

    // Decode from a named lvalue, as before (fromProtobuf may take its range by ref — TODO confirm).
    ubyte[] requestBytes = inputBuffer.data;
    auto request = requestBytes.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}

/// Raised for user-facing generation errors; `CodeGenerator.handle` reports its message back to protoc.
class CodeGeneratorException : Exception
{
    this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}

/**
 * Translates the descriptors of a `CodeGeneratorRequest` into D source files.
 *
 * Before generating, every file, message type and enum type of the request is indexed. Types are keyed by their
 * fully-qualified name with a leading dot (e.g. `.pkg.Outer.Inner`), which is the form looked up from
 * `FieldDescriptorProto.typeName`.
 */
class CodeGenerator
{
    private enum indentSize = 4;

    /**
     * Generates D code for every file of `request` except those in package `google.protobuf` (the well known types).
     *
     * Returns: a response holding one generated file per input file, or, if a `CodeGeneratorException` was raised,
     *          a response whose `error` carries that exception's message.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // compilerVersion is a message reference and may be null; protocVersion then stays empty and
        // generateFile omits the `protocVersion` constant.
        protocVersion = null;
        auto compilerVersion = request.compilerVersion;
        if (compilerVersion !is null)
            protocVersion = format!"%d%03d%03d"(compilerVersion.major, compilerVersion.minor, compilerVersion.patch);

        collectedMessageTypes.clear;
        collectedEnumTypes.clear;
        collectedFiles.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Indexes every file, message type and enum type of `request`.
     *
     * Message and enum types are keyed by their fully-qualified name with a leading dot, matching
     * `FieldDescriptorProto.typeName`. Files are keyed by their `.proto` name, the form used in
     * `FileDescriptorProto.dependencies`.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        // Records messageType (under prefix) plus all of its nested messages and enums.
        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            collectedFiles[file.name] = file;

            // Top-level types live in the package scope, or in the root scope when no package is declared.
            auto prefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, prefix);

            // Top-level enums are package-qualified exactly like messages; a bare ".Name" key would never match
            // the typeName of a field referencing an enum of a file that declares a package.
            foreach (enumType; file.enumTypes)
                collectedEnumTypes[prefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Generates the D module for one `.proto` file.
     *
     * Throws: `CodeGeneratorException` if the file does not use proto3 syntax.
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // The output path mirrors the module name, e.g. module `pkg.name` -> `pkg/name.d`.
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Returns the full source text of the D module generated for `fileDescriptor`.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler. DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependencyModuleName(dependency));

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Returns the D module name to import for the `.proto` file `dependency`.
     *
     * When the dependency's descriptor is part of the request, the name is derived from it exactly as `generate`
     * derives the module name of the generated file, so the import always matches the emitted `module` declaration.
     * Otherwise the name falls back to the dependency's path.
     */
    private string dependencyModuleName(string dependency)
    {
        if (auto dependencyDescriptor = dependency in collectedFiles)
            return (*dependencyDescriptor).moduleName;

        return dependency.moduleName;
    }

    /**
     * Generates a D class for `messageType`, including its fields, oneofs, nested messages and nested enums.
     *
     * Nested classes are emitted `static`. MapEntry messages produce no output because map fields are generated
     * as associative arrays.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto result = appender!string;
        result ~= "\n%*s%sclass %s\n".format(indent, "", indent > 0 ? "static " : "", messageType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        // Emit fields ordered by field number; sort a copy so the request's descriptor is left untouched.
        auto sortedFields = messageType.fields.dup;
        sortedFields.sort!((a, b) => a.number < b.number);

        int[] generatedOneofs;
        foreach (field; sortedFields)
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            // A oneof is emitted once, at the position of its lowest-numbered member.
            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                sortedFields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Generates a single `@Proto(...)` member declaration, optionally with its default-value initializer.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates a oneof as its case enum plus the union holding its member fields.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Generates the `<Name>Case` enum of a oneof, the `_<name>Case` member tracking the active field, its getter
     * and a `clear<Name>` method resetting it to `<name>NotSet`.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto typePrefix = oneof.name.underscoresToCamelCase(true);
        auto memberPrefix = oneof.name.underscoresToCamelCase(false);

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", typePrefix);
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", memberPrefix);
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "", typePrefix, memberPrefix);
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            typePrefix, memberPrefix);
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            typePrefix, memberPrefix);

        return result.data;
    }

    /// Generates the `@Oneof` union that stores the members of a oneof.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        // Only the first union member gets an initializer.
        foreach (index, field; fields)
            result ~= generateOneofField(field, indent + indentSize, index == 0);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Generates one oneof member `_<name>` together with the `oneofAccessors` mixin for it.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates a D enum for `enumType`; the enum and value names are keyword-escaped like class names.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message descriptor referenced by `field.typeName`, or null if there is none.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedMessageTypes;
        return found is null ? null : *found;
    }

    /// Returns the collected enum descriptor referenced by `field.typeName`, or null if there is none.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedEnumTypes;
        return found is null ? null : *found;
    }

    /**
     * Returns the wire encoding flags for `field`.
     *
     * For map fields the key's flags are shifted into bits 2-3 and the value's flags into bits 4-5 (see `Wire`).
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D type of a single value of `field`, ignoring repetition.
     *
     * Throws: `CodeGeneratorException` if a message or enum field references a type that was not collected.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null,
                "Field '" ~ field.name ~ "' has unknown message type '" ~ field.typeName ~ "'");
            // Escaped exactly like the class declaration emitted by generateMessage.
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null,
                "Field '" ~ field.name ~ "' has unknown enum type '" ~ field.typeName ~ "'");
            // Escaped exactly like the enum declaration emitted by generateEnum.
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the full D type of `field`.
     *
     * Map fields become associative arrays `Value[Key]`; repeated fields become dynamic arrays `T[]`.
     */
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        string fieldBaseTypeName = baseTypeName(field);

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";

        return fieldBaseTypeName;
    }

    /// Returns the `@Proto` arguments for `field`: its number, followed by its wire flags unless they are `Wire.none`.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.conv : to;

        auto wire = wireByField(field);
        auto number = field.number.to!string;

        if (wire == Wire.none)
            return number;

        return number ~ ", " ~ wire.toString;
    }

    /// Returns the ` = protoDefaultValue!...` initializer for `field`; array types are parenthesized.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    /// Value of the `protocVersion` constant written into generated modules; empty when the request has no version.
    private string protocVersion;
    /// Message descriptors keyed by fully-qualified name, e.g. `.pkg.Outer.Inner`.
    private DescriptorProto[string] collectedMessageTypes;
    /// Enum descriptors keyed by fully-qualified name, e.g. `.pkg.Outer.Kind`.
    private EnumDescriptorProto[string] collectedEnumTypes;
    /// File descriptors of the request keyed by their `.proto` file name.
    private FileDescriptorProto[string] collectedFiles;
}

/**
 * Returns the field of `messageType` with tag `fieldNumber`.
 *
 * Throws: `CodeGeneratorException` if there is no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto result = messageType.fields.find!(a => a.number == fieldNumber);

    enforce!CodeGeneratorException(!result.empty,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    return result[0];
}

/// Returns true if `messageType` is a synthetic MapEntry message (its options set `mapEntry`).
private bool isMap(DescriptorProto messageType)
{
    return messageType.options && messageType.options.mapEntry;
}

/// Field numbers of the key and value fields inside a MapEntry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}

/**
 * Wire encoding flags.
 *
 * Bits 0-1 describe a scalar field, bits 2-3 the key and bits 4-5 the value of a map field.
 */
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/// Returns the D expression naming `wire` in the generated code, e.g. `Wire.fixedKey`.
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    case none:
        return "Wire.none";
    case fixed:
        return "Wire.fixed";
    case zigzag:
        return "Wire.zigzag";
    case fixed_key:
        return "Wire.fixedKey";
    case zigzag_key:
        return "Wire.zigzagKey";
    case fixed_value:
        return "Wire.fixedValue";
    case zigzag_value:
        return "Wire.zigzagValue";
    case fixed_key_fixed_value:
        return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value:
        return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value:
        return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value:
        return "Wire.zigzagKeyZigzagValue";
    }
}

/// Returns the D module name of a generated file: `<package>.<basename>` (or just `<basename>`), keyword-escaped.
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string moduleName = fileDescriptor.name.baseName(".proto");

    if (!fileDescriptor.package_.empty)
        moduleName = fileDescriptor.package_ ~ "." ~ moduleName;

    return moduleName.escapeKeywords;
}

/// Derives a module name from a `.proto` path, e.g. `a/b.proto` -> `a.b`.
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".");
}

/**
 * Converts `snake_case` to camelCase by dropping underscores and upper-casing the ASCII letter after each one.
 *
 * If `capitalizeNextLetter` is true the first letter is upper-cased as well. The result is keyword-escaped.
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto result = appender!string;

    foreach (ubyte c; input)
    {
        if (c == '_')
        {
            capitalizeNextLetter = true;
            continue;
        }

        if ('a' <= c && c <= 'z' && capitalizeNextLetter)
            c += 'A' - 'a';

        result ~= c;
        capitalizeNextLetter = false;
    }

    return result.data.escapeKeywords;
}

/// Identifiers that cannot be used as generated names; `escapeKeywords` appends `_` to them.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/// Appends `_` to every `separator`-delimited component of `input` that is listed in `keywords`.
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}