module protoc_gen_d;

import google.protobuf;
import google.protobuf.compiler.plugin : CodeGeneratorRequest, CodeGeneratorResponse;
import google.protobuf.descriptor : DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    OneofDescriptorProto;

/**
 * protoc plugin entry point.
 *
 * Reads the serialized `CodeGeneratorRequest` from stdin, runs the D code generator on it and writes the
 * serialized `CodeGeneratorResponse` to stdout. Nothing is written when stdin is empty.
 *
 * Returns: 0 once the response (if any) has been written.
 */
int main()
{
    import std.array : appender, array;
    import std.stdio : stdin, stdout;

    // The request is a single protobuf message spanning all of stdin, so it has to be buffered completely
    // before decoding. Decoding every 1 MiB chunk on its own would split requests larger than one chunk.
    // byChunk reuses its buffer; appending copies the bytes out of it.
    auto input = appender!(ubyte[]);
    foreach (chunk; stdin.byChunk(1024 * 1024))
        input ~= chunk;

    // Keep the previous behaviour of producing no output for empty input.
    if (input.data.length == 0)
        return 0;

    auto requestBytes = input.data;
    auto request = requestBytes.fromProtobuf!CodeGeneratorRequest;
    auto codeGenerator = new CodeGenerator;

    stdout.rawWrite(codeGenerator.handle(request).toProtobuf.array);

    return 0;
}

/// Thrown for invalid input; `CodeGenerator.handle` turns it into the response's `error` field.
class CodeGeneratorException : Exception
{
    this(string message = null, string file = __FILE__, size_t line = __LINE__,
        Throwable next = null) @safe pure nothrow
    {
        super(message, file, line, next);
    }
}

/// Translates a protoc `CodeGeneratorRequest` into D source files.
class CodeGenerator
{
    private enum indentSize = 4;

    /**
     * Generates D modules for the proto files of `request`.
     *
     * Params:
     *     request = the decoded plugin request.
     * Returns: a response with one generated file per proto file outside the `google.protobuf` package or,
     *     if a `CodeGeneratorException` was thrown, a response whose `error` holds the exception message.
     */
    CodeGeneratorResponse handle(CodeGeneratorRequest request)
    {
        import std.algorithm : filter, map;
        import std.array : array;
        import std.conv : to;
        import std.format : format;

        // Reset all per-request state so that a generator instance can serve more than one request.
        protocVersion = null;
        if (request.compilerVersion) with (request.compilerVersion)
            protocVersion = format!"%d%03d%03d"(major, minor, patch);

        collectedMessageTypes.clear;
        collectedEnumTypes.clear;
        collectedFiles.clear;

        auto response = new CodeGeneratorResponse;
        try
        {
            collectMessageAndEnumTypes(request);
            // NOTE(review): this generates every file in `protoFiles` except the well known types, not only
            // the files listed in the request's `file_to_generate` — confirm that this is intended.
            response.files = request.protoFiles
                .filter!(a => a.package_ != "google.protobuf") // don't generate the well known types
                .map!(a => generate(a)).array;
        }
        catch (CodeGeneratorException generatorException)
        {
            response.error = generatorException.message.to!string;
        }

        return response;
    }

    /**
     * Indexes all files of `request` by file name, and all messages and enums (nested ones included) by their
     * fully qualified name of the form `.package.Outer.Inner`, which is the form used in `typeName` references.
     */
    private void collectMessageAndEnumTypes(CodeGeneratorRequest request)
    {
        import std.array : empty;

        void collect(DescriptorProto messageType, string prefix)
        {
            auto absoluteName = prefix ~ "." ~ messageType.name;

            if (absoluteName in collectedMessageTypes)
                return;

            collectedMessageTypes[absoluteName] = messageType;

            foreach (nestedType; messageType.nestedTypes)
                collect(nestedType, absoluteName);

            foreach (enumType; messageType.enumTypes)
                collectedEnumTypes[absoluteName ~ "." ~ enumType.name] = enumType;
        }

        foreach (file; request.protoFiles)
        {
            collectedFiles[file.name] = file;

            // Same emptiness test as `moduleName`, so an empty package never yields a stray "." prefix.
            auto packagePrefix = file.package_.empty ? "" : "." ~ file.package_;

            foreach (messageType; file.messageTypes)
                collect(messageType, packagePrefix);

            foreach (enumType; file.enumTypes)
                collectedEnumTypes[packagePrefix ~ "." ~ enumType.name] = enumType;
        }
    }

