aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmSasTest.java11
-rw-r--r--android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java10
-rw-r--r--android/olm-sdk/src/main/jni/olm_sas.cpp80
-rw-r--r--android/olm-sdk/src/main/jni/olm_sas.h1
4 files changed, 102 insertions, 0 deletions
diff --git a/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmSasTest.java b/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmSasTest.java
index a757050..39127cd 100644
--- a/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmSasTest.java
+++ b/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmSasTest.java
@@ -83,6 +83,17 @@ public class OlmSasTest {
Log.e(OlmSasTest.class.getSimpleName(), "#### Bob Mac is " + new String(bobMac, "UTF-8"));
+ byte[] aliceLongKdfMac = aliceSas.calculateMacLongKdf("Hello world!", "SAS");
+ byte[] bobLongKdfMac = bobSas.calculateMacLongKdf("Hello world!", "SAS");
+
+ assertTrue(aliceLongKdfMac.length > 0 && bobLongKdfMac.length > 0);
+ assertEquals(aliceLongKdfMac.length, bobLongKdfMac.length);
+ assertArrayEquals(aliceLongKdfMac, bobLongKdfMac);
+
+ Log.e(OlmSasTest.class.getSimpleName(), "#### Alice lkdf Mac is " + new String(aliceLongKdfMac, "UTF-8"));
+ Log.e(OlmSasTest.class.getSimpleName(), "#### Bob lkdf Mac is " + new String(bobLongKdfMac, "UTF-8"));
+
+
} catch (Exception e) {
assertTrue("OlmSas init failed " + e.getMessage(), false);
e.printStackTrace();
diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java
index 2869aa4..70cfb8c 100644
--- a/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java
+++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmSAS.java
@@ -103,6 +103,14 @@ public class OlmSAS {
}
}
+ public byte[] calculateMacLongKdf(String message, String info) throws OlmException {
+ try {
+ return calculateMacLongKdfJni(message.getBytes("UTF-8"), info.getBytes("UTF-8"));
+ } catch (UnsupportedEncodingException e) {
+ throw new OlmException(OlmException.EXCEPTION_CODE_SAS_ERROR, e.getMessage());
+ }
+ }
+
/**
* Create an OLM session in native side.<br>
* Do not forget to call {@link #releaseSASJni()} when JAVA side is done.
@@ -127,6 +135,8 @@ public class OlmSAS {
private native byte[] calculateMacJni(byte[] message, byte[] info);
+ private native byte[] calculateMacLongKdfJni(byte[] message, byte[] info);
+
/**
* Release native session and invalid its JAVA reference counter part.<br>
* Public API for {@link #releaseSASJni()}.
diff --git a/android/olm-sdk/src/main/jni/olm_sas.cpp b/android/olm-sdk/src/main/jni/olm_sas.cpp
index 2ad1e0f..40b9183 100644
--- a/android/olm-sdk/src/main/jni/olm_sas.cpp
+++ b/android/olm-sdk/src/main/jni/olm_sas.cpp
@@ -307,4 +307,84 @@ JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacJni)(JNIEnv *env, jobject thiz
}
return returnValue;
+}
+
+JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobject thiz,jbyteArray messageBuffer,jbyteArray infoBuffer) {
+ LOGD("## calculateMacLongKdfJni(): IN");
+ const char* errorMessage = NULL;
+ jbyteArray returnValue = 0;
+ OlmSAS* sasPtr = getOlmSasInstanceId(env, thiz);
+
+ jbyte *messagePtr = NULL;
+ jboolean messageWasCopied = JNI_FALSE;
+
+ jbyte *infoPtr = NULL;
+ jboolean infoWasCopied = JNI_FALSE;
+
+ if (!sasPtr)
+ {
+ LOGE("## calculateMacLongKdfJni(): failure - invalid SAS ptr=NULL");
+ errorMessage = "invalid SAS ptr=NULL";
+ } else if(!messageBuffer) {
+ LOGE("## calculateMacLongKdfJni(): failure - invalid message");
+ errorMessage = "invalid info";
+ }
+ else if (!(messagePtr = env->GetByteArrayElements(messageBuffer, &messageWasCopied)))
+ {
+ LOGE(" ## calculateMacLongKdfJni(): failure - message JNI allocation OOM");
+ errorMessage = "message JNI allocation OOM";
+ }
+ else if (!(infoPtr = env->GetByteArrayElements(infoBuffer, &infoWasCopied)))
+ {
+ LOGE(" ## calculateMacLongKdfJni(): failure - info JNI allocation OOM");
+ errorMessage = "info JNI allocation OOM";
+ } else {
+
+ size_t infoLength = (size_t)env->GetArrayLength(infoBuffer);
+ size_t messageLength = (size_t)env->GetArrayLength(messageBuffer);
+ size_t macLength = olm_sas_mac_length(sasPtr);
+
+ void *macPtr = malloc(macLength*sizeof(uint8_t));
+
+ size_t result = olm_sas_calculate_mac_long_kdf(sasPtr,messagePtr,messageLength,infoPtr,infoLength,macPtr,macLength);
+ if (result == olm_error())
+ {
+ errorMessage = (const char *)olm_sas_last_error(sasPtr);
+ LOGE("## calculateMacLongKdfJni(): failure - error calculating SAS mac Msg=%s", errorMessage);
+ }
+ else
+ {
+ returnValue = env->NewByteArray(macLength);
+ env->SetByteArrayRegion(returnValue, 0 , macLength, (jbyte*)macPtr);
+ }
+
+ if (macPtr) {
+ free(macPtr);
+ }
+ }
+
+ // free alloc
+ if (infoPtr)
+ {
+ if (infoWasCopied)
+ {
+ memset(infoPtr, 0, (size_t)env->GetArrayLength(infoBuffer));
+ }
+ env->ReleaseByteArrayElements(infoBuffer, infoPtr, JNI_ABORT);
+ }
+ if (messagePtr)
+ {
+ if (messageWasCopied)
+ {
+ memset(messagePtr, 0, (size_t)env->GetArrayLength(messageBuffer));
+ }
+ env->ReleaseByteArrayElements(messageBuffer, messagePtr, JNI_ABORT);
+ }
+
+ if (errorMessage)
+ {
+ env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+ }
+
+ return returnValue;
} \ No newline at end of file
diff --git a/android/olm-sdk/src/main/jni/olm_sas.h b/android/olm-sdk/src/main/jni/olm_sas.h
index ffd4494..3340459 100644
--- a/android/olm-sdk/src/main/jni/olm_sas.h
+++ b/android/olm-sdk/src/main/jni/olm_sas.h
@@ -32,6 +32,7 @@ JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(getPubKeyJni)(JNIEnv *env, jobject thiz);
JNIEXPORT void OLM_SAS_FUNC_DEF(setTheirPubKey)(JNIEnv *env, jobject thiz,jbyteArray pubKey);
JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(generateShortCodeJni)(JNIEnv *env, jobject thiz, jbyteArray infoStringBytes, jint byteNb);
JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacJni)(JNIEnv *env, jobject thiz, jbyteArray messageBuffer, jbyteArray infoBuffer);
+JNIEXPORT jbyteArray OLM_SAS_FUNC_DEF(calculateMacLongKdfJni)(JNIEnv *env, jobject thiz, jbyteArray messageBuffer, jbyteArray infoBuffer);
#ifdef __cplusplus
}