Android: Allow switching the voice typing library to Whisper (#11158)

Co-authored-by: Laurent Cozic <laurent22@users.noreply.github.com>
pull/11271/head
Henry Heino 2024-10-26 13:00:56 -07:00 committed by GitHub
parent 3a316a1dbc
commit 9f5282c8f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 980 additions and 158 deletions

View File

@ -737,8 +737,12 @@ packages/app-mobile/services/BackButtonService.js
packages/app-mobile/services/e2ee/RSA.react-native.js
packages/app-mobile/services/plugins/PlatformImplementation.js
packages/app-mobile/services/profiles/index.js
packages/app-mobile/services/voiceTyping/VoiceTyping.js
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
packages/app-mobile/services/voiceTyping/vosk.android.js
packages/app-mobile/services/voiceTyping/vosk.js
packages/app-mobile/services/voiceTyping/whisper.js
packages/app-mobile/setupQuickActions.js
packages/app-mobile/tools/buildInjectedJs/BundledFile.js
packages/app-mobile/tools/buildInjectedJs/constants.js

4
.gitignore vendored
View File

@ -714,8 +714,12 @@ packages/app-mobile/services/BackButtonService.js
packages/app-mobile/services/e2ee/RSA.react-native.js
packages/app-mobile/services/plugins/PlatformImplementation.js
packages/app-mobile/services/profiles/index.js
packages/app-mobile/services/voiceTyping/VoiceTyping.js
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
packages/app-mobile/services/voiceTyping/vosk.android.js
packages/app-mobile/services/voiceTyping/vosk.js
packages/app-mobile/services/voiceTyping/whisper.js
packages/app-mobile/setupQuickActions.js
packages/app-mobile/tools/buildInjectedJs/BundledFile.js
packages/app-mobile/tools/buildInjectedJs/constants.js

View File

@ -136,6 +136,10 @@ dependencies {
} else {
implementation jscFlavor
}
// Needed for Whisper speech-to-text
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
}
apply from: file("../../node_modules/@react-native-community/cli-platform-android/native_modules.gradle"); applyNativeModulesAppBuildGradle(project)

View File

@ -10,6 +10,7 @@
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.READ_MEDIA_IMAGES" />
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<!-- Make these features optional to enable Chromebooks -->
<!-- https://github.com/laurent22/joplin/issues/37 -->

View File

@ -11,6 +11,7 @@ import com.facebook.react.defaults.DefaultNewArchitectureEntryPoint.load
import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost
import com.facebook.react.defaults.DefaultReactNativeHost
import com.facebook.soloader.SoLoader
import net.cozic.joplin.audio.SpeechToTextPackage
import net.cozic.joplin.versioninfo.SystemVersionInformationPackage
import net.cozic.joplin.share.SharePackage
import net.cozic.joplin.ssl.SslPackage
@ -25,6 +26,7 @@ class MainApplication : Application(), ReactApplication {
add(SslPackage())
add(TextInputPackage())
add(SystemVersionInformationPackage())
add(SpeechToTextPackage())
}
override fun getJSMainModuleName(): String = "index"

View File

@ -0,0 +1,101 @@
package net.cozic.joplin.audio
import android.Manifest
import android.annotation.SuppressLint
import android.content.Context
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder.AudioSource
import java.io.Closeable
import kotlin.math.max
import kotlin.math.min
typealias AudioRecorderFactory = (context: Context) -> AudioRecorder

/**
 * Records microphone audio as mono, 16 kHz, float PCM samples into a fixed-size
 * buffer that holds at most [maxLengthSeconds] seconds of audio.
 *
 * Constructing an instance throws [SecurityException] if the RECORD_AUDIO
 * permission has been denied.
 */
class AudioRecorder(context: Context) : Closeable {
    private val sampleRate = 16_000
    private val maxLengthSeconds = 30 // Whisper supports a maximum of 30s
    private val maxBufferSize = sampleRate * maxLengthSeconds
    private val buffer = FloatArray(maxBufferSize)

    // Number of valid samples currently stored at the start of [buffer].
    private var bufferWriteOffset = 0

    /** A copy of the samples buffered so far. */
    val bufferedData: FloatArray get() = buffer.sliceArray(0 until bufferWriteOffset)

    /** Duration, in seconds, of the samples buffered so far. */
    val bufferLengthSeconds: Double get() = bufferWriteOffset.toDouble() / sampleRate

    init {
        // Initializers run in declaration order, so this check happens before
        // [recorder] is built below.
        val permissionResult = context.checkSelfPermission(Manifest.permission.RECORD_AUDIO)
        if (permissionResult == PackageManager.PERMISSION_DENIED) {
            throw SecurityException("Missing RECORD_AUDIO permission!")
        }
    }

    // Permissions check is included above
    @SuppressLint("MissingPermission")
    private val recorder = AudioRecord.Builder()
        .setAudioSource(AudioSource.MIC)
        .setAudioFormat(
            AudioFormat.Builder()
                // PCM: A WAV format
                .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                .setSampleRate(sampleRate)
                .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                .build()
        )
        .setBufferSizeInBytes(maxBufferSize * Float.SIZE_BYTES)
        .build()

    // Discards the first [samples] samples from the start of the buffer. Conceptually, this
    // advances the buffer's start point.
    private fun advanceStartBySamples(samples: Int) {
        // Clamp to the valid region: a negative count becomes a no-op instead of
        // throwing, and at most the buffered samples can be dropped.
        val samplesClamped = samples.coerceIn(0, bufferWriteOffset)
        val newWriteOffset = bufferWriteOffset - samplesClamped

        // Shift the remaining samples to the front in place. copyInto permits
        // overlapping ranges within the same array, so no temporary copy is needed.
        buffer.copyInto(buffer, 0, samplesClamped, bufferWriteOffset)
        // Clear the vacated tail so stale audio never lingers past the write offset.
        buffer.fill(0f, newWriteOffset, bufferWriteOffset)
        bufferWriteOffset = newWriteOffset
    }

    /** Drops the first [seconds] seconds of buffered audio. */
    fun dropFirstSeconds(seconds: Double) {
        advanceStartBySamples((seconds * sampleRate).toInt())
    }

    /** Starts capturing audio from the microphone. */
    fun start() {
        recorder.startRecording()
    }

    // Reads up to [requestedSize] samples (limited by the free space left in the
    // buffer) and appends them after the current write offset.
    private fun read(requestedSize: Int, mode: Int) {
        val size = min(requestedSize, maxBufferSize - bufferWriteOffset)
        val sizeRead = recorder.read(buffer, bufferWriteOffset, size, mode)
        // Zero or negative results (nothing read / error codes) leave the offset unchanged.
        if (sizeRead > 0) {
            bufferWriteOffset += sizeRead
        }
    }

    // Pulls all available data from the audio recorder's buffer
    fun pullAvailable() {
        read(maxBufferSize, AudioRecord.READ_NON_BLOCKING)
    }

    /** Blocks until up to [seconds] more seconds of audio have been appended to the buffer. */
    fun pullNextSeconds(seconds: Double) {
        val remainingSize = maxBufferSize - bufferWriteOffset
        val requestedSize = (seconds * sampleRate).toInt()

        // If low on size, make more room.
        if (remainingSize < maxBufferSize / 3) {
            advanceStartBySamples(maxBufferSize / 3)
        }

        read(requestedSize, AudioRecord.READ_BLOCKING)
    }

    override fun close() {
        try {
            recorder.stop()
        } finally {
            // Always free the native recorder, even if stop() fails.
            recorder.release()
        }
    }

    companion object {
        val factory: AudioRecorderFactory = { context -> AudioRecorder(context) }
    }
}

View File

@ -0,0 +1,5 @@
package net.cozic.joplin.audio
/** Thrown when a speech-to-text session ID does not refer to an open session. */
class InvalidSessionIdException(id: Int) : IllegalArgumentException("Invalid session ID $id")

View File

@ -0,0 +1,136 @@
package net.cozic.joplin.audio
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import ai.onnxruntime.extensions.OrtxPackage
import android.annotation.SuppressLint
import android.content.Context
import android.util.Log
import java.io.Closeable
import java.nio.FloatBuffer
import java.nio.IntBuffer
import kotlin.time.DurationUnit
import kotlin.time.measureTimedValue
/**
 * Records audio with an [AudioRecorder] and converts the buffered samples to
 * text by running them through an ONNX model loaded from [modelPath].
 *
 * @param modelPath Path of the ONNX model file passed to [OrtEnvironment.createSession].
 * @param locale Locale whose language part (the text before the first `_`)
 *   selects the decoder prompt tokens; unknown languages let the model guess.
 * @param recorderFactory Creates the recorder that supplies audio data.
 * @param environment ONNX runtime environment used for the session and tensors.
 * @param context Passed to [recorderFactory].
 */
class SpeechToTextConverter(
    modelPath: String,
    locale: String,
    recorderFactory: AudioRecorderFactory,
    private val environment: OrtEnvironment,
    context: Context,
) : Closeable {
    private val recorder = recorderFactory(context)
    private val session: OrtSession = try {
        environment.createSession(
            modelPath,
            OrtSession.SessionOptions().apply {
                // Needed for audio decoding
                registerCustomOpLibrary(OrtxPackage.getLibraryPath())
            },
        )
    } catch (error: Throwable) {
        // The recorder already exists at this point. Release it so that a failed
        // model load does not leak it.
        try {
            recorder.close()
        } catch (closeError: Throwable) {
            error.addSuppressed(closeError)
        }
        throw error
    }
    private val languageCode = Regex("_.*").replace(locale, "")
    private val decoderInputIds = when (languageCode) {
        // Add 50363 to the end to omit timestamps
        "en" -> intArrayOf(50258, 50259, 50359)
        "fr" -> intArrayOf(50258, 50265, 50359)
        "es" -> intArrayOf(50258, 50262, 50359)
        "de" -> intArrayOf(50258, 50261, 50359)
        "it" -> intArrayOf(50258, 50274, 50359)
        "nl" -> intArrayOf(50258, 50271, 50359)
        "ko" -> intArrayOf(50258, 50264, 50359)
        "th" -> intArrayOf(50258, 50289, 50359)
        "ru" -> intArrayOf(50258, 50263, 50359)
        "pt" -> intArrayOf(50258, 50267, 50359)
        "pl" -> intArrayOf(50258, 50269, 50359)
        "id" -> intArrayOf(50258, 50275, 50359)
        "hi" -> intArrayOf(50258, 50276, 50359)
        // Let Whisper guess the language
        else -> intArrayOf(50258)
    }

    /** Starts recording audio. */
    fun start() {
        recorder.start()
    }

    // Builds the named input tensors for one model run. The caller owns the
    // returned tensors and must close them.
    private fun getInputs(data: FloatArray): MutableMap<String, OnnxTensor> {
        fun intTensor(value: Int) = OnnxTensor.createTensor(
            environment,
            IntBuffer.wrap(intArrayOf(value)),
            longArrayOf(1),
        )
        fun floatTensor(value: Float) = OnnxTensor.createTensor(
            environment,
            FloatBuffer.wrap(floatArrayOf(value)),
            longArrayOf(1),
        )
        val audioPcmTensor = OnnxTensor.createTensor(
            environment,
            FloatBuffer.wrap(data),
            longArrayOf(1, data.size.toLong()),
        )
        val decoderInputIdsTensor = OnnxTensor.createTensor(
            environment,
            IntBuffer.wrap(decoderInputIds),
            longArrayOf(1, decoderInputIds.size.toLong())
        )

        return mutableMapOf(
            "audio_pcm" to audioPcmTensor,
            "max_length" to intTensor(412),
            "min_length" to intTensor(0),
            "num_return_sequences" to intTensor(1),
            "num_beams" to intTensor(1),
            "length_penalty" to floatTensor(1.1f),
            "repetition_penalty" to floatTensor(3f),
            "decoder_input_ids" to decoderInputIdsTensor,
            // Required for timestamps
            "logits_processor" to intTensor(1)
        )
    }

    // Runs the model on [data] and returns the text from its "str" output.
    // TODO .get() fails on older Android versions
    @SuppressLint("NewApi")
    private fun convert(data: FloatArray): String {
        val (inputs, convertInputsTime) = measureTimedValue {
            getInputs(data)
        }
        try {
            val (outputs, getOutputsTime) = measureTimedValue {
                session.run(inputs, setOf("str"))
            }
            val text = try {
                val mainOutput = outputs.get("str").get().value as Array<Array<String>>
                mainOutput[0][0]
            } finally {
                // Release the native result even if reading it fails.
                outputs.close()
            }

            Log.i("Whisper", "Converted ${data.size / 16000}s of data in ${
                getOutputsTime.toString(DurationUnit.SECONDS, 2)
            } converted inputs in ${convertInputsTime.inWholeMilliseconds}ms")
            return text
        } finally {
            // Input tensors are backed by native memory and are not freed by the
            // session. Close them after every run to avoid leaking one set per call.
            inputs.values.forEach { it.close() }
        }
    }

    /** Discards the first [seconds] seconds of recorded audio. */
    fun dropFirstSeconds(seconds: Double) {
        Log.i("Whisper", "Drop first seconds $seconds")
        recorder.dropFirstSeconds(seconds)
    }

    /** Length, in seconds, of the audio currently buffered by the recorder. */
    val bufferLengthSeconds: Double get() = recorder.bufferLengthSeconds

    /** Waits for up to [seconds] more seconds of audio, then converts everything buffered. */
    fun expandBufferAndConvert(seconds: Double): String {
        recorder.pullNextSeconds(seconds)
        // Also pull any extra available data, in case the speech-to-text converter
        // is lagging behind the audio recorder.
        recorder.pullAvailable()

        return convert(recorder.bufferedData)
    }

    // Converts as many seconds of buffered data as possible, without waiting
    fun expandBufferAndConvert(): String {
        recorder.pullAvailable()
        return convert(recorder.bufferedData)
    }

    override fun close() {
        try {
            recorder.close()
        } finally {
            // Always release the model session, even if stopping the recorder fails.
            session.close()
        }
    }
}

