Upload 19 files
- .gitattributes +1 -0
- assets/jfk.flac +3 -0
- moyoyo_asr_models/ggml-medium-encoder.mlmodelc/analytics/coremldata.bin +3 -0
- moyoyo_asr_models/ggml-medium-encoder.mlmodelc/coremldata.bin +3 -0
- moyoyo_asr_models/ggml-medium-encoder.mlmodelc/metadata.json +64 -0
- moyoyo_asr_models/ggml-medium-encoder.mlmodelc/model.mil +0 -0
- moyoyo_asr_models/ggml-medium-encoder.mlmodelc/weights/weight.bin +3 -0
- moyoyo_asr_models/ggml-medium-q5_0.bin +3 -0
- run_client.py +15 -0
- run_server.py +31 -0
- transcribe/__init__.py +0 -0
- transcribe/__pycache__/__init__.cpython-311.pyc +0 -0
- transcribe/__pycache__/client.cpython-311.pyc +0 -0
- transcribe/__pycache__/server.cpython-311.pyc +0 -0
- transcribe/__pycache__/utils.cpython-311.pyc +0 -0
- transcribe/__pycache__/vad.cpython-311.pyc +0 -0
- transcribe/client.py +675 -0
- transcribe/server.py +684 -0
- transcribe/utils.py +81 -0
- transcribe/vad.py +160 -0
.gitattributes
CHANGED
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/jfk.flac filter=lfs diff=lfs merge=lfs -text
 *.icns filter=lfs diff=lfs merge=lfs -text
assets/jfk.flac
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63a4b1e4c1dc655ac70961ffbf518acd249df237e5a0152faae9a4a836949715
size 1152693
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/analytics/coremldata.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:adbe456375e7eb3407732a426ecb65bbda86860e4aa801f3a696b70b8a533cdd
size 207
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/coremldata.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:05fe28591b40616fa0c34ad7b853133623f5300923ec812acb11459c411acf3b
size 149
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/metadata.json
ADDED
@@ -0,0 +1,64 @@
[
  {
    "metadataOutputVersion" : "3.0",
    "storagePrecision" : "Float16",
    "outputSchema" : [
      {
        "hasShapeFlexibility" : "0",
        "isOptional" : "0",
        "dataType" : "Float32",
        "formattedType" : "MultiArray (Float32)",
        "shortDescription" : "",
        "shape" : "[]",
        "name" : "output",
        "type" : "MultiArray"
      }
    ],
    "modelParameters" : [

    ],
    "specificationVersion" : 6,
    "mlProgramOperationTypeHistogram" : {
      "Linear" : 144,
      "Matmul" : 48,
      "Cast" : 2,
      "Conv" : 2,
      "Softmax" : 24,
      "Add" : 49,
      "LayerNorm" : 49,
      "Mul" : 48,
      "Transpose" : 97,
      "Gelu" : 26,
      "Reshape" : 96
    },
    "computePrecision" : "Mixed (Float16, Float32, Int32)",
    "isUpdatable" : "0",
    "availability" : {
      "macOS" : "12.0",
      "tvOS" : "15.0",
      "watchOS" : "8.0",
      "iOS" : "15.0",
      "macCatalyst" : "15.0"
    },
    "modelType" : {
      "name" : "MLModelType_mlProgram"
    },
    "userDefinedMetadata" : {

    },
    "inputSchema" : [
      {
        "hasShapeFlexibility" : "0",
        "isOptional" : "0",
        "dataType" : "Float32",
        "formattedType" : "MultiArray (Float32 1 × 80 × 3000)",
        "shortDescription" : "",
        "shape" : "[1, 80, 3000]",
        "name" : "logmel_data",
        "type" : "MultiArray"
      }
    ],
    "generatedClassName" : "coreml_encoder_medium",
    "method" : "predict"
  }
]
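The metadata above describes the Core ML encoder: it takes a single log-mel spectrogram named logmel_data of shape [1, 80, 3000] (Float32), returns one MultiArray named output, and requires macOS 12 / iOS 15 or newer. whisper.cpp picks up the .mlmodelc automatically when it sits next to the ggml weights, but the compiled model can also be inspected directly from Python. A minimal sketch, assuming coremltools 6.2+ on an Apple machine; the dummy input and this loading path are illustrative, not part of the upload:

```python
import numpy as np
import coremltools as ct  # assumption: coremltools >= 6.2 installed on macOS

# Load the already-compiled encoder (.mlmodelc) directly.
encoder = ct.models.CompiledMLModel(
    "moyoyo_asr_models/ggml-medium-encoder.mlmodelc"
)

# Input name and shape come from metadata.json: "logmel_data", [1, 80, 3000], Float32.
logmel = np.zeros((1, 80, 3000), dtype=np.float32)

# predict() returns a dict keyed by the output name from metadata.json ("output").
result = encoder.predict({"logmel_data": logmel})
print(result["output"].shape)
```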
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/model.mil
ADDED
The diff for this file is too large to render.
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/weights/weight.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a188b0e4e3109f28f38f1f47ea2497ffe623923419df8e1ae12cb5f809a1815
size 614507008
moyoyo_asr_models/ggml-medium-q5_0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19fea4b380c3a618ec4723c3eef2eb785ffba0d0538cf43f8f235e7b3b34220f
size 539212467
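The upload ships two model artifacts: the 5-bit quantized ggml weights (ggml-medium-q5_0.bin, ~539 MB) used by the pywhispercpp backend, and the Core ML encoder above that whisper.cpp can use to speed up encoding on Apple hardware. For a quick offline check that the weights load outside the websocket server, something like the following should work; treat it as a sketch, since the exact pywhispercpp keyword arguments vary between versions:

```python
# Sketch: run the uploaded quantized model once with pywhispercpp,
# the same backend transcribe/server.py imports (pywhispercpp.model.Model).
from pywhispercpp.model import Model

# Assumption: a local ggml file path is accepted in place of a model name.
model = Model("moyoyo_asr_models/ggml-medium-q5_0.bin")
segments = model.transcribe("assets/jfk.flac", language="en")
for seg in segments:
    print(seg.text)
```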
run_client.py
ADDED
@@ -0,0 +1,15 @@
from transcribe.client import TranscriptionClient

client = TranscriptionClient(
    "localhost",
    9000,
    lang="zh",
    save_output_recording=False,  # Only used for microphone input, False by Default
    output_recording_filename="./output_recording.wav",  # Only used for microphone input
    max_clients=4,
    max_connection_time=600,
    mute_audio_playback=False,  # Only used for file input, False by Default
)

if __name__ == '__main__':
    client()
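run_client.py drives the microphone path, but the same TranscriptionClient can also stream a local file or a network stream, because TranscriptionTeeClient.__call__ accepts audio, rtsp_url, or hls_url (at most one of them). A small sketch, assuming a server is already listening on the host/port configured above and that transcribe.utils.resample can read the FLAC sample:

```python
from transcribe.client import TranscriptionClient

client = TranscriptionClient("localhost", 9000, lang="en")

# Stream the bundled sample file instead of the microphone.
client(audio="assets/jfk.flac")

# Or an HLS stream (URL is a placeholder):
# client(hls_url="https://example.com/live/stream.m3u8")
```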
run_server.py
ADDED
@@ -0,0 +1,31 @@
import argparse
import os

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', '-p',
                        type=int,
                        default=9090,
                        help="Websocket port to run the server on.")
    parser.add_argument('--backend', '-b',
                        type=str,
                        default='pywhispercpp',
                        help='Backends from ["pywhispercpp"]')

    parser.add_argument('--omp_num_threads', '-omp',
                        type=int,
                        default=1,
                        help="Number of threads to use for OpenMP")

    args = parser.parse_args()

    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads)

    from transcribe.server import TranscriptionServer
    server = TranscriptionServer()
    server.run(
        "0.0.0.0",
        port=args.port,
        backend=args.backend,
    )
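Note that run_client.py connects to port 9000 while run_server.py defaults to 9090, so out of the box the two scripts will not find each other. Either start the server with --port 9000 or point the client at 9090. The equivalent programmatic call, shown only as a sketch of aligning the ports, is:

```python
from transcribe.server import TranscriptionServer

# Bind on the port run_client.py expects (9000) instead of the 9090 default.
server = TranscriptionServer()
server.run("0.0.0.0", port=9000, backend="pywhispercpp")
```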
transcribe/__init__.py
ADDED
File without changes
transcribe/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (183 Bytes).
transcribe/__pycache__/client.cpython-311.pyc
ADDED
Binary file (39 kB).
transcribe/__pycache__/server.cpython-311.pyc
ADDED
Binary file (36 kB).
transcribe/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (4.64 kB).
transcribe/__pycache__/vad.cpython-311.pyc
ADDED
Binary file (9.36 kB).
transcribe/client.py
ADDED
@@ -0,0 +1,675 @@
import json
import os
import shutil
import threading
import time
import uuid
import wave

import av
import numpy as np
import pyaudio
import websocket

import transcribe.utils as utils


class Client:
    """
    Handles communication with a server using WebSocket.
    """
    INSTANCES = {}
    END_OF_AUDIO = "END_OF_AUDIO"

    def __init__(
        self,
        host=None,
        port=None,
        lang=None,
        log_transcription=True,
        max_clients=4,
        max_connection_time=600,
    ):
        """
        Initializes a Client instance for audio recording and streaming to a server.

        If host and port are not provided, the WebSocket connection will not be established.
        The audio recording starts immediately upon initialization.

        Args:
            host (str): The hostname or IP address of the server.
            port (int): The port number for the WebSocket server.
            lang (str, optional): The selected language for transcription. Default is None.
            log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
            max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
            max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
        """
        self.recording = False
        self.uid = str(uuid.uuid4())
        self.waiting = False
        self.last_response_received = None
        self.disconnect_if_no_response_for = 15
        self.language = lang
        self.server_error = False
        self.last_segment = None
        self.last_received_segment = None
        self.log_transcription = log_transcription
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

        self.audio_bytes = None

        if host is not None and port is not None:
            socket_url = f"ws://{host}:{port}"
            self.client_socket = websocket.WebSocketApp(
                socket_url,
                on_open=lambda ws: self.on_open(ws),
                on_message=lambda ws, message: self.on_message(ws, message),
                on_error=lambda ws, error: self.on_error(ws, error),
                on_close=lambda ws, close_status_code, close_msg: self.on_close(
                    ws, close_status_code, close_msg
                ),
            )
        else:
            print("[ERROR]: No host or port specified.")
            return

        Client.INSTANCES[self.uid] = self

        # start websocket client in a thread
        self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
        self.ws_thread.daemon = True
        self.ws_thread.start()

        self.transcript = []
        print("[INFO]: * recording")

    def handle_status_messages(self, message_data):
        """Handles server status messages."""
        status = message_data["status"]
        if status == "WAIT":
            self.waiting = True
            print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
        elif status == "ERROR":
            print(f"Message from Server: {message_data['message']}")
            self.server_error = True
        elif status == "WARNING":
            print(f"Message from Server: {message_data['message']}")

    def process_segments(self, segments):
        """Processes transcript segments."""
        text = []
        for i, seg in enumerate(segments):
            if not text or text[-1] != seg["text"]:
                text.append(seg["text"])
                if i == len(segments) - 1 and not seg.get("completed", False):
                    self.last_segment = seg

        # update last received segment and last valid response time
        if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
            self.last_response_received = time.time()
            self.last_received_segment = segments[-1]["text"]

        if self.log_transcription:
            # Truncate to last 3 entries for brevity.
            text = text[-3:]
            utils.clear_screen()
            utils.print_transcript(text)

    def on_message(self, ws, message):
        """
        Callback function called when a message is received from the server.

        It updates various attributes of the client based on the received message, including
        recording status, language detection, and server messages. If a disconnect message
        is received, it sets the recording status to False.

        Args:
            ws (websocket.WebSocketApp): The WebSocket client instance.
            message (str): The received message from the server.

        """
        message = json.loads(message)

        if self.uid != message.get("uid"):
            print("[ERROR]: invalid client uid")
            return

        if "status" in message.keys():
            self.handle_status_messages(message)
            return

        if "message" in message.keys() and message["message"] == "DISCONNECT":
            print("[INFO]: Server disconnected due to overtime.")
            self.recording = False

        if "message" in message.keys() and message["message"] == "SERVER_READY":
            self.last_response_received = time.time()
            self.recording = True
            self.server_backend = message["backend"]
            print(f"[INFO]: Server Running with backend {self.server_backend}")
            return

        if "language" in message.keys():
            self.language = message.get("language")
            lang_prob = message.get("language_prob")
            print(
                f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
            )
            return

        if "segments" in message.keys():
            self.process_segments(message["segments"])

    def on_error(self, ws, error):
        print(f"[ERROR] WebSocket Error: {error}")
        self.server_error = True
        self.error_message = error

    def on_close(self, ws, close_status_code, close_msg):
        print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
        self.recording = False
        self.waiting = False

    def on_open(self, ws):
        """
        Callback function called when the WebSocket connection is successfully opened.

        Sends an initial configuration message to the server, including client UID,
        language selection, and task type.

        Args:
            ws (websocket.WebSocketApp): The WebSocket client instance.

        """
        print("[INFO]: Opened connection")
        ws.send(
            json.dumps(
                {
                    "uid": self.uid,
                    "language": self.language,
                    "max_clients": self.max_clients,
                    "max_connection_time": self.max_connection_time,
                }
            )
        )

    def send_packet_to_server(self, message):
        """
        Send an audio packet to the server using WebSocket.

        Args:
            message (bytes): The audio data packet in bytes to be sent to the server.

        """
        try:
            self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
        except Exception as e:
            print(e)

    def close_websocket(self):
        """
        Close the WebSocket connection and join the WebSocket thread.

        First attempts to close the WebSocket connection using `self.client_socket.close()`. After
        closing the connection, it joins the WebSocket thread to ensure proper termination.

        """
        try:
            self.client_socket.close()
        except Exception as e:
            print("[ERROR]: Error closing WebSocket:", e)

        try:
            self.ws_thread.join()
        except Exception as e:
            print("[ERROR]: Error joining WebSocket thread:", e)

    def get_client_socket(self):
        """
        Get the WebSocket client socket instance.

        Returns:
            WebSocketApp: The WebSocket client socket instance currently in use by the client.
        """
        return self.client_socket

    def wait_before_disconnect(self):
        """Waits a bit before disconnecting in order to process pending responses."""
        assert self.last_response_received
        while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
            continue


class TranscriptionTeeClient:
    """
    Client for handling audio recording, streaming, and transcription tasks via one or more
    WebSocket connections.

    Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
    to send audio data for transcription to one or more servers, and receive transcribed text segments.

    Args:
        clients (list): one or more previously initialized Client instances

    Attributes:
        clients (list): the underlying Client instances responsible for handling WebSocket connections.
    """

    def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
                 mute_audio_playback=False):
        self.clients = clients
        if not self.clients:
            raise Exception("At least one client is required.")
        self.chunk = 4096
        self.format = pyaudio.paInt16
        self.channels = 1
        self.rate = 16000
        self.record_seconds = 60000
        self.save_output_recording = save_output_recording
        self.output_recording_filename = output_recording_filename
        self.mute_audio_playback = mute_audio_playback
        self.frames = b""
        self.p = pyaudio.PyAudio()
        try:
            self.stream = self.p.open(
                format=self.format,
                channels=self.channels,
                rate=self.rate,
                input=True,
                frames_per_buffer=self.chunk,
            )
        except OSError as error:
            print(f"[WARN]: Unable to access microphone. {error}")
            self.stream = None

    def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
        """
        Start the transcription process.

        Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
        to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
        will be played and streamed to the server; otherwise, it will perform live recording.

        Args:
            audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.

        """
        assert sum(
            source is not None for source in [audio, rtsp_url, hls_url]
        ) <= 1, 'You must provide only one selected source'

        print("[INFO]: Waiting for server ready ...")
        for client in self.clients:
            while not client.recording:
                if client.waiting or client.server_error:
                    self.close_all_clients()
                    return

        print("[INFO]: Server Ready!")
        if hls_url is not None:
            self.process_hls_stream(hls_url, save_file)
        elif audio is not None:
            resampled_file = utils.resample(audio)
            self.play_file(resampled_file)
        elif rtsp_url is not None:
            self.process_rtsp_stream(rtsp_url)
        else:
            self.record()

    def close_all_clients(self):
        """Closes all client websockets."""
        for client in self.clients:
            client.close_websocket()

    def multicast_packet(self, packet, unconditional=False):
        """
        Sends an identical packet via all clients.

        Args:
            packet (bytes): The audio data packet in bytes to be sent.
            unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
        """
        for client in self.clients:
            if (unconditional or client.recording):
                client.send_packet_to_server(packet)

    def play_file(self, filename):
        """
        Play an audio file and send it to the server for processing.

        Reads an audio file, plays it through the audio output, and simultaneously sends
        the audio data to the server for processing. It uses PyAudio to create an audio
        stream for playback. The audio data is read from the file in chunks, converted to
        floating-point format, and sent to the server using WebSocket communication.
        This method is typically used when you want to process pre-recorded audio and send it
        to the server in real-time.

        Args:
            filename (str): The path to the audio file to be played and sent to the server.
        """

        # read audio and create pyaudio stream
        with wave.open(filename, "rb") as wavfile:
            self.stream = self.p.open(
                format=self.p.get_format_from_width(wavfile.getsampwidth()),
                channels=wavfile.getnchannels(),
                rate=wavfile.getframerate(),
                input=True,
                output=True,
                frames_per_buffer=self.chunk,
            )
            chunk_duration = self.chunk / float(wavfile.getframerate())
            try:
                while any(client.recording for client in self.clients):
                    data = wavfile.readframes(self.chunk)
                    if data == b"":
                        break

                    audio_array = self.bytes_to_float_array(data)
                    self.multicast_packet(audio_array.tobytes())
                    if self.mute_audio_playback:
                        time.sleep(chunk_duration)
                    else:
                        self.stream.write(data)

                wavfile.close()

                for client in self.clients:
                    client.wait_before_disconnect()
                self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
                self.stream.close()
                self.close_all_clients()

            except KeyboardInterrupt:
                wavfile.close()
                self.stream.stop_stream()
                self.stream.close()
                self.p.terminate()
                self.close_all_clients()
                print("[INFO]: Keyboard interrupt.")

    def process_rtsp_stream(self, rtsp_url):
        """
        Connect to an RTSP source, process the audio stream, and send it for transcription.

        Args:
            rtsp_url (str): The URL of the RTSP stream source.
        """
        print("[INFO]: Connecting to RTSP stream...")
        try:
            container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
            self.process_av_stream(container, stream_type="RTSP")
        except Exception as e:
            print(f"[ERROR]: Failed to process RTSP stream: {e}")
        finally:
            for client in self.clients:
                client.wait_before_disconnect()
            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
            self.close_all_clients()
        print("[INFO]: RTSP stream processing finished.")

    def process_hls_stream(self, hls_url, save_file=None):
        """
        Connect to an HLS source, process the audio stream, and send it for transcription.

        Args:
            hls_url (str): The URL of the HLS stream source.
            save_file (str, optional): Local path to save the network stream.
        """
        print("[INFO]: Connecting to HLS stream...")
        try:
            container = av.open(hls_url, format="hls")
            self.process_av_stream(container, stream_type="HLS", save_file=save_file)
        except Exception as e:
            print(f"[ERROR]: Failed to process HLS stream: {e}")
        finally:
            for client in self.clients:
                client.wait_before_disconnect()
            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
            self.close_all_clients()
        print("[INFO]: HLS stream processing finished.")

    def process_av_stream(self, container, stream_type, save_file=None):
        """
        Process an AV container stream and send audio packets to the server.

        Args:
            container (av.container.InputContainer): The input container to process.
            stream_type (str): The type of stream being processed ("RTSP" or "HLS").
            save_file (str, optional): Local path to save the stream. Default is None.
        """
        audio_stream = next((s for s in container.streams if s.type == "audio"), None)
        if not audio_stream:
            print(f"[ERROR]: No audio stream found in {stream_type} source.")
            return

        output_container = None
        if save_file:
            output_container = av.open(save_file, mode="w")
            output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)

        try:
            for packet in container.demux(audio_stream):
                for frame in packet.decode():
                    audio_data = frame.to_ndarray().tobytes()
                    self.multicast_packet(audio_data)

                    if save_file:
                        output_container.mux(frame)
        except Exception as e:
            print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
        finally:
            # Wait for server to send any leftover transcription.
            time.sleep(5)
            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
            if output_container:
                output_container.close()
            container.close()

    def save_chunk(self, n_audio_file):
        """
        Saves the current audio frames to a WAV file in a separate thread.

        Args:
            n_audio_file (int): The index of the audio file which determines the filename.
                This helps in maintaining the order and uniqueness of each chunk.
        """
        t = threading.Thread(
            target=self.write_audio_frames_to_file,
            args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
        )
        t.start()

    def finalize_recording(self, n_audio_file):
        """
        Finalizes the recording process by saving any remaining audio frames,
        closing the audio stream, and terminating the process.

        Args:
            n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
                This index is incremented before use if the last chunk is saved.
        """
        if self.save_output_recording and len(self.frames):
            self.write_audio_frames_to_file(
                self.frames[:], f"chunks/{n_audio_file}.wav"
            )
            n_audio_file += 1
        self.stream.stop_stream()
        self.stream.close()
        self.p.terminate()
        self.close_all_clients()
        if self.save_output_recording:
            self.write_output_recording(n_audio_file)

    def record(self):
        """
        Record audio data from the input stream and save it to a WAV file.

        Continuously records audio data from the input stream, sends it to the server via a WebSocket
        connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
        the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.

        Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
        The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
        The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
        the method combines all the saved audio chunks into the specified `out_file`.
        """
        n_audio_file = 0
        if self.save_output_recording:
            if os.path.exists("chunks"):
                shutil.rmtree("chunks")
            os.makedirs("chunks")
        try:
            for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
                if not any(client.recording for client in self.clients):
                    break
                data = self.stream.read(self.chunk, exception_on_overflow=False)
                self.frames += data

                audio_array = self.bytes_to_float_array(data)

                self.multicast_packet(audio_array.tobytes())

                # save frames if more than a minute
                if len(self.frames) > 60 * self.rate:
                    if self.save_output_recording:
                        self.save_chunk(n_audio_file)
                        n_audio_file += 1
                    self.frames = b""

        except KeyboardInterrupt:
            self.finalize_recording(n_audio_file)

    def write_audio_frames_to_file(self, frames, file_name):
        """
        Write audio frames to a WAV file.

        The WAV file is created or overwritten with the specified name. The audio frames should be
        in the correct format and match the specified channel, sample width, and sample rate.

        Args:
            frames (bytes): The audio frames to be written to the file.
            file_name (str): The name of the WAV file to which the frames will be written.

        """
        with wave.open(file_name, "wb") as wavfile:
            wavfile: wave.Wave_write
            wavfile.setnchannels(self.channels)
            wavfile.setsampwidth(2)
            wavfile.setframerate(self.rate)
            wavfile.writeframes(frames)

    def write_output_recording(self, n_audio_file):
        """
        Combine and save recorded audio chunks into a single WAV file.

        The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
        file, appends its audio data to the final recording, and then deletes the chunk file. After combining
        and saving, the final recording is stored in the specified `out_file`.

        Args:
            n_audio_file (int): The number of audio chunk files to combine.
            out_file (str): The name of the output WAV file to save the final recording.

        """
        input_files = [
            f"chunks/{i}.wav"
            for i in range(n_audio_file)
            if os.path.exists(f"chunks/{i}.wav")
        ]
        with wave.open(self.output_recording_filename, "wb") as wavfile:
            wavfile: wave.Wave_write
            wavfile.setnchannels(self.channels)
            wavfile.setsampwidth(2)
            wavfile.setframerate(self.rate)
            for in_file in input_files:
                with wave.open(in_file, "rb") as wav_in:
                    while True:
                        data = wav_in.readframes(self.chunk)
                        if data == b"":
                            break
                        wavfile.writeframes(data)
                # remove this file
                os.remove(in_file)
        wavfile.close()
        # clean up temporary directory to store chunks
        if os.path.exists("chunks"):
            shutil.rmtree("chunks")

    @staticmethod
    def bytes_to_float_array(audio_bytes):
        """
        Convert audio data from bytes to a NumPy float array.

        It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
        have values between -1 and 1.

        Args:
            audio_bytes (bytes): Audio data in bytes.

        Returns:
            np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
        """
        raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
        return raw_data.astype(np.float32) / 32768.0


class TranscriptionClient(TranscriptionTeeClient):
    """
    Client for handling audio transcription tasks via a single WebSocket connection.

    Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
    to send audio data for transcription to a server and receive transcribed text segments.

    Args:
        host (str): The hostname or IP address of the server.
        port (int): The port number to connect to on the server.
        lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
        save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
        output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
        output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
        log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
        max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
        max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
        mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.

    Attributes:
        client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.

    Example:
        To create a TranscriptionClient and start transcription on microphone audio:
        ```python
        transcription_client = TranscriptionClient(host="localhost", port=9090)
        transcription_client()
        ```
    """

    def __init__(
        self,
        host,
        port,
        lang=None,
        save_output_recording=False,
        output_recording_filename="./output_recording.wav",
        log_transcription=True,
        max_clients=4,
        max_connection_time=600,
        mute_audio_playback=False,
    ):
        self.client = Client(
            host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
            max_connection_time=max_connection_time
        )

        if save_output_recording and not output_recording_filename.endswith(".wav"):
            raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")

        TranscriptionTeeClient.__init__(
            self,
            [self.client],
            save_output_recording=save_output_recording,
            output_recording_filename=output_recording_filename,
            mute_audio_playback=mute_audio_playback
        )
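Across all input paths the client ships raw little-endian float32 PCM at 16 kHz: TranscriptionTeeClient.bytes_to_float_array rescales 16-bit integer samples into [-1, 1], multicast_packet sends the float bytes, and the server turns them back into an array with np.frombuffer(..., dtype=np.float32) in get_audio_from_websocket. A self-contained sketch of that round trip (not part of the uploaded files):

```python
import numpy as np

# 16-bit PCM as produced by PyAudio or wave.readframes()
pcm_bytes = np.array([0, 16384, -16384, 32767], dtype=np.int16).tobytes()

# Client side: the same normalization as bytes_to_float_array()
float_frame = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0

# Wire payload sent by multicast_packet()
payload = float_frame.tobytes()

# Server side: get_audio_from_websocket() recovers the frame directly
recovered = np.frombuffer(payload, dtype=np.float32)
assert np.allclose(recovered, float_frame)
```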
transcribe/server.py
ADDED
@@ -0,0 +1,684 @@
import functools
import json
import logging
import pathlib
import threading
import time
from enum import Enum
from typing import List, Optional

import librosa
import numpy as np
import soundfile
from pywhispercpp.model import Model
from websockets.exceptions import ConnectionClosed
from websockets.sync.server import serve

from transcribe.vad import VoiceActivityDetector

logging.basicConfig(level=logging.INFO)


class ClientManager:
    def __init__(self, max_clients=4, max_connection_time=600):
        """
        Initializes the ClientManager with specified limits on client connections and connection durations.

        Args:
            max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
            max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
                to 600 seconds (10 minutes).
        """
        self.clients = {}
        self.start_times = {}
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

    def add_client(self, websocket, client):
        """
        Adds a client and their connection start time to the tracking dictionaries.

        Args:
            websocket: The websocket associated with the client to add.
            client: The client object to be added and tracked.
        """
        self.clients[websocket] = client
        self.start_times[websocket] = time.time()

    def get_client(self, websocket):
        """
        Retrieves a client associated with the given websocket.

        Args:
            websocket: The websocket associated with the client to retrieve.

        Returns:
            The client object if found, False otherwise.
        """
        if websocket in self.clients:
            return self.clients[websocket]
        return False

    def remove_client(self, websocket):
        """
        Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
        client if necessary.

        Args:
            websocket: The websocket associated with the client to be removed.
        """
        client = self.clients.pop(websocket, None)
        if client:
            client.cleanup()
        self.start_times.pop(websocket, None)

    def get_wait_time(self):
        """
        Calculates the estimated wait time for new clients based on the remaining connection times of current clients.

        Returns:
            The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
        """
        wait_time = None
        for start_time in self.start_times.values():
            current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
            if wait_time is None or current_client_time_remaining < wait_time:
                wait_time = current_client_time_remaining
        return wait_time / 60 if wait_time is not None else 0

    def is_server_full(self, websocket, options):
        """
        Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.

        Args:
            websocket: The websocket of the client attempting to connect.
            options: A dictionary of options that may include the client's unique identifier.

        Returns:
            True if the server is full, False otherwise.
        """
        if len(self.clients) >= self.max_clients:
            wait_time = self.get_wait_time()
            response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
            websocket.send(json.dumps(response))
            return True
        return False

    def is_client_timeout(self, websocket):
        """
        Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.

        Args:
            websocket: The websocket associated with the client to check.

        Returns:
            True if the client's connection time has exceeded the maximum limit, False otherwise.
        """
        elapsed_time = time.time() - self.start_times[websocket]
        if elapsed_time >= self.max_connection_time:
            self.clients[websocket].disconnect()
            logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
            return True
        return False


class BackendType(Enum):
    PYWHISPERCPP = "pywhispercpp"

    @staticmethod
    def valid_types() -> List[str]:
        return [backend_type.value for backend_type in BackendType]

    @staticmethod
    def is_valid(backend: str) -> bool:
        return backend in BackendType.valid_types()

    def is_pywhispercpp(self) -> bool:
        return self == BackendType.PYWHISPERCPP


class TranscriptionServer:
    RATE = 16000

    def __init__(self):
        self.client_manager = None
        self.no_voice_activity_chunks = 0
        self.single_model = False

    def initialize_client(
        self, websocket, options
    ):
        client: Optional[ServeClientBase] = None

        if self.backend.is_pywhispercpp():
            client = ServeClientWhisperCPP(
                websocket,
                language=options["language"],
                client_uid=options["uid"],
                single_model=self.single_model,
            )
            logging.info("Running pywhispercpp backend.")

        if client is None:
            raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")

        self.client_manager.add_client(websocket, client)

    def get_audio_from_websocket(self, websocket):
        """
        Receives audio buffer from websocket and creates a numpy array out of it.

        Args:
            websocket: The websocket to receive audio from.

        Returns:
            A numpy array containing the audio.
        """
        frame_data = websocket.recv()
        if frame_data == b"END_OF_AUDIO":
            return False
        return np.frombuffer(frame_data, dtype=np.float32)

    def handle_new_connection(self, websocket):
        try:
            logging.info("New client connected")
            options = websocket.recv()
            options = json.loads(options)

            if self.client_manager is None:
                max_clients = options.get('max_clients', 4)
                max_connection_time = options.get('max_connection_time', 600)
                self.client_manager = ClientManager(max_clients, max_connection_time)

            if self.client_manager.is_server_full(websocket, options):
                websocket.close()
                return False  # Indicates that the connection should not continue

            if self.backend.is_pywhispercpp():
                self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)

            self.initialize_client(websocket, options)

            return True
        except json.JSONDecodeError:
            logging.error("Failed to decode JSON from client")
            return False
        except ConnectionClosed:
            logging.info("Connection closed by client")
            return False
        except Exception as e:
            logging.error(f"Error during new connection initialization: {str(e)}")
            return False

    def process_audio_frames(self, websocket):
        frame_np = self.get_audio_from_websocket(websocket)
        client = self.client_manager.get_client(websocket)

        # TODO: VAD has some problem, it will be blocking the process loop
        # if frame_np is False:
        #     if self.backend.is_pywhispercpp():
        #         client.set_eos(True)
        #     return False

        # if self.backend.is_pywhispercpp():
        #     voice_active = self.voice_activity(websocket, frame_np)
        #     if voice_active:
        #         self.no_voice_activity_chunks = 0
        #         client.set_eos(False)
        #     if self.use_vad and not voice_active:
        #         return True

        client.add_frames(frame_np)
        return True

    def recv_audio(self,
                   websocket,
                   backend: BackendType = BackendType.PYWHISPERCPP):

        self.backend = backend
        if not self.handle_new_connection(websocket):
            return

        try:
            while not self.client_manager.is_client_timeout(websocket):
                if not self.process_audio_frames(websocket):
                    break
        except ConnectionClosed:
            logging.info("Connection closed by client")
        except Exception as e:
            logging.error(f"Unexpected error: {str(e)}")
        finally:
            if self.client_manager.get_client(websocket):
                self.cleanup(websocket)
                websocket.close()
            del websocket

    def run(self,
            host,
            port=9090,
            backend="pywhispercpp"):
        """
        Run the transcription server.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
        """

        if not BackendType.is_valid(backend):
            raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")

        with serve(
            functools.partial(
                self.recv_audio,
                backend=BackendType(backend),
            ),
            host,
            port
        ) as server:
            server.serve_forever()

    def voice_activity(self, websocket, frame_np):
        """
        Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.

        This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
        contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
        it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
        speech detection to improve subsequent processing steps.

        Args:
            websocket: The websocket associated with the current client. Used to retrieve the client object
                from the client manager for state management.
            frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
                the audio data for the current frame.

        Returns:
            bool: True if voice activity is detected in the current frame, False otherwise. When returning False
                after detecting no voice activity for more than three consecutive frames, it also triggers the
                end-of-speech (EOS) flag for the client.
        """
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > 3:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)  # Sleep 100 ms; wait for some voice activity.
            return False
        return True

    def cleanup(self, websocket):
        """
        Cleans up resources associated with a given client's websocket.

        Args:
            websocket: The websocket associated with the client to be cleaned up.
        """
        if self.client_manager.get_client(websocket):
            self.client_manager.remove_client(websocket)


class ServeClientBase(object):
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, client_uid, websocket):
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_count = 0
        self.show_prev_out_thresh = 5  # if pause (no output from whisper) show previous output for 5 seconds
        self.add_pause_thresh = 3  # add a blank to segment list as a pause (no speech) for 3 seconds
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.pick_previous_segments = 2

        # threading
        self.lock = threading.Lock()

    def speech_to_text(self):
        raise NotImplementedError

    def transcribe_audio(self):
        raise NotImplementedError

    def handle_transcription_output(self):
        raise NotImplementedError

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
        of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
        to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
        of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
        audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.

        """
        self.lock.acquire()
        if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30 * self.RATE):]
            # check timestamp offset (should be >= self.frames_offset)
            # this basically means that there is no speech as the timestamp offset hasn't updated
            # and is less than frames_offset
            if self.timestamp_offset < self.frames_offset:
self.text = []
|
| 335 |
+
self.current_out = ''
|
| 336 |
+
self.prev_out = ''
|
| 337 |
+
self.t_start = None
|
| 338 |
+
self.exit = False
|
| 339 |
+
self.same_output_count = 0
|
| 340 |
+
self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
|
| 341 |
+
self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
|
| 342 |
+
self.transcript = []
|
| 343 |
+
self.send_last_n_segments = 10
|
| 344 |
+
|
| 345 |
+
# text formatting
|
| 346 |
+
self.pick_previous_segments = 2
|
| 347 |
+
|
| 348 |
+
# threading
|
| 349 |
+
self.lock = threading.Lock()
|
| 350 |
+
|
| 351 |
+
def speech_to_text(self):
|
| 352 |
+
raise NotImplementedError
|
| 353 |
+
|
| 354 |
+
def transcribe_audio(self):
|
| 355 |
+
raise NotImplementedError
|
| 356 |
+
|
| 357 |
+
def handle_transcription_output(self):
|
| 358 |
+
raise NotImplementedError
|
| 359 |
+
|
| 360 |
+
def add_frames(self, frame_np):
|
| 361 |
+
"""
|
| 362 |
+
Add audio frames to the ongoing audio stream buffer.
|
| 363 |
+
|
| 364 |
+
This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
|
| 365 |
+
of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
|
| 366 |
+
to prevent excessive memory usage.
|
| 367 |
+
|
| 368 |
+
If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
|
| 369 |
+
of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
|
| 370 |
+
audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
frame_np (numpy.ndarray): The audio frame data as a NumPy array.
|
| 374 |
+
|
| 375 |
+
"""
|
| 376 |
+
self.lock.acquire()
|
| 377 |
+
if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
|
| 378 |
+
self.frames_offset += 30.0
|
| 379 |
+
self.frames_np = self.frames_np[int(30 * self.RATE):]
|
| 380 |
+
# check timestamp offset(should be >= self.frame_offset)
|
| 381 |
+
# this basically means that there is no speech as timestamp offset hasnt updated
|
| 382 |
+
# and is less than frame_offset
|
| 383 |
+
if self.timestamp_offset < self.frames_offset:
|
| 384 |
+
self.timestamp_offset = self.frames_offset
|
| 385 |
+
if self.frames_np is None:
|
| 386 |
+
self.frames_np = frame_np.copy()
|
| 387 |
+
else:
|
| 388 |
+
self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
|
| 389 |
+
self.lock.release()
|
| 390 |
+
|
| 391 |
+
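To make the trimming rule concrete, here is the same arithmetic as a standalone sketch on a dummy buffer (the constants are copied from the class; the buffer contents are fabricated):

import numpy as np

RATE = 16000
frames_offset = 0.0
buffer = np.zeros(50 * RATE, dtype=np.float32)   # pretend 50 s of audio have been buffered

if buffer.shape[0] > 45 * RATE:                  # over the 45 s threshold
    frames_offset += 30.0                        # the buffer now starts 30 s later in absolute time
    buffer = buffer[int(30 * RATE):]             # drop the oldest 30 s, leaving 20 s

print(buffer.shape[0] / RATE, frames_offset)     # -> 20.0 30.0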
    def clip_audio_if_no_valid_segment(self):
        """
        Update the timestamp offset based on audio buffer status.
        Clip audio if the current unprocessed chunk exceeds 25 seconds; this implies that
        whisper has produced no valid segment for that stretch of audio.
        """
        with self.lock:
            if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Retrieves the next chunk of audio data for processing based on the current offsets.

        Calculates which part of the audio data should be processed next, based on
        the difference between the current timestamp offset and the frame's offset, scaled by
        the audio sample rate (RATE). It then returns this chunk of audio data along with its
        duration in seconds.

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration
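The offset bookkeeping is easier to follow with numbers; this standalone sketch repeats the slice computation with made-up offsets:

import numpy as np

RATE = 16000
frames_offset = 30.0        # buffer starts 30 s into the stream (after trimming)
timestamp_offset = 42.5     # everything before 42.5 s has already been transcribed
frames_np = np.zeros(20 * RATE, dtype=np.float32)   # 20 s currently buffered (30 s .. 50 s)

samples_take = max(0, (timestamp_offset - frames_offset) * RATE)   # 12.5 s worth of samples
input_bytes = frames_np[int(samples_take):]                        # the remaining 7.5 s
print(input_bytes.shape[0] / RATE)                                 # -> 7.5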
    def prepare_segments(self, last_segment=None):
        """
        Prepares the segments of transcribed text to be sent to the client.

        This method compiles the recent segments of transcribed text, ensuring that only the
        specified number of the most recent segments are included. It also appends the most
        recent segment of text if provided (which is considered incomplete because of the possibility
        of the last word being truncated in the audio chunk).

        Args:
            last_segment (str, optional): The most recent segment of transcribed text to be added
                to the list of segments. Defaults to None.

        Returns:
            list: A list of transcribed text segments to be sent to the client.
        """
        segments = []
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        else:
            segments = self.transcript.copy()
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculates the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Sends the specified transcription segments to the client over the websocket connection.

        This method formats the transcription segments into a JSON object and attempts to send
        this object to the client. If an error occurs during the send operation, it logs the error.

        Args:
            segments (list): A list of transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")
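On the wire, each update therefore looks like the following JSON message; the uid and text values here are fabricated for illustration, and each segment is a dict with a "text" key as built by handle_transcription_output:

{
    "uid": "0a1b2c3d",
    "segments": [
        {"text": "ask not what your country can do for you "},
        {"text": "ask what you can do for your country"}
    ]
}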
    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to notify them
        that the transcription service is disconnecting gracefully.
        """
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method performs necessary cleanup tasks, including stopping the transcription thread, marking
        the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
        associated with the transcription process.
        """
        logging.info("Cleaning up.")
        self.exit = True


class ServeClientWhisperCPP(ServeClientBase):
    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, language=None, client_uid=None,
                 single_model=False):
        """
        Initialize a ServeClient instance.
        The Whisper model is initialized based on the client's language and device availability.
        The transcription thread is started upon initialization. A "SERVER_READY" message is sent
        to the client to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            single_model (bool, optional): Whether to share a single model instance across all client
                connections instead of instantiating a new one per client. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language
        self.eos = False

        if single_model:
            if ServeClientWhisperCPP.SINGLE_MODEL is None:
                self.create_model()
                ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
        else:
            self.create_model()

        # threading
        logging.info('Create a thread to process audio.')
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "pywhispercpp"
        }))

    def create_model(self, warmup=True):
        """
        Instantiates a new model, sets it as the transcriber, and warms it up if desired.
        """
        model = 'medium-q5_0'
        here = pathlib.Path(__file__)
        models_dir = f'{here.parent.parent / "moyoyo_asr_models"}'
        self.transcriber = Model(model=model, models_dir=models_dir)
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=1):
        """
        Warm up the whisper.cpp engine, since the first few inferences are slow.

        Args:
            warmup_steps (int): Number of steps to warm up the model for.
        """
        logging.info("[INFO:] Warming up whisper.cpp engine..")
        mel, _ = soundfile.read("assets/jfk.flac")
        for i in range(warmup_steps):
            self.transcriber.transcribe(mel, print_progress=False)
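The create_model()/warmup() call path can also be exercised outside the server. A minimal sketch, assuming `Model` here is pywhispercpp's `pywhispercpp.model.Model` and that the ggml model and sample audio shipped in this commit are present at these paths:

# Standalone sketch of the same call path; import path and file locations are assumptions.
import soundfile
from pywhispercpp.model import Model

model = Model(model='medium-q5_0', models_dir='moyoyo_asr_models')
audio, _ = soundfile.read('assets/jfk.flac')          # 16 kHz mono float samples
for segment in model.transcribe(audio, print_progress=False):
    print(segment.text)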
    def set_eos(self, eos):
        """
        Sets the End of Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        self.lock.acquire()
        self.eos = eos
        self.lock.release()

    def handle_transcription_output(self, last_segment, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            last_segment (str): The last segment from the whisper output, which is considered incomplete
                because the final word may be truncated in the audio chunk.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """
        Transcribe the audio chunk and send the results to the client.

        Args:
            input_bytes (np.array): The audio chunk to transcribe.
        """
        if ServeClientWhisperCPP.SINGLE_MODEL:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
        logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
        mel = input_bytes
        duration = librosa.get_duration(y=input_bytes, sr=self.RATE)

        if self.language == "zh":
            # "The following is a sentence in simplified Chinese Mandarin."
            prompt = '以下是简体中文普通话的句子。'
        else:
            prompt = 'The following is an English sentence.'

        segments = self.transcriber.transcribe(mel, language='zh', initial_prompt=prompt, print_progress=False)
        text = []
        for segment in segments:
            content = segment.text
            text.append(content)
        last_segment = ' '.join(text)

        logging.info(f"[pywhispercpp:] Last segment: {last_segment}")

        if ServeClientWhisperCPP.SINGLE_MODEL:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """
        Update timestamp offset and transcript.

        Args:
            last_segment (str): Last transcribed audio from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        if not len(self.transcript):
            self.transcript.append({"text": last_segment + " "})
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        logging.info(f'Transcript list context: {self.transcript}')

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via a WebSocket connection.

        If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
        It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results.
        Segments are sent to the client in real time, and a history of segments is maintained to provide context.
        Pauses in speech (no output from Whisper) are handled by showing the previous output for a set duration.
        A blank segment is added if there is no speech for a specified duration to indicate a pause.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)    # wait for any audio to arrive
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1:
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)

            except Exception as e:
                logging.error(f"[ERROR]: {e}")

transcribe/utils.py
ADDED
@@ -0,0 +1,81 @@
import os
import textwrap
from pathlib import Path

import av


def clear_screen():
    """Clears the console screen."""
    os.system("cls" if os.name == "nt" else "clear")


def print_transcript(text):
    """Prints formatted transcript text."""
    wrapper = textwrap.TextWrapper(width=60)
    for line in wrapper.wrap(text="".join(text)):
        print(line)


def format_time(s):
    """Convert seconds (float) to SRT time format, e.g. 3661.5 -> "01:01:01,500"."""
    hours = int(s // 3600)
    minutes = int((s % 3600) // 60)
    seconds = int(s % 60)
    milliseconds = int((s - int(s)) * 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"


def create_srt_file(segments, resampled_file):
    """Write the given segments to an SRT subtitle file at the path given by resampled_file."""
    with open(resampled_file, 'w', encoding='utf-8') as srt_file:
        segment_number = 1
        for segment in segments:
            start_time = format_time(float(segment['start']))
            end_time = format_time(float(segment['end']))
            text = segment['text']

            srt_file.write(f"{segment_number}\n")
            srt_file.write(f"{start_time} --> {end_time}\n")
            srt_file.write(f"{text}\n\n")

            segment_number += 1

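A short usage sketch for these helpers; the segment dictionaries are fabricated but use the start/end/text keys that create_srt_file expects:

# Usage sketch; the segment values are fabricated.
segments = [
    {"start": 0.0, "end": 2.5, "text": "ask not what your country can do for you"},
    {"start": 2.5, "end": 5.0, "text": "ask what you can do for your country"},
]
create_srt_file(segments, "jfk.srt")
# jfk.srt now contains numbered cues such as:
# 1
# 00:00:00,000 --> 00:00:02,500
# ask not what your country can do for you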
def resample(file: str, sr: int = 16000):
    """
    Resample the audio file to 16kHz mono.

    Args:
        file (str): The audio file to open
        sr (int): The sample rate to resample the audio to, if necessary

    Returns:
        resampled_file (str): The path of the resampled audio file
    """
    container = av.open(file)
    stream = next(s for s in container.streams if s.type == 'audio')

    resampler = av.AudioResampler(
        format='s16',
        layout='mono',
        rate=sr,
    )

    resampled_file = Path(file).stem + "_resampled.wav"
    output_container = av.open(resampled_file, mode='w')
    output_stream = output_container.add_stream('pcm_s16le', rate=sr)
    output_stream.layout = 'mono'

    for frame in container.decode(audio=0):
        frame.pts = None
        resampled_frames = resampler.resample(frame)
        if resampled_frames is not None:
            for resampled_frame in resampled_frames:
                for packet in output_stream.encode(resampled_frame):
                    output_container.mux(packet)

    # Flush any packets still buffered in the encoder.
    for packet in output_stream.encode(None):
        output_container.mux(packet)

    output_container.close()
    return resampled_file
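Usage is a single call; for example (input filename fabricated):

# Decodes talk.mp3 and writes talk_resampled.wav as 16 kHz mono PCM in the working directory.
wav_path = resample("talk.mp3")
print(wav_path)   # -> "talk_resampled.wav"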
transcribe/vad.py
ADDED
@@ -0,0 +1,160 @@
import os
import subprocess
import warnings

import numpy as np
import onnxruntime
import torch


class VoiceActivityDetection():

    def __init__(self, force_onnx_cpu=True):
        path = self.download()

        opts = onnxruntime.SessionOptions()
        opts.log_severity_level = 3

        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)

        self.reset_states()
        if '16k' in path:
            warnings.warn('This model supports only a 16000 Hz sampling rate!')
            self.sample_rates = [16000]
        else:
            self.sample_rates = [8000, 16000]
    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiples of 16000)")
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        x, sr = self._validate_input(x, sr)
        # The ONNX model expects fixed-size chunks: 512 samples (32 ms) at 16 kHz, 256 at 8 kHz.
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            raise ValueError(
                f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if sr == 16000 else 32

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            raise ValueError()

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out
    def audio_forward(self, x, sr: int):
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        # Pad the audio so it splits evenly into fixed-size chunks.
        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i + num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()

    @staticmethod
    def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
        target_dir = os.path.expanduser("~/.cache/silero-vad/")

        # Ensure the target directory exists
        os.makedirs(target_dir, exist_ok=True)

        # Define the target file path
        model_filename = os.path.join(target_dir, "silero_vad.onnx")

        # Check if the model file already exists
        if not os.path.exists(model_filename):
            # If it doesn't exist, download the model with curl
            try:
                # subprocess.run(["wget", "-O", model_filename, model_url], check=True)
                subprocess.run(["curl", "-sL", "-o", model_filename, model_url], check=True)
            except subprocess.CalledProcessError:
                print("Failed to download the model with curl.")
        return model_filename

class VoiceActivityDetector:
    def __init__(self, threshold=0.5, frame_rate=16000):
        """
        Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.

        Args:
            threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
            frame_rate (int, optional): The sampling rate of the incoming audio frames. Defaults to 16000.
        """
        self.model = VoiceActivityDetection()
        self.threshold = threshold
        self.frame_rate = frame_rate

    def __call__(self, audio_frame):
        """
        Determines if the given audio frame contains speech by comparing the detected speech probability against
        the threshold.

        Args:
            audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
                NumPy array of audio samples.

        Returns:
            bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
                False otherwise.
        """
        speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
        return torch.any(speech_probs > self.threshold).item()
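A minimal usage sketch for the detector, assuming the package is importable as transcribe.vad; the input here is synthetic silence, so no voice activity should be reported (the first call also downloads the ONNX model to ~/.cache/silero-vad/):

# Usage sketch with synthetic audio; real input would be a 16 kHz mono float32 frame.
import numpy as np
from transcribe.vad import VoiceActivityDetector

vad = VoiceActivityDetector(threshold=0.5, frame_rate=16000)
silence = np.zeros(16000, dtype=np.float32)      # one second of silence
print(vad(silence))                               # -> False (no voice activity expected)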