    /**
     * Generates the D module for one proto file.
     *
     * Throws: `CodeGeneratorException` if the file does not use proto3 syntax.
     */
    private CodeGeneratorResponse.File generate(FileDescriptorProto fileDescriptor)
    {
        import std.array : replace;
        import std.exception : enforce;

        enforce!CodeGeneratorException(fileDescriptor.syntax == "proto3",
            "Can only generate D code for proto3 .proto files.\n" ~
            "Please add 'syntax = \"proto3\";' to the top of your .proto file.\n");

        auto file = new CodeGeneratorResponse.File;

        file.name = fileDescriptor.moduleName.replace(".", "/") ~ ".d";
        file.content = generateFile(fileDescriptor);

        return file;
    }

    /// Returns the complete source text of the D module generated for `fileDescriptor`.
    private string generateFile(FileDescriptorProto fileDescriptor)
    {
        import std.array : appender, empty;
        import std.format : format;

        auto result = appender!string;
        result ~= "// Generated by the protocol buffer compiler. DO NOT EDIT!\n";
        result ~= "// source: %s\n\n".format(fileDescriptor.name);
        result ~= "module %s;\n\n".format(fileDescriptor.moduleName);
        result ~= "import google.protobuf;\n";

        foreach (dependency; fileDescriptor.dependencies)
            result ~= "import %s;\n".format(dependencyModuleName(dependency));

        if (!protocVersion.empty)
            result ~= "\nenum protocVersion = %s;\n".format(protocVersion);

        foreach (messageType; fileDescriptor.messageTypes)
            result ~= generateMessage(messageType);

        foreach (enumType; fileDescriptor.enumTypes)
            result ~= generateEnum(enumType);

        return result.data;
    }

    /**
     * Returns the D module name for the imported proto file `dependency`.
     *
     * When the dependency's descriptor is part of the request, the name is derived from it (package plus base
     * name). That is the same derivation used for the dependency's own module declaration and output path, so
     * the import matches the generated module even when the import path and the package differ. Otherwise the
     * name falls back to the import path.
     */
    private string dependencyModuleName(string dependency)
    {
        if (auto dependencyDescriptor = dependency in collectedFiles)
            return (*dependencyDescriptor).moduleName;

        return dependency.moduleName;
    }

