1 module google.protobuf.decoding;
2 
3 import std.range : ElementType, empty, isInputRange;
4 import std.traits : isArray, isAssociativeArray, isBoolean, isFloatingPoint, isIntegral, KeyType, ValueType;
5 import google.protobuf.common;
6 import google.protobuf.internal;
7 
/**
 * Decodes a protobuf boolean, which is carried on the wire as a varint.
 *
 * `inputRange` is taken by `ref`, so the bytes read by `fromVarint` are consumed from the
 * caller's range.
 *
 * Params:
 *     inputRange = `ubyte` range positioned at the start of the encoded value
 * Returns: the varint converted to `T` with a cast (any non-zero value becomes `true`)
 */
T fromProtobuf(T, R)(ref R inputRange)
if (isInputRange!R && isBoolean!T)
{
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    // Read the raw varint first, then narrow it to the boolean type.
    const rawValue = fromVarint(inputRange);
    return cast(T) rawValue;
}

unittest
{
    import google.protobuf.encoding : toProtobuf;
    import std.array : array;

    // `true` must survive an encode/decode round trip.
    auto encoded = true.toProtobuf.array;
    const decoded = encoded.fromProtobuf!bool;
    assert(decoded);
}
24 
/**
 * Decodes a protobuf integer of type `T` from the front of `inputRange`.
 *
 * The `wire` template argument chooses the encoding that is read:
 *   - `Wire.none`: a plain varint, cast to `T`;
 *   - `Wire.fixed`: a fixed-width value read by `decodeFixed!T`;
 *   - `Wire.zigzag`: a varint reinterpreted as `Unsigned!T` and decoded by `zagZig`.
 * Any other `wire` value reaches `assert(0, "Invalid wire encoding")` at run time.
 *
 * Params:
 *     inputRange = `ubyte` range positioned at the start of the value; consumed by the read
 * Returns: the decoded integer
 */
T fromProtobuf(T, Wire wire = Wire.none, R)(ref R inputRange)
if (isInputRange!R && isIntegral!T)
{
    import std.traits : Unsigned;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    static if (wire == Wire.fixed)
    {
        return inputRange.decodeFixed!T;
    }
    else static if (wire == Wire.zigzag)
    {
        // ZigZag values travel as unsigned varints; reinterpret before undoing the sign mapping.
        alias UnsignedT = Unsigned!T;
        return zagZig(cast(UnsignedT) fromVarint(inputRange));
    }
    else static if (wire == Wire.none)
    {
        return cast(T) fromVarint(inputRange);
    }
    else
    {
        assert(0, "Invalid wire encoding");
    }
}

unittest
{
    import google.protobuf.encoding : toProtobuf;
    import std.array : array;

    // Plain varint encoding (Wire.none), including negative values and the full 64-bit range.
    auto encoded = 10.toProtobuf.array;
    assert(encoded.fromProtobuf!int == 10);
    encoded = (-1).toProtobuf.array;
    assert(encoded.fromProtobuf!int == -1);
    encoded = (-1L).toProtobuf.array;
    assert(encoded.fromProtobuf!long == -1L);
    encoded = 0xffffffffffffffffUL.toProtobuf.array;
    assert(encoded.fromProtobuf!long == 0xffffffffffffffffUL);

    // Fixed-width encoding (Wire.fixed).
    encoded = 1.toProtobuf!(Wire.fixed).array;
    assert(encoded.fromProtobuf!(int, Wire.fixed) == 1);
    encoded = (-1).toProtobuf!(Wire.fixed).array;
    assert(encoded.fromProtobuf!(int, Wire.fixed) == -1);
    encoded = 0xffffffffU.toProtobuf!(Wire.fixed).array;
    assert(encoded.fromProtobuf!(uint, Wire.fixed) == 0xffffffffU);
    encoded = 1L.toProtobuf!(Wire.fixed).array;
    assert(encoded.fromProtobuf!(long, Wire.fixed) == 1L);

    // ZigZag encoding (Wire.zigzag).
    encoded = 1.toProtobuf!(Wire.zigzag).array;
    assert(encoded.fromProtobuf!(int, Wire.zigzag) == 1);
    encoded = (-1).toProtobuf!(Wire.zigzag).array;
    assert(encoded.fromProtobuf!(int, Wire.zigzag) == -1);
    encoded = 1L.toProtobuf!(Wire.zigzag).array;
    assert(encoded.fromProtobuf!(long, Wire.zigzag) == 1L);
    encoded = (-1L).toProtobuf!(Wire.zigzag).array;
    assert(encoded.fromProtobuf!(long, Wire.zigzag) == -1L);
}
82 
/**
 * Decodes a protobuf `float` or `double` from the front of `inputRange`.
 *
 * Protobuf encodes `float` as fixed32 and `double` as fixed64, so the value is read with
 * `decodeFixed!T`. `inputRange` is taken by `ref` and is advanced by the read.
 *
 * Params:
 *     inputRange = `ubyte` range positioned at the start of the value
 * Returns: the decoded floating-point value
 */
