diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cipher.cpp | 96 | ||||
-rw-r--r-- | src/olm.cpp | 42 | ||||
-rw-r--r-- | src/ratchet.cpp | 32 | ||||
-rw-r--r-- | src/session.cpp | 22 |
4 files changed, 121 insertions, 71 deletions
diff --git a/src/cipher.cpp b/src/cipher.cpp index a550312..8c56efa 100644 --- a/src/cipher.cpp +++ b/src/cipher.cpp @@ -12,15 +12,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "olm/cipher.hh" +#include "olm/cipher.h" #include "olm/crypto.hh" #include "olm/memory.hh" #include <cstring> -olm::Cipher::~Cipher() { - -} - namespace { struct DerivedKeys { @@ -51,41 +47,34 @@ static void derive_keys( static const std::size_t MAC_LENGTH = 8; -} // namespace - - -olm::CipherAesSha256::CipherAesSha256( - std::uint8_t const * kdf_info, std::size_t kdf_info_length -) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) { - -} - - -std::size_t olm::CipherAesSha256::mac_length() const { +size_t aes_sha_256_cipher_mac_length(const struct olm_cipher *cipher) { return MAC_LENGTH; } - -std::size_t olm::CipherAesSha256::encrypt_ciphertext_length( - std::size_t plaintext_length -) const { +size_t aes_sha_256_cipher_encrypt_ciphertext_length( + const struct olm_cipher *cipher, size_t plaintext_length +) { return olm::aes_encrypt_cbc_length(plaintext_length); } +size_t aes_sha_256_cipher_encrypt( + const struct olm_cipher *cipher, + uint8_t const * key, size_t key_length, + uint8_t const * plaintext, size_t plaintext_length, + uint8_t * ciphertext, size_t ciphertext_length, + uint8_t * output, size_t output_length +) { + auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher); -std::size_t olm::CipherAesSha256::encrypt( - std::uint8_t const * key, std::size_t key_length, - std::uint8_t const * plaintext, std::size_t plaintext_length, - std::uint8_t * ciphertext, std::size_t ciphertext_length, - std::uint8_t * output, std::size_t output_length -) const { - if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) { + if (aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length) + < ciphertext_length) { return std::size_t(-1); } + struct DerivedKeys keys; std::uint8_t mac[SHA256_OUTPUT_LENGTH]; - derive_keys(kdf_info, kdf_info_length, key, key_length, keys); + derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys); olm::aes_encrypt_cbc( keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext @@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt( } -std::size_t olm::CipherAesSha256::decrypt_max_plaintext_length( - std::size_t ciphertext_length -) const { +size_t aes_sha_256_cipher_decrypt_max_plaintext_length( + const struct olm_cipher *cipher, + size_t ciphertext_length +) { return ciphertext_length; } -std::size_t olm::CipherAesSha256::decrypt( - std::uint8_t const * key, std::size_t key_length, - std::uint8_t const * input, std::size_t input_length, - std::uint8_t const * ciphertext, std::size_t ciphertext_length, - std::uint8_t * plaintext, std::size_t max_plaintext_length -) const { +size_t aes_sha_256_cipher_decrypt( + const struct olm_cipher *cipher, + uint8_t const * key, size_t key_length, + uint8_t const * input, size_t input_length, + uint8_t const * ciphertext, size_t ciphertext_length, + uint8_t * plaintext, size_t max_plaintext_length +) { + auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher); + DerivedKeys keys; std::uint8_t mac[SHA256_OUTPUT_LENGTH]; - derive_keys(kdf_info, kdf_info_length, key, key_length, keys); + derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys); crypto_hmac_sha256( keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac @@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt( olm::unset(keys); return plaintext_length; } + + +void aes_sha_256_cipher_destruct(struct olm_cipher *cipher) { +} + + +const cipher_ops aes_sha_256_cipher_ops = { + aes_sha_256_cipher_mac_length, + aes_sha_256_cipher_encrypt_ciphertext_length, + aes_sha_256_cipher_encrypt, + aes_sha_256_cipher_decrypt_max_plaintext_length, + aes_sha_256_cipher_decrypt, + aes_sha_256_cipher_destruct +}; + +} // namespace + + +olm_cipher *olm_cipher_aes_sha_256_init(struct olm_cipher_aes_sha_256 *cipher, + uint8_t const * kdf_info, + size_t kdf_info_length) +{ + cipher->base_cipher.ops = &aes_sha_256_cipher_ops; + cipher->kdf_info = kdf_info; + cipher->kdf_info_length = kdf_info_length; + return &(cipher->base_cipher); +} diff --git a/src/olm.cpp b/src/olm.cpp index 56bb11f..9d84758 100644 --- a/src/olm.cpp +++ b/src/olm.cpp @@ -15,9 +15,9 @@ #include "olm/olm.h" #include "olm/session.hh" #include "olm/account.hh" +#include "olm/cipher.h" #include "olm/utility.hh" #include "olm/base64.hh" -#include "olm/cipher.hh" #include "olm/memory.hh" #include <new> @@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) { static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle"; -static const olm::CipherAesSha256 PICKLE_CIPHER( - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1 -); +const olm_cipher *get_pickle_cipher() { + static olm_cipher *cipher = NULL; + static olm_cipher_aes_sha_256 PICKLE_CIPHER; + if (!cipher) { + cipher = olm_cipher_aes_sha_256_init( + &PICKLE_CIPHER, + CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 + ); + } + return cipher; +} std::size_t enc_output_length( size_t raw_length ) { - std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); - length += PICKLE_CIPHER.mac_length(); + auto *cipher = get_pickle_cipher(); + std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); return olm::encode_base64_length(length); } @@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos( std::uint8_t * output, size_t raw_length ) { - std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); - length += PICKLE_CIPHER.mac_length(); + auto *cipher = get_pickle_cipher(); + std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length); + length += cipher->ops->mac_length(cipher); return output + olm::encode_base64_length(length) - length; } @@ -85,13 +95,15 @@ std::size_t enc_output( std::uint8_t const * key, std::size_t key_length, std::uint8_t * output, size_t raw_length ) { - std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length( - raw_length + auto *cipher = get_pickle_cipher(); + std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length( + cipher, raw_length ); - std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length(); + std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher); std::size_t base64_length = olm::encode_base64_length(length); std::uint8_t * raw_output = output + base64_length - length; - PICKLE_CIPHER.encrypt( + cipher->ops->encrypt( + cipher, key, key_length, raw_output, raw_length, raw_output, ciphertext_length, @@ -112,8 +124,10 @@ std::size_t enc_input( return std::size_t(-1); } olm::decode_base64(input, b64_length, input); - std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length(); - std::size_t result = PICKLE_CIPHER.decrypt( + auto *cipher = get_pickle_cipher(); + std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher); + std::size_t result = cipher->ops->decrypt( + cipher, key, key_length, input, enc_length, input, raw_length, diff --git a/src/ratchet.cpp b/src/ratchet.cpp index 56ea106..de46be4 100644 --- a/src/ratchet.cpp +++ b/src/ratchet.cpp @@ -15,7 +15,7 @@ #include "olm/ratchet.hh" #include "olm/message.hh" #include "olm/memory.hh" -#include "olm/cipher.hh" +#include "olm/cipher.h" #include "olm/pickle.hh" #include <cstring> @@ -94,12 +94,13 @@ static void create_message_keys( static std::size_t verify_mac_and_decrypt( - olm::Cipher const & cipher, + olm_cipher const *cipher, olm::MessageKey const & message_key, olm::MessageReader const & reader, std::uint8_t * plaintext, std::size_t max_plaintext_length ) { - return cipher.decrypt( + return cipher->ops->decrypt( + cipher, message_key.key, sizeof(message_key.key), reader.input, reader.input_length, reader.ciphertext, reader.ciphertext_length, @@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain( olm::Ratchet::Ratchet( olm::KdfInfo const & kdf_info, - Cipher const & ratchet_cipher + olm_cipher const * ratchet_cipher ) : kdf_info(kdf_info), ratchet_cipher(ratchet_cipher), last_error(OlmErrorCode::OLM_SUCCESS) { @@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length( if (!sender_chain.empty()) { counter = sender_chain[0].chain_key.index; } - std::size_t padded = ratchet_cipher.encrypt_ciphertext_length( + std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length( + ratchet_cipher, plaintext_length ); return olm::encode_message_length( - counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length() + counter, olm::KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher) ); } @@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt( create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys); advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key); - std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length( + std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length( + ratchet_cipher, plaintext_length ); std::uint32_t counter = keys.index; @@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt( olm::store_array(writer.ratchet_key, ratchet_key.public_key); - ratchet_cipher.encrypt( + ratchet_cipher->ops->encrypt( + ratchet_cipher, keys.key, sizeof(keys.key), plaintext, plaintext_length, writer.ciphertext, ciphertext_length, @@ -484,7 +488,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length( ) { olm::MessageReader reader; olm::decode_message( - reader, input, input_length, ratchet_cipher.mac_length() + reader, input, input_length, + ratchet_cipher->ops->mac_length(ratchet_cipher) ); if (!reader.ciphertext) { @@ -492,7 +497,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length( return std::size_t(-1); } - return ratchet_cipher.decrypt_max_plaintext_length(reader.ciphertext_length); + return ratchet_cipher->ops->decrypt_max_plaintext_length( + ratchet_cipher, reader.ciphertext_length); } @@ -502,7 +508,8 @@ std::size_t olm::Ratchet::decrypt( ) { olm::MessageReader reader; olm::decode_message( - reader, input, input_length, ratchet_cipher.mac_length() + reader, input, input_length, + ratchet_cipher->ops->mac_length(ratchet_cipher) ); if (reader.version != PROTOCOL_VERSION) { @@ -515,7 +522,8 @@ std::size_t olm::Ratchet::decrypt( return std::size_t(-1); } - std::size_t max_length = ratchet_cipher.decrypt_max_plaintext_length( + std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length( + ratchet_cipher, reader.ciphertext_length ); diff --git a/src/session.cpp b/src/session.cpp index c0b6cf4..0d9b58a 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -13,7 +13,7 @@ * limitations under the License. */ #include "olm/session.hh" -#include "olm/cipher.hh" +#include "olm/cipher.h" #include "olm/crypto.hh" #include "olm/account.hh" #include "olm/memory.hh" @@ -30,19 +30,27 @@ static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT"; static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET"; static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS"; -static const olm::CipherAesSha256 OLM_CIPHER( - CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1 -); - static const olm::KdfInfo OLM_KDF_INFO = { ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1, RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1 }; +const olm_cipher *get_cipher() { + static olm_cipher *cipher; + static olm_cipher_aes_sha_256 OLM_CIPHER; + if (!cipher) { + cipher = olm_cipher_aes_sha_256_init( + &OLM_CIPHER, + CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1 + ); + } + return cipher; +} + } // namespace olm::Session::Session( -) : ratchet(OLM_KDF_INFO, OLM_CIPHER), +) : ratchet(OLM_KDF_INFO, get_cipher()), last_error(OlmErrorCode::OLM_SUCCESS), received_message(false) { @@ -149,7 +157,7 @@ std::size_t olm::Session::new_inbound_session( olm::MessageReader message_reader; decode_message( message_reader, reader.message, reader.message_length, - ratchet.ratchet_cipher.mac_length() + ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher) ); if (!message_reader.ratchet_key |