/// protoc plugin that turns proto3 file descriptors into D source code.
///
/// A serialized `CodeGeneratorRequest` is read from standard input and a
/// serialized `CodeGeneratorResponse` is written to standard output.
module protoc_gen_d;

import google.protobuf;
import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    OneofDescriptorProto;

/// Plugin entry point: decodes one request from stdin and writes one response to stdout.
///
/// Returns: always 0; generator errors are reported inside the response (see `CodeGenerator.handle`).
int main()
{
    import std.array : array;
    import std.stdio : stdin, stdout;

    // The request is a single protobuf message, so the whole of stdin has to be
    // collected before decoding. Decoding every 1 MiB chunk on its own would cut
    // any larger request into pieces that cannot be decoded.
    ubyte[] input;
    foreach (chunk; stdin.byChunk(1024 * 1024))
        input ~= chunk; // byChunk reuses its buffer; appending copies the bytes

    // Nothing was received: write nothing, as before.
    if (input.length == 0)
        return 0;

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}

/// Raised for problems in the input descriptors; `CodeGenerator.handle` converts it
/// into the `error` field of the response.
class CodeGeneratorException : Exception
{
    this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}

/// Generates D source text for the files of a `CodeGeneratorRequest`.
class CodeGenerator
{
    /// Number of spaces added per nesting level in the generated code.
    private enum indentSize = 4;

    /// When true, messages are emitted as `struct` instead of `class`.
    private bool messageAsStruct = false;

    /// Processes a request and builds the response.
    ///
    /// Params:
    ///   request = decoded plugin request; its `parameter` is a comma separated option
    ///             list in which `message-as-struct` is recognized.
    ///
    /// Returns: a response holding one generated file per input file whose package is
    ///          not `google.protobuf`, or, if a `CodeGeneratorException` was thrown,
    ///          a response whose `error` is the exception message.
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map, canFind, splitter;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        if (request.parameter.splitter(",").canFind("message-as-struct"))
            messageAsStruct = true;

        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        // Both lookup tables are rebuilt for every request, so neither may keep
        // entries from a previous call.
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            // NOTE(review): every file in protoFiles is generated, not only the files
            // protoc asked for - confirm whether dependencies should be skipped.
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /// Fills `collectedMessageTypes` and `collectedEnumTypes` with every message and
    /// enum of the request, keyed by the dot-prefixed absolute name
    /// (`.package.Outer.Inner`).
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        // Registers a message, its nested messages (recursively) and its enums.
        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            // Test for emptiness, not for a null pointer: an empty but non-null package
            // string must not produce a "." prefix (which would yield "..Name" keys).
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /// Builds the response entry (file name and content) for one `.proto` file.
    ///
    /// Throws: `CodeGeneratorException` if the file's syntax is not `proto3`.
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // Module "a.b.c" is written to "a/b/c.d".
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Produces the complete D module text for one `.proto` file: header comment,
    /// module declaration, imports, optional `protocVersion` constant, then all
    /// top-level messages followed by all top-level enums.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler. DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependency.moduleName);

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /// Generates the aggregate (class or struct) for a message, including its fields,
    /// oneofs, nested messages and nested enums.
    ///
    /// Params:
    ///   messageType = descriptor of the message; its `fields` array is sorted in place
    ///                 by field number.
    ///   indent      = number of leading spaces; values above 0 mark a nested type,
    ///                 which is emitted as `static`.
    ///
    /// Returns: the generated text, or an empty string for map entry messages.
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto indentation = "%*s".format(indent, "");

        auto result = appender!string;
        result ~= "\n";
        result ~= indentation;
        result ~= indent > 0 ? "static " : "";
        result ~= messageAsStruct ? "struct" : "class";
        result ~= " ";
        result ~= messageType.name.escapeKeywords;
        result ~= "\n";
        result ~= indentation;
        result ~= "{\n";

        // Each oneof is emitted once, at the position of its lowest numbered member.
        int[] generatedOneofs;
        foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= indentation;
        result ~= "}\n";

        return result.data;
    }

    /// Generates the declaration line of a regular (non-oneof) field:
    /// `@Proto(...) Type name = initializer;`.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates a oneof: the case enum with its accessors, followed by the union.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /// Generates the `<Name>Case` enum of a oneof, the `_<name>Case` member that stores
    /// the active case, its getter property and the `clear<Name>` method.
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", oneof.name.underscoresToCamelCase(true));
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", oneof.name.underscoresToCamelCase(false));
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));

        return result.data;
    }

    /// Generates the anonymous union holding the members of a oneof. Only the first
    /// member is given an initializer.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (index, field; fields)
            result ~= generateOneofField(field, indent + indentSize, index == 0);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Generates one union member of a oneof together with its accessor mixin.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates a D `enum` declaration. The enum name and every member name are
    /// escaped when they collide with a reserved D identifier.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        // Escape the enum name the same way message names are escaped, otherwise an
        // enum named after a D keyword would produce code that does not compile.
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Looks up the collected message descriptor named by the field's `typeName`.
    /// Returns: the descriptor, or `null` when no such message was collected.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        return field.typeName in collectedMessageTypes ? collectedMessageTypes[field.typeName] : null;
    }

    /// Looks up the collected enum descriptor named by the field's `typeName`.
    /// Returns: the descriptor, or `null` when no such enum was collected.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        return field.typeName in collectedEnumTypes ? collectedEnumTypes[field.typeName] : null;
    }

    /// Determines the `Wire` value to place in the field's `@Proto` attribute.
    ///
    /// For map fields the wire values of the entry's key and value fields are
    /// combined: the key's value is shifted left by 2 bits and the value's by 4 bits,
    /// which maps `fixed`/`zigzag` onto the `*_key`/`*_value` members of `Wire`.
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
            {
                auto fieldMessageType = messageType(field);

                if (fieldMessageType !is null && fieldMessageType.isMap)
                {
                    Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                    Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                    return cast(Wire)((keyWire << 2) | (valueWire << 4));
                }
                return Wire.none;
            }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the D element type of a field, ignoring `repeated` and map handling.
    ///
    /// Message and enum types resolve to the simple (unqualified) name of the
    /// collected descriptor, escaped like the corresponding declaration.
    ///
    /// Throws: `CodeGeneratorException` if the referenced message or enum is unknown.
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
            {
                auto fieldMessageType = messageType(field);
                enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                    "' has unknown message type '" ~ field.typeName ~ "'");
                // Must match the escaped name used by generateMessage for the declaration.
                return fieldMessageType.name.escapeKeywords;
            }
        case TYPE_ENUM:
            {
                auto fieldEnumType = enumType(field);
                enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                    "' has unknown enum type '" ~ field.typeName ~ "'");
                // Must match the escaped name used by generateEnum for the declaration.
                return fieldEnumType.name.escapeKeywords;
            }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Tells whether a repeated field of this type may be marked `Yes.packed`.
    ///
    /// NOTE(review): enums are reported as not packable here although protobuf allows
    /// packed enums - confirm against the runtime library before changing.
    private bool isBaseTypePackable(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
        case TYPE_UINT32: case TYPE_FIXED32:
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
        case TYPE_UINT64: case TYPE_FIXED64:
        case TYPE_FLOAT:
        case TYPE_DOUBLE:
            return true;
        case TYPE_STRING:
        case TYPE_BYTES:
        case TYPE_MESSAGE:
        case TYPE_ENUM:
            return false;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the full D type of a field: `Value[Key]` for map fields, `T[]` for
    /// repeated fields and `T` otherwise.
    ///
    /// Throws: `CodeGeneratorException` if a referenced message or enum is unknown.
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        // Resolved first so that unknown types are reported for every kind of field.
        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /// Builds the argument list of the `@Proto` attribute: the field number, then the
    /// wire kind and the packed flag. Trailing default arguments (`No.packed`, then
    /// `Wire.none`) are dropped.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        // Only repeated fields of a packable type can be packed; they are packed unless
        // options are present and `packed` is false.
        string packedByField(FieldDescriptorProto candidate)
        {
            if (candidate.label != FieldDescriptorProto.Label.LABEL_REPEATED)
                return "No.packed";

            if (!isBaseTypePackable(candidate))
                return "No.packed";

            return (!candidate.options || candidate.options.packed) ? "Yes.packed" : "No.packed";
        }

        return [field.number.to!string, wireByField(field).toString, packedByField(field)]
            .stripRight("No.packed")
            .stripRight("Wire.none")
            .join(", ");
    }

    /// Returns the ` = protoDefaultValue!...` initializer for a field. Array and map
    /// types are wrapped in parentheses because they are not a single token.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    /// Compiler version formatted as major, 3-digit minor, 3-digit patch; empty when
    /// the request carried no version.
    private string protocVersion;

    /// Message descriptors keyed by dot-prefixed absolute name.
    private DescriptorProto[string] collectedMessageTypes;

    /// Enum descriptors keyed by dot-prefixed absolute name.
    private EnumDescriptorProto[string] collectedEnumTypes;
}