T fromProtobuf(T, R)(ref R inputRange)
if (isInputRange!R && isFloatingPoint!T)
{
    // Same element-type guard as every other overload in this module. A non-ubyte range now fails
    // here with a clear message instead of somewhere inside decodeFixed.
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    return inputRange.decodeFixed!T;
}

unittest
{
    import std.array : array;
    import google.protobuf.encoding : toProtobuf;

    auto buffer = (0.0).toProtobuf.array;
    assert(buffer.fromProtobuf!double == 0.0);
    buffer = (0.0f).toProtobuf.array;
    assert(buffer.fromProtobuf!float == 0.0f);

    // Non-zero, exactly representable values catch width or byte-order mistakes that an
    // all-zero encoding cannot reveal.
    buffer = (1.5).toProtobuf.array;
    assert(buffer.fromProtobuf!double == 1.5);
    buffer = (-2.25f).toProtobuf.array;
    assert(buffer.fromProtobuf!float == -2.25f);
}
99 
/**
 * Decodes a length-delimited protobuf `string` or `bytes` value from the front of `inputRange`.
 *
 * `takeLengthPrefixed` consumes the length prefix and yields the payload. The payload is then copied
 * into a freshly allocated array with `std.array.array`, so the result does not alias `inputRange`.
 *
 * Params:
 *     inputRange = `ubyte` range positioned at the length prefix
 * Returns: the payload reinterpreted as `T`
 */
T fromProtobuf(T, R)(ref R inputRange)
if (isInputRange!R && (is(T == string) || is(T == bytes)))
{
    import std.array : array;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    R payload = inputRange.takeLengthPrefixed;
    auto payloadCopy = payload.array;

    // The copy is referenced nowhere else, so casting it to T (immutable for `string`) aliases nothing.
    return cast(T) payloadCopy;
}

unittest
{
    import google.protobuf.encoding : toProtobuf;
    import std.array : array;

    // Strings: non-empty and empty.
    auto encoded = "abc".toProtobuf.array;
    assert(encoded.fromProtobuf!string == "abc");
    encoded = "".toProtobuf.array;
    assert(encoded.fromProtobuf!string.empty);

    // Raw bytes: non-empty and empty.
    encoded = (cast(bytes) [1, 2, 3]).toProtobuf.array;
    assert(encoded.fromProtobuf!bytes == (cast(bytes) [1, 2, 3]));
    encoded = (cast(bytes) []).toProtobuf.array;
    assert(encoded.fromProtobuf!bytes.empty);
}
126 
/**
 * Decodes a packed repeated field into a new array of type `T`.
 *
 * The whole length-delimited run is cut out of `inputRange` with `takeLengthPrefixed`. Elements are
 * then decoded one after another until that run is exhausted.
 *
 * When `wire` is not `Wire.none`, it is forwarded to the integral element decoder. Specifying a wire
 * format for non-integral element types is rejected at compile time.
 *
 * Params:
 *     inputRange = `ubyte` range positioned at the length prefix of the packed run
 * Returns: the decoded elements (empty when the run has length zero)
 */
T fromProtobuf(T, Wire wire = Wire.none, R)(ref R inputRange)
if (isInputRange!R && isArray!T && !is(T == string) && !is(T == bytes))
{
    import std.array : Appender;

    alias Element = ElementType!T;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static if (wire != Wire.none)
        static assert(isIntegral!Element, "Cannot specify wire format for non-integral arrays");

    R payload = inputRange.takeLengthPrefixed;
    Appender!T collected;

    while (!payload.empty)
    {
        static if (wire == Wire.none)
            collected.put(payload.fromProtobuf!Element);
        else
            collected.put(payload.fromProtobuf!(Element, wire));
    }

    return collected.data;
}