View File

@ -0,0 +1,86 @@
package net.cozic.joplin.audio
import ai.onnxruntime.OrtEnvironment
import com.facebook.react.ReactPackage
import com.facebook.react.bridge.LifecycleEventListener
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReactContextBaseJavaModule
import com.facebook.react.bridge.ReactMethod
import com.facebook.react.uimanager.ViewManager
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/** React Native package that exposes [SpeechToTextModule] to JavaScript. */
class SpeechToTextPackage : ReactPackage {
    override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> {
        return listOf<NativeModule>(SpeechToTextModule(reactContext))
    }

    override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
        return emptyList()
    }

    /**
     * Native module ("SpeechToTextModule") that forwards session requests to a
     * [SpeechToTextSessionManager]. Session work runs on a single-threaded executor.
     */
    class SpeechToTextModule(
        private var context: ReactApplicationContext,
    ) : ReactContextBaseJavaModule(context), LifecycleEventListener {
        // Created on first use by openSession and then reused.
        private var environment: OrtEnvironment? = null
        private val executorService: ExecutorService = Executors.newFixedThreadPool(1)
        private val sessionManager = SpeechToTextSessionManager(executorService)

        override fun getName() = "SpeechToTextModule"

        override fun onHostResume() { }
        override fun onHostPause() { }
        // NOTE(review): no addLifecycleEventListener call is visible in this class —
        // confirm these callbacks are actually registered somewhere.
        override fun onHostDestroy() {
            environment?.close()
        }

        /**
         * Opens a speech-to-text session for the model at [modelPath] and resolves
         * [promise] with the new session ID. Any failure rejects [promise].
         */
        @ReactMethod
        fun openSession(modelPath: String, locale: String, promise: Promise) {
            val appContext = context.applicationContext

            try {
                // Initialize environment as late as possible, and remember it so that
                // onHostDestroy can close it. (The previous check was inverted and never
                // stored the environment.) This is inside the try so that a failure to
                // load the runtime rejects the promise instead of escaping the call.
                val ortEnvironment = environment
                    ?: OrtEnvironment.getEnvironment().also { environment = it }

                val sessionId = sessionManager.openSession(modelPath, locale, ortEnvironment, appContext)
                promise.resolve(sessionId)
            } catch (exception: Throwable) {
                promise.reject(exception)
            }
        }

        @ReactMethod
        fun startRecording(sessionId: Int, promise: Promise) {
            sessionManager.startRecording(sessionId, promise)
        }

        @ReactMethod
        fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
            sessionManager.getBufferLengthSeconds(sessionId, promise)
        }

        @ReactMethod
        fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
            sessionManager.dropFirstSeconds(sessionId, duration, promise)
        }

        @ReactMethod
        fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
            sessionManager.expandBufferAndConvert(sessionId, duration, promise)
        }

        @ReactMethod
        fun convertAvailable(sessionId: Int, promise: Promise) {
            sessionManager.convertAvailable(sessionId, promise)
        }

        @ReactMethod
        fun closeSession(sessionId: Int, promise: Promise) {
            sessionManager.closeSession(sessionId, promise)
        }
    }
}

View File

@ -0,0 +1,111 @@
package net.cozic.joplin.audio
import ai.onnxruntime.OrtEnvironment
import android.content.Context
import com.facebook.react.bridge.Promise
import java.util.concurrent.Executor
import java.util.concurrent.locks.ReentrantLock
/** Pairs a [SpeechToTextConverter] with the lock that serializes access to it. */
class SpeechToTextSession(val converter: SpeechToTextConverter) {
    // Held while a callback is operating on [converter].
    val mutex = ReentrantLock()
}
/**
 * Owns the open [SpeechToTextSession]s and runs every operation on them through
 * [executor], reporting results and failures via React Native [Promise]s.
 */
class SpeechToTextSessionManager(
    private var executor: Executor,
) {
    // Sessions are registered on the caller's thread (openSession) but looked up and
    // removed on the executor's thread, so the map must be safe for concurrent access.
    private val sessions: MutableMap<Int, SpeechToTextSession> =
        java.util.concurrent.ConcurrentHashMap<Int, SpeechToTextSession>()
    private var nextSessionId: Int = 0

    /**
     * Creates a new session and returns its ID. Runs synchronously on the calling
     * thread; exceptions from creating the converter propagate to the caller.
     */
    fun openSession(
        modelPath: String,
        locale: String,
        environment: OrtEnvironment,
        context: Context,
    ): Int {
        val sessionId = nextSessionId++
        sessions[sessionId] = SpeechToTextSession(
            SpeechToTextConverter(
                modelPath, locale, recorderFactory = AudioRecorder.factory, environment, context,
            )
        )
        return sessionId
    }

    // Throws InvalidSessionIdException if no session is registered under [id].
    private fun getSession(id: Int): SpeechToTextSession {
        return sessions[id] ?: throw InvalidSessionIdException(id)
    }

    // Runs [callback] on the executor while holding the session's lock. Every failure,
    // including an unknown session ID, is passed to [onError] rather than escaping as
    // an uncaught exception on the executor thread (which would leave the promise
    // pending forever).
    private fun concurrentWithSession(
        id: Int,
        onError: (error: Throwable) -> Unit,
        callback: (session: SpeechToTextSession) -> Unit,
    ) {
        executor.execute {
            try {
                val session = getSession(id)
                session.mutex.lock()
                try {
                    callback(session)
                } finally {
                    session.mutex.unlock()
                }
            } catch (error: Throwable) {
                onError(error)
            }
        }
    }

    fun startRecording(sessionId: Int, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            session.converter.start()
            promise.resolve(null)
        }
    }

    // Left-shifts the recording buffer by [duration] seconds
    fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            session.converter.dropFirstSeconds(duration)
            promise.resolve(sessionId)
        }
    }

    fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            promise.resolve(session.converter.bufferLengthSeconds)
        }
    }

    // Waits for the next [duration] seconds to become available, then converts
    fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            val result = session.converter.expandBufferAndConvert(duration)
            promise.resolve(result)
        }
    }

    // Converts all available recorded data
    fun convertAvailable(sessionId: Int, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            val result = session.converter.expandBufferAndConvert()
            promise.resolve(result)
        }
    }

    // Closes the session and forgets its ID. Rejects [promise] if the ID is unknown
    // or closing fails.
    fun closeSession(sessionId: Int, promise: Promise) {
        this.concurrentWithSession(sessionId, promise::reject) { session ->
            // Unregister first so a closed converter can never be looked up again,
            // even if close() throws.
            sessions.remove(sessionId)
            session.converter.close()
            promise.resolve(null)
        }
    }
}

