/**
 * protoc plugin that generates D modules from proto3 `.proto` files.
 *
 * Reads a serialized `CodeGeneratorRequest` from stdin and writes a serialized
 * `CodeGeneratorResponse` to stdout.
 */
module protoc_gen_d;

import google.protobuf;
import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    OneofDescriptorProto;

/**
 * Plugin entry point.
 *
 * Reads the complete request from stdin, generates the D modules and writes the response to stdout.
 *
 * Returns: 0. Errors raised as `CodeGeneratorException` are reported through
 * `CodeGeneratorResponse.error` rather than through the exit status.
 */
int main()
{
    import std.algorithm : joiner;
    import std.array : array;
    import std.stdio : stdin, stdout;

    // The request has to be decoded as a single message: decoding every stdin chunk on its own
    // would split requests larger than the chunk size into several broken messages.
    // `byChunk` reuses its buffer, so `joiner.array` copies the bytes into a fresh array.
    auto input = stdin.byChunk(1024 * 1024).joiner.array;
    auto request = input.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}

/// Error raised for invalid input; its message is returned to protoc in `CodeGeneratorResponse.error`.
class CodeGeneratorException : Exception
{
    this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}

/// Turns the file descriptors of a `CodeGeneratorRequest` into D source files.
class CodeGenerator
{
    private enum indentSize = 4;
    private bool messageAsStruct = false;