unittest
{
    import google.protobuf.encoding : toProtobuf;
    import std.array : array;

    // Booleans, default (varint) encoding.
    auto encoded = [false, false, true].toProtobuf.array;
    assert(encoded.fromProtobuf!(bool[]) == [false, false, true]);

    // Integers, fixed-width and varint encodings.
    encoded = [1, 2].toProtobuf!(Wire.fixed).array;
    assert(encoded.fromProtobuf!(int[], Wire.fixed) == [1, 2]);
    encoded = [1, 2].toProtobuf.array;
    assert(encoded.fromProtobuf!(int[]) == [1, 2]);

    // Longs, zigzag encoding with a negative element.
    encoded = [-54L, 54L].toProtobuf!(Wire.zigzag).array;
    assert(encoded.fromProtobuf!(long[], Wire.zigzag) == [-54L, 54L]);
}
167 
/**
 * Decodes a protobuf message of class or struct type `T` from `inputRange`.
 *
 * If `T` declares a `fromProtobuf` member, decoding is delegated to `result.fromProtobuf(inputRange)`.
 *
 * Otherwise the generic path reads tags until `inputRange` is empty. So the range must hold exactly
 * one message; nested messages are first cut out with `takeLengthPrefixed` by `fromProtobufByProto`.
 * Each tag is dispatched to the field annotated with that `@Proto` tag. Tags not known to `T` are
 * skipped with `skipUnknown`.
 *
 * Params:
 *     inputRange = `ubyte` range containing the encoded message
 *     result = object to decode into. For a class, `null` is replaced by `new T`. Decoded fields are
 *              written into it, so fields absent from the input keep their current values.
 * Returns: `result` with the decoded fields applied
 * Throws: `ProtobufException` when a known field arrives with a wire type other than the expected
 *         one (the helpers called here may throw as well).
 */