View File

@ -45,7 +45,6 @@ import ImageEditor from '../NoteEditor/ImageEditor/ImageEditor';
import promptRestoreAutosave from '../NoteEditor/ImageEditor/promptRestoreAutosave';
import isEditableResource from '../NoteEditor/ImageEditor/isEditableResource';
import VoiceTypingDialog from '../voiceTyping/VoiceTypingDialog';
import { voskEnabled } from '../../services/voiceTyping/vosk';
import { isSupportedLanguage } from '../../services/voiceTyping/vosk';
import { ChangeEvent as EditorChangeEvent, SelectionRangeChangeEvent, UndoRedoDepthChangeEvent } from '@joplin/editor/events';
import { join } from 'path';
@ -1204,8 +1203,8 @@ class NoteScreenComponent extends BaseScreenComponent<Props, State> implements B
});
}
// Voice typing is enabled only for French language and on Android for now
if (voskEnabled && shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
// Voice typing is enabled only on Android for now
if (shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
output.push({
title: _('Voice typing...'),
onPress: () => {

View File

@ -1,14 +1,18 @@
import * as React from 'react';
import { useState, useEffect, useCallback } from 'react';
import { Banner, ActivityIndicator } from 'react-native-paper';
import { useState, useEffect, useCallback, useRef, useMemo } from 'react';
import { Banner, ActivityIndicator, Text } from 'react-native-paper';
import { _, languageName } from '@joplin/lib/locale';
import useAsyncEffect, { AsyncEffectEvent } from '@joplin/lib/hooks/useAsyncEffect';
import { getVosk, Recorder, startRecording, Vosk } from '../../services/voiceTyping/vosk';
import { IconSource } from 'react-native-paper/lib/typescript/components/Icon';
import { modelIsDownloaded } from '../../services/voiceTyping/vosk';
import VoiceTyping, { OnTextCallback, VoiceTypingSession } from '../../services/voiceTyping/VoiceTyping';
import whisper from '../../services/voiceTyping/whisper';
import vosk from '../../services/voiceTyping/vosk';
import { AppState } from '../../utils/types';
import { connect } from 'react-redux';
interface Props {
locale: string;
provider: string;
onDismiss: ()=> void;
onText: (text: string)=> void;
}
@ -21,44 +25,77 @@ enum RecorderState {
Downloading = 5,
}
const useVosk = (locale: string): [Error | null, boolean, Vosk|null] => {
const [vosk, setVosk] = useState<Vosk>(null);
interface UseVoiceTypingProps {
locale: string;
provider: string;
onSetPreview: OnTextCallback;
onText: OnTextCallback;
}
const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
const [error, setError] = useState<Error>(null);
const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);
useAsyncEffect(async (event: AsyncEffectEvent) => {
if (mustDownloadModel === null) return;
const onTextRef = useRef(onText);
onTextRef.current = onText;
const onSetPreviewRef = useRef(onSetPreview);
onSetPreviewRef.current = onSetPreview;
const voiceTypingRef = useRef(voiceTyping);
voiceTypingRef.current = voiceTyping;
const builder = useMemo(() => {
return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
}, [locale, provider]);
useAsyncEffect(async (event: AsyncEffectEvent) => {
try {
const v = await getVosk(locale);
await voiceTypingRef.current?.stop();
if (!await builder.isDownloaded()) {
if (event.cancelled) return;
await builder.download();
}
if (event.cancelled) return;
setVosk(v);
const voiceTyping = await builder.build({
onPreview: (text) => onSetPreviewRef.current(text),
onFinalize: (text) => onTextRef.current(text),
});
if (event.cancelled) return;
setVoiceTyping(voiceTyping);
} catch (error) {
setError(error);
} finally {
setMustDownloadModel(false);
}
}, [locale, mustDownloadModel]);
}, [builder]);
useAsyncEffect(async (_event: AsyncEffectEvent) => {
setMustDownloadModel(!(await modelIsDownloaded(locale)));
}, [locale]);
setMustDownloadModel(!(await builder.isDownloaded()));
}, [builder]);
return [error, mustDownloadModel, vosk];
return [error, mustDownloadModel, voiceTyping];
};
export default (props: Props) => {
const [recorder, setRecorder] = useState<Recorder>(null);
const VoiceTypingDialog: React.FC<Props> = props => {
const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
const [voskError, mustDownloadModel, vosk] = useVosk(props.locale);
const [preview, setPreview] = useState<string>('');
const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
locale: props.locale,
onSetPreview: setPreview,
onText: props.onText,
provider: props.provider,
});
useEffect(() => {
if (voskError) {
if (modelError) {
setRecorderState(RecorderState.Error);
} else if (vosk) {
} else if (voiceTyping) {
setRecorderState(RecorderState.Recording);
}
}, [vosk, voskError]);
}, [voiceTyping, modelError]);
useEffect(() => {
if (mustDownloadModel) {
@ -68,27 +105,22 @@ export default (props: Props) => {
useEffect(() => {
if (recorderState === RecorderState.Recording) {
setRecorder(startRecording(vosk, {
onResult: (text: string) => {
props.onText(text);
},
}));
void voiceTyping.start();
}
}, [recorderState, vosk, props.onText]);
}, [recorderState, voiceTyping, props.onText]);
const onDismiss = useCallback(() => {
if (recorder) recorder.cleanup();
void voiceTyping?.stop();
props.onDismiss();
}, [recorder, props.onDismiss]);
}, [voiceTyping, props.onDismiss]);
const renderContent = () => {
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
const components: Record<RecorderState, Function> = {
const components: Record<RecorderState, ()=> string> = {
[RecorderState.Loading]: () => _('Loading...'),
[RecorderState.Recording]: () => _('Please record your voice...'),
[RecorderState.Processing]: () => _('Converting speech to text...'),
[RecorderState.Downloading]: () => _('Downloading %s language files...', languageName(props.locale)),
[RecorderState.Error]: () => _('Error: %s', voskError.message),
[RecorderState.Error]: () => _('Error: %s', modelError.message),
};
return components[recorderState]();
@ -106,6 +138,11 @@ export default (props: Props) => {
return components[recorderState];
};
const renderPreview = () => {
return <Text variant='labelSmall'>{preview}</Text>;
};
const headerAndStatus = <Text variant='bodyMedium'>{`${_('Voice typing...')}\n${renderContent()}`}</Text>;
return (
<Banner
visible={true}
@ -115,8 +152,15 @@ export default (props: Props) => {
label: _('Done'),
onPress: onDismiss,
},
]}>
{`${_('Voice typing...')}\n${renderContent()}`}
]}
>
{headerAndStatus}
<Text>{'\n'}</Text>
{renderPreview()}
</Banner>
);
};
export default connect((state: AppState) => ({
provider: state.settings['voiceTyping.preferredProvider'],
}))(VoiceTypingDialog);

