/* * Copyright 2019 New Vector Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "olm_sas.h" #include "olm/olm.h" using namespace AndroidOlmSdk; JNIEXPORT jlong OLM_SAS_FUNC_DEF(createNewSASJni)(JNIEnv *env, jobject thiz) { size_t sasSize = olm_sas_size(); OlmSAS *sasPtr = (OlmSAS *) malloc(sasSize); const char* errorMessage = NULL; if (!sasPtr) { LOGE("## createNewSASJni(): failure - init SAS OOM"); env->ThrowNew(env->FindClass("java/lang/Exception"), "init sas OOM"); } else { sasPtr = olm_sas(sasPtr) LOGD(" ## createNewSASJni(): success - sasPtr=%p (jlong)(intptr_t)accountPtr=%lld",sasPtr,(jlong)(intptr_t)sasPtr); } size_t randomSize = olm_create_sas_random_length(sasPtr); uint8_t *randomBuffPtr = NULL; LOGD("## createNewSASJni(): randomSize=%lu",static_cast<long unsigned int>(randomSize)); if ( (0 != randomSize) && !setRandomInBuffer(env, &randomBuffPtr, randomSize)) { LOGE("## createNewSASJni(): failure - random buffer init"); errorMessage = "Failed to init private key"; } else { size_t result = olm_create_sas(sasPtr, randomBuffPtr, randomSize); if (result == olm_error()) { errorMessage = (const char *)olm_sas_last_error(sasPtr); LOGE("## createNewSASJni(): failure - error creating SAS Msg=%s", errorMessage); } } if (randomBuffPtr) { memset(randomBuffPtr, 0, randomSize); free(randomBuffPtr); } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } return (jlong)(intptr_t)sasPtr; } JNIEXPORT void OLM_SAS_FUNC_DEF(releaseSASJni)(JNIEnv *env, jobject thiz) { LOGD("## releaseSASJni(): IN"); OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); if (!sasPtr) { LOGE("## releaseSessionJni(): failure - invalid Session ptr=NULL"); } else { olm_clear_sas(sasPtr); // even if free(NULL) does not crash, logs are performed for debug purpose free(sasPtr); } } JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(getPubKeyJni)(JNIEnv *env, jobject thiz) { LOGD("## getPubKeyJni(): IN"); const char* errorMessage = NULL; jbyteArray returnValue = 0; OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); if (!sasPtr) { LOGE("## getPubKeyJni(): failure - invalid SAS ptr=NULL"); errorMessage = "invalid SAS ptr=NULL"; } else { size_t pubKeyLength = olm_sas_pubkey_length(sasPtr); void *pubkey = malloc(pubKeyLength*sizeof(uint8_t)); size_t result = olm_sas_get_pubkey(sasPtr, pubkey, pubKeyLength); if (result == olm_error()) { errorMessage = (const char *)olm_sas_last_error(sasPtr); LOGE("## getPubKeyJni(): failure - error getting pub key Msg=%s", errorMessage); } else { returnValue = env->NewByteArray(pubKeyLength); env->SetByteArrayRegion(returnValue, 0 , pubKeyLength, (jbyte*)pubkey); } if (pubkey) { free(pubkey); } } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } return returnValue; } JNIEXPORT void OLM_SAS_FUNC_DEF(setTheirPubKey)(JNIEnv *env, jobject thiz,jbyteArray pubKeyBuffer) { OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); const char* errorMessage = NULL; jbyte *pubKeyPtr = NULL; jboolean pubKeyWasCopied = JNI_FALSE; if (!sasPtr) { LOGE("## setTheirPubKey(): failure - invalid SAS ptr=NULL"); errorMessage = "invalid SAS ptr=NULL"; } else if(!pubKeyBuffer) { LOGE("## setTheirPubKey(): failure - invalid info"); errorMessage = "invalid pubKey"; } else if (!(pubKeyPtr = env->GetByteArrayElements(pubKeyBuffer, &pubKeyWasCopied))) { LOGE(" ## setTheirPubKey(): failure - info JNI allocation OOM"); errorMessage = "info JNI allocation OOM"; } else { size_t pubKeyLength = (size_t)env->GetArrayLength(pubKeyBuffer); size_t result = olm_sas_set_their_key(sasPtr,pubKeyPtr,pubKeyLength); if (result == olm_error()) { errorMessage = (const char *)olm_sas_last_error(sasPtr); LOGE("## setTheirPubKey(): failure - error setting their key Msg=%s", errorMessage); } } // free alloc if (pubKeyPtr) { if (pubKeyWasCopied) { memset(pubKeyPtr, 0, (size_t)env->GetArrayLength(pubKeyBuffer)); } env->ReleaseByteArrayElements(pubKeyBuffer, pubKeyPtr, JNI_ABORT); } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } } JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(generateShortCodeJni)(JNIEnv *env, jobject thiz, jbyteArray infoStringBytes, jint byteNb) { LOGD("## generateShortCodeJni(): IN"); const char* errorMessage = NULL; jbyteArray returnValue = 0; OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); jbyte *infoPtr = NULL; jboolean infoWasCopied = JNI_FALSE; if (!sasPtr) { LOGE("## generateShortCodeJni(): failure - invalid SAS ptr=NULL"); errorMessage = "invalid SAS ptr=NULL"; } else if(!infoStringBytes) { LOGE("## generateShortCodeJni(): failure - invalid info"); errorMessage = "invalid info"; } else if (!(infoPtr = env->GetByteArrayElements(infoStringBytes, &infoWasCopied))) { LOGE(" ## generateShortCodeJni(): failure - info JNI allocation OOM"); errorMessage = "info JNI allocation OOM"; } else { size_t shortBytesCodeLength = (size_t) byteNb; void *shortBytesCode = malloc(shortBytesCodeLength * sizeof(uint8_t)); size_t infoLength = (size_t)env->GetArrayLength(infoStringBytes); olm_sas_generate_bytes(sasPtr, infoPtr, infoLength, shortBytesCode, shortBytesCodeLength); returnValue = env->NewByteArray(shortBytesCodeLength); env->SetByteArrayRegion(returnValue, 0 , shortBytesCodeLength, (jbyte*)shortBytesCode); free(shortBytesCode); } // free alloc if (infoPtr) { if (infoWasCopied) { memset(infoPtr, 0, (size_t)env->GetArrayLength(infoStringBytes)); } env->ReleaseByteArrayElements(infoStringBytes, infoPtr, JNI_ABORT); } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } return returnValue; } JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacJni)(JNIEnv *env, jobject thiz,jbyteArray messageBuffer,jbyteArray infoBuffer) { LOGD("## calculateMacJni(): IN"); const char* errorMessage = NULL; jbyteArray returnValue = 0; OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); jbyte *messagePtr = NULL; jboolean messageWasCopied = JNI_FALSE; jbyte *infoPtr = NULL; jboolean infoWasCopied = JNI_FALSE; if (!sasPtr) { LOGE("## calculateMacJni(): failure - invalid SAS ptr=NULL"); errorMessage = "invalid SAS ptr=NULL"; } else if(!messageBuffer) { LOGE("## calculateMacJni(): failure - invalid message"); errorMessage = "invalid info"; } else if (!(messagePtr = env->GetByteArrayElements(messageBuffer, &messageWasCopied))) { LOGE(" ## calculateMacJni(): failure - message JNI allocation OOM"); errorMessage = "message JNI allocation OOM"; } else if (!(infoPtr = env->GetByteArrayElements(infoBuffer, &infoWasCopied))) { LOGE(" ## calculateMacJni(): failure - info JNI allocation OOM"); errorMessage = "info JNI allocation OOM"; } else { size_t infoLength = (size_t)env->GetArrayLength(infoBuffer); size_t messageLength = (size_t)env->GetArrayLength(messageBuffer); size_t macLength = olm_sas_mac_length(sasPtr); void *macPtr = malloc(macLength*sizeof(uint8_t)); size_t result = olm_sas_calculate_mac(sasPtr,messagePtr,messageLength,infoPtr,infoLength,macPtr,macLength); if (result == olm_error()) { errorMessage = (const char *)olm_sas_last_error(sasPtr); LOGE("## calculateMacJni(): failure - error calculating SAS mac Msg=%s", errorMessage); } else { returnValue = env->NewByteArray(macLength); env->SetByteArrayRegion(returnValue, 0 , macLength, (jbyte*)macPtr); } if (macPtr) { free(macPtr); } } // free alloc if (infoPtr) { if (infoWasCopied) { memset(infoPtr, 0, (size_t)env->GetArrayLength(infoBuffer)); } env->ReleaseByteArrayElements(infoBuffer, infoPtr, JNI_ABORT); } if (messagePtr) { if (messageWasCopied) { memset(messagePtr, 0, (size_t)env->GetArrayLength(messageBuffer)); } env->ReleaseByteArrayElements(messageBuffer, messagePtr, JNI_ABORT); } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } return returnValue; } JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobject thiz,jbyteArray messageBuffer,jbyteArray infoBuffer) { LOGD("## calculateMacLongKdfJni(): IN"); const char* errorMessage = NULL; jbyteArray returnValue = 0; OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz); jbyte *messagePtr = NULL; jboolean messageWasCopied = JNI_FALSE; jbyte *infoPtr = NULL; jboolean infoWasCopied = JNI_FALSE; if (!sasPtr) { LOGE("## calculateMacLongKdfJni(): failure - invalid SAS ptr=NULL"); errorMessage = "invalid SAS ptr=NULL"; } else if(!messageBuffer) { LOGE("## calculateMacLongKdfJni(): failure - invalid message"); errorMessage = "invalid info"; } else if (!(messagePtr = env->GetByteArrayElements(messageBuffer, &messageWasCopied))) { LOGE(" ## calculateMacLongKdfJni(): failure - message JNI allocation OOM"); errorMessage = "message JNI allocation OOM"; } else if (!(infoPtr = env->GetByteArrayElements(infoBuffer, &infoWasCopied))) { LOGE(" ## calculateMacLongKdfJni(): failure - info JNI allocation OOM"); errorMessage = "info JNI allocation OOM"; } else { size_t infoLength = (size_t)env->GetArrayLength(infoBuffer); size_t messageLength = (size_t)env->GetArrayLength(messageBuffer); size_t macLength = olm_sas_mac_length(sasPtr); void *macPtr = malloc(macLength*sizeof(uint8_t)); size_t result = olm_sas_calculate_mac_long_kdf(sasPtr,messagePtr,messageLength,infoPtr,infoLength,macPtr,macLength); if (result == olm_error()) { errorMessage = (const char *)olm_sas_last_error(sasPtr); LOGE("## calculateMacLongKdfJni(): failure - error calculating SAS mac Msg=%s", errorMessage); } else { returnValue = env->NewByteArray(macLength); env->SetByteArrayRegion(returnValue, 0 , macLength, (jbyte*)macPtr); } if (macPtr) { free(macPtr); } } // free alloc if (infoPtr) { if (infoWasCopied) { memset(infoPtr, 0, (size_t)env->GetArrayLength(infoBuffer)); } env->ReleaseByteArrayElements(infoBuffer, infoPtr, JNI_ABORT); } if (messagePtr) { if (messageWasCopied) { memset(messagePtr, 0, (size_t)env->GetArrayLength(messageBuffer)); } env->ReleaseByteArrayElements(messageBuffer, messagePtr, JNI_ABORT); } if (errorMessage) { env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage); } return returnValue; }