    /**
     * Generates a D class for `messageType`, including its oneofs, nested messages and nested enums.
     * Nested classes are emitted as `static class`. Map entry messages produce no code (empty string).
     */
    private string generateMessage(DescriptorProto messageType, size_t indent = 0)
    {
        import std.algorithm : canFind, filter, sort;
        import std.array : appender, array;
        import std.format : format;

        // Don't generate MapEntry messages, they are generated as associative arrays
        if (messageType.isMap)
            return "";

        auto result = appender!string;
        result ~= "\n%*s%sclass %s\n".format(indent, "", indent > 0 ? "static " : "", messageType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        // Emit fields ordered by field number. Sort a copy so the request's descriptor is not reordered.
        auto sortedFields = messageType.fields.dup;
        sortedFields.sort!((a, b) => a.number < b.number);

        // A oneof is emitted once, at the position of its lowest-numbered member.
        int[] generatedOneofs;
        foreach (field; sortedFields)
        {
            if (field.oneofIndex < 0)
            {
                result ~= generateField(field, indent + indentSize);
                continue;
            }

            if (generatedOneofs.canFind(field.oneofIndex))
                continue;

            result ~= generateOneof(messageType.oneofDecls[field.oneofIndex],
                sortedFields.filter!(a => a.oneofIndex == field.oneofIndex).array, indent + indentSize);
            generatedOneofs ~= field.oneofIndex;
        }

        foreach (nestedType; messageType.nestedTypes)
            result ~= generateMessage(nestedType, indent + indentSize);

        foreach (enumType; messageType.enumTypes)
            result ~= generateEnum(enumType, indent + indentSize);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Generates a `@Proto(...)` annotated field declaration, optionally with its default initializer.
    private string generateField(FieldDescriptorProto field, size_t indent, bool printInitializer = true)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s %s%s;\n".format(indent, "", fieldProtoFields(field), typeName(field),
            field.name.underscoresToCamelCase(false), printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates the case enum, the case accessors and the annotated union for one oneof.
    private string generateOneof(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        return generateOneofCaseEnum(oneof, fields, indent) ~ generateOneofUnion(oneof, fields, indent);
    }

    /**
     * Generates the `<Name>Case` enum (a `NotSet = 0` member plus one member per field, valued by field number),
     * the private case variable, its `@property` getter and the `clear<Name>()` method.
     */
    private string generateOneofCaseEnum(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*senum %sCase\n".format(indent, "", oneof.name.underscoresToCamelCase(true));
        result ~= "%*s{\n".format(indent, "");
        result ~= "%*s%sNotSet = 0,\n".format(indent + indentSize, "", oneof.name.underscoresToCamelCase(false));
        foreach (field; fields)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", field.name.underscoresToCamelCase(false),
                field.number);
        result ~= "%*s}\n".format(indent, "");
        result ~= "%*s%3$sCase _%4$sCase = %3$sCase.%4$sNotSet;\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*s@property %3$sCase %4$sCase() { return _%4$sCase; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));
        result ~= "%*svoid clear%3$s() { _%4$sCase = %3$sCase.%4$sNotSet; }\n".format(indent, "",
            oneof.name.underscoresToCamelCase(true), oneof.name.underscoresToCamelCase(false));

        return result.data;
    }

    /**
     * Generates the `@Oneof` union holding the oneof's fields. Only the first field gets an initializer,
     * because a union can initialize only one of its members.
     */
    private string generateOneofUnion(OneofDescriptorProto oneof, FieldDescriptorProto[] fields, size_t indent)
    {
        import std.format : format;
        import std.array : appender;

        auto result = appender!string;
        result ~= "%*s@Oneof(\"_%sCase\") union\n".format(indent, "", oneof.name.underscoresToCamelCase(false));
        result ~= "%*s{\n".format(indent, "");
        foreach (i, field; fields)
            result ~= generateOneofField(field, indent + indentSize, i == 0);
        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Generates one union member `_name` plus a `mixin(oneofAccessors!_name)` for its accessors.
    private string generateOneofField(FieldDescriptorProto field, size_t indent, bool printInitializer)
    {
        import std.format : format;

        return "%*s@Proto(%s) %s _%5$s%6$s; mixin(oneofAccessors!_%5$s);\n".format(indent, "", fieldProtoFields(field),
            typeName(field), field.name.underscoresToCamelCase(false),
            printInitializer ? fieldInitializer(field) : "");
    }

    /// Generates a D enum with the proto enum's values; names are keyword-escaped as in `generateMessage`.
    private string generateEnum(EnumDescriptorProto enumType, size_t indent = 0)
    {
        import std.array : appender;
        import std.format : format;

        auto result = appender!string;
        result ~= "\n%*senum %s\n".format(indent, "", enumType.name.escapeKeywords);
        result ~= "%*s{\n".format(indent, "");

        foreach (value; enumType.values)
            result ~= "%*s%s = %s,\n".format(indent + indentSize, "", value.name.escapeKeywords, value.number);

        result ~= "%*s}\n".format(indent, "");

        return result.data;
    }

    /// Returns the collected message referenced by `field.typeName`, or `null` if there is none.
    private DescriptorProto messageType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedMessageTypes;
        return found ? *found : null;
    }

    /// Returns the collected enum referenced by `field.typeName`, or `null` if there is none.
    private EnumDescriptorProto enumType(FieldDescriptorProto field)
    {
        auto found = field.typeName in collectedEnumTypes;
        return found ? *found : null;
    }

    /**
     * Returns the wire encoding flag for `field`.
     *
     * For map fields, the key's encoding is shifted into the `*_key` bits (<< 2) and the value's into the
     * `*_value` bits (<< 4).
     */
    private Wire wireByField(FieldDescriptorProto field)
    {
        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL: case TYPE_INT32: case TYPE_UINT32: case TYPE_INT64: case TYPE_UINT64:
        case TYPE_FLOAT: case TYPE_DOUBLE: case TYPE_STRING: case TYPE_BYTES: case TYPE_ENUM:
            return Wire.none;
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);

            if (fieldMessageType !is null && fieldMessageType.isMap)
            {
                Wire keyWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.key));
                Wire valueWire = wireByField(fieldMessageType.fieldByNumber(MapFieldNumber.value));