View File

@ -80,6 +80,10 @@ jest.mock('react-native-image-picker', () => {
return { default: { } };
});
jest.mock('react-native-zip-archive', () => {
return { default: { } };
});
jest.mock('react-native-document-picker', () => ({ default: { } }));
// Used by the renderer

View File

@ -0,0 +1,139 @@
import shim from '@joplin/lib/shim';
import Logger from '@joplin/utils/Logger';
import { PermissionsAndroid, Platform } from 'react-native';
import { unzip } from 'react-native-zip-archive';
const md5 = require('md5');
const logger = Logger.create('voiceTyping');
export type OnTextCallback = (text: string)=> void;
export interface SpeechToTextCallbacks {
// Called with a block of text that might change in the future
onPreview: OnTextCallback;
// Called with text that will not change and should be added to the document
onFinalize: OnTextCallback;
}
export interface VoiceTypingSession {
start(): Promise<void>;
stop(): Promise<void>;
}
export interface BuildProviderOptions {
locale: string;
modelPath: string;
callbacks: SpeechToTextCallbacks;
}
// A speech-to-text backend that VoiceTyping can download a model for and build
// a session from.
export interface VoiceTypingProvider {
// Used by VoiceTyping as part of the temporary directory a zipped model is extracted to.
modelName: string;
// VoiceTyping uses the first provider in its list for which this returns true.
supported(): boolean;
// Where the model for `locale` is stored. VoiceTyping resolves this relative to
// the app directory.
modelLocalFilepath(locale: string): string;
// URL the model for `locale` is downloaded from. VoiceTyping unzips the download
// when the URL ends with ".zip".
getDownloadUrl(locale: string): string;
// Path (resolved relative to the app directory) of the marker file whose
// existence VoiceTyping treats as "model downloaded".
getUuidPath(locale: string): string;
// Creates a voice typing session from the downloaded model.
build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
}
// Selects the first supported provider from a list and handles downloading its
// model and building a voice typing session for a given locale.
export default class VoiceTyping {
	private provider: VoiceTypingProvider|null = null;

	public constructor(
		private locale: string,
		providers: VoiceTypingProvider[],
	) {
		this.provider = providers.find(p => p.supported()) ?? null;
	}

	// True if at least one of the given providers is supported.
	public supported() {
		return this.provider !== null;
	}

	// Returns the selected provider, or throws a descriptive error when none of the
	// providers is supported. Used by every method that needs the provider so that
	// callers get the same clear error instead of a TypeError on a null dereference.
	private requireProvider(): VoiceTypingProvider {
		if (!this.provider) {
			throw new Error('No supported provider found!');
		}
		return this.provider;
	}

	// Absolute path of the model for the current locale, within the app directory.
	private getModelPath() {
		const localFilePath = shim.fsDriver().resolveRelativePathWithinDir(
			shim.fsDriver().getAppDirectoryPath(),
			this.requireProvider().modelLocalFilepath(this.locale),
		);
		// download() deletes the model path before downloading, so it must never
		// be the app directory itself.
		if (localFilePath === shim.fsDriver().getAppDirectoryPath()) {
			throw new Error('Invalid local file path!');
		}

		return localFilePath;
	}

	// Absolute path of the marker file that records a completed download.
	private getUuidPath() {
		return shim.fsDriver().resolveRelativePathWithinDir(
			shim.fsDriver().getAppDirectoryPath(),
			this.requireProvider().getUuidPath(this.locale),
		);
	}

	// Resolves to true if the marker file for the current locale's model exists.
	public async isDownloaded() {
		return await shim.fsDriver().exists(this.getUuidPath());
	}

	// Downloads the model for the current locale, replacing any existing copy.
	// Zipped models are extracted to a temporary directory and must contain exactly
	// one top-level file or directory, which is moved to the model path.
	public async download() {
		const provider = this.requireProvider();
		const modelPath = this.getModelPath();
		const modelUrl = provider.getDownloadUrl(this.locale);

		await shim.fsDriver().remove(modelPath);
		logger.info(`Downloading model from: ${modelUrl}`);

		const isZipped = modelUrl.endsWith('.zip');
		const downloadPath = isZipped ? `${modelPath}.zip` : modelPath;
		const response = await shim.fetchBlob(modelUrl, {
			path: downloadPath,
		});

		if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);

		// NOTE(review): the marker file is only written for zipped downloads, so a
		// non-zipped model would never be reported as downloaded — confirm that all
		// providers use .zip URLs.
		if (isZipped) {
			const modelName = provider.modelName;
			const unzipDir = `${shim.fsDriver().getCacheDirectoryPath()}/voice-typing-extract/${modelName}/${this.locale}`;
			try {
				logger.info(`Unzipping ${downloadPath} => ${unzipDir}`);

				await unzip(downloadPath, unzipDir);

				const contents = await shim.fsDriver().readDirStats(unzipDir);
				if (contents.length !== 1) {
					logger.error('Expected 1 file or directory but got', contents);
					throw new Error(`Expected 1 file or directory, but got ${contents.length}`);
				}

				const fullUnzipPath = `${unzipDir}/${contents[0].path}`;

				logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
				await shim.fsDriver().move(fullUnzipPath, modelPath);

				await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
				if (!await this.isDownloaded()) {
					logger.warn('Model should be downloaded!');
				}
			} finally {
				// Always clean up the temporary extraction directory and the archive.
				await shim.fsDriver().remove(unzipDir);
				await shim.fsDriver().remove(downloadPath);
			}
		}
	}

	// Downloads the model if necessary, requests the Android audio recording
	// permission if it has not been granted, then asks the provider for a session.
	public async build(callbacks: SpeechToTextCallbacks) {
		const provider = this.requireProvider();

		if (!await this.isDownloaded()) {
			await this.download();
		}

		const audioPermission = 'android.permission.RECORD_AUDIO';
		if (Platform.OS === 'android' && !await PermissionsAndroid.check(audioPermission)) {
			await PermissionsAndroid.request(audioPermission);
		}

		return provider.build({
			locale: this.locale,
			modelPath: this.getModelPath(),
			callbacks,
		});
	}
}

