aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-05-16 16:25:09 +0100
committerRichard van der Hoff <richard@matrix.org>2016-05-23 18:55:05 +0100
commit294cf482ea49f690ac9eaad52f2574a90b2e51e6 (patch)
treea0b7e6505b020d30177a177f607bc514b9b15ed6
parentf9139dfa6aea6ca8c4054a5b5fff9be484d978fa (diff)
Convert cipher.hh to plain C
-rw-r--r--include/olm/cipher.h134
-rw-r--r--include/olm/cipher.hh132
-rw-r--r--include/olm/ratchet.hh8
-rw-r--r--src/cipher.cpp96
-rw-r--r--src/olm.cpp42
-rw-r--r--src/ratchet.cpp32
-rw-r--r--src/session.cpp22
-rw-r--r--tests/test_ratchet.cpp7
8 files changed, 263 insertions, 210 deletions
diff --git a/include/olm/cipher.h b/include/olm/cipher.h
new file mode 100644
index 0000000..0d6fd5b
--- /dev/null
+++ b/include/olm/cipher.h
@@ -0,0 +1,134 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef OLM_CIPHER_H_
+#define OLM_CIPHER_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct olm_cipher;
+
+struct cipher_ops {
+ /**
+ * Returns the length of the message authentication code that will be
+ * appended to the output.
+ */
+ size_t (*mac_length)(const struct olm_cipher *cipher);
+
+ /**
+ * Returns the length of cipher-text for a given length of plain-text.
+ */
+ size_t (*encrypt_ciphertext_length)(const struct olm_cipher *cipher,
+ size_t plaintext_length);
+
+ /*
+ * Encrypts the plain-text into the output buffer and authenticates the
+ * contents of the output buffer covering both cipher-text and any other
+ * associated data in the output buffer.
+ *
+ * |---------------------------------------output_length-->|
+ * output |--ciphertext_length-->| |---mac_length-->|
+ * ciphertext
+ *
+ * The plain-text pointers and cipher-text pointers may be the same.
+ *
+ * Returns size_t(-1) if the length of the cipher-text or the output
+ * buffer is too small. Otherwise returns the length of the output buffer.
+ */
+ size_t (*encrypt)(
+ const struct olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * ciphertext, size_t ciphertext_length,
+ uint8_t * output, size_t output_length
+ );
+
+ /**
+ * Returns the maximum length of plain-text that a given length of
+ * cipher-text can contain.
+ */
+ size_t (*decrypt_max_plaintext_length)(
+ const struct olm_cipher *cipher,
+ size_t ciphertext_length
+ );
+
+ /**
+ * Authenticates the input and decrypts the cipher-text into the plain-text
+ * buffer.
+ *
+ * |----------------------------------------input_length-->|
+ * input |--ciphertext_length-->| |---mac_length-->|
+ * ciphertext
+ *
+ * The plain-text pointers and cipher-text pointers may be the same.
+ *
+ * Returns size_t(-1) if the length of the plain-text buffer is too
+ * small or if the authentication check fails. Otherwise returns the length
+ * of the plain text.
+ */
+ size_t (*decrypt)(
+ const struct olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * input, size_t input_length,
+ uint8_t const * ciphertext, size_t ciphertext_length,
+ uint8_t * plaintext, size_t max_plaintext_length
+ );
+
+ /** destroy any private data associated with this cipher */
+ void (*destruct)(struct olm_cipher *cipher);
+};
+
+struct olm_cipher {
+ const struct cipher_ops *ops;
+ /* cipher-specific fields follow */
+};
+
+struct olm_cipher_aes_sha_256 {
+ struct olm_cipher base_cipher;
+
+ uint8_t const * kdf_info;
+ size_t kdf_info_length;
+};
+
+
+/**
+ * initialises a cipher type which uses AES256 for encryption and SHA256 for
+ * authentication.
+ *
+ * cipher: structure to be initialised
+ *
+ * kdf_info: context string for the HKDF used for deriving the AES256 key, HMAC
+ * key, and AES IV, from the key material passed to encrypt/decrypt. Note that
+ * this is NOT copied so must have a lifetime at least as long as the cipher
+ * instance.
+ *
+ * kdf_info_length: length of context string kdf_info
+ */
+struct olm_cipher *olm_cipher_aes_sha_256_init(
+ struct olm_cipher_aes_sha_256 *cipher,
+ uint8_t const * kdf_info,
+ size_t kdf_info_length);
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* OLM_CIPHER_H_ */
diff --git a/include/olm/cipher.hh b/include/olm/cipher.hh
deleted file mode 100644
index c561972..0000000
--- a/include/olm/cipher.hh
+++ /dev/null
@@ -1,132 +0,0 @@
-/* Copyright 2015 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef OLM_CIPHER_HH_
-#define OLM_CIPHER_HH_
-
-#include <cstdint>
-#include <cstddef>
-
-namespace olm {
-
-class Cipher {
-public:
- virtual ~Cipher();
-
- /**
- * Returns the length of the message authentication code that will be
- * appended to the output.
- */
- virtual std::size_t mac_length() const = 0;
-
- /**
- * Returns the length of cipher-text for a given length of plain-text.
- */
- virtual std::size_t encrypt_ciphertext_length(
- std::size_t plaintext_length
- ) const = 0;
-
- /*
- * Encrypts the plain-text into the output buffer and authenticates the
- * contents of the output buffer covering both cipher-text and any other
- * associated data in the output buffer.
- *
- * |---------------------------------------output_length-->|
- * output |--ciphertext_length-->| |---mac_length-->|
- * ciphertext
- *
- * The plain-text pointers and cipher-text pointers may be the same.
- *
- * Returns std::size_t(-1) if the length of the cipher-text or the output
- * buffer is too small. Otherwise returns the length of the output buffer.
- */
- virtual std::size_t encrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * plaintext, std::size_t plaintext_length,
- std::uint8_t * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * output, std::size_t output_length
- ) const = 0;
-
- /**
- * Returns the maximum length of plain-text that a given length of
- * cipher-text can contain.
- */
- virtual std::size_t decrypt_max_plaintext_length(
- std::size_t ciphertext_length
- ) const = 0;
-
- /**
- * Authenticates the input and decrypts the cipher-text into the plain-text
- * buffer.
- *
- * |----------------------------------------input_length-->|
- * input |--ciphertext_length-->| |---mac_length-->|
- * ciphertext
- *
- * The plain-text pointers and cipher-text pointers may be the same.
- *
- * Returns std::size_t(-1) if the length of the plain-text buffer is too
- * small or if the authentication check fails. Otherwise returns the length
- * of the plain text.
- */
- virtual std::size_t decrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * input, std::size_t input_length,
- std::uint8_t const * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * plaintext, std::size_t max_plaintext_length
- ) const = 0;
-};
-
-
-class CipherAesSha256 : public Cipher {
-public:
- CipherAesSha256(
- std::uint8_t const * kdf_info, std::size_t kdf_info_length
- );
-
- virtual std::size_t mac_length() const;
-
- virtual std::size_t encrypt_ciphertext_length(
- std::size_t plaintext_length
- ) const;
-
- virtual std::size_t encrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * plaintext, std::size_t plaintext_length,
- std::uint8_t * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * output, std::size_t output_length
- ) const;
-
- virtual std::size_t decrypt_max_plaintext_length(
- std::size_t ciphertext_length
- ) const;
-
- virtual std::size_t decrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * input, std::size_t input_length,
- std::uint8_t const * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * plaintext, std::size_t max_plaintext_length
- ) const;
-
-private:
- std::uint8_t const * kdf_info;
- std::size_t kdf_info_length;
-};
-
-
-} // namespace
-
-
-#endif /* OLM_CIPHER_HH_ */
diff --git a/include/olm/ratchet.hh b/include/olm/ratchet.hh
index b2787c7..e1d462d 100644
--- a/include/olm/ratchet.hh
+++ b/include/olm/ratchet.hh
@@ -17,9 +17,9 @@
#include "olm/list.hh"
#include "olm/error.h"
-namespace olm {
+struct olm_cipher;
-class Cipher;
+namespace olm {
typedef std::uint8_t SharedKey[olm::KEY_LENGTH];
@@ -69,14 +69,14 @@ struct Ratchet {
Ratchet(
KdfInfo const & kdf_info,
- Cipher const & ratchet_cipher
+ olm_cipher const *ratchet_cipher
);
/** A some strings identifying the application to feed into the KDF. */
KdfInfo const & kdf_info;
/** The AEAD cipher to use for encrypting messages. */
- Cipher const & ratchet_cipher;
+ olm_cipher const *ratchet_cipher;
/** The last error that happened encrypting or decrypting a message. */
OlmErrorCode last_error;
diff --git a/src/cipher.cpp b/src/cipher.cpp
index a550312..8c56efa 100644
--- a/src/cipher.cpp
+++ b/src/cipher.cpp
@@ -12,15 +12,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "olm/cipher.hh"
+#include "olm/cipher.h"
#include "olm/crypto.hh"
#include "olm/memory.hh"
#include <cstring>
-olm::Cipher::~Cipher() {
-
-}
-
namespace {
struct DerivedKeys {
@@ -51,41 +47,34 @@ static void derive_keys(
static const std::size_t MAC_LENGTH = 8;
-} // namespace
-
-
-olm::CipherAesSha256::CipherAesSha256(
- std::uint8_t const * kdf_info, std::size_t kdf_info_length
-) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) {
-
-}
-
-
-std::size_t olm::CipherAesSha256::mac_length() const {
+size_t aes_sha_256_cipher_mac_length(const struct olm_cipher *cipher) {
return MAC_LENGTH;
}
-
-std::size_t olm::CipherAesSha256::encrypt_ciphertext_length(
- std::size_t plaintext_length
-) const {
+size_t aes_sha_256_cipher_encrypt_ciphertext_length(
+ const struct olm_cipher *cipher, size_t plaintext_length
+) {
return olm::aes_encrypt_cbc_length(plaintext_length);
}
+size_t aes_sha_256_cipher_encrypt(
+ const struct olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * ciphertext, size_t ciphertext_length,
+ uint8_t * output, size_t output_length
+) {
+ auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
-std::size_t olm::CipherAesSha256::encrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * plaintext, std::size_t plaintext_length,
- std::uint8_t * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * output, std::size_t output_length
-) const {
- if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) {
+ if (aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length)
+ < ciphertext_length) {
return std::size_t(-1);
}
+
struct DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH];
- derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
+ derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
olm::aes_encrypt_cbc(
keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext
@@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt(
}
-std::size_t olm::CipherAesSha256::decrypt_max_plaintext_length(
- std::size_t ciphertext_length
-) const {
+size_t aes_sha_256_cipher_decrypt_max_plaintext_length(
+ const struct olm_cipher *cipher,
+ size_t ciphertext_length
+) {
return ciphertext_length;
}
-std::size_t olm::CipherAesSha256::decrypt(
- std::uint8_t const * key, std::size_t key_length,
- std::uint8_t const * input, std::size_t input_length,
- std::uint8_t const * ciphertext, std::size_t ciphertext_length,
- std::uint8_t * plaintext, std::size_t max_plaintext_length
-) const {
+size_t aes_sha_256_cipher_decrypt(
+ const struct olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * input, size_t input_length,
+ uint8_t const * ciphertext, size_t ciphertext_length,
+ uint8_t * plaintext, size_t max_plaintext_length
+) {
+ auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
+
DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH];
- derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
+ derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
crypto_hmac_sha256(
keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac
@@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt(
olm::unset(keys);
return plaintext_length;
}
+
+
+void aes_sha_256_cipher_destruct(struct olm_cipher *cipher) {
+}
+
+
+const cipher_ops aes_sha_256_cipher_ops = {
+ aes_sha_256_cipher_mac_length,
+ aes_sha_256_cipher_encrypt_ciphertext_length,
+ aes_sha_256_cipher_encrypt,
+ aes_sha_256_cipher_decrypt_max_plaintext_length,
+ aes_sha_256_cipher_decrypt,
+ aes_sha_256_cipher_destruct
+};
+
+} // namespace
+
+
+olm_cipher *olm_cipher_aes_sha_256_init(struct olm_cipher_aes_sha_256 *cipher,
+ uint8_t const * kdf_info,
+ size_t kdf_info_length)
+{
+ cipher->base_cipher.ops = &aes_sha_256_cipher_ops;
+ cipher->kdf_info = kdf_info;
+ cipher->kdf_info_length = kdf_info_length;
+ return &(cipher->base_cipher);
+}
diff --git a/src/olm.cpp b/src/olm.cpp
index 56bb11f..9d84758 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -15,9 +15,9 @@
#include "olm/olm.h"
#include "olm/session.hh"
#include "olm/account.hh"
+#include "olm/cipher.h"
#include "olm/utility.hh"
#include "olm/base64.hh"
-#include "olm/cipher.hh"
#include "olm/memory.hh"
#include <new>
@@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) {
static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";
-static const olm::CipherAesSha256 PICKLE_CIPHER(
- CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
-);
+const olm_cipher *get_pickle_cipher() {
+ static olm_cipher *cipher = NULL;
+ static olm_cipher_aes_sha_256 PICKLE_CIPHER;
+ if (!cipher) {
+ cipher = olm_cipher_aes_sha_256_init(
+ &PICKLE_CIPHER,
+ CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
+ );
+ }
+ return cipher;
+}
std::size_t enc_output_length(
size_t raw_length
) {
- std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
- length += PICKLE_CIPHER.mac_length();
+ auto *cipher = get_pickle_cipher();
+ std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+ length += cipher->ops->mac_length(cipher);
return olm::encode_base64_length(length);
}
@@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos(
std::uint8_t * output,
size_t raw_length
) {
- std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
- length += PICKLE_CIPHER.mac_length();
+ auto *cipher = get_pickle_cipher();
+ std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+ length += cipher->ops->mac_length(cipher);
return output + olm::encode_base64_length(length) - length;
}
@@ -85,13 +95,15 @@ std::size_t enc_output(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, size_t raw_length
) {
- std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length(
- raw_length
+ auto *cipher = get_pickle_cipher();
+ std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
+ cipher, raw_length
);
- std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length();
+ std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
std::size_t base64_length = olm::encode_base64_length(length);
std::uint8_t * raw_output = output + base64_length - length;
- PICKLE_CIPHER.encrypt(
+ cipher->ops->encrypt(
+ cipher,
key, key_length,
raw_output, raw_length,
raw_output, ciphertext_length,
@@ -112,8 +124,10 @@ std::size_t enc_input(
return std::size_t(-1);
}
olm::decode_base64(input, b64_length, input);
- std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length();
- std::size_t result = PICKLE_CIPHER.decrypt(
+ auto *cipher = get_pickle_cipher();
+ std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
+ std::size_t result = cipher->ops->decrypt(
+ cipher,
key, key_length,
input, enc_length,
input, raw_length,
diff --git a/src/ratchet.cpp b/src/ratchet.cpp
index 56ea106..de46be4 100644
--- a/src/ratchet.cpp
+++ b/src/ratchet.cpp
@@ -15,7 +15,7 @@
#include "olm/ratchet.hh"
#include "olm/message.hh"
#include "olm/memory.hh"
-#include "olm/cipher.hh"
+#include "olm/cipher.h"
#include "olm/pickle.hh"
#include <cstring>
@@ -94,12 +94,13 @@ static void create_message_keys(
static std::size_t verify_mac_and_decrypt(
- olm::Cipher const & cipher,
+ olm_cipher const *cipher,
olm::MessageKey const & message_key,
olm::MessageReader const & reader,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) {
- return cipher.decrypt(
+ return cipher->ops->decrypt(
+ cipher,
message_key.key, sizeof(message_key.key),
reader.input, reader.input_length,
reader.ciphertext, reader.ciphertext_length,
@@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain(
olm::Ratchet::Ratchet(
olm::KdfInfo const & kdf_info,
- Cipher const & ratchet_cipher
+ olm_cipher const * ratchet_cipher
) : kdf_info(kdf_info),
ratchet_cipher(ratchet_cipher),
last_error(OlmErrorCode::OLM_SUCCESS) {
@@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length(
if (!sender_chain.empty()) {
counter = sender_chain[0].chain_key.index;
}
- std::size_t padded = ratchet_cipher.encrypt_ciphertext_length(
+ std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
+ ratchet_cipher,
plaintext_length
);
return olm::encode_message_length(
- counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length()
+ counter, olm::KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
);
}
@@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt(
create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys);
advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key);
- std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length(
+ std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
+ ratchet_cipher,
plaintext_length
);
std::uint32_t counter = keys.index;
@@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt(
olm::store_array(writer.ratchet_key, ratchet_key.public_key);
- ratchet_cipher.encrypt(
+ ratchet_cipher->ops->encrypt(
+ ratchet_cipher,
keys.key, sizeof(keys.key),
plaintext, plaintext_length,
writer.ciphertext, ciphertext_length,
@@ -484,7 +488,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
) {
olm::MessageReader reader;
olm::decode_message(
- reader, input, input_length, ratchet_cipher.mac_length()
+ reader, input, input_length,
+ ratchet_cipher->ops->mac_length(ratchet_cipher)
);
if (!reader.ciphertext) {
@@ -492,7 +497,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
return std::size_t(-1);
}
- return ratchet_cipher.decrypt_max_plaintext_length(reader.ciphertext_length);
+ return ratchet_cipher->ops->decrypt_max_plaintext_length(
+ ratchet_cipher, reader.ciphertext_length);
}
@@ -502,7 +508,8 @@ std::size_t olm::Ratchet::decrypt(
) {
olm::MessageReader reader;
olm::decode_message(
- reader, input, input_length, ratchet_cipher.mac_length()
+ reader, input, input_length,
+ ratchet_cipher->ops->mac_length(ratchet_cipher)
);
if (reader.version != PROTOCOL_VERSION) {
@@ -515,7 +522,8 @@ std::size_t olm::Ratchet::decrypt(
return std::size_t(-1);
}
- std::size_t max_length = ratchet_cipher.decrypt_max_plaintext_length(
+ std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
+ ratchet_cipher,
reader.ciphertext_length
);
diff --git a/src/session.cpp b/src/session.cpp
index c0b6cf4..0d9b58a 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -13,7 +13,7 @@
* limitations under the License.
*/
#include "olm/session.hh"
-#include "olm/cipher.hh"
+#include "olm/cipher.h"
#include "olm/crypto.hh"
#include "olm/account.hh"
#include "olm/memory.hh"
@@ -30,19 +30,27 @@ static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT";
static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET";
static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS";
-static const olm::CipherAesSha256 OLM_CIPHER(
- CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
-);
-
static const olm::KdfInfo OLM_KDF_INFO = {
ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1,
RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1
};
+const olm_cipher *get_cipher() {
+ static olm_cipher *cipher;
+ static olm_cipher_aes_sha_256 OLM_CIPHER;
+ if (!cipher) {
+ cipher = olm_cipher_aes_sha_256_init(
+ &OLM_CIPHER,
+ CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
+ );
+ }
+ return cipher;
+}
+
} // namespace
olm::Session::Session(
-) : ratchet(OLM_KDF_INFO, OLM_CIPHER),
+) : ratchet(OLM_KDF_INFO, get_cipher()),
last_error(OlmErrorCode::OLM_SUCCESS),
received_message(false) {
@@ -149,7 +157,7 @@ std::size_t olm::Session::new_inbound_session(
olm::MessageReader message_reader;
decode_message(
message_reader, reader.message, reader.message_length,
- ratchet.ratchet_cipher.mac_length()
+ ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher)
);
if (!message_reader.ratchet_key
diff --git a/tests/test_ratchet.cpp b/tests/test_ratchet.cpp
index 8f89048..2c8bc1b 100644
--- a/tests/test_ratchet.cpp
+++ b/tests/test_ratchet.cpp
@@ -13,7 +13,7 @@
* limitations under the License.
*/
#include "olm/ratchet.hh"
-#include "olm/cipher.hh"
+#include "olm/cipher.h"
#include "unittest.hh"
@@ -28,8 +28,9 @@ olm::KdfInfo kdf_info = {
ratchet_info, sizeof(ratchet_info) - 1
};
-olm::CipherAesSha256 cipher(
- message_info, sizeof(message_info) - 1
+olm_cipher_aes_sha_256 cipher0;
+olm_cipher *cipher = olm_cipher_aes_sha_256_init(
+ &cipher0, message_info, sizeof(message_info) - 1
);
std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF";