aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/olm/base64.hh24
-rw-r--r--include/olm/cipher.hh4
-rw-r--r--include/olm/crypto.hh26
-rw-r--r--include/olm/memory.hh49
-rw-r--r--include/olm/pickle.hh2
-rw-r--r--include/olm/ratchet.hh2
-rw-r--r--include/olm/session.hh46
-rw-r--r--include/olm/utility.hh9
-rw-r--r--src/account.cpp63
-rw-r--r--src/cipher.cpp21
-rw-r--r--src/crypto.cpp56
-rw-r--r--src/message.cpp40
-rw-r--r--src/olm.cpp47
-rw-r--r--src/ratchet.cpp73
-rw-r--r--src/session.cpp125
-rw-r--r--src/utility.cpp6
-rw-r--r--tests/test_base64.cpp1
-rw-r--r--tests/test_crypto.cpp2
-rw-r--r--tests/test_olm_decrypt.cpp75
19 files changed, 412 insertions, 259 deletions
diff --git a/include/olm/base64.hh b/include/olm/base64.hh
index 018924a..da4641d 100644
--- a/include/olm/base64.hh
+++ b/include/olm/base64.hh
@@ -20,31 +20,45 @@
namespace olm {
-
+/**
+ * The number of bytes of unpadded base64 needed to encode a length of input.
+ */
static std::size_t encode_base64_length(
std::size_t input_length
) {
return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2;
}
-
+/**
+ * Encode the raw input as unpadded base64.
+ * Writes encode_base64_length(input_length) bytes to the output buffer.
+ * The input can overlap with the last three quarters of the output buffer.
+ * That is, the input pointer may be output + output_length - input_length.
+ */
std::uint8_t * encode_base64(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output
);
-
+/**
+ * The number of bytes of raw data a length of unpadded base64 will encode to.
+ * Returns std::size_t(-1) if the length is not a valid length for base64.
+ */
std::size_t decode_base64_length(
std::size_t input_length
);
-
+/**
+ * Decodes the unpadded base64 input to raw bytes.
+ * Writes decode_base64_length(input_length) bytes to the output buffer.
+ * The output can overlap with the first three quarters of the input buffer.
+ * That is, the input pointers and output pointer may be the same.
+ */
std::uint8_t const * decode_base64(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output
);
-
} // namespace olm
diff --git a/include/olm/cipher.hh b/include/olm/cipher.hh
index f71b3af..c561972 100644
--- a/include/olm/cipher.hh
+++ b/include/olm/cipher.hh
@@ -47,6 +47,8 @@ public:
* output |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
+ * The plain-text pointers and cipher-text pointers may be the same.
+ *
* Returns std::size_t(-1) if the length of the cipher-text or the output
* buffer is too small. Otherwise returns the length of the output buffer.
*/
@@ -73,6 +75,8 @@ public:
* input |--ciphertext_length-->| |---mac_length-->|
* ciphertext
*
+ * The plain-text pointers and cipher-text pointers may be the same.
+ *
* Returns std::size_t(-1) if the length of the plain-text buffer is too
* small or if the authentication check fails. Otherwise returns the length
* of the plain text.
diff --git a/include/olm/crypto.hh b/include/olm/crypto.hh
index b845bfe..7a05f8d 100644
--- a/include/olm/crypto.hh
+++ b/include/olm/crypto.hh
@@ -20,28 +20,27 @@
namespace olm {
+static const std::size_t KEY_LENGTH = 32;
+static const std::size_t SIGNATURE_LENGTH = 64;
+static const std::size_t IV_LENGTH = 16;
struct Curve25519PublicKey {
- static const int LENGTH = 32;
- std::uint8_t public_key[32];
+ std::uint8_t public_key[KEY_LENGTH];
};
struct Curve25519KeyPair : public Curve25519PublicKey {
- static const int LENGTH = 64;
- std::uint8_t private_key[32];
+ std::uint8_t private_key[KEY_LENGTH];
};
struct Ed25519PublicKey {
- static const int LENGTH = 32;
- std::uint8_t public_key[32];
+ std::uint8_t public_key[KEY_LENGTH];
};
struct Ed25519KeyPair : public Ed25519PublicKey {
- static const int LENGTH = 64;
- std::uint8_t private_key[32];
+ std::uint8_t private_key[KEY_LENGTH];
};
@@ -52,9 +51,6 @@ void curve25519_generate_key(
);
-const std::size_t CURVE25519_SHARED_SECRET_LENGTH = 32;
-
-
/** Create a shared secret using our private key and their public key.
* The output buffer must be at least 32 bytes long. */
void curve25519_shared_secret(
@@ -109,14 +105,12 @@ bool ed25519_verify(
struct Aes256Key {
- static const int LENGTH = 32;
- std::uint8_t key[32];
+ std::uint8_t key[KEY_LENGTH];
};
struct Aes256Iv {
- static const int LENGTH = 16;
- std::uint8_t iv[16];
+ std::uint8_t iv[IV_LENGTH];
};
@@ -156,7 +150,7 @@ void sha256(
);
-const std::size_t HMAC_SHA256_OUTPUT_LENGTH = 32;
+const std::size_t SHA256_OUTPUT_LENGTH = 32;
/** HMAC: Keyed-Hashing for Message Authentication
diff --git a/include/olm/memory.hh b/include/olm/memory.hh
index b19c74b..128990a 100644
--- a/include/olm/memory.hh
+++ b/include/olm/memory.hh
@@ -14,6 +14,8 @@
*/
#include <cstddef>
#include <cstdint>
+#include <cstring>
+#include <type_traits>
namespace olm {
@@ -35,4 +37,51 @@ bool is_equal(
std::size_t length
);
+/** Check if two fixed size arrays are equals */
+template<typename T>
+bool array_equal(
+ T const & array_a,
+ T const & array_b
+) {
+ static_assert(
+ std::is_array<T>::value
+ && std::is_convertible<T, std::uint8_t *>::value
+ && sizeof(T) > 0,
+ "Arguments to array_equal must be std::uint8_t arrays[]."
+ );
+ return is_equal(array_a, array_b, sizeof(T));
+}
+
+/** Copy into a fixed size array */
+template<typename T>
+std::uint8_t const * load_array(
+ T & destination,
+ std::uint8_t const * source
+) {
+ static_assert(
+ std::is_array<T>::value
+ && std::is_convertible<T, std::uint8_t *>::value
+ && sizeof(T) > 0,
+ "The first argument to load_array must be a std::uint8_t array[]."
+ );
+ std::memcpy(destination, source, sizeof(T));
+ return source + sizeof(T);
+}
+
+/** Copy from a fixed size array */
+template<typename T>
+std::uint8_t * store_array(
+ std::uint8_t * destination,
+ T const & source
+) {
+ static_assert(
+ std::is_array<T>::value
+ && std::is_convertible<T, std::uint8_t *>::value
+ && sizeof(T) > 0,
+ "The second argument to store_array must be a std::uint8_t array[]."
+ );
+ std::memcpy(destination, source, sizeof(T));
+ return destination + sizeof(T);
+}
+
} // namespace olm
diff --git a/include/olm/pickle.hh b/include/olm/pickle.hh
index 7a2bd1b..27f1f26 100644
--- a/include/olm/pickle.hh
+++ b/include/olm/pickle.hh
@@ -109,7 +109,7 @@ std::uint8_t const * unpickle(
) {
std::uint32_t size;
pos = unpickle(pos, end, size);
- while (size--) {
+ while (size-- && pos != end) {
T * value = list.insert(list.end());
pos = unpickle(pos, end, *value);
}
diff --git a/include/olm/ratchet.hh b/include/olm/ratchet.hh
index 7274255..2393e5b 100644
--- a/include/olm/ratchet.hh
+++ b/include/olm/ratchet.hh
@@ -21,7 +21,7 @@ namespace olm {
class Cipher;
-typedef std::uint8_t SharedKey[32];
+typedef std::uint8_t SharedKey[olm::KEY_LENGTH];
struct ChainKey {
diff --git a/include/olm/session.hh b/include/olm/session.hh
index 993a8da..b21b0aa 100644
--- a/include/olm/session.hh
+++ b/include/olm/session.hh
@@ -39,8 +39,13 @@ struct Session {
Curve25519PublicKey alice_base_key;
Curve25519PublicKey bob_one_time_key;
+ /** The number of random bytes that are needed to create a new outbound
+ * session. This will be 64 bytes since two ephemeral keys are needed. */
std::size_t new_outbound_session_random_length();
+ /** Start a new outbound session. Returns std::size_t(-1) on failure. On
+ * failure last_error will be set with an error code. The last_error will be
+ * NOT_ENOUGH_RANDOM if the number of random bytes was too small. */
std::size_t new_outbound_session(
Account const & local_account,
Curve25519PublicKey const & identity_key,
@@ -48,42 +53,79 @@ struct Session {
std::uint8_t const * random, std::size_t random_length
);
+ /** Start a new inbound session from a pre-key message.
+ * Returns std::size_t(-1) on failure. On failure last_error will be set
+ * with an error code. The last_error will be BAD_MESSAGE_FORMAT if
+ * the message headers could not be decoded. */
std::size_t new_inbound_session(
Account & local_account,
Curve25519PublicKey const * their_identity_key,
- std::uint8_t const * one_time_key_message, std::size_t message_length
+ std::uint8_t const * pre_key_message, std::size_t message_length
);
+ /** The number of bytes written by session_id() */
std::size_t session_id_length();
+ /** An identifier for this session. Generated by hashing the public keys
+ * used to create the session. Returns the length of the session id on
+ * success or std::size_t(-1) on failure. On failure last_error will be set
+ * with an error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if
+ * the id buffer was too small. */
std::size_t session_id(
std::uint8_t * id, std::size_t id_length
);
+ /** True if this session can be used to decode an inbound pre-key message.
+ * This can be used to test whether a pre-key message should be decoded
+ * with an existing session or if a new session will need to be created.
+ * Returns true if the session is the same. Returns false if either the
+ * session does not match or the pre-key message could not be decoded.
+ */
bool matches_inbound_session(
Curve25519PublicKey const * their_identity_key,
- std::uint8_t const * one_time_key_message, std::size_t message_length
+ std::uint8_t const * pre_key_message, std::size_t message_length
);
+ /** Whether the next message will be a pre-key message or a normal message.
+ * An outbound session will send pre-key messages until it receives a
+ * message with a ratchet key. */
MessageType encrypt_message_type();
std::size_t encrypt_message_length(
std::size_t plaintext_length
);
+ /** The number of bytes of random data the encrypt method will need to
+ * encrypt a message. This will be 32 bytes if the session needs to
+ * generate a new ephemeral key, or will be 0 bytes otherwise. */
std::size_t encrypt_random_length();
+ /** Encrypt some plain-text. Returns the length of the encrypted message
+ * or std::size_t(-1) on failure. On failure last_error will be set with
+ * an error code. The last_error will be NOT_ENOUGH_RANDOM if the number
+ * of random bytes is too small. The last_error will be
+ * OUTPUT_BUFFER_TOO_SMALL if the output buffer is too small. */
std::size_t encrypt(
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t const * random, std::size_t random_length,
std::uint8_t * message, std::size_t message_length
);
+ /** An upper bound on the number of bytes of plain-text the decrypt method
+ * will write for a given input message length. */
std::size_t decrypt_max_plaintext_length(
MessageType message_type,
std::uint8_t const * message, std::size_t message_length
);
+ /** Decrypt a message. Returns the length of the decrypted plain-text or
+ * std::size_t(-1) on failure. On failure last_error will be set with an
+ * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the
+ * plain-text buffer is too small. The last_error will be
+ * BAD_MESSAGE_VERSION if the message was encrypted with an unsupported
+ * version of the protocol. The last_error will be BAD_MESSAGE_FORMAT if
+ * the message headers could not be decoded. The last_error will be
+ * BAD_MESSAGE_MAC if the message could not be verified */
std::size_t decrypt(
MessageType message_type,
std::uint8_t const * message, std::size_t message_length,
diff --git a/include/olm/utility.hh b/include/olm/utility.hh
index 241d7e0..5329a59 100644
--- a/include/olm/utility.hh
+++ b/include/olm/utility.hh
@@ -31,13 +31,22 @@ struct Utility {
ErrorCode last_error;
+ /** The length of a SHA-256 hash in bytes. */
std::size_t sha256_length();
+ /** Compute a SHA-256 hash. Returns the length of the SHA-256 hash in bytes
+ * on success. Returns std::size_t(-1) on failure. On failure last_error
+ * will be set with an error code. If the output buffer was too small then
+ * last error will be OUTPUT_BUFFER_TOO_SMALL. */
std::size_t sha256(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output, std::size_t output_length
);
+ /** Verify a ed25519 signature. Returns std::size_t(0) on success. Returns
+ * std::size_t(-1) on failure or if the signature was invalid. On failure
+ * last_error will be set with an error code. If the signature was too short
+ * or was not a valid signature then last_error will be BAD_MESSAGE_MAC. */
std::size_t ed25519_verify(
Ed25519PublicKey const & key,
std::uint8_t const * message, std::size_t message_length,
diff --git a/src/account.cpp b/src/account.cpp
index cf6f0cb..43033c8 100644
--- a/src/account.cpp
+++ b/src/account.cpp
@@ -15,6 +15,7 @@
#include "olm/account.hh"
#include "olm/base64.hh"
#include "olm/pickle.hh"
+#include "olm/memory.hh"
olm::Account::Account(
) : next_one_time_key_id(0),
@@ -26,7 +27,7 @@ olm::OneTimeKey const * olm::Account::lookup_key(
olm::Curve25519PublicKey const & public_key
) {
for (olm::OneTimeKey const & key : one_time_keys) {
- if (0 == memcmp(key.key.public_key, public_key.public_key, 32)) {
+ if (olm::array_equal(key.key.public_key, public_key.public_key)) {
return &key;
}
}
@@ -38,7 +39,7 @@ std::size_t olm::Account::remove_key(
) {
OneTimeKey * i;
for (i = one_time_keys.begin(); i != one_time_keys.end(); ++i) {
- if (0 == memcmp(i->key.public_key, public_key.public_key, 32)) {
+ if (olm::array_equal(i->key.public_key, public_key.public_key)) {
std::uint32_t id = i->id;
one_time_keys.erase(i);
return id;
@@ -48,7 +49,7 @@ std::size_t olm::Account::remove_key(
}
std::size_t olm::Account::new_account_random_length() {
- return 2 * 32;
+ return 2 * olm::KEY_LENGTH;
}
std::size_t olm::Account::new_account(
@@ -60,35 +61,19 @@ std::size_t olm::Account::new_account(
}
olm::ed25519_generate_key(random, identity_keys.ed25519_key);
- random += 32;
+ random += KEY_LENGTH;
olm::curve25519_generate_key(random, identity_keys.curve25519_key);
- random += 32;
return 0;
}
namespace {
-
-namespace {
uint8_t KEY_JSON_ED25519[] = "\"ed25519\":";
uint8_t KEY_JSON_CURVE25519[] = "\"curve25519\":";
-}
-
-
-std::size_t count_digits(
- std::uint64_t value
-) {
- std::size_t digits = 0;
- do {
- digits++;
- value /= 10;
- } while (value);
- return digits;
-}
template<typename T>
-std::uint8_t * write_string(
+static std::uint8_t * write_string(
std::uint8_t * pos,
T const & value
) {
@@ -96,27 +81,6 @@ std::uint8_t * write_string(
return pos + (sizeof(T) - 1);
}
-std::uint8_t * write_string(
- std::uint8_t * pos,
- std::uint8_t const * value, std::size_t value_length
-) {
- std::memcpy(pos, value, value_length);
- return pos + value_length;
-}
-
-std::uint8_t * write_digits(
- std::uint8_t * pos,
- std::uint64_t value
-) {
- size_t digits = count_digits(value);
- pos += digits;
- do {
- *(--pos) = '0' + (value % 10);
- value /= 10;
- } while (value);
- return pos + digits;
-}
-
}
@@ -143,7 +107,6 @@ std::size_t olm::Account::get_identity_json(
std::uint8_t * identity_json, std::size_t identity_json_length
) {
std::uint8_t * pos = identity_json;
- std::uint8_t signature[64];
size_t expected_length = get_identity_json_length();
if (identity_json_length < expected_length) {
@@ -174,7 +137,7 @@ std::size_t olm::Account::get_identity_json(
std::size_t olm::Account::signature_length(
) {
- return 64;
+ return olm::SIGNATURE_LENGTH;
}
@@ -196,19 +159,20 @@ std::size_t olm::Account::sign(
std::size_t olm::Account::get_one_time_keys_json_length(
) {
std::size_t length = 0;
+ bool is_empty = true;
for (auto const & key : one_time_keys) {
if (key.published) {
continue;
}
+ is_empty = false;
length += 2; /* {" */
length += olm::encode_base64_length(olm::pickle_length(key.id));
length += 3; /* ":" */
length += olm::encode_base64_length(sizeof(key.key.public_key));
length += 1; /* " */
}
- if (length == 0) {
- /* The list was empty. Add a byte for the opening '{' */
- length = 1;
+ if (is_empty) {
+ length += 1; /* { */
}
length += 3; /* }{} */
length += sizeof(KEY_JSON_CURVE25519) - 1;
@@ -244,6 +208,7 @@ std::size_t olm::Account::get_one_time_keys_json(
sep = ',';
}
if (sep != ',') {
+ /* The list was empty */
*(pos++) = sep;
}
*(pos++) = '}';
@@ -273,7 +238,7 @@ std::size_t olm::Account::max_number_of_one_time_keys(
std::size_t olm::Account::generate_one_time_keys_random_length(
std::size_t number_of_keys
) {
- return 32 * number_of_keys;
+ return olm::KEY_LENGTH * number_of_keys;
}
std::size_t olm::Account::generate_one_time_keys(
@@ -289,7 +254,7 @@ std::size_t olm::Account::generate_one_time_keys(
key.id = ++next_one_time_key_id;
key.published = false;
olm::curve25519_generate_key(random, key.key);
- random += 32;
+ random += olm::KEY_LENGTH;
}
return number_of_keys;
}
diff --git a/src/cipher.cpp b/src/cipher.cpp
index 2202746..7bb11b8 100644
--- a/src/cipher.cpp
+++ b/src/cipher.cpp
@@ -23,11 +23,9 @@ olm::Cipher::~Cipher() {
namespace {
-static const std::size_t SHA256_LENGTH = 32;
-
struct DerivedKeys {
olm::Aes256Key aes_key;
- std::uint8_t mac_key[SHA256_LENGTH];
+ std::uint8_t mac_key[olm::KEY_LENGTH];
olm::Aes256Iv aes_iv;
};
@@ -37,16 +35,17 @@ static void derive_keys(
std::uint8_t const * key, std::size_t key_length,
DerivedKeys & keys
) {
- std::uint8_t derived_secrets[80];
+ std::uint8_t derived_secrets[2 * olm::KEY_LENGTH + olm::IV_LENGTH];
olm::hkdf_sha256(
key, key_length,
nullptr, 0,
kdf_info, kdf_info_length,
derived_secrets, sizeof(derived_secrets)
);
- std::memcpy(keys.aes_key.key, derived_secrets, 32);
- std::memcpy(keys.mac_key, derived_secrets + 32, 32);
- std::memcpy(keys.aes_iv.iv, derived_secrets + 64, 16);
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(keys.aes_key.key, pos);
+ pos = olm::load_array(keys.mac_key, pos);
+ pos = olm::load_array(keys.aes_iv.iv, pos);
olm::unset(derived_secrets);
}
@@ -84,7 +83,7 @@ std::size_t olm::CipherAesSha256::encrypt(
return std::size_t(-1);
}
struct DerivedKeys keys;
- std::uint8_t mac[SHA256_LENGTH];
+ std::uint8_t mac[olm::SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
@@ -93,7 +92,7 @@ std::size_t olm::CipherAesSha256::encrypt(
);
olm::hmac_sha256(
- keys.mac_key, SHA256_LENGTH, output, output_length - MAC_LENGTH, mac
+ keys.mac_key, olm::KEY_LENGTH, output, output_length - MAC_LENGTH, mac
);
std::memcpy(output + output_length - MAC_LENGTH, mac, MAC_LENGTH);
@@ -116,12 +115,12 @@ std::size_t olm::CipherAesSha256::decrypt(
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const {
DerivedKeys keys;
- std::uint8_t mac[SHA256_LENGTH];
+ std::uint8_t mac[olm::SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
olm::hmac_sha256(
- keys.mac_key, SHA256_LENGTH, input, input_length - MAC_LENGTH, mac
+ keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac
);
std::uint8_t const * input_mac = input + input_length - MAC_LENGTH;
diff --git a/src/crypto.cpp b/src/crypto.cpp
index ed89d64..8024355 100644
--- a/src/crypto.cpp
+++ b/src/crypto.cpp
@@ -66,8 +66,9 @@ void ed25519_keypair(
namespace {
static const std::uint8_t CURVE25519_BASEPOINT[32] = {9};
+static const std::size_t AES_KEY_SCHEDULE_LENGTH = 60;
+static const std::size_t AES_KEY_BITS = 8 * olm::KEY_LENGTH;
static const std::size_t AES_BLOCK_LENGTH = 16;
-static const std::size_t SHA256_HASH_LENGTH = 32;
static const std::size_t SHA256_BLOCK_LENGTH = 64;
static const std::uint8_t HKDF_DEFAULT_SALT[32] = {};
@@ -99,7 +100,7 @@ inline static void hmac_sha256_key(
}
-inline void hmac_sha256_init(
+inline static void hmac_sha256_init(
::SHA256_CTX * context,
std::uint8_t const * hmac_key
) {
@@ -114,12 +115,12 @@ inline void hmac_sha256_init(
}
-inline void hmac_sha256_final(
+inline static void hmac_sha256_final(
::SHA256_CTX * context,
std::uint8_t const * hmac_key,
std::uint8_t * output
) {
- std::uint8_t o_pad[SHA256_BLOCK_LENGTH + SHA256_HASH_LENGTH];
+ std::uint8_t o_pad[SHA256_BLOCK_LENGTH + olm::SHA256_OUTPUT_LENGTH];
std::memcpy(o_pad, hmac_key, SHA256_BLOCK_LENGTH);
for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) {
o_pad[i] ^= 0x5C;
@@ -140,7 +141,7 @@ void olm::curve25519_generate_key(
std::uint8_t const * random_32_bytes,
olm::Curve25519KeyPair & key_pair
) {
- std::memcpy(key_pair.private_key, random_32_bytes, 32);
+ std::memcpy(key_pair.private_key, random_32_bytes, KEY_LENGTH);
::curve25519_donna(
key_pair.public_key, key_pair.private_key, CURVE25519_BASEPOINT
);
@@ -161,9 +162,9 @@ void olm::curve25519_sign(
std::uint8_t const * message, std::size_t message_length,
std::uint8_t * output
) {
- std::uint8_t private_key[32];
- std::uint8_t public_key[32];
- std::memcpy(private_key, our_key.private_key, 32);
+ std::uint8_t private_key[KEY_LENGTH];
+ std::uint8_t public_key[KEY_LENGTH];
+ std::memcpy(private_key, our_key.private_key, KEY_LENGTH);
::ed25519_keypair(private_key, public_key);
::ed25519_sign(
output,
@@ -179,10 +180,10 @@ bool olm::curve25519_verify(
std::uint8_t const * message, std::size_t message_length,
std::uint8_t const * signature
) {
- std::uint8_t public_key[32];
- std::uint8_t signature_buffer[64];
- std::memcpy(public_key, their_key.public_key, 32);
- std::memcpy(signature_buffer, signature, 64);
+ std::uint8_t public_key[KEY_LENGTH];
+ std::uint8_t signature_buffer[SIGNATURE_LENGTH];
+ std::memcpy(public_key, their_key.public_key, KEY_LENGTH);
+ std::memcpy(signature_buffer, signature, SIGNATURE_LENGTH);
::convert_curve25519_to_ed25519(public_key, signature_buffer);
return 0 != ::ed25519_verify(
signature,
@@ -196,7 +197,7 @@ void olm::ed25519_generate_key(
std::uint8_t const * random_32_bytes,
olm::Ed25519KeyPair & key_pair
) {
- std::memcpy(key_pair.private_key, random_32_bytes, 32);
+ std::memcpy(key_pair.private_key, random_32_bytes, KEY_LENGTH);
::ed25519_keypair(key_pair.private_key, key_pair.public_key);
}
@@ -240,13 +241,13 @@ void olm::aes_encrypt_cbc(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output
) {
- std::uint32_t key_schedule[60];
- ::aes_key_setup(key.key, key_schedule, 256);
+ std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH];
+ ::aes_key_setup(key.key, key_schedule, AES_KEY_BITS);
std::uint8_t input_block[AES_BLOCK_LENGTH];
std::memcpy(input_block, iv.iv, AES_BLOCK_LENGTH);
while (input_length >= AES_BLOCK_LENGTH) {
xor_block<AES_BLOCK_LENGTH>(input_block, input);
- ::aes_encrypt(input_block, output, key_schedule, 256);
+ ::aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS);
std::memcpy(input_block, output, AES_BLOCK_LENGTH);
input += AES_BLOCK_LENGTH;
output += AES_BLOCK_LENGTH;
@@ -259,7 +260,7 @@ void olm::aes_encrypt_cbc(
for (; i < AES_BLOCK_LENGTH; ++i) {
input_block[i] ^= AES_BLOCK_LENGTH - input_length;
}
- ::aes_encrypt(input_block, output, key_schedule, 256);
+ ::aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS);
olm::unset(key_schedule);
olm::unset(input_block);
}
@@ -271,14 +272,14 @@ std::size_t olm::aes_decrypt_cbc(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * output
) {
- std::uint32_t key_schedule[60];
- ::aes_key_setup(key.key, key_schedule, 256);
+ std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH];
+ ::aes_key_setup(key.key, key_schedule, AES_KEY_BITS);
std::uint8_t block1[AES_BLOCK_LENGTH];
std::uint8_t block2[AES_BLOCK_LENGTH];
std::memcpy(block1, iv.iv, AES_BLOCK_LENGTH);
for (std::size_t i = 0; i < input_length; i += AES_BLOCK_LENGTH) {
std::memcpy(block2, &input[i], AES_BLOCK_LENGTH);
- ::aes_decrypt(&input[i], &output[i], key_schedule, 256);
+ ::aes_decrypt(&input[i], &output[i], key_schedule, AES_KEY_BITS);
xor_block<AES_BLOCK_LENGTH>(&output[i], block1);
std::memcpy(block1, block2, AES_BLOCK_LENGTH);
}
@@ -301,6 +302,7 @@ void olm::sha256(
olm::unset(context);
}
+
void olm::hmac_sha256(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
@@ -325,7 +327,7 @@ void olm::hkdf_sha256(
) {
::SHA256_CTX context;
std::uint8_t hmac_key[SHA256_BLOCK_LENGTH];
- std::uint8_t step_result[SHA256_HASH_LENGTH];
+ std::uint8_t step_result[olm::SHA256_OUTPUT_LENGTH];
std::size_t bytes_remaining = output_length;
std::uint8_t iteration = 1;
if (!salt) {
@@ -337,20 +339,20 @@ void olm::hkdf_sha256(
hmac_sha256_init(&context, hmac_key);
::sha256_update(&context, input, input_length);
hmac_sha256_final(&context, hmac_key, step_result);
- hmac_sha256_key(step_result, SHA256_HASH_LENGTH, hmac_key);
+ hmac_sha256_key(step_result, olm::SHA256_OUTPUT_LENGTH, hmac_key);
/* Extract */
hmac_sha256_init(&context, hmac_key);
::sha256_update(&context, info, info_length);
::sha256_update(&context, &iteration, 1);
hmac_sha256_final(&context, hmac_key, step_result);
- while (bytes_remaining > SHA256_HASH_LENGTH) {
- std::memcpy(output, step_result, SHA256_HASH_LENGTH);
- output += SHA256_HASH_LENGTH;
- bytes_remaining -= SHA256_HASH_LENGTH;
+ while (bytes_remaining > olm::SHA256_OUTPUT_LENGTH) {
+ std::memcpy(output, step_result, olm::SHA256_OUTPUT_LENGTH);
+ output += olm::SHA256_OUTPUT_LENGTH;
+ bytes_remaining -= olm::SHA256_OUTPUT_LENGTH;
iteration ++;
hmac_sha256_init(&context, hmac_key);
- ::sha256_update(&context, step_result, SHA256_HASH_LENGTH);
+ ::sha256_update(&context, step_result, olm::SHA256_OUTPUT_LENGTH);
::sha256_update(&context, info, info_length);
::sha256_update(&context, &iteration, 1);
hmac_sha256_final(&context, hmac_key, step_result);
diff --git a/src/message.cpp b/src/message.cpp
index 93473b9..05f707f 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -17,7 +17,7 @@
namespace {
template<typename T>
-std::size_t varint_length(
+static std::size_t varint_length(
T value
) {
std::size_t result = 1;
@@ -30,7 +30,7 @@ std::size_t varint_length(
template<typename T>
-std::uint8_t * varint_encode(
+static std::uint8_t * varint_encode(
std::uint8_t * output,
T value
) {
@@ -44,11 +44,14 @@ std::uint8_t * varint_encode(
template<typename T>
-T varint_decode(
+static T varint_decode(
std::uint8_t const * varint_start,
std::uint8_t const * varint_end
) {
T value = 0;
+ if (varint_end == varint_start) {
+ return 0;
+ }
do {
value <<= 7;
value |= 0x7F & *(--varint_end);
@@ -57,7 +60,7 @@ T varint_decode(
}
-std::uint8_t const * varint_skip(
+static std::uint8_t const * varint_skip(
std::uint8_t const * input,
std::uint8_t const * input_end
) {
@@ -71,7 +74,7 @@ std::uint8_t const * varint_skip(
}
-std::size_t varstring_length(
+static std::size_t varstring_length(
std::size_t string_length
) {
return varint_length(string_length) + string_length;
@@ -82,7 +85,7 @@ static std::uint8_t const RATCHET_KEY_TAG = 012;
static std::uint8_t const COUNTER_TAG = 020;
static std::uint8_t const CIPHERTEXT_TAG = 042;
-std::uint8_t * encode(
+static std::uint8_t * encode(
std::uint8_t * pos,
std::uint8_t tag,
std::uint32_t value
@@ -91,7 +94,7 @@ std::uint8_t * encode(
return varint_encode(pos, value);
}
-std::uint8_t * encode(
+static std::uint8_t * encode(
std::uint8_t * pos,
std::uint8_t tag,
std::uint8_t * & value, std::size_t value_length
@@ -102,7 +105,7 @@ std::uint8_t * encode(
return pos + value_length;
}
-std::uint8_t const * decode(
+static std::uint8_t const * decode(
std::uint8_t const * pos, std::uint8_t const * end,
std::uint8_t tag,
std::uint32_t & value, bool & has_value
@@ -118,7 +121,7 @@ std::uint8_t const * decode(
}
-std::uint8_t const * decode(
+static std::uint8_t const * decode(
std::uint8_t const * pos, std::uint8_t const * end,
std::uint8_t tag,
std::uint8_t const * & value, std::size_t & value_length
@@ -136,7 +139,7 @@ std::uint8_t const * decode(
return pos;
}
-std::uint8_t const * skip_unknown(
+static std::uint8_t const * skip_unknown(
std::uint8_t const * pos, std::uint8_t const * end
) {
if (pos != end) {
@@ -201,13 +204,17 @@ void olm::decode_message(
std::uint8_t const * end = input + input_length - mac_length;
std::uint8_t const * unknown = nullptr;
- if (pos == end) return;
- reader.version = *(pos++);
reader.input = input;
reader.input_length = input_length;
reader.has_counter = false;
reader.ratchet_key = nullptr;
+ reader.ratchet_key_length = 0;
reader.ciphertext = nullptr;
+ reader.ciphertext_length = 0;
+
+ if (pos == end) return;
+ if (input_length < mac_length) return;
+ reader.version = *(pos++);
while (pos != end) {
pos = decode(
@@ -281,12 +288,17 @@ void olm::decode_one_time_key_message(
std::uint8_t const * end = input + input_length;
std::uint8_t const * unknown = nullptr;
- if (pos == end) return;
- reader.version = *(pos++);
reader.one_time_key = nullptr;
+ reader.one_time_key_length = 0;
reader.identity_key = nullptr;
+ reader.identity_key_length = 0;
reader.base_key = nullptr;
+ reader.base_key_length = 0;
reader.message = nullptr;
+ reader.message_length = 0;
+
+ if (pos == end) return;
+ reader.version = *(pos++);
while (pos != end) {
pos = decode(
diff --git a/src/olm.cpp b/src/olm.cpp
index 28b194f..63f3d83 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -520,8 +520,13 @@ size_t olm_create_outbound_session(
void const * their_one_time_key, size_t their_one_time_key_length,
void * random, size_t random_length
) {
- if (olm::decode_base64_length(their_identity_key_length) != 32
- || olm::decode_base64_length(their_one_time_key_length) != 32
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::uint8_t const * ot_key = from_c(their_one_time_key);
+ std::size_t id_key_length = their_identity_key_length;
+ std::size_t ot_key_length = their_one_time_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH
+ || olm::decode_base64_length(ot_key_length) != olm::KEY_LENGTH
) {
from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
@@ -529,14 +534,8 @@ size_t olm_create_outbound_session(
olm::Curve25519PublicKey identity_key;
olm::Curve25519PublicKey one_time_key;
- olm::decode_base64(
- from_c(their_identity_key), their_identity_key_length,
- identity_key.public_key
- );
- olm::decode_base64(
- from_c(their_one_time_key), their_one_time_key_length,
- one_time_key.public_key
- );
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
+ olm::decode_base64(ot_key, ot_key_length, one_time_key.public_key);
size_t result = from_c(session)->new_outbound_session(
*from_c(account), identity_key, one_time_key,
@@ -570,15 +569,15 @@ size_t olm_create_inbound_session_from(
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
) {
- if (olm::decode_base64_length(their_identity_key_length) != 32) {
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::size_t id_key_length = their_identity_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH) {
from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
}
olm::Curve25519PublicKey identity_key;
- olm::decode_base64(
- from_c(their_identity_key), their_identity_key_length,
- identity_key.public_key
- );
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
std::size_t raw_length = b64_input(
from_c(one_time_key_message), message_length, from_c(session)->last_error
@@ -641,15 +640,15 @@ size_t olm_matches_inbound_session_from(
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
) {
- if (olm::decode_base64_length(their_identity_key_length) != 32) {
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::size_t id_key_length = their_identity_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != olm::KEY_LENGTH) {
from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
}
olm::Curve25519PublicKey identity_key;
- olm::decode_base64(
- from_c(their_identity_key), their_identity_key_length,
- identity_key.public_key
- );
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
std::size_t raw_length = b64_input(
from_c(one_time_key_message), message_length, from_c(session)->last_error
@@ -800,15 +799,12 @@ size_t olm_ed25519_verify(
void const * message, size_t message_length,
void * signature, size_t signature_length
) {
- if (olm::decode_base64_length(key_length) != 32) {
+ if (olm::decode_base64_length(key_length) != olm::KEY_LENGTH) {
from_c(utility)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
}
olm::Ed25519PublicKey verify_key;
- olm::decode_base64(
- from_c(key), key_length,
- verify_key.public_key
- );
+ olm::decode_base64(from_c(key), key_length, verify_key.public_key);
std::size_t raw_signature_length = b64_input(
from_c(signature), signature_length, from_c(utility)->last_error
);
@@ -822,5 +818,4 @@ size_t olm_ed25519_verify(
);
}
-
}
diff --git a/src/ratchet.cpp b/src/ratchet.cpp
index b9108db..5ef1e56 100644
--- a/src/ratchet.cpp
+++ b/src/ratchet.cpp
@@ -23,13 +23,12 @@
namespace {
-std::uint8_t PROTOCOL_VERSION = 3;
-std::size_t KEY_LENGTH = olm::Curve25519PublicKey::LENGTH;
-std::uint8_t MESSAGE_KEY_SEED[1] = {0x01};
-std::uint8_t CHAIN_KEY_SEED[1] = {0x02};
-std::size_t MAX_MESSAGE_GAP = 2000;
+static const std::uint8_t PROTOCOL_VERSION = 3;
+static const std::uint8_t MESSAGE_KEY_SEED[1] = {0x01};
+static const std::uint8_t CHAIN_KEY_SEED[1] = {0x02};
+static const std::size_t MAX_MESSAGE_GAP = 2000;
-void create_chain_key(
+static void create_chain_key(
olm::SharedKey const & root_key,
olm::Curve25519KeyPair const & our_key,
olm::Curve25519PublicKey const & their_key,
@@ -39,22 +38,23 @@ void create_chain_key(
) {
olm::SharedKey secret;
olm::curve25519_shared_secret(our_key, their_key, secret);
- std::uint8_t derived_secrets[64];
+ std::uint8_t derived_secrets[2 * olm::KEY_LENGTH];
olm::hkdf_sha256(
secret, sizeof(secret),
root_key, sizeof(root_key),
info.ratchet_info, info.ratchet_info_length,
derived_secrets, sizeof(derived_secrets)
);
- std::memcpy(new_root_key, derived_secrets, 32);
- std::memcpy(new_chain_key.key, derived_secrets + 32, 32);
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(new_root_key, pos);
+ pos = olm::load_array(new_chain_key.key, pos);
new_chain_key.index = 0;
olm::unset(derived_secrets);
olm::unset(secret);
}
-void advance_chain_key(
+static void advance_chain_key(
olm::ChainKey const & chain_key,
olm::ChainKey & new_chain_key
) {
@@ -67,7 +67,7 @@ void advance_chain_key(
}
-void create_message_keys(
+static void create_message_keys(
olm::ChainKey const & chain_key,
olm::KdfInfo const & info,
olm::MessageKey & message_key
@@ -81,7 +81,7 @@ void create_message_keys(
}
-std::size_t verify_mac_and_decrypt(
+static std::size_t verify_mac_and_decrypt(
olm::Cipher const & cipher,
olm::MessageKey const & message_key,
olm::MessageReader const & reader,
@@ -96,7 +96,7 @@ std::size_t verify_mac_and_decrypt(
}
-std::size_t verify_mac_and_decrypt_for_existing_chain(
+static std::size_t verify_mac_and_decrypt_for_existing_chain(
olm::Ratchet const & session,
olm::ChainKey const & chain,
olm::MessageReader const & reader,
@@ -130,7 +130,7 @@ std::size_t verify_mac_and_decrypt_for_existing_chain(
}
-std::size_t verify_mac_and_decrypt_for_new_chain(
+static std::size_t verify_mac_and_decrypt_for_new_chain(
olm::Ratchet const & session,
olm::MessageReader const & reader,
std::uint8_t * plaintext, std::size_t max_plaintext_length
@@ -148,9 +148,7 @@ std::size_t verify_mac_and_decrypt_for_new_chain(
if (reader.counter > MAX_MESSAGE_GAP) {
return std::size_t(-1);
}
- std::memcpy(
- new_chain.ratchet_key.public_key, reader.ratchet_key, KEY_LENGTH
- );
+ olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key);
create_chain_key(
session.root_key, session.sender_chain[0].ratchet_key,
@@ -183,7 +181,7 @@ void olm::Ratchet::initialise_as_bob(
std::uint8_t const * shared_secret, std::size_t shared_secret_length,
olm::Curve25519PublicKey const & their_ratchet_key
) {
- std::uint8_t derived_secrets[64];
+ std::uint8_t derived_secrets[2 * olm::KEY_LENGTH];
olm::hkdf_sha256(
shared_secret, shared_secret_length,
nullptr, 0,
@@ -192,8 +190,9 @@ void olm::Ratchet::initialise_as_bob(
);
receiver_chains.insert();
receiver_chains[0].chain_key.index = 0;
- std::memcpy(root_key, derived_secrets, 32);
- std::memcpy(receiver_chains[0].chain_key.key, derived_secrets + 32, 32);
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(root_key, pos);
+ pos = olm::load_array(receiver_chains[0].chain_key.key, pos);
receiver_chains[0].ratchet_key = their_ratchet_key;
olm::unset(derived_secrets);
}
@@ -203,7 +202,7 @@ void olm::Ratchet::initialise_as_alice(
std::uint8_t const * shared_secret, std::size_t shared_secret_length,
olm::Curve25519KeyPair const & our_ratchet_key
) {
- std::uint8_t derived_secrets[64];
+ std::uint8_t derived_secrets[2 * olm::KEY_LENGTH];
olm::hkdf_sha256(
shared_secret, shared_secret_length,
nullptr, 0,
@@ -212,8 +211,9 @@ void olm::Ratchet::initialise_as_alice(
);
sender_chain.insert();
sender_chain[0].chain_key.index = 0;
- std::memcpy(root_key, derived_secrets, 32);
- std::memcpy(sender_chain[0].chain_key.key, derived_secrets + 32, 32);
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(root_key, pos);
+ pos = olm::load_array(sender_chain[0].chain_key.key, pos);
sender_chain[0].ratchet_key = our_ratchet_key;
olm::unset(derived_secrets);
}
@@ -224,7 +224,7 @@ namespace olm {
static std::size_t pickle_length(
const olm::SharedKey & value
) {
- return KEY_LENGTH;
+ return olm::KEY_LENGTH;
}
@@ -232,7 +232,7 @@ static std::uint8_t * pickle(
std::uint8_t * pos,
const olm::SharedKey & value
) {
- return olm::pickle_bytes(pos, value, KEY_LENGTH);
+ return olm::pickle_bytes(pos, value, olm::KEY_LENGTH);
}
@@ -240,7 +240,7 @@ static std::uint8_t const * unpickle(
std::uint8_t const * pos, std::uint8_t const * end,
olm::SharedKey & value
) {
- return olm::unpickle_bytes(pos, end, value, KEY_LENGTH);
+ return olm::unpickle_bytes(pos, end, value, olm::KEY_LENGTH);
}
@@ -349,7 +349,7 @@ std::size_t olm::pickle_length(
olm::Ratchet const & value
) {
std::size_t length = 0;
- length += KEY_LENGTH;
+ length += olm::KEY_LENGTH;
length += olm::pickle_length(value.sender_chain);
length += olm::pickle_length(value.receiver_chains);
length += olm::pickle_length(value.skipped_message_keys);
@@ -391,13 +391,13 @@ std::size_t olm::Ratchet::encrypt_output_length(
plaintext_length
);
return olm::encode_message_length(
- counter, KEY_LENGTH, padded, ratchet_cipher.mac_length()
+ counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length()
);
}
std::size_t olm::Ratchet::encrypt_random_length() {
- return sender_chain.empty() ? KEY_LENGTH : 0;
+ return sender_chain.empty() ? olm::KEY_LENGTH : 0;
}
@@ -442,10 +442,11 @@ std::size_t olm::Ratchet::encrypt(
olm::MessageWriter writer;
olm::encode_message(
- writer, PROTOCOL_VERSION, counter, KEY_LENGTH, ciphertext_length, output
+ writer, PROTOCOL_VERSION, counter, olm::KEY_LENGTH, ciphertext_length,
+ output
);
- std::memcpy(writer.ratchet_key, ratchet_key.public_key, KEY_LENGTH);
+ olm::store_array(writer.ratchet_key, ratchet_key.public_key);
ratchet_cipher.encrypt(
keys.key, sizeof(keys.key),
@@ -504,7 +505,7 @@ std::size_t olm::Ratchet::decrypt(
return std::size_t(-1);
}
- if (reader.ratchet_key_length != KEY_LENGTH) {
+ if (reader.ratchet_key_length != olm::KEY_LENGTH) {
last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
@@ -513,7 +514,7 @@ std::size_t olm::Ratchet::decrypt(
for (olm::ReceiverChain & receiver_chain : receiver_chains) {
if (0 == std::memcmp(
receiver_chain.ratchet_key.public_key, reader.ratchet_key,
- KEY_LENGTH
+ olm::KEY_LENGTH
)) {
chain = &receiver_chain;
break;
@@ -533,7 +534,7 @@ std::size_t olm::Ratchet::decrypt(
if (reader.counter == skipped.message_key.index
&& 0 == std::memcmp(
skipped.ratchet_key.public_key, reader.ratchet_key,
- KEY_LENGTH
+ olm::KEY_LENGTH
)
) {
/* Found the key for this message. Check the MAC. */
@@ -569,9 +570,7 @@ std::size_t olm::Ratchet::decrypt(
* We can discard our previous empheral ratchet key.
* We will generate a new key when we send the next message. */
chain = receiver_chains.insert();
- std::memcpy(
- chain->ratchet_key.public_key, reader.ratchet_key, KEY_LENGTH
- );
+ olm::load_array(chain->ratchet_key.public_key, reader.ratchet_key);
create_chain_key(
root_key, sender_chain[0].ratchet_key, chain->ratchet_key,
kdf_info, root_key, chain->chain_key
diff --git a/src/session.cpp b/src/session.cpp
index b17a059..71c80e4 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -24,7 +24,6 @@
namespace {
-static const std::size_t KEY_LENGTH = 32;
static const std::uint8_t PROTOCOL_VERSION = 0x3;
static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT";
@@ -51,7 +50,7 @@ olm::Session::Session(
std::size_t olm::Session::new_outbound_session_random_length() {
- return KEY_LENGTH * 2;
+ return olm::KEY_LENGTH * 2;
}
@@ -66,54 +65,54 @@ std::size_t olm::Session::new_outbound_session(
return std::size_t(-1);
}
- Curve25519KeyPair base_key;
+ olm::Curve25519KeyPair base_key;
olm::curve25519_generate_key(random, base_key);
- Curve25519KeyPair ratchet_key;
- olm::curve25519_generate_key(random + 32, ratchet_key);
+ olm::Curve25519KeyPair ratchet_key;
+ olm::curve25519_generate_key(random + olm::KEY_LENGTH, ratchet_key);
+
+ olm::Curve25519KeyPair const & alice_identity_key_pair = (
+ local_account.identity_keys.curve25519_key
+ );
received_message = false;
- alice_identity_key = local_account.identity_keys.curve25519_key;
+ alice_identity_key = alice_identity_key_pair;
alice_base_key = base_key;
bob_one_time_key = one_time_key;
- std::uint8_t shared_secret[96];
+ std::uint8_t secret[3 * olm::KEY_LENGTH];
+ std::uint8_t * pos = secret;
- olm::curve25519_shared_secret(
- local_account.identity_keys.curve25519_key,
- one_time_key, shared_secret
- );
- olm::curve25519_shared_secret(
- base_key, identity_key, shared_secret + 32
- );
- olm::curve25519_shared_secret(
- base_key, one_time_key, shared_secret + 64
- );
+ olm::curve25519_shared_secret(alice_identity_key_pair, one_time_key, pos);
+ pos += olm::KEY_LENGTH;
+ olm::curve25519_shared_secret(base_key, identity_key, pos);
+ pos += olm::KEY_LENGTH;
+ olm::curve25519_shared_secret(base_key, one_time_key, pos);
- ratchet.initialise_as_alice(shared_secret, 96, ratchet_key);
+ ratchet.initialise_as_alice(secret, sizeof(secret), ratchet_key);
olm::unset(base_key);
olm::unset(ratchet_key);
- olm::unset(shared_secret);
+ olm::unset(secret);
return std::size_t(0);
}
namespace {
-bool check_message_fields(
+static bool check_message_fields(
olm::PreKeyMessageReader & reader, bool have_their_identity_key
) {
bool ok = true;
ok = ok && (have_their_identity_key || reader.identity_key);
if (reader.identity_key) {
- ok = ok && reader.identity_key_length == KEY_LENGTH;
+ ok = ok && reader.identity_key_length == olm::KEY_LENGTH;
}
ok = ok && reader.message;
ok = ok && reader.base_key;
- ok = ok && reader.base_key_length == KEY_LENGTH;
+ ok = ok && reader.base_key_length == olm::KEY_LENGTH;
ok = ok && reader.one_time_key;
- ok = ok && reader.one_time_key_length == KEY_LENGTH;
+ ok = ok && reader.one_time_key_length == olm::KEY_LENGTH;
return ok;
}
@@ -135,7 +134,7 @@ std::size_t olm::Session::new_inbound_session(
if (reader.identity_key && their_identity_key) {
bool same = 0 == std::memcmp(
- their_identity_key->public_key, reader.identity_key, KEY_LENGTH
+ their_identity_key->public_key, reader.identity_key, olm::KEY_LENGTH
);
if (!same) {
last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
@@ -150,16 +149,16 @@ std::size_t olm::Session::new_inbound_session(
);
if (!message_reader.ratchet_key
- || message_reader.ratchet_key_length != KEY_LENGTH) {
+ || message_reader.ratchet_key_length != olm::KEY_LENGTH) {
last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
- std::memcpy(alice_identity_key.public_key, reader.identity_key, 32);
- std::memcpy(alice_base_key.public_key, reader.base_key, 32);
- std::memcpy(bob_one_time_key.public_key, reader.one_time_key, 32);
+ olm::load_array(alice_identity_key.public_key, reader.identity_key);
+ olm::load_array(alice_base_key.public_key, reader.base_key);
+ olm::load_array(bob_one_time_key.public_key, reader.one_time_key);
olm::Curve25519PublicKey ratchet_key;
- std::memcpy(ratchet_key.public_key, message_reader.ratchet_key, 32);
+ olm::load_array(ratchet_key.public_key, message_reader.ratchet_key);
olm::OneTimeKey const * our_one_time_key = local_account.lookup_key(
bob_one_time_key
@@ -170,27 +169,28 @@ std::size_t olm::Session::new_inbound_session(
return std::size_t(-1);
}
- std::uint8_t shared_secret[96];
-
- olm::curve25519_shared_secret(
- our_one_time_key->key, alice_identity_key, shared_secret
- );
- olm::curve25519_shared_secret(
- local_account.identity_keys.curve25519_key,
- alice_base_key, shared_secret + 32
- );
- olm::curve25519_shared_secret(
- our_one_time_key->key, alice_base_key, shared_secret + 64
+ olm::Curve25519KeyPair const & bob_identity_key = (
+ local_account.identity_keys.curve25519_key
);
+ olm::Curve25519KeyPair const & bob_one_time_key = our_one_time_key->key;
+
+ std::uint8_t secret[olm::KEY_LENGTH * 3];
+ std::uint8_t * pos = secret;
+ olm::curve25519_shared_secret(bob_one_time_key, alice_identity_key, pos);
+ pos += olm::KEY_LENGTH;
+ olm::curve25519_shared_secret(bob_identity_key, alice_base_key, pos);
+ pos += olm::KEY_LENGTH;
+ olm::curve25519_shared_secret(bob_one_time_key, alice_base_key, pos);
- ratchet.initialise_as_bob(shared_secret, 96, ratchet_key);
+ ratchet.initialise_as_bob(secret, sizeof(secret), ratchet_key);
+ olm::unset(secret);
return std::size_t(0);
}
std::size_t olm::Session::session_id_length() {
- return 32;
+ return olm::SHA256_OUTPUT_LENGTH;
}
@@ -201,10 +201,11 @@ std::size_t olm::Session::session_id(
last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
return std::size_t(-1);
}
- std::uint8_t tmp[96];
- std::memcpy(tmp, alice_identity_key.public_key, 32);
- std::memcpy(tmp + 32, alice_base_key.public_key, 32);
- std::memcpy(tmp + 64, bob_one_time_key.public_key, 32);
+ std::uint8_t tmp[olm::KEY_LENGTH * 3];
+ std::uint8_t * pos = tmp;
+ pos = olm::store_array(pos, alice_identity_key.public_key);
+ pos = olm::store_array(pos, alice_base_key.public_key);
+ pos = olm::store_array(pos, bob_one_time_key.public_key);
olm::sha256(tmp, sizeof(tmp), id);
return session_id_length();
}
@@ -224,20 +225,20 @@ bool olm::Session::matches_inbound_session(
bool same = true;
if (reader.identity_key) {
same = same && 0 == std::memcmp(
- reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
+ reader.identity_key, alice_identity_key.public_key, olm::KEY_LENGTH
);
}
if (their_identity_key) {
same = same && 0 == std::memcmp(
their_identity_key->public_key, alice_identity_key.public_key,
- KEY_LENGTH
+ olm::KEY_LENGTH
);
}
same = same && 0 == std::memcmp(
- reader.base_key, alice_base_key.public_key, KEY_LENGTH
+ reader.base_key, alice_base_key.public_key, olm::KEY_LENGTH
);
same = same && 0 == std::memcmp(
- reader.one_time_key, bob_one_time_key.public_key, KEY_LENGTH
+ reader.one_time_key, bob_one_time_key.public_key, olm::KEY_LENGTH
);
return same;
}
@@ -264,9 +265,9 @@ std::size_t olm::Session::encrypt_message_length(
}
return encode_one_time_key_message_length(
- KEY_LENGTH,
- KEY_LENGTH,
- KEY_LENGTH,
+ olm::KEY_LENGTH,
+ olm::KEY_LENGTH,
+ olm::KEY_LENGTH,
message_length
);
}
@@ -298,21 +299,15 @@ std::size_t olm::Session::encrypt(
encode_one_time_key_message(
writer,
PROTOCOL_VERSION,
- KEY_LENGTH,
- KEY_LENGTH,
- KEY_LENGTH,
+ olm::KEY_LENGTH,
+ olm::KEY_LENGTH,
+ olm::KEY_LENGTH,
message_body_length,
message
);
- std::memcpy(
- writer.one_time_key, bob_one_time_key.public_key, KEY_LENGTH
- );
- std::memcpy(
- writer.identity_key, alice_identity_key.public_key, KEY_LENGTH
- );
- std::memcpy(
- writer.base_key, alice_base_key.public_key, KEY_LENGTH
- );
+ olm::store_array(writer.one_time_key, bob_one_time_key.public_key);
+ olm::store_array(writer.identity_key, alice_identity_key.public_key);
+ olm::store_array(writer.base_key, alice_base_key.public_key);
message_body = writer.message;
}
diff --git a/src/utility.cpp b/src/utility.cpp
index 1d8c5c1..bc51cff 100644
--- a/src/utility.cpp
+++ b/src/utility.cpp
@@ -23,7 +23,7 @@ olm::Utility::Utility(
size_t olm::Utility::sha256_length() {
- return olm::HMAC_SHA256_OUTPUT_LENGTH;
+ return olm::SHA256_OUTPUT_LENGTH;
}
@@ -36,7 +36,7 @@ size_t olm::Utility::sha256(
return std::size_t(-1);
}
olm::sha256(input, input_length, output);
- return 32;
+ return olm::SHA256_OUTPUT_LENGTH;
}
@@ -45,7 +45,7 @@ size_t olm::Utility::ed25519_verify(
std::uint8_t const * message, std::size_t message_length,
std::uint8_t const * signature, std::size_t signature_length
) {
- if (signature_length < 64) {
+ if (signature_length < olm::SIGNATURE_LENGTH) {
last_error = olm::ErrorCode::BAD_MESSAGE_MAC;
return std::size_t(-1);
}
diff --git a/tests/test_base64.cpp b/tests/test_base64.cpp
index e6b3710..5bae2f9 100644
--- a/tests/test_base64.cpp
+++ b/tests/test_base64.cpp
@@ -33,5 +33,4 @@ olm::decode_base64(input, input_length, output);
assert_equals(expected_output, output, output_length);
}
-
}
diff --git a/tests/test_crypto.cpp b/tests/test_crypto.cpp
index b3ff593..4606c52 100644
--- a/tests/test_crypto.cpp
+++ b/tests/test_crypto.cpp
@@ -70,7 +70,7 @@ olm::curve25519_generate_key(bob_private, bob_pair);
assert_equals(bob_private, bob_pair.private_key, 32);
assert_equals(bob_public, bob_pair.public_key, 32);
-std::uint8_t actual_agreement[olm::CURVE25519_SHARED_SECRET_LENGTH] = {};
+std::uint8_t actual_agreement[olm::KEY_LENGTH] = {};
olm::curve25519_shared_secret(alice_pair, bob_pair, actual_agreement);
diff --git a/tests/test_olm_decrypt.cpp b/tests/test_olm_decrypt.cpp
new file mode 100644
index 0000000..2a2db98
--- /dev/null
+++ b/tests/test_olm_decrypt.cpp
@@ -0,0 +1,75 @@
+#include "olm/olm.hh"
+#include "unittest.hh"
+
+const char * test_cases[] = {
+ "41776f",
+ "7fff6f0101346d671201",
+ "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d",
+ "e9e9c9c1e9e9c9e9c9c1e9e9c9c1",
+};
+
+
+const char * session_data =
+ "E0p44KO2y2pzp9FIjv0rud2wIvWDi2dx367kP4Fz/9JCMrH+aG369HGymkFtk0+PINTLB9lQRt"
+ "ohea5d7G/UXQx3r5y4IWuyh1xaRnojEZQ9a5HRZSNtvmZ9NY1f1gutYa4UtcZcbvczN8b/5Bqg"
+ "e16cPUH1v62JKLlhoAJwRkH1wU6fbyOudERg5gdXA971btR+Q2V8GKbVbO5fGKL5phmEPVXyMs"
+ "rfjLdzQrgjOTxN8Pf6iuP+WFPvfnR9lDmNCFxJUVAdLIMnLuAdxf1TGcS+zzCzEE8btIZ99mHF"
+ "dGvPXeH8qLeNZA";
+
+void decode_hex(
+ const char * input,
+ std::uint8_t * output, std::size_t output_length
+) {
+ std::uint8_t * end = output + output_length;
+ while (output != end) {
+ char high = *(input++);
+ char low = *(input++);
+ if (high >= 'a') high -= 'a' - ('9' + 1);
+ if (low >= 'a') low -= 'a' - ('9' + 1);
+ uint8_t value = ((high - '0') << 4) | (low - '0');
+ *(output++) = value;
+ }
+}
+
+void decrypt_case(int message_type, const char * test_case) {
+ std::uint8_t session_memory[olm_session_size()];
+ ::OlmSession * session = ::olm_session(session_memory);
+
+ std::uint8_t pickled[strlen(session_data)];
+ ::memcpy(pickled, session_data, sizeof(pickled));
+ ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled));
+
+ std::size_t message_length = strlen(test_case) / 2;
+ std::uint8_t * message = (std::uint8_t *) ::malloc(message_length);
+ decode_hex(test_case, message, message_length);
+
+ size_t max_length = olm_decrypt_max_plaintext_length(
+ session, message_type, message, message_length
+ );
+
+ if (max_length == std::size_t(-1)) {
+ free(message);
+ return;
+ }
+
+ uint8_t plaintext[max_length];
+ decode_hex(test_case, message, message_length);
+ olm_decrypt(
+ session, message_type,
+ message, message_length,
+ plaintext, max_length
+ );
+ free(message);
+}
+
+
+int main() {
+{
+TestCase my_test("Olm decrypt test");
+
+for (int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) {
+ decrypt_case(0, test_cases[i]);
+}
+
+}
+}