diff options
Diffstat (limited to 'src/message.cpp')
-rw-r--r-- | src/message.cpp | 253 |
1 files changed, 196 insertions, 57 deletions
diff --git a/src/message.cpp b/src/message.cpp index 46cadd8..fcedd07 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -81,6 +81,82 @@ static std::uint8_t const RATCHET_KEY_TAG = 012; static std::uint8_t const COUNTER_TAG = 020; static std::uint8_t const CIPHERTEXT_TAG = 042; +std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint32_t value +) { + *(pos++) = tag; + return varint_encode(pos, value); +} + +std::uint8_t * encode( + std::uint8_t * pos, + std::uint8_t tag, + std::uint8_t * & value, std::size_t value_length +) { + *(pos++) = tag; + pos = varint_encode(pos, value_length); + value = pos; + return pos + value_length; +} + +std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint32_t & value, bool & has_value +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * value_start = pos; + pos = varint_skip(pos, end); + value = varint_decode<std::uint32_t>(value_start, pos); + has_value = true; + } + return pos; +} + + +std::uint8_t const * decode( + std::uint8_t const * pos, std::uint8_t const * end, + std::uint8_t tag, + std::uint8_t const * & value, std::size_t & value_length +) { + if (pos != end && *pos == tag) { + ++pos; + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode<std::size_t>(len_start, pos); + if (len > end - pos) return end; + value = pos; + value_length = len; + pos += len; + } + return pos; +} + +std::uint8_t const * skip_unknown( + std::uint8_t const * pos, std::uint8_t const * end +) { + if (pos != end) { + uint8_t tag = *pos; + if (tag & 0x7 == 0) { + pos = varint_skip(pos, end); + pos = varint_skip(pos, end); + } else if (tag & 0x7 == 2) { + pos = varint_skip(pos, end); + std::uint8_t const * len_start = pos; + pos = varint_skip(pos, end); + std::size_t len = varint_decode<std::size_t>(len_start, pos); + if (len > end - pos) return end; + pos += len; + } else { + return end; + } + } + return pos; +} + } // namespace @@ -109,75 +185,138 @@ void axolotl::encode_message( ) { std::uint8_t * pos = output; *(pos++) = version; - *(pos++) = COUNTER_TAG; - pos = varint_encode(pos, counter); - *(pos++) = RATCHET_KEY_TAG; - pos = varint_encode(pos, ratchet_key_length); - writer.ratchet_key = pos; - pos += ratchet_key_length; - *(pos++) = CIPHERTEXT_TAG; - pos = varint_encode(pos, ciphertext_length); - writer.ciphertext = pos; - pos += ciphertext_length; + pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length); + pos = encode(pos, COUNTER_TAG, counter); + pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length); } -std::size_t axolotl::decode_message( +void axolotl::decode_message( axolotl::MessageReader & reader, std::uint8_t const * input, std::size_t input_length, std::size_t mac_length ) { std::uint8_t const * pos = input; std::uint8_t const * end = input + input_length - mac_length; - std::uint8_t flags = 0; - std::size_t result = std::size_t(-1); - if (pos == end) return result; + std::uint8_t const * unknown = NULL; + + if (pos == end) return; reader.version = *(pos++); + reader.input = input; + reader.input_length = input_length; + reader.has_counter = false; + reader.ratchet_key = NULL; + reader.ciphertext = NULL; + while (pos != end) { - uint8_t tag = *(pos); - if (tag == COUNTER_TAG) { - ++pos; - std::uint8_t const * counter_start = pos; - pos = varint_skip(pos, end); - reader.counter = varint_decode<std::uint32_t>(counter_start, pos); - flags |= 1; - } else if (tag == RATCHET_KEY_TAG) { - ++pos; - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode<std::size_t>(len_start, pos); - if (len > end - pos) return result; - reader.ratchet_key_length = len; - reader.ratchet_key = pos; - pos += len; - flags |= 2; - } else if (tag == CIPHERTEXT_TAG) { - ++pos; - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode<std::size_t>(len_start, pos); - if (len > end - pos) return result; - reader.ciphertext_length = len; - reader.ciphertext = pos; - pos += len; - flags |= 4; - } else if (tag & 0x7 == 0) { - pos = varint_skip(pos, end); - pos = varint_skip(pos, end); - } else if (tag & 0x7 == 2) { - std::uint8_t const * len_start = pos; - pos = varint_skip(pos, end); - std::size_t len = varint_decode<std::size_t>(len_start, pos); - if (len > end - pos) return result; - pos += len; - } else { - return std::size_t(-1); + pos = decode( + pos, end, RATCHET_KEY_TAG, + reader.ratchet_key, reader.ratchet_key_length + ); + pos = decode( + pos, end, COUNTER_TAG, + reader.counter, reader.has_counter + ); + pos = decode( + pos, end, CIPHERTEXT_TAG, + reader.ciphertext, reader.ciphertext_length + ); + if (unknown == pos) { + pos == skip_unknown(pos, end); } + unknown = pos; } - if (flags == 0x7) { - reader.input = input; - reader.input_length = input_length; - return std::size_t(pos - input); +} + + +namespace { + +static std::uint8_t const REGISTRATION_ID_TAG = 050; +static std::uint8_t const ONE_TIME_KEY_ID_TAG = 010; +static std::uint8_t const BASE_KEY_TAG = 022; +static std::uint8_t const IDENTITY_KEY_TAG = 032; +static std::uint8_t const MESSAGE_TAG = 042; + +} // namespace + + +std::size_t axolotl::encode_one_time_key_message_length( + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length +) { + std::size_t length = VERSION_LENGTH; + length += 1 + varint_length(registration_id); + length += 1 + varint_length(one_time_key_id); + length += 1 + varstring_length(identity_key_length); + length += 1 + varstring_length(base_key_length); + length += 1 + varstring_length(message_length); + return length; +} + + +void axolotl::encode_one_time_key_message( + axolotl::PreKeyMessageWriter & writer, + std::uint8_t version, + std::uint32_t registration_id, + std::uint32_t one_time_key_id, + std::size_t identity_key_length, + std::size_t base_key_length, + std::size_t message_length, + std::uint8_t * output +) { + std::uint8_t * pos = output; + *(pos++) = version; + pos = encode(pos, REGISTRATION_ID_TAG, registration_id); + pos = encode(pos, ONE_TIME_KEY_ID_TAG, one_time_key_id); + pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length); + pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length); + pos = encode(pos, MESSAGE_TAG, writer.message, message_length); +} + + +void axolotl::decode_one_time_key_message( + PreKeyMessageReader & reader, + std::uint8_t const * input, std::size_t input_length +) { + std::uint8_t const * pos = input; + std::uint8_t const * end = input + input_length; + std::uint8_t const * unknown = NULL; + + if (pos == end) return; + reader.version = *(pos++); + reader.has_registration_id = false; + reader.has_one_time_key_id = false; + reader.identity_key = NULL; + reader.base_key = NULL; + reader.message = NULL; + + while (pos != end) { + pos = decode( + pos, end, REGISTRATION_ID_TAG, + reader.registration_id, reader.has_registration_id + ); + pos = decode( + pos, end, ONE_TIME_KEY_ID_TAG, + reader.one_time_key_id, reader.has_one_time_key_id + ); + pos = decode( + pos, end, BASE_KEY_TAG, + reader.base_key, reader.base_key_length + ); + pos = decode( + pos, end, IDENTITY_KEY_TAG, + reader.identity_key, reader.identity_key_length + ); + pos = decode( + pos, end, MESSAGE_TAG, + reader.message, reader.message_length + ); + if (unknown == pos) { + pos == skip_unknown(pos, end); + } + unknown = pos; } - return result; } |