View File

@ -0,0 +1,61 @@
import splitWhisperText from './splitWhisperText';
// Each case gives Whisper output with <|seconds|> timestamp markers and the length
// of the recording in seconds, and checks where splitWhisperText chooses to split:
// - trimTo: the timestamp (in seconds) chosen as the split point,
// - dataBeforeTrim / dataAfterTrim: the text before and after the chosen marker.
describe('splitWhisperText', () => {
test.each([
{
// Should trim at sentence breaks
input: '<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>',
recordingLength: 8,
expected: {
trimTo: 6,
dataBeforeTrim: '<|0.00|> This is a test. ',
dataAfterTrim: ' This is another sentence. <|7.00|>',
},
},
{
// Should prefer sentence break splits to non sentence break splits
input: '<|0.00|> This is <|4.00|><|4.50|> a test. <|5.00|><|5.50|> Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
recordingLength: 8,
expected: {
trimTo: 5.50,
dataBeforeTrim: '<|0.00|> This is <|4.00|><|4.50|> a test. ',
dataAfterTrim: ' Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
},
},
{
// Should avoid splitting for very small timestamps
input: '<|0.00|> This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
recordingLength: 4,
expected: {
trimTo: 0,
dataBeforeTrim: '',
dataAfterTrim: ' This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
},
},
{
// For larger timestamps, should allow splitting at pauses, even if not on sentence breaks.
input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
recordingLength: 16,
expected: {
trimTo: 12,
dataBeforeTrim: '<|0.00|> This is a test, ',
dataAfterTrim: ' of splitting on timestamps. <|15.00|>',
},
},
{
// Should prefer to break at the end, if a large gap after the last timestamp.
input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
recordingLength: 30,
expected: {
trimTo: 15,
dataBeforeTrim: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. ',
dataAfterTrim: '',
},
},
])('should prefer to split at the end of sentences (case %#)', ({ input, recordingLength, expected }) => {
const actual = splitWhisperText(input, recordingLength);
// trimTo is a parsed decimal, so it is compared approximately.
expect(actual.trimTo).toBeCloseTo(expected.trimTo);
expect(actual.dataBeforeTrim).toBe(expected.dataBeforeTrim);
expect(actual.dataAfterTrim).toBe(expected.dataAfterTrim);
});
});

View File

@ -0,0 +1,65 @@
// Matches pairs of timestamps or single timestamps.
const timestampExp = /<\|(\d+\.\d*)\|>(?:<\|(\d+\.\d*)\|>)?/g;
// Converts a match of a single timestamp or a timestamp pair into a number of seconds.
const timestampMatchToNumber = (match: RegExpMatchArray) => {
	// Group 1 is the first timestamp; group 2 is the optional second one of a pair.
	const [fullMatch, firstTimestamp, secondTimestamp] = match;

	// When the match is a pair, use its second timestamp so that the silence
	// between the two is dropped.
	const timestamp = Number(secondTimestamp || firstTimestamp);

	// The captured text should always parse to a finite number (never NaN).
	if (!isFinite(timestamp)) throw new Error(`Timestamp match failed with ${fullMatch}`);

	return timestamp;
};
// Chooses a timestamp at which to split Whisper output into text that can be
// finalized (dataBeforeTrim) and text that stays in the preview (dataAfterTrim).
// trimTo is the chosen timestamp, in seconds.
const splitWhisperText = (textWithTimestamps: string, recordingLengthSeconds: number) => {
	const timestamps = Array.from(textWithTimestamps.matchAll(timestampExp), match => ({
		timestamp: timestampMatchToNumber(match),
		match,
	}));
	type TimestampData = typeof timestamps[number];

	// Without any timestamps there is nowhere to split: keep everything in the preview.
	if (timestamps.length === 0) {
		return { trimTo: 0, dataBeforeTrim: '', dataAfterTrim: textWithTimestamps };
	}

	const lastTimestamp = timestamps[timestamps.length - 1];
	const contentLength = textWithTimestamps.trimEnd().length;

	// Heuristic for whether content can be moved from the preview to the document at
	// a given timestamp. The thresholds are based on the maximum buffer length of 30
	// seconds: the longer the buffer, the more readily it is broken into chunks, with
	// breaks near the end of a sentence preferred.
	const canBreakAt = ({ match, timestamp }: TimestampData) => {
		if (timestamp > 16) return true;

		const matchEnd = match.index + match[0].length;
		if (matchEnd >= contentLength) return false; // Near the end of the data

		if (timestamp > 8) return true;
		if (timestamp <= 4) return false;

		// Look for sentence-ending punctuation in the three characters before the match.
		const contentBefore = textWithTimestamps.substring(Math.max(match.index - 3, 0), match.index);
		return !!contentBefore.match(/[.?!]/);
	};

	let breakAt: TimestampData;
	if (lastTimestamp.timestamp + 4 < recordingLengthSeconds) {
		// A long pause follows the last timestamp: everything before it can be split off.
		breakAt = lastTimestamp;
	} else {
		// Otherwise use the first acceptable break, falling back to the first timestamp.
		breakAt = timestamps.find(canBreakAt) ?? timestamps[0];
	}

	const breakStart = breakAt.match.index;
	const breakEnd = breakStart + breakAt.match[0].length;

	return {
		trimTo: breakAt.timestamp,
		dataBeforeTrim: textWithTimestamps.substring(0, breakStart),
		dataAfterTrim: textWithTimestamps.substring(breakEnd),
	};
};
export default splitWhisperText;

View File

