diff options
-rw-r--r-- | include/olm/base64.hh | 24 | ||||
-rw-r--r-- | include/olm/cipher.hh | 4 | ||||
-rw-r--r-- | include/olm/crypto.hh | 26 | ||||
-rw-r--r-- | include/olm/memory.hh | 49 | ||||
-rw-r--r-- | include/olm/pickle.hh | 2 | ||||
-rw-r--r-- | include/olm/ratchet.hh | 2 | ||||
-rw-r--r-- | include/olm/session.hh | 46 | ||||
-rw-r--r-- | include/olm/utility.hh | 9 | ||||
-rw-r--r-- | src/account.cpp | 63 | ||||
-rw-r--r-- | src/cipher.cpp | 21 | ||||
-rw-r--r-- | src/crypto.cpp | 56 | ||||
-rw-r--r-- | src/message.cpp | 40 | ||||
-rw-r--r-- | src/olm.cpp | 47 | ||||
-rw-r--r-- | src/ratchet.cpp | 73 | ||||
-rw-r--r-- | src/session.cpp | 125 | ||||
-rw-r--r-- | src/utility.cpp | 6 | ||||
-rw-r--r-- | tests/test_base64.cpp | 1 | ||||
-rw-r--r-- | tests/test_crypto.cpp | 2 | ||||
-rw-r--r-- | tests/test_olm_decrypt.cpp | 75 |
19 files changed, 412 insertions, 259 deletions
diff --git a/include/olm/base64.hh b/include/olm/base64.hh index 018924a..da4641d 100644 --- a/include/olm/base64.hh +++ b/include/olm/base64.hh @@ -20,31 +20,45 @@ namespace olm { - +/** + * The number of bytes of unpadded base64 needed to encode a length of input. + */ static std::size_t encode_base64_length( std::size_t input_length ) { return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2; } - +/** + * Encode the raw input as unpadded base64. + * Writes encode_base64_length(input_length) bytes to the output buffer. + * The input can overlap with the last three quarters of the output buffer. + * That is, the input pointer may be output + output_length - input_length. + */ std::uint8_t * encode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ); - +/** + * The number of bytes of raw data a length of unpadded base64 will encode to. + * Returns std::size_t(-1) if the length is not a valid length for base64. + */ std::size_t decode_base64_length( std::size_t input_length ); - +/** + * Decodes the unpadded base64 input to raw bytes. + * Writes decode_base64_length(input_length) bytes to the output buffer. + * The output can overlap with the first three quarters of the input buffer. + * That is, the input pointers and output pointer may be the same. + */ std::uint8_t const * decode_base64( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ); - } // namespace olm diff --git a/include/olm/cipher.hh b/include/olm/cipher.hh index f71b3af..c561972 100644 --- a/include/olm/cipher.hh +++ b/include/olm/cipher.hh @@ -47,6 +47,8 @@ public: * output |--ciphertext_length-->| |---mac_length-->| * ciphertext * + * The plain-text pointers and cipher-text pointers may be the same. + * * Returns std::size_t(-1) if the length of the cipher-text or the output * buffer is too small. Otherwise returns the length of the output buffer. */ @@ -73,6 +75,8 @@ public: * input |--ciphertext_length-->| |---mac_length-->| * ciphertext * + * The plain-text pointers and cipher-text pointers may be the same. + * * Returns std::size_t(-1) if the length of the plain-text buffer is too * small or if the authentication check fails. Otherwise returns the length * of the plain text. diff --git a/include/olm/crypto.hh b/include/olm/crypto.hh index b845bfe..7a05f8d 100644 --- a/include/olm/crypto.hh +++ b/include/olm/crypto.hh @@ -20,28 +20,27 @@ namespace olm { +static const std::size_t KEY_LENGTH = 32; +static const std::size_t SIGNATURE_LENGTH = 64; +static const std::size_t IV_LENGTH = 16; struct Curve25519PublicKey { - static const int LENGTH = 32; - std::uint8_t public_key[32]; + std::uint8_t public_key[KEY_LENGTH]; }; struct Curve25519KeyPair : public Curve25519PublicKey { - static const int LENGTH = 64; - std::uint8_t private_key[32]; + std::uint8_t private_key[KEY_LENGTH]; }; struct Ed25519PublicKey { - static const int LENGTH = 32; - std::uint8_t public_key[32]; + std::uint8_t public_key[KEY_LENGTH]; }; struct Ed25519KeyPair : public Ed25519PublicKey { - static const int LENGTH = 64; - std::uint8_t private_key[32]; + std::uint8_t private_key[KEY_LENGTH]; }; @@ -52,9 +51,6 @@ void curve25519_generate_key( ); -const std::size_t CURVE25519_SHARED_SECRET_LENGTH = 32; - - /** Create a shared secret using our private key and their public key. * The output buffer must be at least 32 bytes long. */ void curve25519_shared_secret( @@ -109,14 +105,12 @@ bool ed25519_verify( struct Aes256Key { - static const int LENGTH = 32; - std::uint8_t key[32]; + std::uint8_t key[KEY_LENGTH]; }; struct Aes256Iv { - static const int LENGTH = 16; - std::uint8_t iv[16]; + std::uint8_t iv[IV_LENGTH]; }; @@ -156,7 +150,7 @@ void sha256( ); -const std::size_t HMAC_SHA256_OUTPUT_LENGTH = 32; +const std::size_t SHA256_OUTPUT_LENGTH = 32; /** HMAC: Keyed-Hashing for Message Authentication diff --git a/include/olm/memory.hh b/include/olm/memory.hh index b19c74b..128990a 100644 --- a/include/olm/memory.hh +++ b/include/olm/memory.hh @@ -14,6 +14,8 @@ */ #include <cstddef> #include <cstdint> +#include <cstring> +#include <type_traits> namespace olm { @@ -35,4 +37,51 @@ bool is_equal( std::size_t length ); +/** Check if two fixed size arrays are equals */ +template<typename T> +bool array_equal( + T const & array_a, + T const & array_b +) { + static_assert( + std::is_array<T>::value + && std::is_convertible<T, std::uint8_t *>::value + && sizeof(T) > 0, + "Arguments to array_equal must be std::uint8_t arrays[]." + ); + return is_equal(array_a, array_b, sizeof(T)); +} + +/** Copy into a fixed size array */ +template<typename T> +std::uint8_t const * load_array( + T & destination, + std::uint8_t const * source +) { + static_assert( + std::is_array<T>::value + && std::is_convertible<T, std::uint8_t *>::value + && sizeof(T) > 0, + "The first argument to load_array must be a std::uint8_t array[]." + ); + std::memcpy(destination, source, sizeof(T)); + return source + sizeof(T); +} + +/** Copy from a fixed size array */ +template<typename T> +std::uint8_t * store_array( + std::uint8_t * destination, + T const & source +) { + static_assert( + std::is_array<T>::value + && std::is_convertible<T, std::uint8_t *>::value + && sizeof(T) > 0, + "The second argument to store_array must be a std::uint8_t array[]." + ); + std::memcpy(destination, source, sizeof(T)); + return destination + sizeof(T); +} + } // namespace olm diff --git a/include/olm/pickle.hh b/include/olm/pickle.hh index 7a2bd1b..27f1f26 100644 --- a/include/olm/pickle.hh +++ b/include/olm/pickle.hh @@ -109,7 +109,7 @@ std::uint8_t const * unpickle( ) { std::uint32_t size; pos = unpickle(pos, end, size); - while (size--) { + while (size-- && pos != end) { T * value = list.insert(list.end()); pos = unpickle(pos, end, *value); } diff --git a/include/olm/ratchet.hh b/include/olm/ratchet.hh index 7274255..2393e5b 100644 --- a/include/olm/ratchet.hh +++ b/include/olm/ratchet.hh @@ -21,7 +21,7 @@ namespace olm { class Cipher; -typedef std::uint8_t SharedKey[32]; +typedef std::uint8_t SharedKey[olm::KEY_LENGTH]; struct ChainKey { diff --git a/include/olm/session.hh b/include/olm/session.hh index 993a8da..b21b0aa 100644 --- a/include/olm/session.hh +++ b/include/olm/session.hh @@ -39,8 +39,13 @@ struct Session { Curve25519PublicKey alice_base_key; Curve25519PublicKey bob_one_time_key; + /** The number of random bytes that are needed to create a new outbound + * session. This will be 64 bytes since two ephemeral keys are needed. */ std::size_t new_outbound_session_random_length(); + /** Start a new outbound session. Returns std::size_t(-1) on failure. On + * failure last_error will be set with an error code. The last_error will be + * NOT_ENOUGH_RANDOM if the number of random bytes was too small. */ std::size_t new_outbound_session( Account const & local_account, Curve25519PublicKey const & identity_key, @@ -48,42 +53,79 @@ struct Session { std::uint8_t const * random, std::size_t random_length ); + /** Start a new inbound session from a pre-key message. + * Returns std::size_t(-1) on failure. On failure last_error will be set + * with an error code. The last_error will be BAD_MESSAGE_FORMAT if + * the message headers could not be decoded. */ std::size_t new_inbound_session( Account & local_account, Curve25519PublicKey const * their_identity_key, - std::uint8_t const * one_time_key_message, std::size_t message_length + std::uint8_t const * pre_key_message, std::size_t message_length ); + /** The number of bytes written by session_id() */ std::size_t session_id_length(); + /** An identifier for this session. Generated by hashing the public keys + * used to create the session. Returns the length of the session id on + * success or std::size_t(-1) on failure. On failure last_error will be set + * with an error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if + * the id buffer was too small. */ std::size_t session_id( std::uint8_t * id, std::size_t id_length ); + /** True if this session can be used to decode an inbound pre-key message. + * This can be used to test whether a pre-key message should be decoded + * with an existing session or if a new session will need to be created. + * Returns true if the session is the same. Returns false if either the + * session does not match or the pre-key message could not be decoded. + */ bool matches_inbound_session( Curve25519PublicKey const * their_identity_key, - std::uint8_t const * one_time_key_message, std::size_t message_length + std::uint8_t const * pre_key_message, std::size_t message_length ); + /** Whether the next message will be a pre-key message or a normal message. + * An outbound session will send pre-key messages until it receives a + * message with a ratchet key. */ MessageType encrypt_message_type(); std::size_t encrypt_message_length( std::size_t plaintext_length ); + /** The number of bytes of random data the encrypt method will need to + * encrypt a message. This will be 32 bytes if the session needs to + * generate a new ephemeral key, or will be 0 bytes otherwise. */ std::size_t encrypt_random_length(); + /** Encrypt some plain-text. Returns the length of the encrypted message + * or std::size_t(-1) on failure. On failure last_error will be set with + * an error code. The last_error will be NOT_ENOUGH_RANDOM if the number + * of random bytes is too small. The last_error will be + * OUTPUT_BUFFER_TOO_SMALL if the output buffer is too small. */ std::size_t encrypt( std::uint8_t const * plaintext, std::size_t plaintext_length, std::uint8_t const * random, std::size_t random_length, std::uint8_t * message, std::size_t message_length ); + /** An upper bound on the number of bytes of plain-text the decrypt method + * will write for a given input message length. */ std::size_t decrypt_max_plaintext_length( MessageType message_type, std::uint8_t const * message, std::size_t message_length ); + /** Decrypt a message. Returns the length of the decrypted plain-text or + * std::size_t(-1) on failure. On failure last_error will be set with an + * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the + * plain-text buffer is too small. The last_error will be + * BAD_MESSAGE_VERSION if the message was encrypted with an unsupported + * version of the protocol. The last_error will be BAD_MESSAGE_FORMAT if + * the message headers could not be decoded. The last_error will be + * BAD_MESSAGE_MAC if the message could not be verified */ std::size_t decrypt( MessageType message_type, std::uint8_t const * message, std::size_t message_length, diff --git a/include/olm/utility.hh b/include/olm/utility.hh index 241d7e0..5329a59 100644 --- a/include/olm/utility.hh +++ b/include/olm/utility.hh @@ -31,13 +31,22 @@ struct Utility { ErrorCode last_error; + /** The length of a SHA-256 hash in bytes. */ std::size_t sha256_length(); + /** Compute a SHA-256 hash. Returns the length of the SHA-256 hash in bytes + * on success. Returns std::size_t(-1) on failure. On failure last_error + * will be set with an error code. If the output buffer was too small then + * last error will be OUTPUT_BUFFER_TOO_SMALL. */ std::size_t sha256( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output, std::size_t output_length ); + /** Verify a ed25519 signature. Returns std::size_t(0) on success. Returns + * std::size_t(-1) on failure or if the signature was invalid. On failure + * last_error will be set with an error code. If the signature was too short + * or was not a valid signature then last_error will be BAD_MESSAGE_MAC. */ std::size_t ed25519_verify( Ed25519PublicKey const & key, std::uint8_t const * message, std::size_t message_length, diff --git a/src/account.cpp b/src/account.cpp index cf6f0cb..43033c8 100644 --- a/src/account.cpp +++ b/src/account.cpp @@ -15,6 +15,7 @@ #include "olm/account.hh" #include "olm/base64.hh" #include "olm/pickle.hh" +#include "olm/memory.hh" olm::Account::Account( ) : next_one_time_key_id(0), @@ -26,7 +27,7 @@ olm::OneTimeKey const * olm::Account::lookup_key( olm::Curve25519PublicKey const & public_key ) { for (olm::OneTimeKey const & key : one_time_keys) { - if (0 == memcmp(key.key.public_key, public_key.public_key, 32)) { + if (olm::array_equal(key.key.public_key, public_key.public_key)) { return &key; } } @@ -38,7 +39,7 @@ std::size_t olm::Account::remove_key( ) { OneTimeKey * i; for (i = one_time_keys.begin(); i != one_time_keys.end(); ++i) { - if (0 == memcmp(i->key.public_key, public_key.public_key, 32)) { + if (olm::array_equal(i->key.public_key, public_key.public_key)) { std::uint32_t id = i->id; one_time_keys.erase(i); return id; @@ -48,7 +49,7 @@ std::size_t olm::Account::remove_key( } std::size_t olm::Account::new_account_random_length() { - return 2 * 32; + return 2 * olm::KEY_LENGTH; } std::size_t olm::Account::new_account( @@ -60,35 +61,19 @@ std::size_t olm::Account::new_account( } olm::ed25519_generate_key(random, identity_keys.ed25519_key); - random += 32; + random += KEY_LENGTH; olm::curve25519_generate_key(random, identity_keys.curve25519_key); - random += 32; return 0; } namespace { - -namespace { uint8_t KEY_JSON_ED25519[] = "\"ed25519\":"; uint8_t KEY_JSON_CURVE25519[] = "\"curve25519\":"; -} - - -std::size_t count_digits( - std::uint64_t value -) { - std::size_t digits = 0; - do { - digits++; - value /= 10; - } while (value); - return digits; -} template<typename T> -std::uint8_t * write_string( +static std::uint8_t * write_string( std::uint8_t * pos, T const & value ) { @@ -96,27 +81,6 @@ std::uint8_t * write_string( return pos + (sizeof(T) - 1); } -std::uint8_t * write_string( - std::uint8_t * pos, - std::uint8_t const * value, std::size_t value_length -) { - std::memcpy(pos, value, value_length); - return pos + value_length; -} - -std::uint8_t * write_digits( - std::uint8_t * pos, - std::uint64_t value -) { - size_t digits = count_digits(value); - pos += digits; - do { - *(--pos) = '0' + (value % 10); - value /= 10; - } while (value); - return pos + digits; -} - } @@ -143,7 +107,6 @@ std::size_t olm::Account::get_identity_json( std::uint8_t * identity_json, std::size_t identity_json_length ) { std::uint8_t * pos = identity_json; - std::uint8_t signature[64]; size_t expected_length = get_identity_json_length(); if (identity_json_length < expected_length) { @@ -174,7 +137,7 @@ std::size_t olm::Account::get_identity_json( std::size_t olm::Account::signature_length( ) { - return 64; + return olm::SIGNATURE_LENGTH; } @@ -196,19 +159,20 @@ std::size_t olm::Account::sign( std::size_t olm::Account::get_one_time_keys_json_length( ) { std::size_t length = 0; + bool is_empty = true; for (auto const & key : one_time_keys) { if (key.published) { continue; } + is_empty = false; length += 2; /* {" */ length += olm::encode_base64_length(olm::pickle_length(key.id)); length += 3; /* ":" */ length += olm::encode_base64_length(sizeof(key.key.public_key)); length += 1; /* " */ } - if (length == 0) { - /* The list was empty. Add a byte for the opening '{' */ - length = 1; + if (is_empty) { + length += 1; /* { */ } length += 3; /* }{} */ length += sizeof(KEY_JSON_CURVE25519) - 1; @@ -244,6 +208,7 @@ std::size_t olm::Account::get_one_time_keys_json( sep = ','; } if (sep != ',') { + /* The list was empty */ *(pos++) = sep; } *(pos++) = '}'; @@ -273,7 +238,7 @@ std::size_t olm::Account::max_number_of_one_time_keys( std::size_t olm::Account::generate_one_time_keys_random_length( std::size_t number_of_keys ) { - return 32 * number_of_keys; + return olm::KEY_LENGTH * number_of_keys; } std::size_t olm::Account::generate_one_time_keys( @@ -289,7 +254,7 @@ std::size_t olm::Account::generate_one_time_keys( key.id = ++next_one_time_key_id; key.published = false; olm::curve25519_generate_key(random, key.key); - random += 32; + random += olm::KEY_LENGTH; } return number_of_keys; } diff --git a/src/cipher.cpp b/src/cipher.cpp index 2202746..7bb11b8 100644 --- a/src/cipher.cpp +++ b/src/cipher.cpp @@ -23,11 +23,9 @@ olm::Cipher::~Cipher() { namespace { -static const std::size_t SHA256_LENGTH = 32; - struct DerivedKeys { olm::Aes256Key aes_key; - std::uint8_t mac_key[SHA256_LENGTH]; + std::uint8_t mac_key[olm::KEY_LENGTH]; olm::Aes256Iv aes_iv; }; @@ -37,16 +35,17 @@ static void derive_keys( std::uint8_t const * key, std::size_t key_length, DerivedKeys & keys ) { - std::uint8_t derived_secrets[80]; + std::uint8_t derived_secrets[2 * olm::KEY_LENGTH + olm::IV_LENGTH]; olm::hkdf_sha256( key, key_length, nullptr, 0, kdf_info, kdf_info_length, derived_secrets, sizeof(derived_secrets) ); - std::memcpy(keys.aes_key.key, derived_secrets, 32); - std::memcpy(keys.mac_key, derived_secrets + 32, 32); - std::memcpy(keys.aes_iv.iv, derived_secrets + 64, 16); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(keys.aes_key.key, pos); + pos = olm::load_array(keys.mac_key, pos); + pos = olm::load_array(keys.aes_iv.iv, pos); olm::unset(derived_secrets); } @@ -84,7 +83,7 @@ std::size_t olm::CipherAesSha256::encrypt( return std::size_t(-1); } struct DerivedKeys keys; - std::uint8_t mac[SHA256_LENGTH]; + std::uint8_t mac[olm::SHA256_OUTPUT_LENGTH]; derive_keys(kdf_info, kdf_info_length, key, key_length, keys); @@ -93,7 +92,7 @@ std::size_t olm::CipherAesSha256::encrypt( ); olm::hmac_sha256( - keys.mac_key, SHA256_LENGTH, output, output_length - MAC_LENGTH, mac + keys.mac_key, olm::KEY_LENGTH, output, output_length - MAC_LENGTH, mac ); std::memcpy(output + output_length - MAC_LENGTH, mac, MAC_LENGTH); @@ -116,12 +115,12 @@ std::size_t olm::CipherAesSha256::decrypt( std::uint8_t * plaintext, std::size_t max_plaintext_length ) const { DerivedKeys keys; - std::uint8_t mac[SHA256_LENGTH]; + std::uint8_t mac[olm::SHA256_OUTPUT_LENGTH]; derive_keys(kdf_info, kdf_info_length, key, key_length, keys); olm::hmac_sha256( - keys.mac_key, SHA256_LENGTH, input, input_length - MAC_LENGTH, mac + keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac ); std::uint8_t const * input_mac = input + input_length - MAC_LENGTH; diff --git a/src/crypto.cpp b/src/crypto.cpp index ed89d64..8024355 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -66,8 +66,9 @@ void ed25519_keypair( namespace { static const std::uint8_t CURVE25519_BASEPOINT[32] = {9}; +static const std::size_t AES_KEY_SCHEDULE_LENGTH = 60; +static const std::size_t AES_KEY_BITS = 8 * olm::KEY_LENGTH; static const std::size_t AES_BLOCK_LENGTH = 16; -static const std::size_t SHA256_HASH_LENGTH = 32; static const std::size_t SHA256_BLOCK_LENGTH = 64; static const std::uint8_t HKDF_DEFAULT_SALT[32] = {}; @@ -99,7 +100,7 @@ inline static void hmac_sha256_key( } -inline void hmac_sha256_init( +inline static void hmac_sha256_init( ::SHA256_CTX * context, std::uint8_t const * hmac_key ) { @@ -114,12 +115,12 @@ inline void hmac_sha256_init( } -inline void hmac_sha256_final( +inline static void hmac_sha256_final( ::SHA256_CTX * context, std::uint8_t const * hmac_key, std::uint8_t * output ) { - std::uint8_t o_pad[SHA256_BLOCK_LENGTH + SHA256_HASH_LENGTH]; + std::uint8_t o_pad[SHA256_BLOCK_LENGTH + olm::SHA256_OUTPUT_LENGTH]; std::memcpy(o_pad, hmac_key, SHA256_BLOCK_LENGTH); for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) { o_pad[i] ^= 0x5C; @@ -140,7 +141,7 @@ void olm::curve25519_generate_key( std::uint8_t const * random_32_bytes, olm::Curve25519KeyPair & key_pair ) { - std::memcpy(key_pair.private_key, random_32_bytes, 32); + std::memcpy(key_pair.private_key, random_32_bytes, KEY_LENGTH); ::curve25519_donna( key_pair.public_key, key_pair.private_key, CURVE25519_BASEPOINT ); @@ -161,9 +162,9 @@ void olm::curve25519_sign( std::uint8_t const * message, std::size_t message_length, std::uint8_t * output ) { - std::uint8_t private_key[32]; - std::uint8_t public_key[32]; - std::memcpy(private_key, our_key.private_key, 32); + std::uint8_t private_key[KEY_LENGTH]; + std::uint8_t public_key[KEY_LENGTH]; + std::memcpy(private_key, our_key.private_key, KEY_LENGTH); ::ed25519_keypair(private_key, public_key); ::ed25519_sign( output, @@ -179,10 +180,10 @@ bool olm::curve25519_verify( std::uint8_t const * message, std::size_t message_length, std::uint8_t const * signature ) { - std::uint8_t public_key[32]; - std::uint8_t signature_buffer[64]; - std::memcpy(public_key, their_key.public_key, 32); - std::memcpy(signature_buffer, signature, 64); + std::uint8_t public_key[KEY_LENGTH]; + std::uint8_t signature_buffer[SIGNATURE_LENGTH]; + std::memcpy(public_key, their_key.public_key, KEY_LENGTH); + std::memcpy(signature_buffer, signature, SIGNATURE_LENGTH); ::convert_curve25519_to_ed25519(public_key, signature_buffer); return 0 != ::ed25519_verify( signature, @@ -196,7 +197,7 @@ void olm::ed25519_generate_key( std::uint8_t const * random_32_bytes, olm::Ed25519KeyPair & key_pair ) { - std::memcpy(key_pair.private_key, random_32_bytes, 32); + std::memcpy(key_pair.private_key, random_32_bytes, KEY_LENGTH); ::ed25519_keypair(key_pair.private_key, key_pair.public_key); } @@ -240,13 +241,13 @@ void olm::aes_encrypt_cbc( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ) { - std::uint32_t key_schedule[60]; - ::aes_key_setup(key.key, key_schedule, 256); + std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH]; + ::aes_key_setup(key.key, key_schedule, AES_KEY_BITS); std::uint8_t input_block[AES_BLOCK_LENGTH]; std::memcpy(input_block, iv.iv, AES_BLOCK_LENGTH); while (input_length >= AES_BLOCK_LENGTH) { xor_block<AES_BLOCK_LENGTH>(input_block, input); - ::aes_encrypt(input_block, output, key_schedule, 256); + ::aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS); std::memcpy(input_block, output, AES_BLOCK_LENGTH); input += AES_BLOCK_LENGTH; output += AES_BLOCK_LENGTH; @@ -259,7 +260,7 @@ void olm::aes_encrypt_cbc( for (; i < AES_BLOCK_LENGTH; ++i) { input_block[i] ^= AES_BLOCK_LENGTH - input_length; } - ::aes_encrypt(input_block, output, key_schedule, 256); + ::aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS); olm::unset(key_schedule); olm::unset(input_block); } @@ -271,14 +272,14 @@ std::size_t olm::aes_decrypt_cbc( std::uint8_t const * input, std::size_t input_length, std::uint8_t * output ) { - std::uint32_t key_schedule[60]; - ::aes_key_setup(key.key, key_schedule, 256); + std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH]; + ::aes_key_setup(key.key, key_schedule, AES_KEY_BITS); std::uint8_t block1[AES_BLOCK_LENGTH]; std::uint8_t block2[AES_BLOCK_LENGTH]; std::memcpy(block1, iv.iv, AES_BLOCK_LENGTH); for (std::size_t i = 0; i < input_length; i += AES_BLOCK_LENGTH) { std::memcpy(block2, &input[i], AES_BLOCK_LENGTH); - ::aes_decrypt(&input[i], &output[i], key_schedule, 256); + ::aes_decrypt(&input[i], &output[i], key_schedule, AES_KEY_BITS); xor_block<AES_BLOCK_LENGTH>(&output[i], block1); std::memcpy(block1, block2, AES_BLOCK_LENGTH); } @@ -301,6 +302,7 @@ void olm::sha256( olm::unset(context); } + void olm::hmac_sha256( std::uint8_t const * key, std::size_t key_length, std::uint8_t const * input, std::size_t input_length, @@ -325,7 +327,7 @@ void olm::hkdf_sha256( ) { ::SHA256_CTX context; std::uint8_t hmac_key[SHA256_BLOCK_LENGTH]; - std::uint8_t step_result[SHA256_HASH_LENGTH]; + std::uint8_t step_result[olm::SHA256_OUTPUT_LENGTH]; std::size_t bytes_remaining = output_length; std::uint8_t iteration = 1; if (!salt) { @@ -337,20 +339,20 @@ void olm::hkdf_sha256( hmac_sha256_init(&context, hmac_key); ::sha256_update(&context, input, input_length); hmac_sha256_final(&context, hmac_key, step_result); - hmac_sha256_key(step_result, SHA256_HASH_LENGTH, hmac_key); + hmac_sha256_key(step_result, olm::SHA256_OUTPUT_LENGTH, hmac_key); /* Extract */ hmac_sha256_init(&context, hmac_key); ::sha256_update(&context, info, info_length); ::sha256_update(&context, &iteration, 1); hmac_sha256_final(&context, hmac_key, step_result); - while (bytes_remaining > SHA256_HASH_LENGTH) { - std::memcpy(output, step_result, SHA256_HASH_LENGTH); - output += SHA256_HASH_LENGTH; - bytes_remaining -= SHA256_HASH_LENGTH; + while (bytes_remaining > olm::SHA256_OUTPUT_LENGTH) { + std::memcpy(output, step_result, olm::SHA256_OUTPUT_LENGTH); + output += olm::SHA256_OUTPUT_LENGTH; + bytes_remaining -= olm::SHA256_OUTPUT_LENGTH; iteration ++; hmac_sha256_init(&context, hmac_key); - ::sha256_update(&context, step_result, SHA256_HASH_LENGTH); + ::sha256_update(&context, step_result, olm::SHA256_OUTPUT_LENGTH); ::sha256_update(&context, info, info_length); ::sha256_update(&context, &iteration, 1); hmac_sha256_final(&context, hmac_key, step_result); diff --git a/src/message.cpp b/src/message.cpp index 93473b9..05f707f 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -17,7 +17,7 @@ namespace { template<typename T> -std::size_t varint_length( +static std::size_t varint_length( T value ) { std::size_t result = 1; @@ -30,7 +30,7 @@ std::size_t varint_length( template<typename T> -std::uint8_t * varint_encode( +static std::uint8_t * varint_encode( std::uint8_t * output, T value ) { @@ -44,11 +44,14 @@ std::uint8_t * varint_encode( template<typename T> -T varint_decode( +static T varint_decode( std::uint8_t const * varint_start, std::uint8_t const * varint_end ) { T value = 0; + if (varint_end == varint_start) { + return 0; + } do { value <<= 7; value |= 0x7F & *(--varint_end); @@ -57,7 +60,7 @@ T varint_decode( } -std::uint8_t const * varint_skip( +static std::uint8_t const * varint_skip( std::uint8_t const * input, std::uint8_t const * input_end ) { @@ -71,7 +74,7 @@ std::uint8_t const * varint_skip( } -std::size_t varstring_length( +static std::size_t varstring_length( std::size_t string_length ) { return varint_length(string_length) + string_length; @@ -82,7 +85,7 @@ static std::uint8_t const RATCHET_KEY_TAG = 012; static std::uint8_t const COUNTER_TAG = 020; static std::uint8_t const CIPHERTEXT_TAG = 042; -std::uint8_t * encode( +static std::uint8_t * encode( std::uint8_t * pos, std::uint8_t tag, std::uint32_t value @@ -91,7 +94,7 @@ std::uint8_t * encode( return varint_encode(pos, value); } -std::uint8_t * encode( +static std::uint8_t * encode( std::uint8_t * pos, std::uint8_t tag, std::uint8_t * & value, std::size_t value_length @@ -102,7 +105,7 @@ std::uint8_t * encode( return pos + value_length; } -std::uint8_t const * decode( +static std::uint8_t const * decode( std::uint8_t const * pos, std::uint8_t const * end, std::uint8_t tag, std::uint32_t & value, bool & has_value @@ -118,7 +121,7 @@ std::uint8_t const * decode( } -std::uint8_t const * decode( +static std::uint8_t const * decode( std::uint8_t const * pos, std::uint8_t const * end, std::uint8_t tag, std::uint8_t const * & value, std::size_t & value_length @@ -136,7 +139,7 @@ std::uint8_t const * decode( return pos; } -std::uint8_t const * skip_unknown( +static std::uint8_t const * skip_unknown( std::uint8_t const * pos, std::uint8_t const * end ) { if (pos != end) { @@ -201,13 +204,17 @@ void olm::decode_message( std::uint8_t const * end = input + input_length - mac_length; std::uint8_t const * unknown = nullptr; - if (pos == end) return; - reader.version = *(pos++); reader.input = input; reader.input_length = input_length; reader.has_counter = false; reader.ratchet_key = nullptr; + reader.ratchet_key_length = 0; reader.ciphertext = nullptr; + reader.ciphertext_length = 0; + + if (pos == end) return; + if (input_length < mac_length) return; + reader.version = *(pos++); while (pos != end) { pos = decode( @@ -281,12 +288,17 @@ void olm::decode_one_time_key_message( std::uint8_t const * end = input + input_length; std::uint8_t const * unknown = nullptr; - if (pos == end) return; - reader.version = *(pos++); reader.one_time_key = nullptr; + reader.one_time_key_length = 0; reader.identity_key = nullptr; + reader.identity_key_length = 0; reader.base_key = nullptr; + reader.base_key_length = 0; reader.message = nullptr; + reader.message_length = 0; + + if (pos == end) return; + reader.version = *(pos++); while (pos != end) { pos = decode( diff --git a/src/olm.cpp b/src/olm.cpp index 28b194f..63f3d83 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -520,8 +520,13 @@ size_t olm_create_outbound_session( void const * their_one_time_key, size_t their_one_time_key_length, void * random, size_t random_length ) { - if (olm::decode_base64_length(their_identity_key_length) != 32 - || olm::decode_base64_length(their_one_time_key_length) != 32 + std::uint8_t const * id_key = from_c(their_identity_key); + std::uint8_t const * ot_key = from_c(their_one_time_key); + std::size_t id_key_length = their_identity_key_length; + std::size_t ot_key_length = their_one_time_key_length; + + if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH + || olm::decode_base64_length(ot_key_length) != olm::KEY_LENGTH ) { from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64; return std::size_t(-1); @@ -529,14 +534,8 @@ size_t olm_create_outbound_session( olm::Curve25519PublicKey identity_key; olm::Curve25519PublicKey one_time_key; - olm::decode_base64( - from_c(their_identity_key), their_identity_key_length, - identity_key.public_key - ); - olm::decode_base64( - from_c(their_one_time_key), their_one_time_key_length, - one_time_key.public_key - ); + olm::decode_base64(id_key, id_key_length, identity_key.public_key); + olm::decode_base64(ot_key, ot_key_length, one_time_key.public_key); size_t result = from_c(session)->new_outbound_session( *from_c(account), identity_key, one_time_key, @@ -570,15 +569,15 @@ size_t olm_create_inbound_session_from( void const * their_identity_key, size_t their_identity_key_length, void * one_time_key_message, size_t message_length ) { - if (olm::decode_base64_length(their_identity_key_length) != 32) { + std::uint8_t const * id_key = from_c(their_identity_key); + std::size_t id_key_length = their_identity_key_length; + + if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH) { from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64; return std::size_t(-1); } olm::Curve25519PublicKey identity_key; - olm::decode_base64( - from_c(their_identity_key), their_identity_key_length, - identity_key.public_key - ); + olm::decode_base64(id_key, id_key_length, identity_key.public_key); std::size_t raw_length = b64_input( from_c(one_time_key_message), message_length, from_c(session)->last_error @@ -641,15 +640,15 @@ size_t olm_matches_inbound_session_from( void const * their_identity_key, size_t their_identity_key_length, void * one_time_key_message, size_t message_length ) { - if (olm::decode_base64_length(their_identity_key_length) != 32) { + std::uint8_t const * id_key = from_c(their_identity_key); + std::size_t id_key_length = their_identity_key_length; + + if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH) { from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64; return std::size_t(-1); } olm::Curve25519PublicKey identity_key; - olm::decode_base64( - from_c(their_identity_key), their_identity_key_length, - identity_key.public_key - ); + olm::decode_base64(id_key, id_key_length, identity_key.public_key); std::size_t raw_length = b64_input( from_c(one_time_key_message), message_length, from_c(session)->last_error @@ -800,15 +799,12 @@ size_t olm_ed25519_verify( void const * message, size_t message_length, void * signature, size_t signature_length ) { - if (olm::decode_base64_length(key_length) != 32) { + if (olm::decode_base64_length(key_length) != olm::KEY_LENGTH) { from_c(utility)->last_error = olm::ErrorCode::INVALID_BASE64; return std::size_t(-1); } olm::Ed25519PublicKey verify_key; - olm::decode_base64( - from_c(key), key_length, - verify_key.public_key - ); + olm::decode_base64(from_c(key), key_length, verify_key.public_key); std::size_t raw_signature_length = b64_input( from_c(signature), signature_length, from_c(utility)->last_error ); @@ -822,5 +818,4 @@ size_t olm_ed25519_verify( ); } - } diff --git a/src/ratchet.cpp b/src/ratchet.cpp index b9108db..5ef1e56 100644 --- a/src/ratchet.cpp +++ b/src/ratchet.cpp @@ -23,13 +23,12 @@ namespace { -std::uint8_t PROTOCOL_VERSION = 3; -std::size_t KEY_LENGTH = olm::Curve25519PublicKey::LENGTH; -std::uint8_t MESSAGE_KEY_SEED[1] = {0x01}; -std::uint8_t CHAIN_KEY_SEED[1] = {0x02}; -std::size_t MAX_MESSAGE_GAP = 2000; +static const std::uint8_t PROTOCOL_VERSION = 3; +static const std::uint8_t MESSAGE_KEY_SEED[1] = {0x01}; +static const std::uint8_t CHAIN_KEY_SEED[1] = {0x02}; +static const std::size_t MAX_MESSAGE_GAP = 2000; -void create_chain_key( +static void create_chain_key( olm::SharedKey const & root_key, olm::Curve25519KeyPair const & our_key, olm::Curve25519PublicKey const & their_key, @@ -39,22 +38,23 @@ void create_chain_key( ) { olm::SharedKey secret; olm::curve25519_shared_secret(our_key, their_key, secret); - std::uint8_t derived_secrets[64]; + std::uint8_t derived_secrets[2 * olm::KEY_LENGTH]; olm::hkdf_sha256( secret, sizeof(secret), root_key, sizeof(root_key), info.ratchet_info, info.ratchet_info_length, derived_secrets, sizeof(derived_secrets) ); - std::memcpy(new_root_key, derived_secrets, 32); - std::memcpy(new_chain_key.key, derived_secrets + 32, 32); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(new_root_key, pos); + pos = olm::load_array(new_chain_key.key, pos); new_chain_key.index = 0; olm::unset(derived_secrets); olm::unset(secret); } -void advance_chain_key( +static void advance_chain_key( olm::ChainKey const & chain_key, olm::ChainKey & new_chain_key ) { @@ -67,7 +67,7 @@ void advance_chain_key( } -void create_message_keys( +static void create_message_keys( olm::ChainKey const & chain_key, olm::KdfInfo const & info, olm::MessageKey & message_key @@ -81,7 +81,7 @@ void create_message_keys( } -std::size_t verify_mac_and_decrypt( +static std::size_t verify_mac_and_decrypt( olm::Cipher const & cipher, olm::MessageKey const & message_key, olm::MessageReader const & reader, @@ -96,7 +96,7 @@ std::size_t verify_mac_and_decrypt( } -std::size_t verify_mac_and_decrypt_for_existing_chain( +static std::size_t verify_mac_and_decrypt_for_existing_chain( olm::Ratchet const & session, olm::ChainKey const & chain, olm::MessageReader const & reader, @@ -130,7 +130,7 @@ std::size_t verify_mac_and_decrypt_for_existing_chain( } -std::size_t verify_mac_and_decrypt_for_new_chain( +static std::size_t verify_mac_and_decrypt_for_new_chain( olm::Ratchet const & session, olm::MessageReader const & reader, std::uint8_t * plaintext, std::size_t max_plaintext_length @@ -148,9 +148,7 @@ std::size_t verify_mac_and_decrypt_for_new_chain( if (reader.counter > MAX_MESSAGE_GAP) { return std::size_t(-1); } - std::memcpy( - new_chain.ratchet_key.public_key, reader.ratchet_key, KEY_LENGTH - ); + olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key); create_chain_key( session.root_key, session.sender_chain[0].ratchet_key, @@ -183,7 +181,7 @@ void olm::Ratchet::initialise_as_bob( std::uint8_t const * shared_secret, std::size_t shared_secret_length, olm::Curve25519PublicKey const & their_ratchet_key ) { - std::uint8_t derived_secrets[64]; + std::uint8_t derived_secrets[2 * olm::KEY_LENGTH]; olm::hkdf_sha256( shared_secret, shared_secret_length, nullptr, 0, @@ -192,8 +190,9 @@ void olm::Ratchet::initialise_as_bob( ); receiver_chains.insert(); receiver_chains[0].chain_key.index = 0; - std::memcpy(root_key, derived_secrets, 32); - std::memcpy(receiver_chains[0].chain_key.key, derived_secrets + 32, 32); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(root_key, pos); + pos = olm::load_array(receiver_chains[0].chain_key.key, pos); receiver_chains[0].ratchet_key = their_ratchet_key; olm::unset(derived_secrets); } @@ -203,7 +202,7 @@ void olm::Ratchet::initialise_as_alice( std::uint8_t const * shared_secret, std::size_t shared_secret_length, olm::Curve25519KeyPair const & our_ratchet_key ) { - std::uint8_t derived_secrets[64]; + std::uint8_t derived_secrets[2 * olm::KEY_LENGTH]; olm::hkdf_sha256( shared_secret, shared_secret_length, nullptr, 0, @@ -212,8 +211,9 @@ void olm::Ratchet::initialise_as_alice( ); sender_chain.insert(); sender_chain[0].chain_key.index = 0; - std::memcpy(root_key, derived_secrets, 32); - std::memcpy(sender_chain[0].chain_key.key, derived_secrets + 32, 32); + std::uint8_t const * pos = derived_secrets; + pos = olm::load_array(root_key, pos); + pos = olm::load_array(sender_chain[0].chain_key.key, pos); sender_chain[0].ratchet_key = our_ratchet_key; olm::unset(derived_secrets); } @@ -224,7 +224,7 @@ namespace olm { static std::size_t pickle_length( const olm::SharedKey & value ) { - return KEY_LENGTH; + return olm::KEY_LENGTH; } @@ -232,7 +232,7 @@ static std::uint8_t * pickle( std::uint8_t * pos, const olm::SharedKey & value ) { - return olm::pickle_bytes(pos, value, KEY_LENGTH); + return olm::pickle_bytes(pos, value, olm::KEY_LENGTH); } @@ -240,7 +240,7 @@ static std::uint8_t const * unpickle( std::uint8_t const * pos, std::uint8_t const * end, olm::SharedKey & value ) { - return olm::unpickle_bytes(pos, end, value, KEY_LENGTH); + return olm::unpickle_bytes(pos, end, value, olm::KEY_LENGTH); } @@ -349,7 +349,7 @@ std::size_t olm::pickle_length( olm::Ratchet const & value ) { std::size_t length = 0; - length += KEY_LENGTH; + length += olm::KEY_LENGTH; length += olm::pickle_length(value.sender_chain); length += olm::pickle_length(value.receiver_chains); length += olm::pickle_length(value.skipped_message_keys); @@ -391,13 +391,13 @@ std::size_t olm::Ratchet::encrypt_output_length( plaintext_length ); return olm::encode_message_length( - counter, KEY_LENGTH, padded, ratchet_cipher.mac_length() + counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length() ); } std::size_t olm::Ratchet::encrypt_random_length() { - return sender_chain.empty() ? KEY_LENGTH : 0; + return sender_chain.empty() ? olm::KEY_LENGTH : 0; } @@ -442,10 +442,11 @@ std::size_t olm::Ratchet::encrypt( olm::MessageWriter writer; olm::encode_message( - writer, PROTOCOL_VERSION, counter, KEY_LENGTH, ciphertext_length, output + writer, PROTOCOL_VERSION, counter, olm::KEY_LENGTH, ciphertext_length, + output ); - std::memcpy(writer.ratchet_key, ratchet_key.public_key, KEY_LENGTH); + olm::store_array(writer.ratchet_key, ratchet_key.public_key); ratchet_cipher.encrypt( keys.key, sizeof(keys.key), @@ -504,7 +505,7 @@ std::size_t olm::Ratchet::decrypt( return std::size_t(-1); } - if (reader.ratchet_key_length != KEY_LENGTH) { + if (reader.ratchet_key_length != olm::KEY_LENGTH) { last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT; return std::size_t(-1); } @@ -513,7 +514,7 @@ std::size_t olm::Ratchet::decrypt( for (olm::ReceiverChain & receiver_chain : receiver_chains) { if (0 == std::memcmp( receiver_chain.ratchet_key.public_key, reader.ratchet_key, - KEY_LENGTH + olm::KEY_LENGTH )) { chain = &receiver_chain; break; @@ -533,7 +534,7 @@ std::size_t olm::Ratchet::decrypt( if (reader.counter == skipped.message_key.index && 0 == std::memcmp( skipped.ratchet_key.public_key, reader.ratchet_key, - KEY_LENGTH + olm::KEY_LENGTH ) ) { /* Found the key for this message. Check the MAC. */ @@ -569,9 +570,7 @@ std::size_t olm::Ratchet::decrypt( * We can discard our previous empheral ratchet key. * We will generate a new key when we send the next message. */ chain = receiver_chains.insert(); - std::memcpy( - chain->ratchet_key.public_key, reader.ratchet_key, KEY_LENGTH - ); + olm::load_array(chain->ratchet_key.public_key, reader.ratchet_key); create_chain_key( root_key, sender_chain[0].ratchet_key, chain->ratchet_key, kdf_info, root_key, chain->chain_key diff --git a/src/session.cpp b/src/session.cpp index b17a059..71c80e4 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -24,7 +24,6 @@ namespace { -static const std::size_t KEY_LENGTH = 32; static const std::uint8_t PROTOCOL_VERSION = 0x3; static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT"; @@ -51,7 +50,7 @@ olm::Session::Session( std::size_t olm::Session::new_outbound_session_random_length() { - return KEY_LENGTH * 2; + return olm::KEY_LENGTH * 2; } @@ -66,54 +65,54 @@ std::size_t olm::Session::new_outbound_session( return std::size_t(-1); } - Curve25519KeyPair base_key; + olm::Curve25519KeyPair base_key; olm::curve25519_generate_key(random, base_key); - Curve25519KeyPair ratchet_key; - olm::curve25519_generate_key(random + 32, ratchet_key); + olm::Curve25519KeyPair ratchet_key; + olm::curve25519_generate_key(random + olm::KEY_LENGTH, ratchet_key); + + olm::Curve25519KeyPair const & alice_identity_key_pair = ( + local_account.identity_keys.curve25519_key + ); received_message = false; - alice_identity_key = local_account.identity_keys.curve25519_key; + alice_identity_key = alice_identity_key_pair; alice_base_key = base_key; bob_one_time_key = one_time_key; - std::uint8_t shared_secret[96]; + std::uint8_t secret[3 * olm::KEY_LENGTH]; + std::uint8_t * pos = secret; - olm::curve25519_shared_secret( - local_account.identity_keys.curve25519_key, - one_time_key, shared_secret - ); - olm::curve25519_shared_secret( - base_key, identity_key, shared_secret + 32 - ); - olm::curve25519_shared_secret( - base_key, one_time_key, shared_secret + 64 - ); + olm::curve25519_shared_secret(alice_identity_key_pair, one_time_key, pos); + pos += olm::KEY_LENGTH; + olm::curve25519_shared_secret(base_key, identity_key, pos); + pos += olm::KEY_LENGTH; + olm::curve25519_shared_secret(base_key, one_time_key, pos); - ratchet.initialise_as_alice(shared_secret, 96, ratchet_key); + ratchet.initialise_as_alice(secret, sizeof(secret), ratchet_key); olm::unset(base_key); olm::unset(ratchet_key); - olm::unset(shared_secret); + olm::unset(secret); return std::size_t(0); } namespace { -bool check_message_fields( +static bool check_message_fields( olm::PreKeyMessageReader & reader, bool have_their_identity_key ) { bool ok = true; ok = ok && (have_their_identity_key || reader.identity_key); if (reader.identity_key) { - ok = ok && reader.identity_key_length == KEY_LENGTH; + ok = ok && reader.identity_key_length == olm::KEY_LENGTH; } ok = ok && reader.message; ok = ok && reader.base_key; - ok = ok && reader.base_key_length == KEY_LENGTH; + ok = ok && reader.base_key_length == olm::KEY_LENGTH; ok = ok && reader.one_time_key; - ok = ok && reader.one_time_key_length == KEY_LENGTH; + ok = ok && reader.one_time_key_length == olm::KEY_LENGTH; return ok; } @@ -135,7 +134,7 @@ std::size_t olm::Session::new_inbound_session( if (reader.identity_key && their_identity_key) { bool same = 0 == std::memcmp( - their_identity_key->public_key, reader.identity_key, KEY_LENGTH + their_identity_key->public_key, reader.identity_key, olm::KEY_LENGTH ); if (!same) { last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID; @@ -150,16 +149,16 @@ std::size_t olm::Session::new_inbound_session( ); if (!message_reader.ratchet_key - || message_reader.ratchet_key_length != KEY_LENGTH) { + || message_reader.ratchet_key_length != olm::KEY_LENGTH) { last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT; return std::size_t(-1); } - std::memcpy(alice_identity_key.public_key, reader.identity_key, 32); - std::memcpy(alice_base_key.public_key, reader.base_key, 32); - std::memcpy(bob_one_time_key.public_key, reader.one_time_key, 32); + olm::load_array(alice_identity_key.public_key, reader.identity_key); + olm::load_array(alice_base_key.public_key, reader.base_key); + olm::load_array(bob_one_time_key.public_key, reader.one_time_key); olm::Curve25519PublicKey ratchet_key; - std::memcpy(ratchet_key.public_key, message_reader.ratchet_key, 32); + olm::load_array(ratchet_key.public_key, message_reader.ratchet_key); olm::OneTimeKey const * our_one_time_key = local_account.lookup_key( bob_one_time_key @@ -170,27 +169,28 @@ std::size_t olm::Session::new_inbound_session( return std::size_t(-1); } - std::uint8_t shared_secret[96]; - - olm::curve25519_shared_secret( - our_one_time_key->key, alice_identity_key, shared_secret - ); - olm::curve25519_shared_secret( - local_account.identity_keys.curve25519_key, - alice_base_key, shared_secret + 32 - ); - olm::curve25519_shared_secret( - our_one_time_key->key, alice_base_key, shared_secret + 64 + olm::Curve25519KeyPair const & bob_identity_key = ( + local_account.identity_keys.curve25519_key ); + olm::Curve25519KeyPair const & bob_one_time_key = our_one_time_key->key; + + std::uint8_t secret[olm::KEY_LENGTH * 3]; + std::uint8_t * pos = secret; + olm::curve25519_shared_secret(bob_one_time_key, alice_identity_key, pos); + pos += olm::KEY_LENGTH; + olm::curve25519_shared_secret(bob_identity_key, alice_base_key, pos); + pos += olm::KEY_LENGTH; + olm::curve25519_shared_secret(bob_one_time_key, alice_base_key, pos); - ratchet.initialise_as_bob(shared_secret, 96, ratchet_key); + ratchet.initialise_as_bob(secret, sizeof(secret), ratchet_key); + olm::unset(secret); return std::size_t(0); } std::size_t olm::Session::session_id_length() { - return 32; + return olm::SHA256_OUTPUT_LENGTH; } @@ -201,10 +201,11 @@ std::size_t olm::Session::session_id( last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL; return std::size_t(-1); } - std::uint8_t tmp[96]; - std::memcpy(tmp, alice_identity_key.public_key, 32); - std::memcpy(tmp + 32, alice_base_key.public_key, 32); - std::memcpy(tmp + 64, bob_one_time_key.public_key, 32); + std::uint8_t tmp[olm::KEY_LENGTH * 3]; + std::uint8_t * pos = tmp; + pos = olm::store_array(pos, alice_identity_key.public_key); + pos = olm::store_array(pos, alice_base_key.public_key); + pos = olm::store_array(pos, bob_one_time_key.public_key); olm::sha256(tmp, sizeof(tmp), id); return session_id_length(); } @@ -224,20 +225,20 @@ bool olm::Session::matches_inbound_session( bool same = true; if (reader.identity_key) { same = same && 0 == std::memcmp( - reader.identity_key, alice_identity_key.public_key, KEY_LENGTH + reader.identity_key, alice_identity_key.public_key, olm::KEY_LENGTH ); } if (their_identity_key) { same = same && 0 == std::memcmp( their_identity_key->public_key, alice_identity_key.public_key, - KEY_LENGTH + olm::KEY_LENGTH ); } same = same && 0 == std::memcmp( - reader.base_key, alice_base_key.public_key, KEY_LENGTH + reader.base_key, alice_base_key.public_key, olm::KEY_LENGTH ); same = same && 0 == std::memcmp( - reader.one_time_key, bob_one_time_key.public_key, KEY_LENGTH + reader.one_time_key, bob_one_time_key.public_key, olm::KEY_LENGTH ); return same; } @@ -264,9 +265,9 @@ std::size_t olm::Session::encrypt_message_length( } return encode_one_time_key_message_length( - KEY_LENGTH, - KEY_LENGTH, - KEY_LENGTH, + olm::KEY_LENGTH, + olm::KEY_LENGTH, + olm::KEY_LENGTH, message_length ); } @@ -298,21 +299,15 @@ std::size_t olm::Session::encrypt( encode_one_time_key_message( writer, PROTOCOL_VERSION, - KEY_LENGTH, - KEY_LENGTH, - KEY_LENGTH, + olm::KEY_LENGTH, + olm::KEY_LENGTH, + olm::KEY_LENGTH, message_body_length, message ); - std::memcpy( - writer.one_time_key, bob_one_time_key.public_key, KEY_LENGTH - ); - std::memcpy( - writer.identity_key, alice_identity_key.public_key, KEY_LENGTH - ); - std::memcpy( - writer.base_key, alice_base_key.public_key, KEY_LENGTH - ); + olm::store_array(writer.one_time_key, bob_one_time_key.public_key); + olm::store_array(writer.identity_key, alice_identity_key.public_key); + olm::store_array(writer.base_key, alice_base_key.public_key); message_body = writer.message; } diff --git a/src/utility.cpp b/src/utility.cpp index 1d8c5c1..bc51cff 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -23,7 +23,7 @@ olm::Utility::Utility( size_t olm::Utility::sha256_length() { - return olm::HMAC_SHA256_OUTPUT_LENGTH; + return olm::SHA256_OUTPUT_LENGTH; } @@ -36,7 +36,7 @@ size_t olm::Utility::sha256( return std::size_t(-1); } olm::sha256(input, input_length, output); - return 32; + return olm::SHA256_OUTPUT_LENGTH; } @@ -45,7 +45,7 @@ size_t olm::Utility::ed25519_verify( std::uint8_t const * message, std::size_t message_length, std::uint8_t const * signature, std::size_t signature_length ) { - if (signature_length < 64) { + if (signature_length < olm::SIGNATURE_LENGTH) { last_error = olm::ErrorCode::BAD_MESSAGE_MAC; return std::size_t(-1); } diff --git a/tests/test_base64.cpp b/tests/test_base64.cpp index e6b3710..5bae2f9 100644 --- a/tests/test_base64.cpp +++ b/tests/test_base64.cpp @@ -33,5 +33,4 @@ olm::decode_base64(input, input_length, output); assert_equals(expected_output, output, output_length); } - } diff --git a/tests/test_crypto.cpp b/tests/test_crypto.cpp index b3ff593..4606c52 100644 --- a/tests/test_crypto.cpp +++ b/tests/test_crypto.cpp @@ -70,7 +70,7 @@ olm::curve25519_generate_key(bob_private, bob_pair); assert_equals(bob_private, bob_pair.private_key, 32); assert_equals(bob_public, bob_pair.public_key, 32); -std::uint8_t actual_agreement[olm::CURVE25519_SHARED_SECRET_LENGTH] = {}; +std::uint8_t actual_agreement[olm::KEY_LENGTH] = {}; olm::curve25519_shared_secret(alice_pair, bob_pair, actual_agreement); diff --git a/tests/test_olm_decrypt.cpp b/tests/test_olm_decrypt.cpp new file mode 100644 index 0000000..2a2db98 --- /dev/null +++ b/tests/test_olm_decrypt.cpp @@ -0,0 +1,75 @@ +#include "olm/olm.hh" +#include "unittest.hh" + +const char * test_cases[] = { + "41776f", + "7fff6f0101346d671201", + "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d", + "e9e9c9c1e9e9c9e9c9c1e9e9c9c1", +}; + + +const char * session_data = + "E0p44KO2y2pzp9FIjv0rud2wIvWDi2dx367kP4Fz/9JCMrH+aG369HGymkFtk0+PINTLB9lQRt" + "ohea5d7G/UXQx3r5y4IWuyh1xaRnojEZQ9a5HRZSNtvmZ9NY1f1gutYa4UtcZcbvczN8b/5Bqg" + "e16cPUH1v62JKLlhoAJwRkH1wU6fbyOudERg5gdXA971btR+Q2V8GKbVbO5fGKL5phmEPVXyMs" + "rfjLdzQrgjOTxN8Pf6iuP+WFPvfnR9lDmNCFxJUVAdLIMnLuAdxf1TGcS+zzCzEE8btIZ99mHF" + "dGvPXeH8qLeNZA"; + +void decode_hex( + const char * input, + std::uint8_t * output, std::size_t output_length +) { + std::uint8_t * end = output + output_length; + while (output != end) { + char high = *(input++); + char low = *(input++); + if (high >= 'a') high -= 'a' - ('9' + 1); + if (low >= 'a') low -= 'a' - ('9' + 1); + uint8_t value = ((high - '0') << 4) | (low - '0'); + *(output++) = value; + } +} + +void decrypt_case(int message_type, const char * test_case) { + std::uint8_t session_memory[olm_session_size()]; + ::OlmSession * session = ::olm_session(session_memory); + + std::uint8_t pickled[strlen(session_data)]; + ::memcpy(pickled, session_data, sizeof(pickled)); + ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled)); + + std::size_t message_length = strlen(test_case) / 2; + std::uint8_t * message = (std::uint8_t *) ::malloc(message_length); + decode_hex(test_case, message, message_length); + + size_t max_length = olm_decrypt_max_plaintext_length( + session, message_type, message, message_length + ); + + if (max_length == std::size_t(-1)) { + free(message); + return; + } + + uint8_t plaintext[max_length]; + decode_hex(test_case, message, message_length); + olm_decrypt( + session, message_type, + message, message_length, + plaintext, max_length + ); + free(message); +} + + +int main() { +{ +TestCase my_test("Olm decrypt test"); + +for (int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) { + decrypt_case(0, test_cases[i]); +} + +} +} |