From 4d77608da1d758e1932dc0616e943aaab5820022 Mon Sep 17 00:00:00 2001 From: GiviMAD Date: Sun, 20 Feb 2022 12:45:31 +0100 Subject: [PATCH] [googlestt] lazy abort (#12317) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Miguel Álvarez Díez --- .../googlestt/internal/GoogleSTTService.java | 75 +++++++++---------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java b/bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java index 5f82829cf14..a0f0bfa0f81 100644 --- a/bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java +++ b/bundles/org.openhab.voice.googlestt/src/main/java/org/openhab/voice/googlestt/internal/GoogleSTTService.java @@ -24,7 +24,6 @@ import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.Nullable; @@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService { @Override public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, Set set) { - AtomicBoolean keepStreaming = new AtomicBoolean(true); - Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set); + AtomicBoolean aborted = new AtomicBoolean(false); + backgroundRecognize(sttListener, audioStream, aborted, locale, set); return new STTServiceHandle() { @Override public void abort() { - keepStreaming.set(false); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - } - scheduledTask.cancel(true); + aborted.set(true); } }; } @@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService { } } - private Future backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming, + private Future backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted, Locale locale, Set set) { Credentials credentials = getCredentials(); return executor.submit(() -> { @@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService { ClientStream clientStream = null; try (SpeechClient client = SpeechClient .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) { - TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, - (t) -> keepStreaming.set(false)); + TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted); clientStream = client.streamingRecognizeCallable().splitCall(responseObserver); - streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale); + streamAudio(clientStream, audioStream, responseObserver, aborted, locale); clientStream.closeSend(); logger.debug("Background recognize done"); } catch (IOException e) { @@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService { } private void streamAudio(ClientStream clientStream, AudioStream audioStream, - TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException { + TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException { // Gather stream info and send config AudioFormat streamFormat = audioStream.getFormat(); RecognitionConfig.AudioEncoding streamEncoding; @@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService { long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L); long maxSilenceMillis = (config.maxSilenceSeconds * 1000L); int readBytes = 6400; - while (keepStreaming.get()) { + while (!aborted.get()) { byte[] data = new byte[readBytes]; int dataN = audioStream.read(data); - if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) { + if (aborted.get()) { + logger.debug("Stops listening, aborted"); + break; + } + if (isExpiredInterval(maxTranscriptionMillis, startTime)) { logger.debug("Stops listening, max transcription time reached"); break; } @@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService { private final StringBuilder transcriptBuilder = new StringBuilder(); private final STTListener sttListener; GoogleSTTConfiguration config; - private final Consumer<@Nullable Throwable> completeListener; + private final AtomicBoolean aborted; private float confidenceSum = 0; private int responseCount = 0; private long lastInputTime = 0; - public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, - Consumer<@Nullable Throwable> completeListener) { + public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) { this.sttListener = sttListener; this.config = config; - this.completeListener = completeListener; + this.aborted = aborted; } @Override @@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService { responseCount++; // when in single utterance mode we can just get one final result so complete if (config.singleUtteranceMode) { - completeListener.accept(null); + onComplete(); } } }); @@ -380,16 +376,18 @@ public class GoogleSTTService implements STTService { @Override public void onComplete() { - sttListener.sttEventReceived(new RecognitionStopEvent()); - float averageConfidence = confidenceSum / responseCount; - String transcript = transcriptBuilder.toString(); - if (!transcript.isBlank()) { - sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence)); - } else { - if (!config.noResultsMessage.isBlank()) { - sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage)); + if (!aborted.getAndSet(true)) { + sttListener.sttEventReceived(new RecognitionStopEvent()); + float averageConfidence = confidenceSum / responseCount; + String transcript = transcriptBuilder.toString(); + if (!transcript.isBlank()) { + sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence)); } else { - sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results")); + if (!config.noResultsMessage.isBlank()) { + sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage)); + } else { + sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results")); + } } } } @@ -397,14 +395,15 @@ public class GoogleSTTService implements STTService { @Override public void onError(@Nullable Throwable t) { logger.warn("Recognition error: ", t); - completeListener.accept(t); - sttListener.sttEventReceived(new RecognitionStopEvent()); - if (!config.errorMessage.isBlank()) { - sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage)); - } else { - String errorMessage = t.getMessage(); - sttListener.sttEventReceived( - new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error")); + if (!aborted.getAndSet(true)) { + sttListener.sttEventReceived(new RecognitionStopEvent()); + if (!config.errorMessage.isBlank()) { + sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage)); + } else { + String errorMessage = t.getMessage(); + sttListener.sttEventReceived( + new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error")); + } } }