@ -4,9 +4,9 @@ import Setting from '@joplin/lib/models/Setting';
import { rtrimSlashes } from '@joplin/lib/path-utils';
import shim from '@joplin/lib/shim';
import Vosk from 'react-native-vosk';
import { unzip } from 'react-native-zip-archive';
import RNFetchBlob from 'rn-fetch-blob';
const md5 = require('md5');
import { VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
import { join } from 'path';
const logger = Logger.create('voiceTyping/vosk');
@ -24,15 +24,6 @@ let vosk_: Record<string, Vosk> = {};
let state_: State = State.Idle;
export const voskEnabled = true;
export { Vosk };
export interface Recorder {
stop: ()=> Promise<string>;
cleanup: ()=> void;
}
const defaultSupportedLanguages = {
'en': 'https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip',
'zh': 'https://alphacephei.com/vosk/models/vosk-model-small-cn-0.22.zip',
@ -74,7 +65,7 @@ const getModelDir = (locale: string) => {
return `${getUnzipDir(locale)}/model`;
};
const languageModelUrl = (locale: string) => {
const languageModelUrl = (locale: string): string => {
const lang = languageCodeOnly(locale).toLowerCase();
if (!(lang in defaultSupportedLanguages)) throw new Error(`No language file for: ${locale}`);
@ -90,16 +81,11 @@ const languageModelUrl = (locale: string) => {
}
};
export const modelIsDownloaded = async (locale: string) => {
const uuidFile = `${getModelDir(locale)}/uuid`;
return shim.fsDriver().exists(uuidFile);
};
export const getVosk = async (locale: string) => {
export const getVosk = async (modelDir: string, locale: string) => {
if (vosk_[locale]) return vosk_[locale];
const vosk = new Vosk();
const modelDir = await downloadModel(locale);
logger.info(`Loading model from ${modelDir}`);
await shim.fsDriver().readDirStats(modelDir);
const result = await vosk.loadModel(modelDir);
@ -110,51 +96,7 @@ export const getVosk = async (locale: string) => {
return vosk;
};
const downloadModel = async (locale: string) => {
const modelUrl = languageModelUrl(locale);
const unzipDir = getUnzipDir(locale);
const zipFilePath = `${unzipDir}.zip`;
const modelDir = getModelDir(locale);
const uuidFile = `${modelDir}/uuid`;
if (await modelIsDownloaded(locale)) {
logger.info(`Model for ${locale} already exists at ${modelDir}`);
return modelDir;
}
await shim.fsDriver().remove(unzipDir);
logger.info(`Downloading model from: ${modelUrl}`);
const response = await shim.fetchBlob(modelUrl, {
path: zipFilePath,
});
if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);
logger.info(`Unzipping ${zipFilePath} => ${unzipDir}`);
await unzip(zipFilePath, unzipDir);
const dirs = await shim.fsDriver().readDirStats(unzipDir);
if (dirs.length !== 1) {
logger.error('Expected 1 directory but got', dirs);
throw new Error(`Expected 1 directory, but got ${dirs.length}`);
}
const fullUnzipPath = `${unzipDir}/${dirs[0].path}`;
logger.info(`Moving ${fullUnzipPath} => ${modelDir}`);
await shim.fsDriver().rename(fullUnzipPath, modelDir);
await shim.fsDriver().writeFile(uuidFile, md5(modelUrl));
await shim.fsDriver().remove(zipFilePath);
return modelDir;
};
export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSession => {
if (state_ !== State.Idle) throw new Error('Vosk is already recording');
state_ = State.Recording;
@ -163,10 +105,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
const eventHandlers: any[] = [];
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
let finalResultPromiseResolve: Function = null;
const finalResultPromiseResolve: Function = null;
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
let finalResultPromiseReject: Function = null;
let finalResultTimeout = false;
const finalResultPromiseReject: Function = null;
const finalResultTimeout = false;
const completeRecording = (finalResult: string, error: Error) => {
logger.info(`Complete recording. Final result: ${finalResult}. Error:`, error);
@ -212,31 +154,13 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
completeRecording(e.data, null);
}));
logger.info('Starting recording...');
void vosk.start();
return {
stop: (): Promise<string> => {
logger.info('Stopping recording...');
vosk.stopOnly();
logger.info('Waiting for final result...');
setTimeout(() => {
finalResultTimeout = true;
logger.warn('Timed out waiting for finalResult event');
completeRecording('', new Error('Could not process your message. Please try again.'));
}, 5000);
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
return new Promise((resolve: Function, reject: Function) => {
finalResultPromiseResolve = resolve;
finalResultPromiseReject = reject;
});
start: async () => {
logger.info('Starting recording...');
await vosk.start();
},
cleanup: () => {
stop: async () => {
if (state_ === State.Recording) {
logger.info('Cancelling...');
state_ = State.Completing;
@ -246,3 +170,18 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
},
};
};
// Voice typing provider that wraps the Vosk helpers defined in this file
// (getModelDir, languageModelUrl, getVosk and startRecording).
const vosk: VoiceTypingProvider = {
	modelName: 'vosk',

	supported: () => {
		return true;
	},

	// The model for a locale lives in that locale's model directory; the
	// "uuid" file is stored directly inside that directory.
	modelLocalFilepath: (locale: string) => {
		return getModelDir(locale);
	},
	getUuidPath: (locale: string) => {
		return join(getModelDir(locale), 'uuid');
	},

	getDownloadUrl: (locale) => {
		return languageModelUrl(locale);
	},

	// Loads (or reuses) the Vosk instance for the locale, then creates a session
	// that forwards Vosk results to the caller's onFinalize callback.
	build: async ({ callbacks, locale, modelPath }) => {
		const loadedVosk = await getVosk(modelPath, locale);
		return startRecording(loadedVosk, { onResult: callbacks.onFinalize });
	},
};
export default vosk;

View File

@ -1,37 +1,14 @@
// Currently disabled on non-Android platforms
import { VoiceTypingProvider } from './VoiceTyping';
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
type Vosk = any;
export { Vosk };
interface StartOptions {
onResult: (text: string)=> void;
}
export interface Recorder {
stop: ()=> Promise<string>;
cleanup: ()=> void;
}
export const isSupportedLanguage = (_locale: string) => {
return false;
// Placeholder Vosk provider: reports itself as unsupported, returns null for
// every model path/URL lookup, and rejects any attempt to build a session.
const vosk: VoiceTypingProvider = {
supported: () => false,
modelLocalFilepath: () => null,
getDownloadUrl: () => null,
getUuidPath: () => null,
// Always rejects, since this provider never reports itself as supported.
build: async () => {
throw new Error('Unsupported!');
},
modelName: 'vosk',
};
export const modelIsDownloaded = async (_locale: string) => {
return false;
};
export const getVosk = async (_locale: string) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
return {} as any;
};
export const startRecording = (_vosk: Vosk, _options: StartOptions): Recorder => {
return {
stop: async () => { return ''; },
cleanup: () => {},
};
};
export const voskEnabled = false;
export default vosk;

View File

@ -0,0 +1,119 @@
import Setting from '@joplin/lib/models/Setting';
import shim from '@joplin/lib/shim';
import Logger from '@joplin/utils/Logger';
import { rtrimSlashes } from '@joplin/utils/path';
import { dirname, join } from 'path';
import { NativeModules } from 'react-native';
import { SpeechToTextCallbacks, VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
import splitWhisperText from './utils/splitWhisperText';
const logger = Logger.create('voiceTyping/whisper');
const { SpeechToTextModule } = NativeModules;
// Whisper output contains timestamp markers of the form <|0.00|>.
// From observation, they seem to be emitted:
// - After long pauses.
// - Between sentences (in pairs).
// - At the beginning and end of a sequence.
const timestampExp = /<\|(\d+\.\d*)\|>/g;

// Converts raw Whisper output into user-visible text by removing all timestamp
// markers and all [BLANK_AUDIO] placeholders. Surrounding whitespace is left as-is.
const postProcessSpeech = (text: string) => {
	const withoutTimestamps = text.replace(timestampExp, '');
	const withoutBlankAudio = withoutTimestamps.replace(/\[BLANK_AUDIO\]/g, '');
	return withoutBlankAudio;
};
// A voice typing session backed by the native SpeechToTextModule.
//
// start() runs the recognition loop until stop() is called (or an error occurs).
// Text is reported through the callbacks: onPreview for text that may still
// change and onFinalize for text that is considered final.
class Whisper implements VoiceTypingSession {
	// Most recent raw (not yet post-processed) text that was sent to onPreview but
	// has not been finalized. stop() finalizes it so that it isn't lost.
	private lastPreviewData = '';

	// Incremented each time stop() closes the session. The recognition loop
	// compares against the value it saw when it started to know when to exit.
	private closeCounter = 0;

	public constructor(
		// null once the session has been closed.
		private sessionId: number|null,
		private callbacks: SpeechToTextCallbacks,
	) { }

	// Starts recording and repeatedly converts the recorded audio to text.
	// Resolves when the session is stopped. On error, closes the session and rethrows.
	// Throws immediately if the session has already been closed.
	public async start() {
		// Use a local copy for all native calls: this.sessionId is set to null by stop(),
		// which can run while this method is suspended at any of the awaits below.
		const sessionId = this.sessionId;
		if (sessionId === null) {
			throw new Error('Session closed.');
		}

		// Captured BEFORE the first await so that a stop() that happens while the
		// recorder is still starting is not missed.
		const loopStartCounter = this.closeCounter;

		try {
			logger.debug('starting recorder');
			await SpeechToTextModule.startRecording(sessionId);
			logger.debug('recorder started');

			while (this.closeCounter === loopStartCounter) {
				logger.debug('reading block');
				// NOTE(review): the second argument is presumably the number of seconds of
				// audio to add to the buffer -- confirm against the native module.
				const data: string = await SpeechToTextModule.expandBufferAndConvert(sessionId, 4);
				logger.debug('done reading block. Length', data?.length);

				if (this.sessionId === null) {
					logger.debug('Session stopped. Ending inference loop.');
					return;
				}

				const recordingLength = await SpeechToTextModule.getBufferLengthSeconds(sessionId);

				// stop() may also have run during the await above. In that case, stop() is
				// responsible for finalizing the last preview -- emitting more callbacks or
				// trimming the (closed) session's buffer here would duplicate text or fail.
				if (this.sessionId === null) {
					logger.debug('Session stopped. Ending inference loop.');
					return;
				}
				logger.debug('recording length so far', recordingLength);

				const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(data, recordingLength);

				if (trimTo > 2) {
					// The text before the break point is final. Move it to the document and
					// drop the corresponding audio so that it isn't converted again.
					logger.debug('Trim to', trimTo, 'in recording with length', recordingLength);
					this.callbacks.onFinalize(postProcessSpeech(dataBeforeTrim));
					this.callbacks.onPreview(postProcessSpeech(dataAfterTrim));
					this.lastPreviewData = dataAfterTrim;
					await SpeechToTextModule.dropFirstSeconds(sessionId, trimTo);
				} else {
					logger.debug('Preview', data);
					this.lastPreviewData = data;
					this.callbacks.onPreview(postProcessSpeech(data));
				}
			}
		} catch (error) {
			logger.error('Whisper error:', error);
			// Discard the pending preview so that stop() doesn't finalize it.
			this.lastPreviewData = '';
			await this.stop();
			throw error;
		}
	}

	// Closes the native session, then finalizes any text that was only previewed.
	// Does nothing (other than logging a warning) if the session is already closed.
	public async stop() {
		if (this.sessionId === null) {
			logger.warn('Session already closed.');
			return;
		}

		// Mark the session as closed before awaiting so that the recognition loop in
		// start() (and any concurrent stop() call) sees the closed state immediately.
		const sessionId = this.sessionId;
		this.sessionId = null;
		this.closeCounter ++;
		await SpeechToTextModule.closeSession(sessionId);

		if (this.lastPreviewData) {
			this.callbacks.onFinalize(postProcessSpeech(this.lastPreviewData));
		}
	}
}
// Returns the path at which the Whisper model file is stored: a fixed file name
// within the "voice-typing-models" folder of the app directory.
const modelLocalFilepath = () => {
	const appDirectory = shim.fsDriver().getAppDirectoryPath();
	return `${appDirectory}/voice-typing-models/whisper_tiny.onnx`;
};
// Voice typing provider backed by the native SpeechToTextModule.
const whisper: VoiceTypingProvider = {
	modelName: 'whisper',

	// Only supported when the native module is present.
	supported: () => Boolean(SpeechToTextModule),

	modelLocalFilepath,

	// The "uuid" file is stored in the same directory as the model file.
	getUuidPath: () => {
		const modelDirectory = dirname(modelLocalFilepath());
		return join(modelDirectory, 'uuid');
	},

	// Builds the model download URL from the "voiceTypingBaseUrl" setting, falling
	// back to a default template when the setting is empty. Each "{task}"
	// placeholder in the template is replaced with the model name.
	getDownloadUrl: () => {
		const customTemplate = rtrimSlashes(Setting.value('voiceTypingBaseUrl').trim());
		const urlTemplate = customTemplate
			? customTemplate
			: 'https://github.com/personalizedrefrigerator/joplin-voice-typing-test/releases/download/test-release/{task}.zip';
		return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
	},

	// Opens a native session for the given model and locale and wraps it in a
	// Whisper session object.
	build: async ({ modelPath, callbacks, locale }) => {
		const openedSessionId = await SpeechToTextModule.openSession(modelPath, locale);
		return new Whisper(openedSessionId, callbacks);
	},
};
export default whisper;

View File

@ -1615,6 +1615,25 @@ const builtInMetadata = (Setting: typeof SettingType) => {
section: 'note',
},
// Public, mobile-only enum setting that stores which voice typing provider the
// user prefers. Defaults to 'whisper-tiny'.
'voiceTyping.preferredProvider': {
value: 'whisper-tiny',
type: SettingItemType.String,
public: true,
appTypes: [AppType.Mobile],
label: () => _('Preferred voice typing provider'),
isEnum: true,
// For now, iOS and web don't support voice typing.
show: () => shim.mobilePlatform() === 'android',
section: 'note',
// Maps each stored setting value to its translated display label.
options: () => {
return {
'vosk': _('Vosk'),
'whisper-tiny': _('Whisper'),
};
},
},
'trash.autoDeletionEnabled': {
value: true,
type: SettingItemType.Bool,

View File

@ -132,4 +132,6 @@ Famegear
rcompare
tabindex
Backblaze
onnx
onnxruntime
treeitem