diff options
Diffstat (limited to 'python/olm/session.py')
-rw-r--r-- | python/olm/session.py | 106 |
1 files changed, 66 insertions, 40 deletions
diff --git a/python/olm/session.py b/python/olm/session.py index b123e8a..cba9be0 100644 --- a/python/olm/session.py +++ b/python/olm/session.py @@ -40,7 +40,7 @@ from future.utils import bytes_to_native_str # pylint: disable=no-name-in-module from _libolm import ffi, lib # type: ignore -from ._compat import URANDOM, to_bytes +from ._compat import URANDOM, to_bytearray, to_bytes from ._finalize import track_for_finalization # This is imported only for type checking purposes @@ -164,15 +164,22 @@ class Session(object): passphrase(str, optional): The passphrase to be used to encrypt the session. """ - byte_key = bytes(passphrase, "utf-8") if passphrase else b"" - key_buffer = ffi.new("char[]", byte_key) + byte_key = bytearray(passphrase, "utf-8") if passphrase else b"" pickle_length = lib.olm_pickle_session_length(self._session) pickle_buffer = ffi.new("char[]", pickle_length) - self._check_error( - lib.olm_pickle_session(self._session, key_buffer, len(byte_key), - pickle_buffer, pickle_length)) + try: + self._check_error( + lib.olm_pickle_session(self._session, + ffi.from_buffer(byte_key), + len(byte_key), + pickle_buffer, pickle_length)) + finally: + # clear out copies of the passphrase + for i in range(0, len(byte_key)): + byte_key[i] = 0 + return ffi.unpack(pickle_buffer, pickle_length) @classmethod @@ -196,16 +203,23 @@ class Session(object): if not pickle: raise ValueError("Pickle can't be empty") - byte_key = bytes(passphrase, "utf-8") if passphrase else b"" - key_buffer = ffi.new("char[]", byte_key) + byte_key = bytearray(passphrase, "utf-8") if passphrase else b"" + # copy because unpickle will destroy the buffer pickle_buffer = ffi.new("char[]", pickle) session = cls.__new__(cls) - ret = lib.olm_unpickle_session(session._session, key_buffer, - len(byte_key), pickle_buffer, - len(pickle)) - session._check_error(ret) + try: + ret = lib.olm_unpickle_session(session._session, + ffi.from_buffer(byte_key), + len(byte_key), + pickle_buffer, + len(pickle)) + session._check_error(ret) + finally: + # clear out copies of the passphrase + for i in range(0, len(byte_key)): + byte_key[i] = 0 return session @@ -217,29 +231,32 @@ class Session(object): Args: plaintext(str): The plaintext message that will be encrypted. """ - byte_plaintext = to_bytes(plaintext) + byte_plaintext = to_bytearray(plaintext) r_length = lib.olm_encrypt_random_length(self._session) random = URANDOM(r_length) - random_buffer = ffi.new("char[]", random) - message_type = lib.olm_encrypt_message_type(self._session) + try: + message_type = lib.olm_encrypt_message_type(self._session) - self._check_error(message_type) - - ciphertext_length = lib.olm_encrypt_message_length( - self._session, len(plaintext) - ) - ciphertext_buffer = ffi.new("char[]", ciphertext_length) + self._check_error(message_type) - plaintext_buffer = ffi.new("char[]", byte_plaintext) + ciphertext_length = lib.olm_encrypt_message_length( + self._session, len(byte_plaintext) + ) + ciphertext_buffer = ffi.new("char[]", ciphertext_length) - self._check_error(lib.olm_encrypt( - self._session, - plaintext_buffer, len(byte_plaintext), - random_buffer, r_length, - ciphertext_buffer, ciphertext_length, - )) + self._check_error(lib.olm_encrypt( + self._session, + ffi.from_buffer(byte_plaintext), len(byte_plaintext), + ffi.from_buffer(random), r_length, + ciphertext_buffer, ciphertext_length, + )) + finally: + # clear out copies of plaintext + if byte_plaintext is not plaintext: + for i in range(0, len(byte_plaintext)): + byte_plaintext[i] = 0 if message_type == lib.OLM_MESSAGE_TYPE_PRE_KEY: return OlmPreKeyMessage( @@ -274,22 +291,34 @@ class Session(object): raise ValueError("Ciphertext can't be empty") byte_ciphertext = to_bytes(message.ciphertext) + # make a copy the ciphertext buffer, because + # olm_decrypt_max_plaintext_length wants to destroy something ciphertext_buffer = ffi.new("char[]", byte_ciphertext) max_plaintext_length = lib.olm_decrypt_max_plaintext_length( self._session, message.message_type, ciphertext_buffer, len(byte_ciphertext) ) + self._check_error(max_plaintext_length) plaintext_buffer = ffi.new("char[]", max_plaintext_length) + + # make a copy the ciphertext buffer, because + # olm_decrypt_max_plaintext_length wants to destroy something ciphertext_buffer = ffi.new("char[]", byte_ciphertext) plaintext_length = lib.olm_decrypt( - self._session, message.message_type, ciphertext_buffer, - len(byte_ciphertext), plaintext_buffer, max_plaintext_length + self._session, message.message_type, + ciphertext_buffer, len(byte_ciphertext), + plaintext_buffer, max_plaintext_length ) self._check_error(plaintext_length) - return bytes_to_native_str( + plaintext = bytes_to_native_str( ffi.unpack(plaintext_buffer, plaintext_length)) + # clear out copies of the plaintext + lib.memset(plaintext_buffer, 0, max_plaintext_length) + + return plaintext + @property def id(self): # type: () -> str @@ -331,16 +360,16 @@ class Session(object): ret = None byte_ciphertext = to_bytes(message.ciphertext) - + # make a copy, because olm_matches_inbound_session(_from) will distroy + # it message_buffer = ffi.new("char[]", byte_ciphertext) if identity_key: byte_id_key = to_bytes(identity_key) - identity_key_buffer = ffi.new("char[]", byte_id_key) ret = lib.olm_matches_inbound_session_from( self._session, - identity_key_buffer, len(byte_id_key), + ffi.from_buffer(byte_id_key), len(byte_id_key), message_buffer, len(byte_ciphertext) ) @@ -447,14 +476,11 @@ class OutboundSession(Session): self._session) random = URANDOM(session_random_length) - random_buffer = ffi.new("char[]", random) - identity_key_buffer = ffi.new("char[]", byte_id_key) - one_time_key_buffer = ffi.new("char[]", byte_one_time) self._check_error(lib.olm_create_outbound_session( self._session, account._account, - identity_key_buffer, len(byte_id_key), - one_time_key_buffer, len(byte_one_time), - random_buffer, session_random_length + ffi.from_buffer(byte_id_key), len(byte_id_key), + ffi.from_buffer(byte_one_time), len(byte_one_time), + ffi.from_buffer(random), session_random_length )) |