T fromProtobuf(T, R)(ref R inputRange, T result = protoDefaultValue!T)
if (isInputRange!R && (is(T == class) || is(T == struct)))
{
    import std.exception : enforce;
    import std.format : format;
    import std.meta : Alias;
    import std.traits : hasMember;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    static if (is(T == class))
    {
        if (result is null)
            result = new T;
    }

    static if (hasMember!(T, "fromProtobuf"))
    {
        // The type provides its own decoder; use it instead of the field-by-field path below.
        return result.fromProtobuf(inputRange);
    }
    else
    {
        while (!inputRange.empty)
        {
            auto tagWire = inputRange.decodeTag;

            // The foreach below runs over a compile-time list of field names and is unrolled into one
            // `case` per field. The label lets each case break out of the switch, not just the foreach.
            chooseFieldDecoder:
            switch (tagWire.tag)
            {
            foreach (fieldName; Message!T.fieldNames)
            {
                alias field = Alias!(mixin("T." ~ fieldName));
                case protoByField!field.tag:
                {
                    enum proto = protoByField!field;

                    // Packable fields accept both a packed (length-delimited) run and a single unpacked
                    // element, whatever the field's own `packed` flag says. The incoming wire type picks
                    // which decoder is used.
                    static if (isFieldPackable!field)
                    {
                        if (tagWire.wireType == WireType.withLength)
                        {
                            // Packed occurrence: decode with a packed copy of the field's Proto.
                            enum proto2 = Proto(proto.tag, proto.wire, Yes.packed);
                            enum wireTypeExpected = wireType!(proto2, typeof(field));
                            enforce!ProtobufException(tagWire.wireType == wireTypeExpected,
                                "Wrong wire format '%s' of field %s, expected '%s' "
                                    .format(tagWire.wireType, T.stringof ~ "." ~ fieldName, wireTypeExpected));

                            inputRange.fromProtobufByProto!proto2(mixin("result." ~ __traits(identifier, field)));
                        }
                        else
                        {
                            // Unpacked occurrence: decode a single element with an unpacked copy of the Proto.
                            enum proto2 = Proto(proto.tag, proto.wire, No.packed);
                            enum wireTypeExpected = wireType!(proto2, typeof(field));
                            enforce!ProtobufException(tagWire.wireType == wireTypeExpected,
                                "Wrong wire format '%s' of field %s, expected '%s' "
                                    .format(tagWire.wireType, T.stringof ~ "." ~ fieldName, wireTypeExpected));

                            inputRange.fromProtobufByProto!proto2(mixin("result." ~ __traits(identifier, field)));
                        }
                    }
                    else {
                        // Non-packable field: exactly one wire type is acceptable.
                        enum wireTypeExpected = wireType!(proto, typeof(field));
                        enforce!ProtobufException(tagWire.wireType == wireTypeExpected,
                            "Wrong wire format '%s' of field %s, expected '%s' "
                                .format(tagWire.wireType, T.stringof ~ "." ~ fieldName, wireTypeExpected));

                        inputRange.fromProtobufByProto!proto(mixin("result." ~ __traits(identifier, field)));
                    }
                    // For a oneof member, record in the oneof's case field that this member was set.
                    static if (isOneof!field)
                    {
                        enum oneofCase = "result." ~ oneofCaseFieldName!field;
                        enum fieldCase = "T." ~ typeof(mixin(oneofCase)).stringof ~ "." ~ oneofAccessorName!field;

                        mixin(oneofCase) = mixin(fieldCase);
                    }

                    break chooseFieldDecoder;
                }
            }
            default:
                // Tag not declared on T: skip its payload according to its wire type.
                skipUnknown(inputRange, tagWire.wireType);
                break;
            }
        }
        return result;
    }
}
254 
unittest
{
    // Fields are declared in tag order 1, 3, 2; the encoded buffer below is in tag order 1, 2, 3.
    static class Foo
    {
        @Proto(1) int bar;
        @Proto(3) bool qux;
        @Proto(2, Wire.fixed) long baz;
    }

    // tag 1 varint 5 | tag 2 fixed64 1 | tag 3 varint 1 (true)
    ubyte[] encoded = [
        0x08, 0x05,
        0x11, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x18, 0x01,
    ];
    auto decoded = encoded.fromProtobuf!Foo;

    // The whole input must have been consumed.
    assert(encoded.empty);
    assert(decoded.bar == 5);
    assert(decoded.baz == 1);
    assert(decoded.qux);
}
271 
unittest
{
    // A message type without fields decodes from an empty buffer.
    static class EmptyMessage
    {
    }

    ubyte[] input = [];
    auto decoded = input.fromProtobuf!EmptyMessage;
    assert(input.empty);
}
282 
unittest
{
    import google.protobuf.encoding : toProtobuf;
    import std.array : array;
    import std.typecons : Yes;

    struct Foo
    {
        @Proto(1) int[] bar = protoDefaultValue!(int[]);
        @Proto(2, Wire.zigzag, Yes.packed) int[] baz = protoDefaultValue!(int[]);
    }

    // Encode one repeated field without an explicit packed flag and one declared packed + zigzag...
    Foo message;
    message.bar = [1, 2];
    message.baz = [3, 4];
    auto encoded = message.toProtobuf.array;

    // ...reset the message...
    message = Foo.init;
    assert(message.bar.empty);
    assert(message.baz.empty);

    // ...and check that decoding restores both fields.
    message = encoded.fromProtobuf!Foo;
    assert(message.bar == [1, 2]);
    assert(message.baz == [3, 4]);
}
308 
unittest
{
    import std.typecons : Yes;

    struct Foo
    {
        @Proto(1) int[] bar = protoDefaultValue!(int[]);
        @Proto(2, Wire.zigzag, Yes.packed) int[] baz = protoDefaultValue!(int[]);
    }

    // Both fields accept packed and unpacked occurrences whatever their declaration says, and
    // successive occurrences are concatenated:
    //   0x08 0x01            bar: 1       (unpacked varint)
    //   0x0a 0x02 0x02 0x03  bar: [2, 3]  (packed)
    //   0x12 0x02 0x08 0x0a  baz: [4, 5]  (packed zigzag)
    //   0x10 0x0c            baz: 6       (unpacked zigzag)
    ubyte[] encoded = [
        0x08, 0x01,
        0x0a, 0x02, 0x02, 0x03,
        0x12, 0x02, 0x08, 0x0a,
        0x10, 0x0c,
    ];

    auto decoded = encoded.fromProtobuf!Foo;
    assert(decoded.bar == [1, 2, 3]);
    assert(decoded.baz == [4, 5, 6]);
}
330 
/**
 * Decodes a bool, floating-point, `string` or `bytes` field into `field`, overwriting its value.
 *
 * These types need no wire selection, so the matching `fromProtobuf` overload is called with `T` alone.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && (isBoolean!T || isFloatingPoint!T || is(T == string) || is(T == bytes)))
{
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    auto decoded = inputRange.fromProtobuf!T;
    field = decoded;
}
339 
/**
 * Decodes an integral field into `field`, overwriting its value.
 *
 * The wire encoding (varint, fixed or zigzag) is taken from `proto.wire`.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && isIntegral!T)
{
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    enum Wire encoding = proto.wire;
    field = fromProtobuf!(T, encoding)(inputRange);
}
348 
/**
 * Decodes one packed run of a repeated field and appends its elements to `field`.
 *
 * Selected when `proto.packed` is set. The whole length-prefixed run is decoded by the array overload
 * of `fromProtobuf` using `proto.wire`. Existing elements of `field` are kept, because a repeated
 * field may occur several times in one message.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && isArray!T && !is(T == string) && !is(T == bytes) && proto.packed)
{
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    T run = inputRange.fromProtobuf!(T, proto.wire);
    field ~= run;
}
357 
/**
 * Decodes a single unpacked element of a repeated field and appends it to `field`.
 *
 * Selected when `proto.packed` is not set. The element starts from its `protoDefaultValue` and is
 * decoded with the same `proto` by the element-type overload of `fromProtobufByProto`.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && isArray!T && !is(T == string) && !is(T == bytes) && !proto.packed)
{
    alias Element = ElementType!T;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    Element item = protoDefaultValue!Element;
    fromProtobufByProto!proto(inputRange, item);
    field ~= item;
}
368 
/**
 * Decodes one map entry and stores it in the associative array `field`.
 *
 * The entry is cut out of `inputRange` with `takeLengthPrefixed` and parsed as a small message.
 * Tag `MapFieldTag.key` carries the key and `MapFieldTag.value` the value. Each is decoded with the
 * wire encoding derived from `proto.wire` by `keyWireToWire` / `valueWireToWire`.
 *
 * A key or value missing from the entry keeps its `protoDefaultValue`. An existing entry with the
 * same key is overwritten.
 *
 * Throws: `ProtobufException` if the key or value arrives with an unexpected wire type (the message
 *         names the received and expected wire types), or if the entry contains any other tag.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && isAssociativeArray!T)
{
    import std.conv : to;
    import std.exception : enforce;
    import std.format : format;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    enum keyProto = Proto(MapFieldTag.key, keyWireToWire(proto.wire));
    enum valueProto = Proto(MapFieldTag.value, valueWireToWire(proto.wire));
    KeyType!T key = protoDefaultValue!(KeyType!T);
    ValueType!T value = protoDefaultValue!(ValueType!T);
    R fieldRange = inputRange.takeLengthPrefixed;

    while (!fieldRange.empty)
    {
        auto tagWire = fieldRange.decodeTag;

        switch (tagWire.tag)
        {
        case MapFieldTag.key:
            enum keyWireExpected = wireType!(keyProto, KeyType!T);
            // Report received vs. expected wire type, as the message decoder does. The message is
            // lazy in `enforce`, so it is only formatted on failure.
            enforce!ProtobufException(tagWire.wireType == keyWireExpected,
                "Wrong wire format '%s' of %s key, expected '%s'"
                    .format(tagWire.wireType, T.stringof, keyWireExpected));
            fieldRange.fromProtobufByProto!keyProto(key);
            break;
        case MapFieldTag.value:
            enum valueWireExpected = wireType!(valueProto, ValueType!T);
            enforce!ProtobufException(tagWire.wireType == valueWireExpected,
                "Wrong wire format '%s' of %s value, expected '%s'"
                    .format(tagWire.wireType, T.stringof, valueWireExpected));
            fieldRange.fromProtobufByProto!valueProto(value);
            break;
        default:
            throw new ProtobufException("Unexpected field tag " ~ tagWire.tag.to!string ~ " while decoding a map");
        }
    }
    field[key] = value;
}
406 
/**
 * Decodes a nested message field into `field`.
 *
 * The message bytes are cut out of `inputRange` with `takeLengthPrefixed`. They are decoded by the
 * message overload of `fromProtobuf`, with the current `field` as the target, so repeated occurrences
 * of the field are merged into the same value.
 */
