module protoc_gen_d;

import google.protobuf;
import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    OneofDescriptorProto;

/**
 * Plugin entry point.
 *
 * Reads one serialized `CodeGeneratorRequest` from stdin, runs the code
 * generator on it and writes the serialized `CodeGeneratorResponse` to stdout.
 *
 * Returns: always 0; generator errors are reported inside the response.
 */
int main()
{
    import std.array : array;
    import std.stdio : stdin, stdout;

    // The request is a single protobuf message and may be larger than one
    // chunk, so the whole stream is collected before decoding. Decoding each
    // chunk separately would cut a large request in the middle of a message.
    // `byChunk` reuses its buffer, hence the copying append.
    ubyte[] input;
    foreach (chunk; stdin.byChunk(1024 * 1024))
        input ~= chunk;

    // Nothing was sent: emit nothing, as before.
    if (input.length == 0)
        return 0;

    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}

/// Thrown for problems that are reported back through `CodeGeneratorResponse.error`.
class CodeGeneratorException : Exception
{
    this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}

/// Translates the file descriptors of a `CodeGeneratorRequest` into D source files.
class CodeGenerator
{
    /// Number of spaces added per nesting level in the generated code.
    private enum indentSize = 4;
    /// When set, messages are emitted as `struct` instead of `class`.
    private bool messageAsStruct = false;

    /**
     * Generates D code for every file in `request`.
     *
     * The comma separated plugin parameter is searched for
     * `message-as-struct`. Files of the `google.protobuf` package are skipped.
     *
     * Params:
     *     request = the request received from the compiler
     * Returns: a response holding either the generated files or, if a
     *     `CodeGeneratorException` was thrown, its message in `error`.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map, canFind, splitter;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        if (request.parameter.splitter(",").canFind("message-as-struct"))
            messageAsStruct = true;

        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        // Start from empty lookup tables so that a reused generator does not
        // resolve types against a previous request.
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Fills `collectedMessageTypes` and `collectedEnumTypes` with every
     * message and enum of the request, keyed by the dot-prefixed absolute
     * name (`.package.Outer.Inner`), which is the form `field.typeName` is
     * looked up with.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            // Test for emptiness explicitly: an empty package must not add a
            // second leading dot to the absolute names.
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Generates the output file for one `.proto` file.
     *
     * Throws: `CodeGeneratorException` if the file is not `proto3`.
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        // The module name determines the path: `a.b.c` becomes `a/b/c.d`.
        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Returns the complete D module text for `fileDescriptor`.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler. DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependency.moduleName);

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Returns the declaration of `messageType`, including its fields, oneofs,
     * nested messages and nested enums, indented by `indent` spaces.
     * Map entry messages yield an empty string.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto indentation = "%*s".format(indent, "");

        auto result = appender!string;
        result ~= "\n";
        result ~= indentation;
        result ~= indent > 0 ? "static " : "";
        result ~= messageAsStruct ? "struct" : "class";
        result ~= " ";
        result ~= messageType.name.escapeKeywords;
        result ~= "\n";
        result ~= indentation;
        result ~= "{\n";

        int[] generatedOneofs;
        // `sort` reorders `messageType.fields` in place; the oneof member
        // list collected below is therefore in field number order as well.
        foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            // A oneof is emitted once, at the position of its first member.
            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= indentation;
        result ~= "}\n";

        return result.data;
    }

    /// Returns the `@Proto(...)` annotated member declaration for a regular field.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Returns the case enum followed by the union for one oneof declaration.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Returns the `<Name>Case` enum of a oneof together with the `_<name>Case`
     * member, its getter and the `clear<Name>` method.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        // Both spellings are used several times below; convert once.
        auto upperName = oneof.name.underscoresToCamelCase(true);
        auto lowerName = oneof.name.underscoresToCamelCase(false);

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", upperName);
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", lowerName);
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "", upperName, lowerName);
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            upperName, lowerName);
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            upperName, lowerName);

        return result.data;
    }

    /// Returns the `@Oneof` annotated anonymous union holding the oneof members.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        // Only the first union member gets an initializer.
        foreach (field; fields)
            result ~= generateOneofField(field, indent + indentSize, field == fields[0]);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns one underscore-prefixed union member plus its `oneofAccessors` mixin.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Returns the declaration of `enumType`, indented by `indent` spaces.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        // Escape the name the same way message names and enum values are
        // escaped, so an enum called like a D keyword still compiles.
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message descriptor `field` refers to, or `null`.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        return field.typeName in collectedMessageTypes ? collectedMessageTypes[field.typeName] : null;
    }

    /// Returns the collected enum descriptor `field` refers to, or `null`.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        return field.typeName in collectedEnumTypes ? collectedEnumTypes[field.typeName] : null;
    }

    /**
     * Returns the `Wire` value for `field`.
     *
     * For map fields the wire values of the entry's key and value fields are
     * combined; every other message field yields `Wire.none`.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // Shifting by 2 and 4 moves `fixed`/`zigzag` onto the
                // `*_key` and `*_value` members of `Wire`.
                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D type of a single element of `field`, ignoring
     * repetition. Message and enum types are returned by their simple,
     * keyword-escaped name.
     *
     * Throws: `CodeGeneratorException` if the referenced message or enum
     *     type was not collected.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type '" ~ field.typeName ~ "'");
            // Must match the escaped name used by the declaration.
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type '" ~ field.typeName ~ "'");
            // Must match the escaped name used by the declaration.
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the full D type of `field`: `Value[Key]` for map fields,
     * `T[]` for repeated fields and `T` otherwise.
     */
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /**
     * Returns the argument list of the `@Proto` attribute: field number,
     * wire and packed flag, with trailing `No.packed` and `Wire.none`
     * entries removed.
     */
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        static string packedByField(FieldDescriptorProto field)
        {
            return (field.options && field.options.packed) ? "Yes.packed" : "No.packed";
        }