                return keyWire << 2 | valueWire << 4;
            }
            return Wire.none;
        }
        case TYPE_SINT32: case TYPE_SINT64:
            return Wire.zigzag;
        case TYPE_SFIXED32: case TYPE_FIXED32: case TYPE_SFIXED64: case TYPE_FIXED64:
            return Wire.fixed;
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /**
     * Returns the D element type of `field`, ignoring repetition and maps.
     *
     * Throws: `CodeGeneratorException` if a referenced message or enum type was not collected.
     */
    private string baseTypeName(FieldDescriptorProto field)
    {
        import std.exception : enforce;

        final switch (field.type) with (FieldDescriptorProto.Type)
        {
        case TYPE_BOOL:
            return "bool";
        case TYPE_INT32: case TYPE_SINT32: case TYPE_SFIXED32:
            return "int";
        case TYPE_UINT32: case TYPE_FIXED32:
            return "uint";
        case TYPE_INT64: case TYPE_SINT64: case TYPE_SFIXED64:
            return "long";
        case TYPE_UINT64: case TYPE_FIXED64:
            return "ulong";
        case TYPE_FLOAT:
            return "float";
        case TYPE_DOUBLE:
            return "double";
        case TYPE_STRING:
            return "string";
        case TYPE_BYTES:
            return "bytes";
        case TYPE_MESSAGE:
        {
            auto fieldMessageType = messageType(field);
            enforce!CodeGeneratorException(fieldMessageType !is null, "Field '" ~ field.name ~
                "' has unknown message type `" ~ field.typeName ~ "`");
            // Escaped the same way as the class declaration in generateMessage.
            return fieldMessageType.name.escapeKeywords;
        }
        case TYPE_ENUM:
        {
            auto fieldEnumType = enumType(field);
            enforce!CodeGeneratorException(fieldEnumType !is null, "Field '" ~ field.name ~
                "' has unknown enum type `" ~ field.typeName ~ "`");
            // Escaped the same way as the enum declaration in generateEnum.
            return fieldEnumType.name.escapeKeywords;
        }
        case TYPE_GROUP: case TYPE_ERROR:
            assert(0, "Invalid field type");
        }
    }

    /// Returns the full D type of `field`: `Value[Key]` for maps, `T[]` for repeated fields, else `T`.
    string typeName(FieldDescriptorProto field)
    {
        import std.format : format;

        string fieldBaseTypeName = baseTypeName(field);

        auto fieldMessageType = messageType(field);

        if (fieldMessageType !is null && fieldMessageType.isMap)
        {
            auto keyField = fieldMessageType.fieldByNumber(MapFieldNumber.key);
            auto valueField = fieldMessageType.fieldByNumber(MapFieldNumber.value);

            return "%s[%s]".format(baseTypeName(valueField), baseTypeName(keyField));
        }

        if (field.label == FieldDescriptorProto.Label.LABEL_REPEATED)
            return fieldBaseTypeName ~ "[]";
        else
            return fieldBaseTypeName;
    }

    /// Returns the `@Proto` arguments: the field number, followed by the wire flag unless it is `Wire.none`.
    private string fieldProtoFields(FieldDescriptorProto field)
    {
        import std.algorithm : stripRight;
        import std.conv : to;
        import std.range : join;

        return [field.number.to!string, wireByField(field).toString].stripRight("").stripRight("Wire.none").join(", ");
    }

    /// Returns ` = protoDefaultValue!T`; `T` is parenthesized when it ends in `]` (arrays, maps).
    private string fieldInitializer(FieldDescriptorProto field)
    {
        import std.algorithm : endsWith;
        import std.format : format;

        auto fieldTypeName = typeName(field);
        if (fieldTypeName.endsWith("]"))
            return " = protoDefaultValue!(%s)".format(fieldTypeName);
        else
            return " = protoDefaultValue!%s".format(fieldTypeName);
    }

    // protoc version as major/minor/patch digits (e.g. "3012004"); empty when the request has none.
    private string protocVersion;
    // Messages and enums of the current request, keyed by fully qualified name (".pkg.Outer.Inner").
    private DescriptorProto[string] collectedMessageTypes;
    private EnumDescriptorProto[string] collectedEnumTypes;
    // Proto files of the current request, keyed by their file name (the form used in `dependencies`).
    private FileDescriptorProto[string] collectedFiles;
}