private void fromProtobufByProto(Proto proto, T, R)(ref R inputRange, ref T field)
if (isInputRange!R && (is(T == class) || is(T == struct)))
{
    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");
    static assert(validateProto!(proto, T));

    R nested = inputRange.takeLengthPrefixed;
    field = fromProtobuf!T(nested, field);
}
417 
/**
 * Skips the payload of a field whose tag is not recognised, based on its wire type.
 *
 * How much is consumed from `inputRange` for each wire type:
 *   - varint: one varint;
 *   - 64-bit: 8 bytes;
 *   - 32-bit: 4 bytes;
 *   - length-delimited: the length prefix and its payload.
 *
 * Throws: `ProtobufException("Unknown wire format")` for any other wire type. This presumably
 *         includes the deprecated group markers; confirm against the `WireType` definition.
 */
void skipUnknown(R)(ref R inputRange, WireType wireType)
if (isInputRange!R)
{
    import std.exception : enforce;

    static assert(is(ElementType!R == ubyte), "Input range should be an ubyte range");

    switch (wireType)
    {
    case WireType.varint:
        fromVarint(inputRange);
        break;
    case WireType.bits32:
        takeN(inputRange, 4);
        break;
    case WireType.bits64:
        takeN(inputRange, 8);
        break;
    case WireType.withLength:
        takeLengthPrefixed(inputRange);
        break;
    default:
        enforce!ProtobufException(false, "Unknown wire format");
        break;
    }
}