        return [field.number.to!string, wireByField(field).toString, packedByField(field)]
            .stripRight("No.packed")
            .stripRight("Wire.none")
            .join(", ");
    }

    /// Returns the ` = protoDefaultValue!...` initializer text for `field`.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        // Array and associative array types need parentheses to form a
        // valid template argument.
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    /// Compiler version formatted as major, minor and patch digits; empty if not provided.
    private string protocVersion;
    /// Message descriptors by dot-prefixed absolute name.
    private DescriptorProto[string] collectedMessageTypes;
    /// Enum descriptors by dot-prefixed absolute name.
    private EnumDescriptorProto[string] collectedEnumTypes;
}

/**
 * Returns the field of `messageType` whose number is `fieldNumber`.
 *
 * Throws: `CodeGeneratorException` if there is no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto result = messageType.fields.find!(a => a.number == fieldNumber);

    enforce!CodeGeneratorException(!result.empty,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    return result[0];
}

/// Returns whether `messageType` has options with `mapEntry` set.
private bool isMap(DescriptorProto messageType)
{
    return messageType.options && messageType.options.mapEntry;
}

/// Field numbers of the key and the value inside a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}

/**
 * Wire encoding flags. The `*_key` members are `fixed`/`zigzag` shifted
 * left by 2 and the `*_value` members are the same flags shifted left by 4;
 * `CodeGenerator.wireByField` depends on this layout.
 */
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/// Returns the spelling of `wire` that is written into the generated code.
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    case none:
        return "Wire.none";
    case fixed:
        return "Wire.fixed";
    case zigzag:
        return "Wire.zigzag";
    case fixed_key:
        return "Wire.fixedKey";
    case zigzag_key:
        return "Wire.zigzagKey";
    case fixed_value:
        return "Wire.fixedValue";
    case zigzag_value:
        return "Wire.zigzagValue";
    case fixed_key_fixed_value:
        return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value:
        return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value:
        return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value:
        return "Wire.zigzagKeyZigzagValue";
    }
}

/**
 * Returns the D module name of a file descriptor: its package, if any,
 * followed by the file's base name without `.proto`, keyword-escaped.
 */
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string moduleName = fileDescriptor.name.baseName(".proto");

    if (!fileDescriptor.package_.empty)
        moduleName = fileDescriptor.package_ ~ "." ~ moduleName;

    return moduleName.escapeKeywords;
}

/**
 * Returns the D module name derived from a `.proto` file path: the
 * extension is dropped, `/` becomes `.` and keywords are escaped.
 */
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}

/**
 * Removes the underscores of `input` and upper-cases the lower-case ASCII
 * letter following each of them; the result is keyword-escaped.
 *
 * Params:
 *     input = identifier to convert
 *     capitalizeNextLetter = whether the first letter is upper-cased too
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto result = appender!string;

    foreach (ubyte c; input)
    {
        if (c == '_')
        {
            capitalizeNextLetter = true;
            continue;
        }

        if ('a' <= c && c <= 'z' && capitalizeNextLetter)
            c += 'A' - 'a';

        result ~= c;
        capitalizeNextLetter = false;
    }

    return result.data.escapeKeywords;
}

/// Identifiers that get a trailing underscore from `escapeKeywords`.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/**
 * Splits `input` at `separator` and appends `_` to every part that is
 * listed in `keywords`, then joins the parts again.
 */
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}