
import sys
import os
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio

# Validate command-line arguments FIRST, before touching the model.
# The original loaded (and possibly downloaded) a large Hugging Face
# checkpoint before checking argv, so a simple usage mistake only
# failed after an expensive load. Fail fast instead.
if len(sys.argv) < 2:
    sys.stderr.write("Usage: python3.11 hindi_stt.py <audio_file_path>\n")
    sys.exit(1)

audio_file_path = sys.argv[1]

if not os.path.exists(audio_file_path):
    sys.stderr.write(f"Error: Audio file not found at {audio_file_path}\n")
    sys.exit(1)

# Hindi Wav2Vec2 CTC model and its matching processor
# (feature extractor + tokenizer) from the Hugging Face hub.
model_name = "nalini2799/CDAC_hindispeechrecognition"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Run on CPU (no CUDA assumed) and in eval mode — eval() explicitly
# disables dropout so inference is deterministic.
model.to("cpu")
model.eval()

try:
    # Load the audio file as a (channels, samples) float tensor.
    speech, sample_rate = torchaudio.load(audio_file_path)

    # Wav2Vec2 models expect 16 kHz input; resample anything else.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        speech = resampler(speech)

    # Downmix multi-channel audio to mono by averaging the channels.
    if speech.ndim > 1:
        speech = speech.mean(dim=0)

    # Normalize the raw waveform into model input values.
    input_values = processor(speech.squeeze(), sampling_rate=16000, return_tensors="pt").input_values

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: argmax over the vocabulary per frame,
    # then collapse repeats/blanks via the processor's decoder.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])

    sys.stdout.write(transcription + "\n")

except Exception as e:
    # Top-level script boundary: report any failure on stderr and
    # exit nonzero so callers can detect the error.
    sys.stderr.write(f"Error during Hindi STT transcription: {e}\n")
    sys.exit(1)