/**
 * Returns the field of `messageType` that has number `fieldNumber`.
 *
 * Throws: `CodeGeneratorException` if there is no such field.
 */
private FieldDescriptorProto fieldByNumber(DescriptorProto messageType, int fieldNumber)
{
    import std.algorithm : find;
    import std.array : empty;
    import std.exception : enforce;
    import std.format : format;

    auto result = messageType.fields.find!(a => a.number == fieldNumber);

    enforce!CodeGeneratorException(!result.empty,
        "Message '%s' has no field with tag %s".format(messageType.name, fieldNumber));

    return result[0];
}

/// True when `messageType` is a synthesized map entry (its options have `mapEntry` set).
private bool isMap(DescriptorProto messageType)
{
    return messageType.options && messageType.options.mapEntry;
}

/// Field numbers of the key and value fields in a map entry message.
private enum MapFieldNumber
{
    key = 1,
    value = 2,
}

/// Wire encoding flags; the `*_key` / `*_value` variants are the base flags shifted by 2 / 4 for map fields.
private enum Wire
{
    none,
    fixed = 1 << 0,
    zigzag = 1 << 1,
    fixed_key = 1 << 2,
    zigzag_key = 1 << 3,
    fixed_value = 1 << 4,
    zigzag_value = 1 << 5,
    fixed_key_fixed_value = fixed_key | fixed_value,
    fixed_key_zigzag_value = fixed_key | zigzag_value,
    zigzag_key_fixed_value = zigzag_key | fixed_value,
    zigzag_key_zigzag_value = zigzag_key | zigzag_value,
}