/// Finds the field of a message with the given field number.
///
/// Throws: `CodeGeneratorException` if the message has no such field.
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto result = messageType.fields.find!(a => a.number == fieldNumber);

    enforce!CodeGeneratorException(!result.empty,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    return result[0];
}

/// Tells whether a message descriptor has options with `mapEntry` set.
private bool isMap(DescriptorProto messageType)
{
    return messageType.options && messageType.options.mapEntry;
}

/// Field numbers of the key and value fields inside a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}

/// Wire kind of a field. The `*_key` and `*_value` members are `fixed`/`zigzag`
/// shifted left by 2 and 4 bits respectively, as computed for map fields.
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/// Returns the D expression text written into generated code for a `Wire` value.
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    case none:
        return "Wire.none";
    case fixed:
        return "Wire.fixed";
    case zigzag:
        return "Wire.zigzag";
    case fixed_key:
        return "Wire.fixedKey";
    case zigzag_key:
        return "Wire.zigzagKey";
    case fixed_value:
        return "Wire.fixedValue";
    case zigzag_value:
        return "Wire.zigzagValue";
    case fixed_key_fixed_value:
        return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value:
        return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value:
        return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value:
        return "Wire.zigzagKeyZigzagValue";
    }
}

/// Derives the D module name of a file: its package (when not empty) followed by the
/// file's base name without the `.proto` extension, with reserved identifiers escaped.
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string moduleName = fileDescriptor.name.baseName(".proto");

    if (!fileDescriptor.package_.empty)
        moduleName = fileDescriptor.package_ ~ "." ~ moduleName;

    return moduleName.escapeKeywords;
}

/// Derives a D module name from a `.proto` path by dropping the extension and turning
/// `/` into `.`, with reserved identifiers escaped.
///
/// NOTE(review): this uses the directory path while the `FileDescriptorProto` overload
/// uses the package; the two disagree when a file's package differs from its
/// directory - verify against how dependencies are laid out.
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}

/// Converts `snake_case` to camel case and escapes the result if it is a reserved
/// D identifier.
///
/// Underscores are removed; the ASCII lower-case letter following an underscore is
/// upper-cased. All other bytes are copied unchanged.
///
/// Params:
///   input                = identifier to convert
///   capitalizeNextLetter = whether the first letter is upper-cased as well
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto result = appender!string;

    foreach (ubyte c; input)
    {
        if (c == '_')
        {
            capitalizeNextLetter = true;
            continue;
        }

        if ('a' <= c && c <= 'z' && capitalizeNextLetter)
            c += 'A' - 'a';

        result ~= c;
        capitalizeNextLetter = false;
    }

    return result.data.escapeKeywords;
}

/// Identifiers that get a trailing underscore when used as a generated name.
///
/// Stored as immutable module data rather than an `enum` so that the array is not
/// re-created at every use.
private immutable string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/// Appends `_` to every `separator`-delimited part of `input` that is listed in
/// `keywords` and joins the parts again with `separator`.
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}