    /**
     * Generates the D modules for the files in `request`.
     *
     * The comma-separated `request.parameter` may contain `message-as-struct`, which makes messages
     * generate as D `struct`s instead of `class`es. Files of package `google.protobuf` (the well
     * known types) are skipped.
     *
     * Returns: a response holding either the generated files or, if a `CodeGeneratorException`
     * was thrown, its message in `error`.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : canFind, filter, map, splitter;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // Reset all per-request state so that one generator instance can serve several requests.
        messageAsStruct = request.parameter.splitter(",").canFind("message-as-struct");
        protocVersion = null;
        collectedMessageTypes.clear;
        collectedEnumTypes.clear;

        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Registers every message and enum type declared in `request`, including nested ones, under the
     * name `.package.Outer.Inner`. These are the keys `messageType` and `enumType` look up with
     * `FieldDescriptorProto.typeName`.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            // Test for emptiness (as `moduleName` does) rather than for a null string, so that an empty
            // package does not produce a ".." prefix.
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Generates the D file for one `.proto` file. The file path is the module name with '.' replaced by '/'.
     *
     * Throws: `CodeGeneratorException` if the file is not proto3.
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Returns the module source: header, module/import declarations, optional `protocVersion`, messages, enums.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler.  DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependency.moduleName);

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Returns the D `class` (or `struct` with `message-as-struct`) for `messageType`.
     *
     * Fields are emitted in field-number order. The fields of a oneof are emitted together at the
     * position of its lowest-numbered member. Nested messages become `static` nested types. Map
     * entry messages produce an empty string because maps are generated as associative arrays.
     *
     * Note: sorts `messageType.fields` in place.
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto indentation = "%*s".format(indent, "");

        auto result = appender!string;
        result ~= "\n";
        result ~= indentation;
        result ~= indent > 0 ? "static " : "";
        result ~= messageAsStruct ? "struct" : "class";
        result ~= " ";
        result ~= messageType.name.escapeKeywords;
        result ~= "\n";
        result ~= indentation;
        result ~= "{\n";

        int[] generatedOneofs;
        foreach (field; messageType.fields.sort!((a, b) => a.number < b.number))
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                messageType.fields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= indentation;
        result ~= "}\n";

        return result.data;
    }

    /// Returns one `@Proto(...)` field declaration line, optionally with a `protoDefaultValue` initializer.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Returns the case enum with its accessors, followed by the union holding the oneof's fields.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Returns, for the oneof: the `<Name>Case` enum (a `<name>NotSet = 0` member plus one member per
     * field, valued with the field number), the `_<name>Case` variable, its getter and `clear<Name>()`.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", oneof.name.underscoresToCamelCase(true));
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", oneof.name.underscoresToCamelCase(false));
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));

        return result.data;
    }

    /// Returns the `@Oneof` union; only its first field gets an initializer, since union members share storage.
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (field; fields)
            result ~= generateOneofField(field, indent + indentSize, field == fields[0]);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the `_<name>` union member declaration followed by its `oneofAccessors` mixin.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Returns the D `enum` for `enumType`. The type name and value names are keyword-escaped.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        // Escaped the same way `baseTypeName` escapes it when the enum is referenced as a field type.
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message type named by `field.typeName`, or `null` if there is none.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        return field.typeName in collectedMessageTypes ? collectedMessageTypes[field.typeName] : null;
    }

    /// Returns the collected enum type named by `field.typeName`, or `null` if there is none.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        return field.typeName in collectedEnumTypes ? collectedEnumTypes[field.typeName] : null;
    }

    /**
     * Returns the wire encoding annotation of `field`: `fixed` for (s)fixed types, `zigzag` for sint
     * types, otherwise `none`. For map fields, the key's encoding is shifted left by 2 and the value's by 4.
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                // Shifting yields an int; every key/value combination is a declared `Wire` member.
                return cast(Wire) (keyWire << 2 | valueWire << 4);
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D element type of `field`, without the `[]` added for repeated fields.
     *
     * Throws: `CodeGeneratorException` if a message or enum field refers to a type that was not collected.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type `" ~ field.typeName ~ "`");
            // Must match the escaped name used when the message type is declared.
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type `" ~ field.typeName ~ "`");
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the full D type of `field`: `Value[Key]` for maps, `T[]` for repeated fields, otherwise `T`.
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /// Returns the `@Proto` arguments: the field number, followed by the wire annotation unless it is `Wire.none`.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.array : join;
        import std.conv : to;

        return [field.number.to!string, wireByField(field).toString].stripRight("").stripRight("Wire.none").join(", ");
    }

    /// Returns the ` = protoDefaultValue!T` initializer, parenthesizing array and associative-array types.
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    private string protocVersion;
    private DescriptorProto[string] collectedMessageTypes;
    private EnumDescriptorProto[string] collectedEnumTypes;
}

/**
 * Returns the field of `messageType` with number `fieldNumber`.
 *
 * Throws: `CodeGeneratorException` if there is no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto result = messageType.fields.find!(a => a.number == fieldNumber);

    enforce!CodeGeneratorException(!result.empty,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    return result[0];
}

/// Whether `messageType` is a synthesized map entry message (`options.mapEntry` set).
private bool isMap(DescriptorProto messageType)
{
    return messageType.options && messageType.options.mapEntry;
}

/// Field numbers of the key and value fields in a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}

/// Wire encoding flags: bits 0-1 for plain fields, bits 2-3 for map keys, bits 4-5 for map values.
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/**
 * Returns the D source spelling of `wire` that is emitted into generated code.
 * NOTE(review): presumably matches the member names of the runtime library's `Wire` enum — confirm.
 */
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    case none:
        return "Wire.none";
    case fixed:
        return "Wire.fixed";
    case zigzag:
        return "Wire.zigzag";
    case fixed_key:
        return "Wire.fixedKey";
    case zigzag_key:
        return "Wire.zigzagKey";
    case fixed_value:
        return "Wire.fixedValue";
    case zigzag_value:
        return "Wire.zigzagValue";
    case fixed_key_fixed_value:
        return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value:
        return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value:
        return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value:
        return "Wire.zigzagKeyZigzagValue";
    }
}

/// Returns the module name for a generated file: `package.basename` (without ".proto"), keyword-escaped.
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string moduleName = fileDescriptor.name.baseName(".proto");

    if (!fileDescriptor.package_.empty)
        moduleName = fileDescriptor.package_ ~ "." ~ moduleName;

    return moduleName.escapeKeywords;
}

/// Returns the module name for an imported `.proto` path: ".proto" removed, '/' replaced by '.', keyword-escaped.
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}

/**
 * Converts a snake_case name to camelCase.
 *
 * Underscores are dropped and the ASCII lowercase letter following one is upper-cased. The first
 * letter is upper-cased too if `capitalizeNextLetter` is true. The result is keyword-escaped.
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto result = appender!string;

    foreach (ubyte c; input)
    {
        if (c == '_')
        {
            capitalizeNextLetter = true;
            continue;
        }

        if ('a' <= c && c <= 'z' && capitalizeNextLetter)
            c += 'A' - 'a';

        result ~= c;
        capitalizeNextLetter = false;
    }

    return result.data.escapeKeywords;
}

/// Identifiers that cannot be emitted verbatim: D keywords, special tokens and common type aliases.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/// Appends '_' to every `separator`-delimited component of `input` that is listed in `keywords`.
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}