/// Returns the D source spelling of `wire` as used in the generated `@Proto` attributes.
private string toString(Wire wire)
{
    final switch (wire) with (Wire)
    {
    case none:
        return "Wire.none";
    case fixed:
        return "Wire.fixed";
    case zigzag:
        return "Wire.zigzag";
    case fixed_key:
        return "Wire.fixedKey";
    case zigzag_key:
        return "Wire.zigzagKey";
    case fixed_value:
        return "Wire.fixedValue";
    case zigzag_value:
        return "Wire.zigzagValue";
    case fixed_key_fixed_value:
        return "Wire.fixedKeyFixedValue";
    case fixed_key_zigzag_value:
        return "Wire.fixedKeyZigzagValue";
    case zigzag_key_fixed_value:
        return "Wire.zigzagKeyFixedValue";
    case zigzag_key_zigzag_value:
        return "Wire.zigzagKeyZigzagValue";
    }
}

/// Module name of a generated file: `package.basename` (or just `basename`), keyword-escaped.
private string moduleName(FileDescriptorProto fileDescriptor)
{
    import std.array : empty;
    import std.path : baseName;

    string moduleName = fileDescriptor.name.baseName(".proto");

    if (!fileDescriptor.package_.empty)
        moduleName = fileDescriptor.package_ ~ "." ~ moduleName;

    return moduleName.escapeKeywords;
}

/// Module name derived from a proto file path: drops ".proto" and turns "/" into ".", keyword-escaped.
private string moduleName(string fileName)
{
    import std.array : replace;
    import std.string : chomp;

    return fileName.chomp(".proto").replace("/", ".").escapeKeywords;
}

/**
 * Converts snake_case to camelCase: each '_' is dropped and the next ASCII lowercase letter is capitalized.
 * The first letter is capitalized only when `capitalizeNextLetter` is true. The result is keyword-escaped.
 */
private string underscoresToCamelCase(string input, bool capitalizeNextLetter)
{
    import std.array : appender;

    auto result = appender!string;

    foreach (ubyte c; input)
    {
        if (c == '_')
        {
            capitalizeNextLetter = true;
            continue;
        }

        if ('a' <= c && c <= 'z' && capitalizeNextLetter)
            c += 'A' - 'a';

        result ~= c;
        capitalizeNextLetter = false;
    }

    return result.data.escapeKeywords;
}

/// D keywords and reserved/predefined identifiers that generated names must not collide with.
private enum string[] keywords = [
    "abstract", "alias", "align", "asm", "assert", "auto", "body", "bool", "break", "byte", "case", "cast", "catch",
    "cdouble", "cent", "cfloat", "char", "class", "const", "continue", "creal", "dchar", "debug", "default",
    "delegate", "delete", "deprecated", "do", "double", "else", "enum", "export", "extern", "false", "final",
    "finally", "float", "for", "foreach", "foreach_reverse", "function", "goto", "idouble", "if", "ifloat",
    "immutable", "import", "in", "inout", "int", "interface", "invariant", "ireal", "is", "lazy", "long", "macro",
    "mixin", "module", "new", "nothrow", "null", "out", "override", "package", "pragma", "private", "protected",
    "public", "pure", "real", "ref", "return", "scope", "shared", "short", "static", "struct", "super", "switch",
    "synchronized", "template", "this", "throw", "true", "try", "typedef", "typeid", "typeof", "ubyte", "ucent",
    "uint", "ulong", "union", "unittest", "ushort", "version", "void", "volatile", "wchar", "while", "with",
    "__FILE__", "__MODULE__", "__LINE__", "__FUNCTION__", "__PRETTY_FUNCTION__", "__gshared", "__traits", "__vector",
    "__parameters", "string", "wstring", "dstring", "size_t", "ptrdiff_t", "__DATE__", "__EOF__", "__TIME__",
    "__TIMESTAMP__", "__VENDOR__", "__VERSION__",
];

/// Appends '_' to every `separator`-delimited segment of `input` that is in `keywords`.
private string escapeKeywords(string input, string separator = ".")
{
    import std.algorithm : canFind, joiner, map, splitter;
    import std.conv : to;

    return input.splitter(separator).map!(a => keywords.canFind(a) ? a ~ '_' : a).joiner(separator).to!string;
}