
Examples: GUI and "Advanced" CLI #10

Closed
AcTePuKc opened this issue Feb 28, 2025 · 20 comments


AcTePuKc commented Feb 28, 2025

Want CLI?

Running this as a Python script in the main folder, with the conda env active, produces more accurate speech patterns for me for some reason:

import os
import torch
import numpy as np
import soundfile as sf
import logging
from datetime import datetime
from cli.SparkTTS import SparkTTS

def generate_tts_audio(
    text,
    model_dir="pretrained_models/Spark-TTS-0.5B",
    device="cuda:0",
    prompt_speech_path=None,
    prompt_text=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    segmentation_threshold=150  # Do not go above this unless you have a better GPU, or generation may crash
):
    """
    Generates TTS audio from input text, splitting into segments if necessary.

    Args:
        text (str): Input text for speech synthesis.
        model_dir (str): Path to the model directory.
        device (str): Device identifier (e.g., "cuda:0" or "cpu").
        prompt_speech_path (str, optional): Path to prompt audio for cloning.
        prompt_text (str, optional): Transcript of prompt audio.
        gender (str, optional): Gender parameter ("male"/"female").
        pitch (str, optional): Pitch parameter (e.g., "moderate").
        speed (str, optional): Speed parameter (e.g., "moderate").
        save_dir (str): Directory where generated audio will be saved.
        segmentation_threshold (int): Maximum number of words per segment.

    Returns:
        str: The unique file path where the generated audio is saved.
    """
    logging.info("Initializing TTS model...")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)
    
    # Ensure the save directory exists.
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    # Check if the text is too long.
    words = text.split()
    if len(words) > segmentation_threshold:
        logging.info("Input text exceeds threshold; splitting into segments...")
        segments = [' '.join(words[i:i+segmentation_threshold]) for i in range(0, len(words), segmentation_threshold)]
        wavs = []
        for seg in segments:
            with torch.no_grad():
                wav = model.inference(
                    seg,
                    prompt_speech_path,
                    prompt_text=prompt_text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed
                )
            wavs.append(wav)
        final_wav = np.concatenate(wavs, axis=0)
    else:
        with torch.no_grad():
            final_wav = model.inference(
                text,
                prompt_speech_path,
                prompt_text=prompt_text,
                gender=gender,
                pitch=pitch,
                speed=speed
            )
    
    # Save the generated audio.
    sf.write(save_path, final_wav, samplerate=16000)
    logging.info(f"Audio saved at: {save_path}")
    return save_path

# Example usage:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    
    # Sample input (feel free to adjust)
    sample_text = (
        "The mind that opens to a new idea never returns to its original size. "
        "Hellstrom’s Hive: Chapter 1 – The Awakening. Mara Vance stirred from a deep, dreamless sleep, "
        "her consciousness surfacing like a diver breaking through the ocean's surface. "
        "A dim, amber light filtered through her closed eyelids, warm and pulsing softly. "
        "She hesitated to open her eyes, savoring the fleeting peace before reality set in. "
        "A cool, earthy scent filled her nostrils—damp soil mingled with something sweet and metallic. "
        "The air was thick, almost humid, carrying with it a faint vibration that resonated in her bones. "
        "It wasn't just a sound; it was a presence. "
        "Her eyelids fluttered open. Above her stretched a ceiling unlike any she'd seen—organic and alive, "
        "composed of interwoven tendrils that glowed with the same amber light. They pulsated gently, "
        "like the breathing of some colossal creature. Shadows danced between the strands, creating shifting patterns."
    )
    
    # Call the function (adjust parameters as needed)
    output_file = generate_tts_audio(
        sample_text,
        gender="male",
        pitch="moderate",
        speed="moderate"
    )
    print("Generated audio file:", output_file)

Better GUI

And a GUI, if someone wants one. It's lightweight (same as running through the CLI); at least on a 3060 it runs fine. It combines the text segments, but unfortunately it will crash if you paste in a ton of text.
It only requires:

 pip install pyside6 
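
Save the GUI script below in the Spark-TTS root folder (so the from cli.SparkTTS import resolves) and, with the conda env active, launch it with something like (spark_gui.py is just a placeholder name):

 python spark_gui.py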

🔹 Buttons?!?

  • Text Input: The big text box where you enter text to be converted into speech.
  • Load Voice Sample: Loads a voice sample (MP3/WAV) for RVC-like functionality, allowing voice transformation.
  • Reset Voice Sample: Clears the loaded voice sample, letting you switch back to gender-based synthesis without restarting the app.
  • Gender Selection Dropdown:
    • If using Spark-TTS, select "Male" or "Female" for a generated voice.
    • If left on "Auto," Spark-TTS will fail (see the mapping sketch after this list).
    • Takes a few seconds to generate before synthesis starts.
  • Generate Speech: Starts generating speech based on the entered text and selected parameters.
  • Play: Plays the last generated audio file.
  • Stop: Stops playback.
  • Save Audio: Saves the last generated audio to a file.
  • Word Count: That thing that counts words.
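
For reference, here is a rough, self-contained sketch of how these controls map to SparkTTS.inference arguments in the script below (map_gui_params is just an illustrative name): a loaded voice sample takes priority and is passed as the cloning prompt, otherwise the gender choice and the 0-4 pitch/speed sliders are translated into the level strings the model expects.

def map_gui_params(voice_sample, gender_text, pitch_idx, speed_idx):
    """Illustrative mapping of GUI selections to SparkTTS.inference arguments."""
    levels = ["very_low", "low", "moderate", "high", "very_high"]
    if voice_sample is not None:
        # A loaded sample is used as the cloning prompt; gender/pitch/speed stay unset.
        return voice_sample, None, None, None
    gender = gender_text.lower()
    # "Auto" becomes None, which is why generation fails when no sample is loaded either.
    gender = None if gender == "auto" else gender
    return None, gender, levels[pitch_idx], levels[speed_idx]

# e.g. map_gui_params(None, "Male", 2, 2) -> (None, "male", "moderate", "moderate")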

😎


import sys
import os
import time
import torch
import shutil
import numpy as np
import soundfile as sf
from PySide6.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QPushButton, QLabel,
    QTextEdit, QSlider, QFileDialog, QComboBox, QHBoxLayout
)
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
from PySide6.QtGui import QPainter, QColor, QPen, QIcon
from cli.SparkTTS import SparkTTS

# --- Worker Thread for TTS Generation (with segmentation support) ---
class TTSWorker(QThread):
    result_ready = Signal(object, float)  # Emits (final result, generation_time)
    progress_update = Signal(int, int)      # Emits (current_segment, total_segments)

    def __init__(self, model, text, voice_sample, gender, pitch, speed):
        """
        text: Either a string or a list of strings (segments).
        """
        super().__init__()
        self.model = model
        self.text = text
        self.voice_sample = voice_sample
        self.gender = gender
        self.pitch = pitch
        self.speed = speed
        

    def run(self):
        start = time.time()
        try:
            results = []
            if isinstance(self.text, list):
                total = len(self.text)
                for i, segment in enumerate(self.text):
                    with torch.no_grad():
                        wav = self.model.inference(
                            segment,
                            prompt_speech_path=self.voice_sample,
                            gender=self.gender,
                            pitch=self.pitch,
                            speed=self.speed
                        )
                    results.append(wav)
                    self.progress_update.emit(i + 1, total)
                final_wav = np.concatenate(results, axis=0)
            else:
                with torch.no_grad():
                    final_wav = self.model.inference(
                        self.text,
                        prompt_speech_path=self.voice_sample,
                        gender=self.gender,
                        pitch=self.pitch,
                        speed=self.speed
                    )
                self.progress_update.emit(1, 1)
            elapsed = time.time() - start
            self.result_ready.emit(final_wav, elapsed)
        except Exception as e:
            self.result_ready.emit(e, 0)

# --- Waveform Visualization Widget ---
class WaveformWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.progress = 0.0  # Range: 0.0 to 1.0

    def set_progress(self, progress):
        self.progress = progress
        self.update()

    def paintEvent(self, event):
        painter = QPainter(self)
        painter.fillRect(self.rect(), QColor("black"))
        pen = QPen(QColor("green"))
        pen.setWidth(5)
        painter.setPen(pen)
        painter.drawLine(0, self.height() // 2, int(self.width() * self.progress), self.height() // 2)

# --- Main Application Class ---
class SparkTTSApp(QWidget):
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.voice_sample = None
        self.current_audio_file = None
        self.total_duration = 0
        self.init_ui()
        self.status_label.setText(f"Model loaded on {device}")

        # Set app icon if available.
        icon_path = "src/logo.webp"
        if os.path.exists(icon_path):
            self.setWindowIcon(QIcon(icon_path))  # Set app icon if found.

        # Initialize audio player and output.    
        self.audio_player = QMediaPlayer()
        self.audio_output = QAudioOutput()
        self.audio_player.setAudioOutput(self.audio_output)
        self.audio_player.positionChanged.connect(self.on_position_changed)
        self.audio_player.durationChanged.connect(self.on_duration_changed)

    def init_ui(self):
        self.setWindowTitle("Spark-TTS GUI")
        self.setMinimumSize(600, 400)
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(15, 15, 15, 15)
        
        # Text input.
        self.text_input = QTextEdit()
        self.text_input.setPlaceholderText("Enter text for speech synthesis...")
        main_layout.addWidget(self.text_input)

        self.word_count_label = QLabel("Word Count: 0")
        main_layout.addWidget(self.word_count_label)

        self.text_input.textChanged.connect(self.update_word_count)

        
        
        btn_layout = QHBoxLayout()
        self.voice_btn = QPushButton("Load Voice Sample")
        self.voice_btn.clicked.connect(self.select_voice_sample)
        self.reset_voice_btn = QPushButton("Reset Voice Sample")
        self.reset_voice_btn.clicked.connect(self.reset_voice_sample)
        self.generate_btn = QPushButton("Generate Speech")
        self.generate_btn.clicked.connect(self.run_synthesis)
        btn_layout.addWidget(self.voice_btn)
        btn_layout.addWidget(self.reset_voice_btn)
        btn_layout.addWidget(self.generate_btn)
        main_layout.addLayout(btn_layout)
        
        # Controls Layout (only Gender, Pitch, and Speed).
        controls_layout = QHBoxLayout()
        self.gender_selector = QComboBox()
        self.gender_selector.addItems(["Auto", "Male", "Female"])
        controls_layout.addWidget(QLabel("Gender:"))
        controls_layout.addWidget(self.gender_selector)
        
        self.pitch_slider, pitch_layout = self.create_slider_with_value("Pitch")
        controls_layout.addLayout(pitch_layout)
        
        self.speed_slider, speed_layout = self.create_slider_with_value("Speed")
        controls_layout.addLayout(speed_layout)
        
        main_layout.addLayout(controls_layout)
        
        # Audio controls layout.
        audio_controls = QHBoxLayout()
        self.play_btn = QPushButton("Play")
        self.play_btn.clicked.connect(self.play_audio)
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self.stop_audio)
        self.save_btn = QPushButton("Save Audio")
        self.save_btn.clicked.connect(self.save_audio)
        audio_controls.addWidget(self.play_btn)
        audio_controls.addWidget(self.stop_btn)
        audio_controls.addWidget(self.save_btn)
        main_layout.addLayout(audio_controls)
        
        # Status bar.
        self.status_label = QLabel("Ready")
        self.status_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(self.status_label)
        
        # Waveform visualization widget.
        self.waveform = WaveformWidget()
        main_layout.addWidget(self.waveform)
        
        self.setLayout(main_layout)

    def create_slider_with_value(self, label_text):
        layout = QVBoxLayout()
        label = QLabel(label_text)
        slider = QSlider(Qt.Horizontal)
        slider.setRange(0, 4)
        slider.setValue(2)
        value_label = QLabel("2")
        slider.valueChanged.connect(lambda val: value_label.setText(str(val)))
        layout.addWidget(label)
        layout.addWidget(slider)
        layout.addWidget(value_label)
        # Descriptive text under the slider.
        desc_label = QLabel(f"Adjust {label_text.lower()} level")
        layout.addWidget(desc_label)
        return slider, layout
    
    def update_word_count(self):
        """Updates the word count dynamically as the user types."""
        text = self.text_input.toPlainText().strip()
        word_count = len(text.split()) if text else 0
        self.word_count_label.setText(f"Word Count: {word_count}")

    def reset_voice_sample(self):
        """Clears the loaded voice sample and restores gender selection."""
        self.voice_sample = None
        self.gender_selector.setEnabled(True)
        self.status_label.setText("Voice sample cleared. You can now use gender selection.")

    def select_voice_sample(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Voice Sample", "", "Audio Files (*.wav *.mp3)"
        )
        if file_path:
            self.voice_sample = file_path
            self.status_label.setText(f"Loaded voice sample: {os.path.basename(file_path)}")

    def save_audio(self):
        if not (self.current_audio_file and os.path.exists(self.current_audio_file)):
            self.status_label.setText("No audio to save!")
            return
        save_path, _ = QFileDialog.getSaveFileName(
            self, "Save Audio", "", "WAV Files (*.wav)"
        )
        if save_path:
            shutil.copy(self.current_audio_file, save_path)
            self.status_label.setText(f"Audio saved to: {os.path.basename(save_path)}")

    def play_audio(self):
        if self.current_audio_file and os.path.exists(self.current_audio_file):
            self.audio_player.setSource(self.current_audio_file)
            self.audio_player.play()

    def stop_audio(self):
        self.audio_player.stop()

    def on_duration_changed(self, duration):
        self.total_duration = duration

    def on_position_changed(self, position):
        if self.total_duration > 0:
            progress = position / self.total_duration
            self.waveform.set_progress(progress)

    def run_synthesis(self):
        text = self.text_input.toPlainText().strip()
        if not text:
            self.status_label.setText("Please enter some text!")
            return

        # Segmentation: Limit each segment to 150 words.
        segmentation_threshold = 150
        words = text.split()
        if len(words) > segmentation_threshold:
            text_to_process = [
                ' '.join(words[i:i + segmentation_threshold])
                for i in range(0, len(words), segmentation_threshold)
            ]
            self.status_label.setText("Text too long: processing segments...")
        else:
            text_to_process = text

        # Determine parameters based on whether a voice sample is loaded.
        if self.voice_sample is not None:
            prompt = self.voice_sample
            gender = None
            pitch = None
            speed = None
        else:
            prompt = None
            gender = self.gender_selector.currentText().lower()
            gender = None if gender == "auto" else gender
            pitch_map = ["very_low", "low", "moderate", "high", "very_high"]
            speed_map = ["very_low", "low", "moderate", "high", "very_high"]
            pitch = pitch_map[self.pitch_slider.value()]
            speed = speed_map[self.speed_slider.value()]

        self.generate_btn.setEnabled(False)
        self.status_label.setText("Generating speech...")

        self.worker = TTSWorker(self.model, text_to_process, prompt, gender, pitch, speed)
        self.worker.progress_update.connect(self.on_generation_progress)
        self.worker.result_ready.connect(self.on_generation_complete)
        self.worker.start()

    def on_generation_progress(self, current, total):
        self.status_label.setText(f"Generating segment {current} / {total}...")

    def on_generation_complete(self, result, elapsed):
        if isinstance(result, Exception):
            self.status_label.setText(f"Error: {result}")
        else:
            filename = f"output_{int(time.time())}.wav"
            sf.write(filename, result, samplerate=16000)
            self.current_audio_file = filename
            self.status_label.setText(f"Generated in {elapsed:.1f}s | Saved to {filename}")
        self.generate_btn.setEnabled(True)

if __name__ == "__main__":
    app = QApplication(sys.argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device=device)
    window = SparkTTSApp(model, device.type.upper())
    window.show()
    sys.exit(app.exec())

I got side-tracked and I'm updating it to look better, but this version works. It's slower than the CLI, but hey, you type, it speaks.

Cheers.

@xinshengwang
Member

@AcTePuKc I’ve updated the README to include an Optional Methods section that links to your CLI/Web UI implementation. Thanks for sharing! 🚀


AcTePuKc commented Feb 28, 2025

This one is even more advanced, with colors and stuff.

# Save in the main directory 
'''
SparkTTS Studio - GUI Description (Concise)

SparkTTS Studio is a user-friendly interface for text-to-speech synthesis, featuring voice cloning capabilities.

Key Features:

Text Input: Large text area to enter text for speech synthesis.

Voice Sample Loading: Load WAV/MP3 files to enable voice cloning and mimic voice styles.

Gender Selection: Choose "Male," "Female," or "Pick Voice (Gender optional)" for voice synthesis.

Pitch & Speed Control: Sliders to adjust voice pitch and speaking speed.

Interactive Waveform: Visual display of generated audio with playback progress and clickable seeking.

Playback Controls: "Play/Pause," "Stop," and Volume slider for audio playback.

Save Audio: Save generated speech to WAV files.

Multi-Language GUI: Switch UI language between English, Bulgarian, Spanish, French, and Japanese.

Status Bar: Displays messages about model loading, generation progress, and errors.

System Tray: Option to minimize to system tray for background operation.

Workflow Highlights:

Enter Text: Input the text you want to convert to speech.

Optional Voice Cloning: Load a voice sample to clone a voice.

Set Voice Parameters: Adjust gender, pitch, and speed.

Generate Speech: Click "Generate Speech" to synthesize audio.

Playback & Seek: Use playback controls and click on the waveform to navigate audio.

Save Audio: Save the generated audio to a WAV file.
'''
import traceback
import sys
import os
import time
import torch
import shutil
import numpy as np
import soundfile as sf
from PySide6.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QPushButton, QLabel,
    QTextEdit, QSlider, QFileDialog, QComboBox, QHBoxLayout,
    QGroupBox, QProgressBar, QSystemTrayIcon, QMenu, QSizePolicy
)
from PySide6.QtCore import Qt, QThread, Signal, QPoint, QTimer, QCoreApplication
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
from PySide6.QtGui import QPainter, QColor, QPen, QLinearGradient, QIcon, QAction
from cli.SparkTTS import SparkTTS

# ------------------- Modern Style Sheet -------------------
STYLE_SHEET = """
QWidget {
    background-color: #2D2D2D;
    color: #FFFFFF;
    font-family: 'Segoe UI';
    font-size: 12px;
}

QTextEdit {
    background-color: #404040;
    border: 2px solid #505050;
    border-radius: 5px;
    padding: 8px;
    selection-background-color: #3DAEE9;
}

QPushButton {
    background-color: #3DAEE9;
    border: none;
    border-radius: 4px;
    color: white;
    padding: 8px 16px;
    min-width: 80px;
}

QPushButton:hover {
    background-color: #2D9CDB;
}

QPushButton:disabled {
    background-color: #505050;
    color: #808080;
}

QSlider::groove:horizontal {
    height: 6px;
    background: #404040;
    border-radius: 3px;
}

QSlider::handle:horizontal {
    background: #3DAEE9;
    border: 2px solid #2D2D2D;
    width: 16px;
    margin: -6px 0;
    border-radius: 8px;
}

QComboBox {
    background-color: #404040;
    border: 2px solid #505050;
    border-radius: 4px;
    padding: 4px;
    min-width: 100px;
}

QGroupBox {
    border: 2px solid #505050;
    border-radius: 6px;
    margin-top: 10px;
    padding-top: 15px;
    color: #FFFFFF; /* Added to ensure title text is white */
}

QGroupBox::title {
    subcontrol-origin: margin;
    left: 10px;
    padding: 0 5px;
    color: #FFFFFF; /* Added to ensure title text is white */
}
"""

# ------------------- Enhanced Waveform Widget -------------------
class WaveformWidget(QWidget):
    seek_position_signal = Signal(int) # New signal to emit seek position in milliseconds

    def __init__(self, parent=None):
        super().__init__(parent)
        self.waveform_data = None
        self.playback_progress = 0.0
        self.playhead_progress = 0.0
        self.audio_player_duration_ms = 0  # ADDED: Store audio duration here!
        self.setMinimumHeight(100)
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self.setMouseTracking(True) # Enable mouse tracking for click events

    def set_waveform(self, data):
        self.waveform_data = data
        self.update()

    def set_playback_progress_overlay(self, progress):
        self.playback_progress = progress
        self.update()

    def set_playhead_progress(self, progress):
        self.playhead_progress = progress
        self.update()

    def mousePressEvent(self, event):
        if self.waveform_data is not None and self.audio_player_duration_ms > 0: # Now it has the attribute!
            click_x = event.position().x()
            progress_ratio = click_x / self.width()
            seek_position_ms = int(self.audio_player_duration_ms * progress_ratio)
            self.seek_position_signal.emit(seek_position_ms)

    def set_audio_duration(self, duration_ms): # NEW: Method to set duration
        self.audio_player_duration_ms = duration_ms


    def paintEvent(self, event):
        painter = QPainter(self)
        rect = self.rect()

        # Draw background gradient
        gradient = QLinearGradient(0, 0, 0, rect.height())
        gradient.setColorAt(0, QColor("#363636"))
        gradient.setColorAt(1, QColor("#2D2D2D"))
        painter.fillRect(rect, gradient)

        if self.waveform_data is not None and len(self.waveform_data) > 0:
            # Normalize waveform data
            normalized_waveform = self.waveform_data / np.max(np.abs(self.waveform_data)) if np.max(np.abs(self.waveform_data)) > 0 else self.waveform_data

            # Draw waveform
            pen = QPen(QColor("#3DAEE9"))
            pen.setWidth(2)
            painter.setPen(pen)

            num_samples = len(normalized_waveform)
            step = max(1, num_samples // rect.width())
            center_y = rect.height() / 2

            for x in range(rect.width()):
                idx = min(int(x * step), num_samples - 1)
                sample_value = normalized_waveform[idx]
                value_pixel_height = int(abs(sample_value) * center_y * 0.95)
                y1 = int(center_y - value_pixel_height)
                y2 = int(center_y + value_pixel_height)
                y1 = max(0, min(int(y1), rect.height()))
                y2 = max(0, min(int(y2), rect.height()))
                painter.drawLine(x, y1, x, y2)

        # Draw playback progress overlay
        painter.setCompositionMode(QPainter.CompositionMode_SourceOver)
        progress_width = int(rect.width() * self.playback_progress)
        progress_rect = rect.adjusted(0, 0, progress_width - rect.width(), 0)
        painter.fillRect(progress_rect, QColor(61, 174, 233, 80))

        # Draw playhead
        if self.playhead_progress > 0:
            playhead_x = int(rect.width() * self.playhead_progress)
            playhead_pen = QPen(QColor("white"))
            playhead_pen.setWidth(2)
            painter.setPen(playhead_pen)
            painter.drawLine(playhead_x, 0, playhead_x, rect.height())


# ------------------- Main Application Class -------------------
class SparkTTSApp(QWidget):
    def __init__(self, model, device):
        super().__init__() # Corrected super().__init__() call - no arguments
        self.model = model
        self.device = device # Store device as instance attribute
        self.voice_sample = None
        self.current_audio_file = None
        self.audio_player_duration_ms = 0 # Store audio player duration in milliseconds
        self.init_ui()
        self.init_tray_icon()
        self.status_label.setText(f"Model loaded on {device}")
        self.audio_player = QMediaPlayer()
        self.audio_output = QAudioOutput()
        self.audio_player.setAudioOutput(self.audio_output)
        self.audio_player.positionChanged.connect(self.update_waveform_playhead)
        self.audio_player.positionChanged.connect(self.on_position_changed)
        self.audio_player.durationChanged.connect(self.on_duration_changed)
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_time_display)
        self.setWindowIcon(QIcon("src/logo.webp"))  # Add actual icon file
        self.waveform.seek_position_signal.connect(self.seek_audio) # Connect seek_position_signal to seek_audio

    def init_ui(self):
        # Window title - needs translation
        self.setWindowTitle("SparkTTS Studio")
        self.setMinimumSize(800, 600)
        self.setStyleSheet(STYLE_SHEET)

        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)
        main_layout.setSpacing(15)

        # GUI Language Selector (Placed at the Top Right)
        top_bar_layout = QHBoxLayout()
        top_bar_layout.addStretch()  # Push to the right
        # Language Label - needs translation
        top_bar_layout.addWidget(QLabel("Language:"))
        self.language_label = top_bar_layout.itemAt(
            1).widget()  # Store language label
        self.gui_language_selector = QComboBox()
        self.gui_language_selector.addItems(
            ["English", "Bulgarian", "Spanish", "French", "Japanese"])
        self.gui_language_selector.currentIndexChanged.connect(
            self.update_gui_language)
        top_bar_layout.addWidget(self.gui_language_selector)
        main_layout.addLayout(top_bar_layout)

        # Text Input Group
        # Group Box Title - needs translation
        input_group = QGroupBox("Text Input")
        self.input_group = input_group  # Store group box for translation
        input_layout = QVBoxLayout()
        self.text_input = QTextEdit()
        # Placeholder - needs translation
        self.text_input.setPlaceholderText(
            "Enter text for speech synthesis...")
        input_layout.addWidget(self.text_input)
        input_group.setLayout(input_layout)
        main_layout.addWidget(input_group)

        # Controls Group
        # Group Box Title - needs translation
        controls_group = QGroupBox("Synthesis Controls")
        self.controls_group = controls_group  # Store group box for translation
        controls_layout = QVBoxLayout()

        # Voice Sample Section
        voice_layout = QHBoxLayout()
        # Label - needs translation
        self.voice_label = QLabel("No voice sample loaded")
        self.voice_label_status = self.voice_label  # Store label for translation
        # Button text - needs translation
        self.voice_btn = QPushButton("Load Voice Sample")
        self.voice_btn_load_voice = self.voice_btn  # Store button for translation
        self.voice_btn.clicked.connect(self.select_voice_sample)
        self.voice_btn.setIcon(QIcon())  # Add actual icon file
        voice_layout.addWidget(self.voice_btn)
        voice_layout.addWidget(self.voice_label)
        controls_layout.addLayout(voice_layout)

        # Parameters
        params_layout = QHBoxLayout()

        # Gender Selector Group
        # Group Box Title - needs translation
        gender_box = QGroupBox("Voice Parameters")
        self.gender_box = gender_box  # Store group box for translation
        gender_layout = QVBoxLayout()
        self.gender_selector_label = QLabel(
            "Gender:")  # Label - needs translation
        # Label - needs translation
        gender_layout.addWidget(self.gender_selector_label)
        # Store label for translation
        self.gender_selector_label_widget = self.gender_selector_label
        self.gender_selector = QComboBox()
        self.gender_selector.addItems(
            # Items - need translation
            ["Pick Voice (Gender optional)", "Male", "Female"])
        # Store items for translation
        self.gender_selector_items = [
            "Pick Voice (Gender optional)", "Male", "Female"]
        self.gender_selector.currentIndexChanged.connect(
            self.on_gender_changed)
        gender_layout.addWidget(self.gender_selector)
        gender_box.setLayout(gender_layout)
        params_layout.addWidget(gender_box)

        # Pitch Control Group
        pitch_box = QGroupBox("Pitch")  # Group Box Title - needs translation
        self.pitch_box = pitch_box  # Store group box for translation
        pitch_layout = QVBoxLayout()
        self.pitch_slider = QSlider(Qt.Horizontal)
        self.pitch_slider.setRange(0, 4)
        self.pitch_slider.setValue(2)
        pitch_layout.addWidget(self.pitch_slider)
        pitch_box.setLayout(pitch_layout)
        params_layout.addWidget(pitch_box)

        # Speed Control Group
        speed_box = QGroupBox("Speed")  # Group Box Title - needs translation
        self.speed_box = speed_box  # Store group box for translation
        speed_layout = QVBoxLayout()
        self.speed_slider = QSlider(Qt.Horizontal)
        self.speed_slider.setRange(0, 4)
        self.speed_slider.setValue(2)
        speed_layout.addWidget(self.speed_slider)
        speed_box.setLayout(speed_layout)
        params_layout.addWidget(speed_box)

        controls_layout.addLayout(params_layout)
        controls_group.setLayout(controls_layout)
        main_layout.addWidget(controls_group)

        # Visualization and Playback Group
        # Group Box Title - needs translation
        vis_group = QGroupBox("Audio Visualization")
        self.vis_group = vis_group  # Store group box for translation
        vis_layout = QVBoxLayout()
        self.waveform = WaveformWidget()
        vis_layout.addWidget(self.waveform)

        # Time Display
        time_layout = QHBoxLayout()
        # Label - needs translation (though format is universal)
        self.current_time = QLabel("00:00")
        # Label - needs translation (though format is universal)
        self.total_time = QLabel("00:00")
        time_layout.addWidget(self.current_time)
        time_layout.addStretch()
        time_layout.addWidget(self.total_time)
        vis_layout.addLayout(time_layout)

        # Playback Controls
        playback_layout = QHBoxLayout()
        self.play_btn = QPushButton("Play")  # Button text - needs translation
        self.play_btn_play = self.play_btn  # Store button for translation
        self.play_btn.clicked.connect(self.play_audio)
        self.play_btn.setIcon(QIcon())

        self.pause_btn = QPushButton("Pause")
        self.pause_btn_pause = self.pause_btn # Store for translation
        self.pause_btn.clicked.connect(self.pause_audio) # Connect to pause_audio method
        self.pause_btn.setIcon(QIcon()) # Add pause icon if you have one

        self.stop_btn = QPushButton("Stop")  # Button text - needs translation
        self.play_btn_stop = self.stop_btn  # Store button for translation
        self.stop_btn.clicked.connect(self.stop_audio)
        self.stop_btn.setIcon(QIcon())
        self.volume_label = QLabel("Volume:")  # Label - needs translation
        self.volume_label_widget = self.volume_label  # Store label for translation
        self.volume_slider = QSlider(Qt.Horizontal)
        self.volume_slider.setRange(0, 100)
        self.volume_slider.setValue(100)
        self.volume_slider.valueChanged.connect(self.set_volume)
        playback_layout.addWidget(self.play_btn)
        playback_layout.addWidget(self.pause_btn)
        playback_layout.addWidget(self.stop_btn)
        playback_layout.addWidget(self.volume_label)
        playback_layout.addWidget(self.volume_slider)
        vis_layout.addLayout(playback_layout)

        vis_group.setLayout(vis_layout)
        main_layout.addWidget(vis_group)

        # Bottom Panel
        bottom_layout = QHBoxLayout()
        # Button text - needs translation
        self.generate_btn = QPushButton("Generate Speech")
        self.generate_btn_generate = self.generate_btn  # Store button for translation
        self.generate_btn.clicked.connect(self.run_synthesis)
        self.generate_btn.setIcon(QIcon())
        self.generate_btn.setEnabled(False)  # Initially disabled

        # Button text - needs translation
        self.save_btn = QPushButton("Save Audio")
        self.save_btn_save = self.save_btn  # Store button for translation
        self.save_btn.clicked.connect(self.save_audio)
        self.save_btn.setIcon(QIcon())  # Add actual icon file

        # New Exit Button - needs translation
        self.exit_btn = QPushButton("Exit")
        self.exit_btn_main_window = self.exit_btn # Store for translation
        self.exit_btn.clicked.connect(self.quit_app) # Connect to quit_app
        self.exit_btn.setIcon(QIcon()) # Add exit icon if you have one

        # Progress Bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setTextVisible(False)
        self.progress_bar.setFixedHeight(6)

        bottom_layout.addWidget(self.generate_btn)
        bottom_layout.addWidget(self.save_btn)
        bottom_layout.addWidget(self.exit_btn) # Add Exit button to layout
        bottom_layout.addWidget(self.progress_bar)
        main_layout.addLayout(bottom_layout)

        # Status Bar
        self.status_label = QLabel("Ready")  # Label - needs translation
        self.status_label_bottom = self.status_label  # Store status label
        self.status_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(self.status_label)

        self.setLayout(main_layout)
        self.update_gui_language()  # Initial language setup

    def init_tray_icon(self):
        self.tray_icon = QSystemTrayIcon(self)
        self.tray_icon.setIcon(QIcon("src/logo.webp"))  # Use logo for tray icon as well

        tray_menu = QMenu()
        # Action text - needs translation (if tray menu is always visible)
        show_action = QAction("Show", self)
        self.show_action_tray = show_action  # Store action for translation
        show_action.triggered.connect(self.show)
        # Action text - needs translation (if tray menu is always visible)
        exit_action = QAction("Exit", self)
        self.exit_action_tray = exit_action  # Store action for translation
        exit_action.triggered.connect(self.quit_app) # Connect to quit_app instead of close
        tray_menu.addAction(show_action)
        tray_menu.addAction(exit_action)
        self.tray_icon.setContextMenu(tray_menu)
        self.tray_icon.show()

    def quit_app(self): # New function to quit the app properly
        QCoreApplication.quit() # Use QCoreApplication.quit() to properly exit

    def closeEvent(self, event):
        self.hide() # Minimize to tray on window close
        event.ignore() # Still ignore the close event to prevent window destruction, but now we have quit_app for proper exit.
        self.tray_icon.showMessage(
            # Title - needs translation (if tray message is always visible)
            "SparkTTS Studio",
            # Message - needs translation (if tray message is always visible)
            "The application is running in the system tray",
            QSystemTrayIcon.Information,
            2000
        )


    def set_volume(self, value):
        self.audio_output.setVolume(value / 100)

    def update_time_display(self):
        if self.audio_player.isPlaying():
            position_ms = self.audio_player.position()
            duration_ms = self.audio_player.duration()

            if not isinstance(position_ms, (int, float)): # Check if position is a number
                print(f"Error: audio_player.position() returned unexpected type: {type(position_ms)}")
                return # Exit if unexpected type

            if not isinstance(duration_ms, (int, float)): # Check if duration is a number
                print(f"Error: audio_player.duration() returned unexpected type: {type(duration_ms)}")
                return # Exit if unexpected type


            current_seconds = position_ms // 1000
            total_seconds = duration_ms // 1000

            self.current_time.setText(f"{current_seconds//60:02}:{current_seconds % 60:02}")
            if total_seconds > 0: # Use total_seconds here
                self.total_time.setText(f"{total_seconds//60:02}:{total_seconds % 60:02}")
            else:
                self.total_time.setText("00:00") # Or some default if total duration is invalid

    def on_generation_complete(self, result, elapsed):
        if isinstance(result, Exception):
            error_message = f"Error during speech generation: {result}. "
            # Error related to voice sample
            if "prompt_speech_path" in str(result):
                # Needs translation
                error_message += "Please load a voice sample or select a gender (if no voice sample is loaded)."
            # Error related to gender parameter
            elif "Gender must be 'male' or 'female' or None" in str(result):
                # Needs translation
                error_message += "Please select a valid gender (Male or Female) if not using 'Pick Voice (Gender optional)'."
            else:  # Generic error message
                # Needs translation
                error_message += "Please check your input and try again. See console for details."
            self.status_label.setText(error_message)
            # Print full error to console for debugging
            print(f"Full Error Details:")
            traceback.print_exc() # Print full traceback to console!
            self.progress_bar.setValue(0)
        else:
            filename = f"output_{int(time.time())}.wav"
            sf.write(filename, result, samplerate=16000)
            self.current_audio_file = filename
            # No translation needed for technical status
            self.status_label.setText(
                f"Generated in {elapsed:.1f}s | Saved to {filename}")
            self.waveform.set_waveform(result)
            self.progress_bar.setValue(0)
        self.generate_btn.setEnabled(True)
        self.update_generate_button_state()  # Re-validate and set button state again

    def on_position_changed(self, position):
        if self.audio_player.duration() > 0: # Keep the duration check
            progress = position / self.audio_player.duration()
            self.update_waveform_playhead(progress) # Update waveform playhead
            self.update_time_display() # Call update_time_display from here

    def update_waveform_playhead(self, progress): # New function to update playhead
            self.waveform.set_playhead_progress(progress)

    def on_duration_changed(self, duration):
        if duration > 0:
            self.timer.start(200)
            total_seconds = duration // 1000
            self.total_time.setText(f"{total_seconds//60:02}:{total_seconds % 60:02}")
            self.audio_player_duration_ms = duration
            self.waveform.set_audio_duration(duration) # NEW: Pass duration to WaveformWidget!
        else:
            self.total_time.setText("00:00")
            self.audio_player_duration_ms = 0
            self.waveform.set_audio_duration(0) # Also reset in WaveformWidget

    def seek_audio(self, position_ms): # New method to seek audio
        self.audio_player.setPosition(position_ms)
        if not self.audio_player.isPlaying(): # If not playing, start playing from seek position
            self.play_audio()

    def update_word_count(self):
        """Updates the word count dynamically as the user types."""
        text = self.text_input.toPlainText().strip()
        word_count = len(text.split()) if text else 0
        # Keep word count in English - usually numbers are universal
        self.word_count_label.setText(f"Word Count: {word_count}")

    def validate_inputs(self):
        """
        Validates if required inputs (voice sample OR gender) are provided.
        Returns True if inputs are valid, False otherwise.
        """
        if self.voice_sample is not None:
            return True  # Voice sample loaded, valid
        elif self.gender_selector.currentIndex() != 0:  # Not "Pick Voice (Gender optional)"
            # Gender selected, valid (assuming not "Pick Voice...")
            return True
        else:
            return False  # Neither voice sample nor gender selected, invalid

    def update_generate_button_state(self):
        """Updates the 'Generate Speech' button's enabled state based on input validity."""
        is_valid = self.validate_inputs()
        self.generate_btn.setEnabled(is_valid)
        if not is_valid:
            self.status_label.setText(
                "Load a voice sample or select gender to enable 'Generate Speech'.")  # Needs translation

    def on_gender_changed(self, index):
        """Handler for gender selector changes. Updates generate button state."""
        self.update_generate_button_state()

    def reset_voice_sample(self):
        """Clears the loaded voice sample and restores gender selection."""
        self.voice_sample = None
        # Needs translation - update status label
        self.voice_label.setText("No voice sample loaded")
        self.update_generate_button_state()  # Update button state after reset

    def select_voice_sample(self):
        file_path, _ = QFileDialog.getOpenFileName(
            # "Select Voice Sample" - Dialog title - OS dependent usually
            self, "Select Voice Sample", "", "Audio Files (*.wav *.mp3)"
        )
        if file_path:
            self.voice_sample = file_path
            # Needs translation - update status label, but keep filename in English
            self.voice_label.setText(
                f"Loaded voice sample: {os.path.basename(file_path)}")
            # Update button state after loading voice sample
            self.update_generate_button_state()
        else:
            # Re-validate in case selection was cancelled
            self.update_generate_button_state()

    def save_audio(self):
        if not (self.current_audio_file and os.path.exists(self.current_audio_file)):
            self.status_label.setText("No audio to save!")  # Needs translation
            return
        default_filename = f"SparkTTS_output_{int(time.time())}.wav" # Generate default filename
        save_path, _ = QFileDialog.getSaveFileName(
            # "Save Audio" - Dialog title - OS dependent usually
            self, "Save Audio", default_filename, "WAV Files (*.wav)" # Added default filename here
        )
        if save_path:
            shutil.copy(self.current_audio_file, save_path)
            # No translation needed for technical status
            self.status_label.setText(
                f"Audio saved to: {os.path.basename(save_path)}")

    def play_audio(self): 
        if self.current_audio_file and os.path.exists(self.current_audio_file):
            if not self.audio_player.isPlaying(): 
                self.audio_player.setSource(self.current_audio_file) 
                self.audio_player.play()
                self.play_btn.setText("Pause")
                self.play_btn_play.setText("Pause") 
            else: 
                self.audio_player.pause()
                self.play_btn.setText("Play") 
                self.play_btn_play.setText("Play") 
        elif self.audio_player.isPlaying(): 
            self.audio_player.pause()
            self.play_btn.setText("Play") 
            self.play_btn_play.setText("Play") 

    def pause_audio(self): 
        if self.audio_player.isPlaying():
            self.audio_player.pause()
            self.play_btn.setText("Play") 
            self.play_btn_play.setText("Play")
        else:
            self.play_audio()


    def stop_audio(self): 
        self.audio_player.stop()
        self.play_btn.setText("Play") 
        self.play_btn_play.setText("Play") 
        self.waveform.set_playback_progress_overlay(0.0)
        self.waveform.set_playhead_progress(0.0)
        self.current_time.setText("00:00")

    def run_synthesis(self):
        text = self.text_input.toPlainText().strip()
        if not text:
            self.status_label.setText(
                "Please enter some text!")  # Needs translation
            return

        if not self.validate_inputs():  # Double check validation before synthesis (optional, but good practice)
            self.status_label.setText(
                "Load a voice sample or select gender to generate speech.")  # Needs translation
            return

        # Segmentation: Limit each segment to 150 words.
        segmentation_threshold = 150
        words = text.split()
        if len(words) > segmentation_threshold:
            text_to_process = [
                ' '.join(words[i:i + segmentation_threshold])
                for i in range(0, len(words), segmentation_threshold)
                ]
            self.status_label.setText(
                "Text too long: processing segments...")  # Needs translation
            # Setup progress bar for segments
            self.progress_bar.setMaximum(len(text_to_process))
            self.progress_bar.setValue(0)
        else:
            text_to_process = text
            self.progress_bar.setMaximum(1)  # Single segment
            self.progress_bar.setValue(0)

        # Determine parameters based on whether a voice sample is loaded.
        if self.voice_sample is not None:
            prompt = self.voice_sample
            gender = None
            pitch = None
            speed = None
        else:
            prompt = None
            gender = self.gender_selector.currentText().lower()
            # Corrected gender logic
            gender = None if gender == "pick voice (gender optional)" else gender
            # Map slider positions (0-4) to the level strings SparkTTS expects, matching the simpler GUI above.
            pitch_map = ["very_low", "low", "moderate", "high", "very_high"]
            speed_map = ["very_low", "low", "moderate", "high", "very_high"]
            pitch = pitch_map[self.pitch_slider.value()]
            speed = speed_map[self.speed_slider.value()]

        # Disable again right before generation, just in case
        self.generate_btn.setEnabled(False)
        self.status_label.setText("Generating speech...")  # Needs translation

        self.worker = TTSWorker(
            self.model, text_to_process, prompt, gender, pitch, speed)
        self.worker.progress_update.connect(self.on_generation_progress)
        self.worker.result_ready.connect(self.on_generation_complete)
        self.worker.start()

    def on_generation_progress(self, current, total):
        # Needs translation - segment info
        self.status_label.setText(f"Generating segment {current} / {total}...")
        self.progress_bar.setValue(current)  # Update progress bar

    # AI GENERATED LANGUAGE TRANSLATIONS
    translations = {
        "English": {
            "SparkTTS Studio": "SparkTTS Studio",
            "enter_text": "Enter text for speech synthesis...",
            "language": "Language:",
            "word_count": "Word Count:",
            "load_voice": "Load Voice Sample",
            "reset_voice": "Reset Voice Sample",
            "generate_speech": "Generate Speech",
            "gender": "Gender:",
            "auto": "Pick Voice (Gender optional)",  # Renamed "Auto"
            "male": "Male",
            "female": "Female",
            "pitch_label": "Pitch",
            "speed_label": "Speed",
            "play_button": "Play",
            "stop": "Stop",
            "save_audio": "Save Audio",
            "model_cuda": "Model loaded on CUDA",
            "pitch": "Pitch",
            "speed": "Speed",
            "text_input_group": "Text Input",
            "synthesis_controls_group": "Synthesis Controls",
            "audio_visualization_group": "Audio Visualization",
            "voice_parameters_group": "Voice Parameters",
            "no_voice_sample_loaded": "No voice sample loaded",
            "volume_label": "Volume:",
            "ready_status": "Ready",
            "tray_show": "Show",
            "tray_exit": "Exit",
            "tray_message_title": "SparkTTS Studio",
            "tray_message_text": "The application is running in the system tray",
            "error_voice_sample_missing": "Please load a voice sample or select a gender (if no voice sample is loaded).",
            "error_gender_invalid": "Please select a valid gender (Male or Female) if not using 'Pick Voice (Gender optional)'.",
            "error_generic": "Please check your input and try again. See console for details.",
            "status_generating_segment": "Generating segment {current} / {total}...",
            "status_generating_speech": "Generating speech...",
            "status_load_voice_sample_enable_generate": "Load a voice sample or select gender to enable 'Generate Speech'.",
            "status_no_audio_to_save": "No audio to save!",
            "status_please_enter_text": "Please enter some text!",
            "status_text_too_long_segments": "Text too long: processing segments...",
            "status_voice_sample_cleared": "Voice sample cleared.",
            "status_loaded_voice_sample": "Loaded voice sample: {filename}",
            "status_audio_saved_to": "Audio saved to: {filename}",
            "play_button_play": "Play",
            "play_button_pause": "Pause",
            "pause_button": "Pause",
            "stop_button": "Stop",
            "generate_button": "Generate Speech",
            "save_audio_button": "Save Audio",
            "reset_voice_sample_status": "Voice sample cleared.",
            "exit_button": "Exit", 
            
        },
        "Bulgarian": {
            "SparkTTS Studio": "SparkTTS Студио",
            "enter_text": "Въведете текст за синтез на реч...",
            "language": "Език:",
            "word_count": "Брой думи:",
            "load_voice": "Зареди гласов файл",
            "reset_voice": "Изчисти гласов файл",
            "generate_speech": "Генерирай реч",
            "gender": "Пол:",
            "auto": "Избери глас",
            "male": "Мъжски",
            "female": "Женоски",
            "pitch_label": "Височина на тона",
            "speed_label": "Скорост",
            "play": "Пусни",
            "stop": "Спри",
            "save_audio": "Запази аудио",
            "model_cuda": "Моделът е зареден на CUDA",
            "pitch": "Височина на тона",
            "speed": "Скорост",
            "text_input_group": "Въвеждане на текст",
            "synthesis_controls_group": "Контрол на синтеза",
            "audio_visualization_group": "Визуализация на аудио",
            "voice_parameters_group": "Гласови параметри",
            "no_voice_sample_loaded": "Не е зареден гласов файл",
            "volume_label": "Сила на звука:",
            "ready_status": "Готов",
            "tray_show": "Покажи",
            "tray_exit": "Изход",
            "tray_message_title": "SparkTTS Studio",
            "tray_message_text": "Приложението работи в системния трей",
            "error_voice_sample_missing": "Моля, заредете гласов файл или изберете пол (ако не е зареден гласов файл).",
            "error_gender_invalid": "Моля, изберете валиден пол (Мъж или Жена), ако не използвате 'Избери глас (Пол по избор)'.",
            "error_generic": "Моля, проверете въведените данни и опитайте отново. Вижте конзолата за подробности.",
            "status_generating_segment": "Генериране на сегмент {current} / {total}...",
            "status_generating_speech": "Генериране на реч...",
            "status_load_voice_sample_enable_generate": "Заредете гласов файл или изберете пол, за да активирате 'Генерирай реч'.",
            "status_no_audio_to_save": "Няма аудио за запазване!",
            "status_please_enter_text": "Моля, въведете текст!",
            "status_text_too_long_segments": "Текстът е твърде дълъг: обработка на сегменти...",
            "status_voice_sample_cleared": "Гласовият файл е изчистен.",
            "status_loaded_voice_sample": "Зареден гласов файл: {filename}",
            "status_audio_saved_to": "Аудиото е запазено в: {filename}",
            "play_button_play": "Пусни",
            "play_button_pause": "Пауза",
            "stop_button": "Спри",
            "generate_button": "Генерирай реч",
            "save_audio_button": "Запази аудио",
            "reset_voice_sample_status": "Гласовият файл е изчистен.",
            "exit_button": "Затвори", # Added translation for Exit button
            "pause_button": "Пауза", # Added translation for Pause button
        },
        "Spanish": {
            "SparkTTS Studio": "SparkTTS Studio",
            "enter_text": "Introduzca texto para la síntesis de voz...",
            "language": "Idioma:",
            "word_count": "Recuento de palabras:",
            "load_voice": "Cargar muestra de voz",
            "reset_voice": "Restablecer muestra de voz",
            "generate_speech": "Generar voz",
            "gender": "Género:",
            # Renamed "Auto" - Example translation, verify!
            "auto": "Elegir voz (Género opcional)",
            "male": "Masculino",
            "female": "Femenino",
            "pitch_label": "Tono",
            "speed_label": "Velocidad",
            "play": "Reproducir",
            "stop": "Detener",
            "save_audio": "Guardar audio",
            "model_cuda": "Modelo cargado en CUDA",
            "pitch": "Tono",
            "speed": "Velocidad",
            "text_input_group": "Entrada de texto",
            "synthesis_controls_group": "Controles de síntesis",
            "audio_visualization_group": "Visualización de audio",
            "voice_parameters_group": "Parámetros de voz",
            "no_voice_sample_loaded": "No se ha cargado ninguna muestra de voz",
            "volume_label": "Volumen:",
            "ready_status": "Listo",
            "tray_show": "Mostrar",
            "tray_exit": "Salir",
            "tray_message_title": "SparkTTS Studio",
            "tray_message_text": "La aplicación se está ejecutando en la bandeja del sistema",
            "error_voice_sample_missing": "Por favor, cargue una muestra de voz o seleccione un género (si no se carga ninguna muestra de voz).",
            "error_gender_invalid": "Por favor, seleccione un género válido (Masculino o Femenino) si no utiliza 'Elegir voz (Género opcional)'.",
            "error_generic": "Por favor, revise su entrada e inténtelo de nuevo. Consulte la consola para obtener más detalles.",
            "status_generating_segment": "Generando segmento {current} / {total}...",
            "status_generating_speech": "Generando voz...",
            "status_load_voice_sample_enable_generate": "Cargue una muestra de voz o seleccione un género para activar 'Generar voz'.",
            "status_no_audio_to_save": "¡No hay audio para guardar!",
            "status_please_enter_text": "¡Por favor, introduzca algún texto!",
            "status_text_too_long_segments": "Texto demasiado largo: procesando segmentos...",
            "status_voice_sample_cleared": "Muestra de voz borrada.",
            "status_loaded_voice_sample": "Muestra de voz cargada: {filename}",
            "status_audio_saved_to": "Audio guardado en: {filename}",
            "play_button_play": "Reproducir",
            "play_button_pause": "Pausa",
            "stop_button": "Detener",
            "generate_button": "Generar voz",
            "save_audio_button": "Guardar audio",
            "reset_voice_sample_status": "Muestra de voz borrada.",
            "exit_button": "Salir", # Added translation for Exit button
            "pause_button": "Pausa", # Added translation for Pause button
        },
        "French": {
            "SparkTTS Studio": "SparkTTS Studio",
            "enter_text": "Entrez du texte pour la synthèse vocale...",
            "language": "Langue:",
            "word_count": "Nombre de mots:",
            "load_voice": "Charger un échantillon de voix",
            "reset_voice": "Réinitialiser l'échantillon vocal",
            "generate_speech": "Générer la parole",
            "gender": "Genre:",
            # Renamed "Auto" - Example translation, verify!
            "auto": "Choisir une voix (Genre optionnel)",
            "male": "Masculin",
            "female": "Féminin",
            "pitch_label": "Hauteur",
            "speed_label": "Vitesse",
            "play": "Lecture",
            "stop": "Arrêter",
            "save_audio": "Enregistrer l'audio",
            "model_cuda": "Modèle chargé sur CUDA",
            "pitch": "Hauteur",
            "speed": "Vitesse",
            "text_input_group": "Saisie de texte",
            "synthesis_controls_group": "Contrôles de synthèse",
            "audio_visualization_group": "Visualisation audio",
            "voice_parameters_group": "Paramètres vocaux",
            "no_voice_sample_loaded": "Aucun échantillon vocal chargé",
            "volume_label": "Volume:",
            "ready_status": "Prêt",
            "tray_show": "Afficher",
            "tray_exit": "Quitter",
            "tray_message_title": "SparkTTS Studio",
            "tray_message_text": "L'application fonctionne dans la barre des tâches",
            "error_voice_sample_missing": "Veuillez charger un échantillon vocal ou sélectionner un genre (si aucun échantillon vocal n'est chargé).",
            "error_gender_invalid": "Veuillez sélectionner un genre valide (Masculin ou Féminin) si vous n'utilisez pas 'Choisir une voix (Genre optionnel)'.",
            "error_generic": "Veuillez vérifier votre saisie et réessayer. Consultez la console pour plus de détails.",
            "status_generating_segment": "Génération du segment {current} / {total}...",
            "status_generating_speech": "Génération de la parole...",
            "status_load_voice_sample_enable_generate": "Chargez un échantillon vocal ou sélectionnez un genre pour activer 'Générer la parole'.",
            "status_no_audio_to_save": "Aucun audio à enregistrer !",
            "status_please_enter_text": "Veuillez saisir du texte !",
            "status_text_too_long_segments": "Texte trop long : traitement des segments...",
            "status_voice_sample_cleared": "Échantillon vocal effacé.",
            "status_loaded_voice_sample": "Échantillon vocal chargé : {filename}",
            "status_audio_saved_to": "Audio enregistré dans : {filename}",
            "play_button_play": "Lecture",
            "play_button_pause": "Pause",
            "stop_button": "Arrêter",
            "generate_button": "Générer la parole",
            "save_audio_button": "Enregistrer l'audio",
            "reset_voice_sample_status": "Échantillon vocal effacé.",
            "exit_button": "Quitter", # Added translation for Exit button
            "pause_button": "Pause", # Added translation for Pause button
        },
        "Japanese": {
            "SparkTTS Studio": "SparkTTS Studio",
            "enter_text": "音声合成のためのテキストを入力してください…",
            "language": "言語:",
            "word_count": "単語数:",
            "load_voice": "音声サンプルを読み込む",
            "reset_voice": "音声サンプルをリセット",
            "generate_speech": "音声を生成",
            "gender": "性別:",
            # Renamed "Auto" - Example translation, verify!
            "auto": "音声を選択 (性別はオプション)",
            "male": "男性",
            "female": "女性",
            "pitch_label": "ピッチ",
            "speed_label": "速度",
            "play": "再生",
            "stop": "停止",
            "save_audio": "音声を保存",
            "model_cuda": "CUDAでモデルが読み込まれました",
            "pitch": "ピッチ",
            "speed": "速度",
            "text_input_group": "テキスト入力",
            "synthesis_controls_group": "合成コントロール",
            "audio_visualization_group": "オーディオ可視化",
            "voice_parameters_group": "音声パラメータ",
            "no_voice_sample_loaded": "音声サンプルはロードされていません",
            "volume_label": "音量:",
            "ready_status": "準備完了",
            "tray_show": "表示",
            "tray_exit": "終了",
            "tray_message_title": "SparkTTS Studio",
            "tray_message_text": "アプリケーションはシステムトレイで実行されています",
            "error_voice_sample_missing": "音声サンプルをロードするか、性別を選択してください(音声サンプルがロードされていない場合)。",
            "error_gender_invalid": "'音声を選択(性別はオプション)'を使用しない場合は、有効な性別(男性または女性)を選択してください。",
            "error_generic": "入力内容を確認して、もう一度お試しください。詳細については、コンソールを参照してください。",
            "status_generating_segment": "セグメント {current} / {total} を生成中...",
            "status_generating_speech": "音声を生成中...",
            "status_load_voice_sample_enable_generate": "音声サンプルをロードするか、性別を選択して「音声を生成」を有効にしてください。",
            "status_no_audio_to_save": "保存するオーディオはありません!",
            "status_please_enter_text": "テキストを入力してください!",
            "status_text_too_long_segments": "テキストが長すぎます: セグメントを処理中...",
            "status_voice_sample_cleared": "音声サンプルをクリアしました。",
            "status_loaded_voice_sample": "音声サンプルをロードしました: {filename}",
            "status_audio_saved_to": "オーディオを保存しました: {filename}",
            "play_button_play": "再生",
            "play_button_pause": "一時停止",
            "stop_button": "停止",
            "generate_button": "音声を生成",
            "save_audio_button": "音声を保存",
            "reset_voice_sample_status": "音声サンプルをクリアしました。",
            "exit_button": "終了", # Added translation for Exit button
            "pause_button": "一時停止", # Added translation for Pause button
        }
    }

    def update_gui_language(self):
        """Updates the GUI labels based on the selected language."""

        # Get selected language, default to English
        selected_lang = self.gui_language_selector.currentText()
        t = self.translations.get(selected_lang, self.translations["English"])

        # Apply translations to UI elements
        self.setWindowTitle(t["SparkTTS Studio"])  # Window title
        self.language_label.setText(t["language"])  # "Language:" label
        # "Text Input" group box
        self.input_group.setTitle(t["text_input_group"])
        self.text_input.setPlaceholderText(
            t["enter_text"])  # Text input placeholder
        # "Synthesis Controls" group box
        self.controls_group.setTitle(t["synthesis_controls_group"])
        self.voice_btn_load_voice.setText(
            t["load_voice"])  # "Load Voice Sample" button
        # "No voice sample loaded" label
        self.voice_label_status.setText(t["no_voice_sample_loaded"])
        # "Voice Parameters" group box
        self.gender_box.setTitle(t["voice_parameters_group"])
        self.gender_selector_label_widget.setText(
            t["gender"])  # "Gender:" label
        # Gender ComboBox items
        for i, item_text in enumerate([t["auto"], t["male"], t["female"]]):
            self.gender_selector.setItemText(i, item_text)
        self.pitch_box.setTitle(t["pitch_label"])  # "Pitch" group box
        self.speed_box.setTitle(t["speed_label"])  # "Speed" group box
        # "Audio Visualization" group box
        self.vis_group.setTitle(t["audio_visualization_group"])
        self.volume_label_widget.setText(t["volume_label"])  # "Volume:" label
        # "Play" button (initial text)
        self.play_btn_play.setText(t["play_button_play"])
        self.play_btn_play.setText("Play") # Ensure initial text is "Play" after language change too.
        self.play_btn_stop.setText(t["stop_button"])  # "Stop" button
        self.generate_btn_generate.setText(
            t["generate_button"])  # "Generate Speech" button
        self.save_btn_save.setText(
            t["save_audio_button"])  # "Save Audio" button
        # New Exit button text
        self.exit_btn_main_window.setText(t["exit_button"]) # Set Exit button text
        # Bottom status bar "Ready" text
        self.status_label_bottom.setText(t["ready_status"])

        # Tray message - conditionally translate if it's always shown
        # self.tray_icon.showMessage(
        #     t["tray_message_title"],
        #     t["tray_message_text"],
        #     QSystemTrayIcon.Information,
        #     2000
        # )
        self.update_generate_button_state()  # Update button state after language change


# --- TTSWorker class (no changes needed) ---
class TTSWorker(QThread):
    # Emits (final result, generation_time)
    result_ready = Signal(object, float)
    # Emits (current_segment, total_segments)
    progress_update = Signal(int, int)

    def __init__(self, model, text, voice_sample, gender, pitch, speed):
        """
        text: Either a string or a list of strings (segments).
        """
        super().__init__()
        self.model = model
        self.text = text
        self.voice_sample = voice_sample
        self.gender = gender
        self.pitch = pitch
        self.speed = speed

    def run(self):
        start = time.time()
        try:
            results = []
            if isinstance(self.text, list):
                total = len(self.text)
                for i, segment in enumerate(self.text):
                    with torch.no_grad():
                        wav = self.model.inference(
                            segment,
                            prompt_speech_path=self.voice_sample,
                            gender=self.gender,
                            pitch=self.pitch,
                            speed=self.speed
                        )
                    results.append(wav)
                    self.progress_update.emit(i + 1, total)
                final_wav = np.concatenate(results, axis=0)
            else:
                with torch.no_grad():
                    final_wav = self.model.inference(
                        self.text,
                        prompt_speech_path=self.voice_sample,
                        gender=self.gender,
                        pitch=self.pitch,
                        speed=self.speed
                    )
                self.progress_update.emit(1, 1)
            elapsed = time.time() - start
            self.result_ready.emit(final_wav, elapsed)
        except Exception as e:
            self.result_ready.emit(e, 0)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    app.setStyleSheet(STYLE_SHEET)  # Apply stylesheet globally for the app
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device=device)
    window = SparkTTSApp(model, device.type.upper())
    window.show()
    sys.exit(app.exec())
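
For anyone adapting the worker: below is a minimal sketch of how TTSWorker is typically driven from the GUI side. The handler names and attributes like prepare_segments, voice_sample and generated_wav are hypothetical (the real SparkTTSApp wires this up internally); only the signal signatures come from the class above.

# Hypothetical usage sketch - handler and attribute names are illustrative, not from the GUI above.
def start_generation(self):
    segments = self.prepare_segments()   # hypothetical helper: returns a str or list[str]
    self.worker = TTSWorker(
        self.model,
        segments,
        self.voice_sample,                      # path to a loaded sample, or None
        gender=None, pitch=None, speed=None,    # cloning mode: attributes come from the sample
    )
    self.worker.progress_update.connect(self.on_progress)   # Signal(int, int)
    self.worker.result_ready.connect(self.on_result)        # Signal(object, float)
    self.worker.start()                                      # runs TTSWorker.run() off the UI thread

def on_progress(self, current, total):
    self.status_label_bottom.setText(f"Generating segment {current} / {total}...")

def on_result(self, result, elapsed):
    if isinstance(result, Exception):
        self.status_label_bottom.setText(str(result))
        return
    self.generated_wav = result          # 16 kHz numpy waveform, ready for playback or saving
    self.status_label_bottom.setText(f"Done in {elapsed:.2f}s")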

@Rakile
Copy link

Rakile commented Mar 9, 2025

@AcTePuKc I think you forgot to convert the pitch and speed value in run() before sending it to inference.
So maybe just include:
pitch_val = LEVELS_MAP_UI[int(self.pitch)]
speed_val = LEVELS_MAP_UI[int(self.speed)]
so the inference function arguments would then be:
final_wav = self.model.inference(
    self.text,
    prompt_speech_path=self.voice_sample,
    gender=self.gender,
    pitch=pitch_val,
    speed=speed_val
)

Also just set speed and pitch here ( so that it doesn't crash on LEVELS_MAP_UI):
if self.voice_sample is not None:
    prompt = self.voice_sample
    gender = None
    speed = self.speed_slider.value()
    pitch = self.pitch_slider.value()

ofc you also have to import LEVELS_MAP_UI:
from sparktts.utils.token_parser import LEVELS_MAP_UI
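
Putting those suggestions together, a rough sketch of how TTSWorker.run() could look with the conversion applied. This assumes self.pitch / self.speed arrive as integer slider values (or None) and that LEVELS_MAP_UI maps those integers to the level strings inference() expects; the segmentation branch from the original run() is omitted for brevity:

from sparktts.utils.token_parser import LEVELS_MAP_UI

def run(self):
    start = time.time()
    try:
        # Convert slider integers (e.g. 3) to level strings (e.g. "moderate"); leave None as-is.
        pitch_val = LEVELS_MAP_UI[int(self.pitch)] if self.pitch is not None else None
        speed_val = LEVELS_MAP_UI[int(self.speed)] if self.speed is not None else None

        with torch.no_grad():
            final_wav = self.model.inference(
                self.text,
                prompt_speech_path=self.voice_sample,
                gender=self.gender,
                pitch=pitch_val,
                speed=speed_val,
            )
        self.progress_update.emit(1, 1)
        self.result_ready.emit(final_wav, time.time() - start)
    except Exception as e:
        self.result_ready.emit(e, 0)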

@prasannakulkarni333
Copy link

Hello, thank you for the CLI code. How do I keep the voice consistent across all the segments? A new voice is created for each segment. Are there any predefined voices I can use in the generate_tts_audio function? Thanks @AcTePuKc

@AcTePuKc
Copy link
Author

AcTePuKc commented Mar 10, 2025


🔧 Patch Drop – Spark-TTS Updated Files (Manual Copy-Paste)

Hey all 👋 – since Git being Git decided to mess up my merge and force a pull that doesn’t even work properly (🙄), I’m posting the updated full files here instead.

Just copy & replace the files below into your Spark-TTS repo.
They include a bunch of small improvements, cleanups, and test tooling I’ve been working on.


Notes:

  • SEEDS ADDED - you can generate the same voice over and over again using seeds. Change the seed in the CLI, or open the file and read how to use it.
  • Some options like EMOTIONS are not implemented by the devs and are NOT working, but I left them in for anyone who wants to tinker with, improve, or remove them.
  • This is shared as full file replacements - diff or merge them against your current files before overwriting, since they may become outdated (safety check).
  • The folder layout assumes you're working inside:
    Spark-TTS-main/ (the same folder where webui.py is) - see the layout sketch after this list.
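
For orientation, the layout these replacements assume (paths collected from the file headers and notes in this post) looks roughly like this:

Spark-TTS-main/
├── webui.py
├── cli/
│   ├── SparkTTS.py
│   ├── inference.py
│   ├── tts_cli.py                  (new)
│   └── test_emotions_batch.py      (optional, for batch testing)
├── sparktts/
│   ├── models/BiCodec.py
│   └── utils/token_parser.py
├── pretrained_models/
│   └── Spark-TTS-0.5B/
└── example/results/                (generated audio lands here)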

📄 File Replacements:

🔁 Spark-TTS-main/sparktts/models/BiCodec.py

# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Any
from omegaconf import DictConfig
from safetensors.torch import load_file

from sparktts.utils.file import load_config
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize


class BiCodec(nn.Module):
    """
    BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
    quantizer, and wave generator.
    """

    def __init__(
        self,
        mel_params: Dict[str, Any],
        encoder: nn.Module,
        decoder: nn.Module,
        quantizer: nn.Module,
        speaker_encoder: nn.Module,
        prenet: nn.Module,
        postnet: nn.Module,
        **kwargs
    ) -> None:
        """
        Initializes the BiCodec model with the required components.

        Args:
            mel_params (dict): Parameters for the mel-spectrogram transformer.
            encoder (nn.Module): Encoder module.
            decoder (nn.Module): Decoder module.
            quantizer (nn.Module): Quantizer module.
            speaker_encoder (nn.Module): Speaker encoder module.
            prenet (nn.Module): Prenet network.
            postnet (nn.Module): Postnet network.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.speaker_encoder = speaker_encoder
        self.prenet = prenet
        self.postnet = postnet
        self.init_mel_transformer(mel_params)

    @classmethod
    def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
        """
        Loads the model from a checkpoint.

        Args:
            model_dir (Path): Path to the model directory containing checkpoint and config.
        
        Returns:
            BiCodec: The initialized BiCodec model.
        """
        ckpt_path = f'{model_dir}/model.safetensors'
        config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
        mel_params = config["mel_params"]
        encoder = Encoder(**config["encoder"])
        quantizer = FactorizedVectorQuantize(**config["quantizer"])
        prenet = Decoder(**config["prenet"])
        postnet = Decoder(**config["postnet"])
        decoder = WaveGenerator(**config["decoder"])
        speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])

        model = cls(
            mel_params=mel_params,
            encoder=encoder,
            decoder=decoder,
            quantizer=quantizer,
            speaker_encoder=speaker_encoder,
            prenet=prenet,
            postnet=postnet,
        )

        state_dict = load_file(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        # Show filtered + cleaner warning instead of spamming per tensor
        known_safe_missing = [
            "mel_transformer.spectrogram.window",
            "mel_transformer.mel_scale.fb"
        ]
        important_missing = [k for k in missing_keys if k not in known_safe_missing]

        if important_missing:
            print(f"Warning: Important model tensors missing: {important_missing}")
        elif missing_keys:
            print("Note: Some known model tensors (like mel_transformer) are missing and will be auto-regenerated.")


        model.eval()
        model.remove_weight_norm()

        return model

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Performs a forward pass through the model.

        Args:
            batch (dict): A dictionary containing features, reference waveform, and target waveform.
        
        Returns:
            dict: A dictionary containing the reconstruction, features, and other metrics.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        vq_outputs = self.quantizer(z)

        x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))

        conditions = d_vector
        with_speaker_loss = False

        x = self.prenet(vq_outputs["z_q"], conditions)
        pred_feat = self.postnet(x)
        x = x + conditions.unsqueeze(-1)
        wav_recon = self.decoder(x)

        return {
            "vq_loss": vq_outputs["vq_loss"],
            "perplexity": vq_outputs["perplexity"],
            "cluster_size": vq_outputs["active_num"],
            "recons": wav_recon,
            "pred_feat": pred_feat,
            "x_vector": x_vector,
            "d_vector": d_vector,
            "audios": batch["wav"].unsqueeze(1),
            "with_speaker_loss": with_speaker_loss,
        }

    @torch.no_grad()
    def tokenize(self, batch: Dict[str, Any]):
        """
        Tokenizes the input audio into semantic and global tokens.

        Args:
            batch (dict): The input audio features and reference waveform.

        Returns:
            tuple: Semantic tokens and global tokens.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        semantic_tokens = self.quantizer.tokenize(z)
        global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))

        return semantic_tokens, global_tokens

    @torch.no_grad()
    def detokenize(self, semantic_tokens, global_tokens):
        """
        Detokenizes the semantic and global tokens into a waveform.

        Args:
            semantic_tokens (tensor): Semantic tokens.
            global_tokens (tensor): Global tokens.

        Returns:
            tensor: Reconstructed waveform.
        """
        z_q = self.quantizer.detokenize(semantic_tokens)
        d_vector = self.speaker_encoder.detokenize(global_tokens)
        x = self.prenet(z_q, d_vector)
        x = x + d_vector.unsqueeze(-1)
        wav_recon = self.decoder(x)

        return wav_recon

    def init_mel_transformer(self, config: Dict[str, Any]):
        """
        Initializes the MelSpectrogram transformer based on the provided configuration.

        Args:
            config (dict): Configuration parameters for MelSpectrogram.
        """
        import torchaudio.transforms as TT

        self.mel_transformer = TT.MelSpectrogram(
            config["sample_rate"],
            config["n_fft"],
            config["win_length"],
            config["hop_length"],
            config["mel_fmin"],
            config["mel_fmax"],
            n_mels=config["num_mels"],
            power=1,
            norm="slaney",
            mel_scale="slaney",
        )

    def remove_weight_norm(self):
        """Removes weight normalization from all layers using updated PyTorch API."""
        from torch.nn.utils.parametrizations import weight_norm
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                try:
                    weight_norm.remove(m)
                except Exception:
                    pass  # Already removed or not applicable
        self.apply(_remove_weight_norm)



# Test the model
if __name__ == "__main__":

    config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
    model = BiCodec.load_from_checkpoint(
        model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
    )

    # Generate random inputs for testing
    duration = 0.96
    x = torch.randn(20, 1, int(duration * 16000))
    feat = torch.randn(20, int(duration * 50), 1024)
    inputs = {"feat": feat, "wav": x, "ref_wav": x}

    # Forward pass
    outputs = model(inputs)
    semantic_tokens, global_tokens = model.tokenize(inputs)
    wav_recon = model.detokenize(semantic_tokens, global_tokens)

    # Verify if the reconstruction matches
    if torch.allclose(outputs["recons"].detach(), wav_recon):
        print("Test successful")
    else:
        print("Test failed")

📄 NEW File: (ADDED --seed, --text_file, --text, --prompt_audio, --gender, --pitch, --speed; --emotion is included but not working)

🔁 Spark-TTS-main/cli/tts_cli.py

import sys
import os
import torch
import numpy as np
import soundfile as sf
import logging
from datetime import datetime
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import EMO_MAP

# Global cache for reuse
_cached_model_instance = None


def generate_tts_audio(
    text,
    model_dir=None,
    device="cuda:0",
    prompt_speech_path=None,
    prompt_text=None,
    gender=None,
    pitch=None,
    speed=None,
    emotion=None,
    save_dir="example/results",
    segmentation_threshold=150,
    seed=None,
    model=None,
    skip_model_init=False
):
    
    """
    Generates TTS audio from input text, splitting into segments if necessary.

    Args:
        text (str): Input text for speech synthesis.
        model_dir (str): Path to the model directory.
        device (str): Device identifier (e.g., "cuda:0" or "cpu").
        prompt_speech_path (str, optional): Path to prompt audio for cloning.
        prompt_text (str, optional): Transcript of prompt audio.
        gender (str, optional): Gender parameter ("male"/"female").
        pitch (str, optional): Pitch parameter (e.g., "moderate").
        speed (str, optional): Speed parameter (e.g., "moderate").
        emotion (str, optional): Emotion tag (e.g., "HAPPY", "SAD", "ANGRY").
        save_dir (str): Directory where generated audio will be saved.
        segmentation_threshold (int): Maximum number of words per segment.
        seed (int, optional): Seed value for deterministic voice generation.

    Returns:
        str: The unique file path where the generated audio is saved.
    """
    # ============================== OPTIONS REFERENCE ==============================
    # ✔ Gender options: "male", "female"
    # ✔ Pitch options: "very_low", "low", "moderate", "high", "very_high"
    # ✔ Speed options: same as pitch
    # ✔ Emotion options: list from token_parser.py EMO_MAP keys
    # ✔ Seed: any integer (e.g., 1337, 42, 123456) = same voice (mostly)
    # ==============================================================================

    if model_dir is None:
        model_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "pretrained_models", "Spark-TTS-0.5B"))

    global _cached_model_instance

    if not skip_model_init or model is None:
        if _cached_model_instance is None:
            logging.info("Initializing TTS model...")
            if not prompt_speech_path:
                logging.info(f"Using Gender: {gender or 'default'}, Pitch: {pitch or 'default'}, Speed: {speed or 'default'}, Emotion: {emotion or 'none'}, Seed: {seed or 'random'}")
            model = SparkTTS(model_dir, torch.device(device))
            _cached_model_instance = model
        else:
            model = _cached_model_instance


    # Set seed for reproducibility
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        logging.info(f"Seed set to: {seed}")

    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    words = text.split()
    if len(words) > segmentation_threshold:
        logging.info("Text exceeds threshold; splitting into segments...")
        segments = [' '.join(words[i:i + segmentation_threshold]) for i in range(0, len(words), segmentation_threshold)]
        wavs = []
        for seg in segments:
            with torch.no_grad():
                wav = model.inference(
                    seg,
                    prompt_speech_path,
                    prompt_text=prompt_text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                    emotion=emotion
                )
            wavs.append(wav)
        final_wav = np.concatenate(wavs, axis=0)
    else:
        with torch.no_grad():
            final_wav = model.inference(
                text,
                prompt_speech_path,
                prompt_text=prompt_text,
                gender=gender,
                pitch=pitch,
                speed=speed,
                emotion=emotion
            )

    sf.write(save_path, final_wav, samplerate=16000)
    logging.info(f"Audio saved at: {save_path}")
    return save_path


# Example CLI usage
if __name__ == "__main__":
    import argparse


    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_audio", type=str, help="Path to audio file for voice cloning")
    parser.add_argument("--prompt_text", type=str, help="Transcript text for the prompt audio (optional)")
    parser.add_argument("--text", type=str, help="Text to generate", required=False)
    parser.add_argument("--text_file", type=str, help="Path to .txt file with input text")
    parser.add_argument("--gender", type=str, choices=["male", "female"], default=None)
    parser.add_argument("--pitch", type=str, choices=["very_low", "low", "moderate", "high", "very_high"], default="moderate")
    parser.add_argument("--speed", type=str, choices=["very_low", "low", "moderate", "high", "very_high"], default="moderate")
    parser.add_argument("--emotion", type=str, choices=list(EMO_MAP.keys()), default=None)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()


    # ---------------- Argument Validation Block ---------------- NEW! SPECIAL!!!EXTRA SPICY!!!
    if not args.prompt_audio and not args.gender:
        print("❌ Error: You must provide either --gender (male/female) or --prompt_audio for voice cloning.")
        print("   Example 1: python tts_cli.py --text \"Hello there.\" --gender female")
        print("   Example 2: python tts_cli.py --text \"Hello there.\" --prompt_audio sample.wav")
        sys.exit(1)

    # --------------- Emotions ------------
    if args.emotion:
        logging.warning("⚠ Emotion input is experimental — model may not reflect emotion changes reliably or at all.")



    # Allow loading text from a file if provided
    if args.text_file:
        if os.path.exists(args.text_file):
            with open(args.text_file, "r", encoding="utf-8") as f:
                args.text = f.read().strip()
        else:
            raise FileNotFoundError(f"Text file not found: {args.text_file}")

    # If Not Provided Text or Text File
    if not args.text:
        raise ValueError("You must provide either --text or --text_file.")

    # Voice Cloning Mode Overrides
    if args.prompt_audio:
        # Normalize path + validate
        args.prompt_audio = os.path.abspath(args.prompt_audio)
        if not os.path.exists(args.prompt_audio):
            logging.error(f"❌ Prompt audio file not found: {args.prompt_audio}")
            sys.exit(1)

        # Log cloning info
        logging.info("🔊 Voice cloning mode enabled")
        logging.info(f"🎧 Cloning from: {args.prompt_audio}")

        # Bonus: Log audio info
        try:
            info = sf.info(args.prompt_audio)
            logging.info(f"📏 Prompt duration: {info.duration:.2f} seconds | Sample Rate: {info.samplerate}")
        except Exception as e:
            logging.warning(f"⚠️ Could not read prompt audio info: {e}")

        # Override pitch/speed/gender
        if args.gender or args.pitch or args.speed:
            print("[!] Warning: Voice cloning mode detected — ignoring gender/pitch/speed settings.")
        args.gender = None
        args.pitch = None
        args.speed = None

    # Start timing
    start_time = time.time()

    output_file = generate_tts_audio(
        text=args.text,
        gender=args.gender,
        pitch=args.pitch,
        speed=args.speed,
        emotion=args.emotion,
        seed=args.seed,
        prompt_speech_path=args.prompt_audio,
        prompt_text=args.prompt_text,
    )

    # End timing
    end_time = time.time()
    elapsed = end_time - start_time

    print(f"Generated audio file: {output_file}")
    print(f"⏱ Generation time: {elapsed:.2f} seconds")

📄 File Replacement :

🔁 Spark-TTS-main/cli/SparkTTS.py

# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import re
import torch
import numpy as np
from typing import Tuple
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from sparktts.utils.file import load_config
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
from sparktts.utils.token_parser import TokenParser 


class SparkTTS:
    """
    Spark-TTS for text-to-speech generation.
    """

    def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
        """
        Initializes the SparkTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device
        self.model_dir = model_dir
        self.configs = load_config(f"{model_dir}/config.yaml")
        self.sample_rate = self.configs["sample_rate"]
        self._initialize_inference()

        # Device Info Logging
        if self.device.type == "cuda":
            print(f"CUDA Device in use: {torch.cuda.get_device_name(self.device.index if self.device.index is not None else 0)}")
        else:
            print("CPU Mode activated (fallback) – slower generation expected.")

    def _initialize_inference(self):
        """Initializes the tokenizer, model, and audio tokenizer for inference."""
        self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
        self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
        self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
        self.model.to(self.device)

    def process_prompt(
        self,
        text: str,
        prompt_speech_path: Path,
        prompt_text: str = None,
    ) -> Tuple[str, torch.Tensor]:
        """
        Process input for voice cloning.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.

        Return:
            Tuple[str, torch.Tensor]: Input prompt; global tokens
        """

        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
            prompt_speech_path
        )
        global_tokens = "".join(
            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
        )

        # Prepare the input tokens for the model
        if prompt_text is not None:
            semantic_tokens = "".join(
                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
            )
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                prompt_text,
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
                "<|start_semantic_token|>",
                semantic_tokens,
               "<|end_semantic_token|>",
            ]
        else:
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
            ]

        inputs = "".join(inputs)

        return inputs, global_token_ids

    def process_prompt_control(
        self,
        gender: str,
        pitch: str,
        speed: str,
        text: str,
        emotion: str = None  # ← New
    ):
        """
        Process input for voice creation.

        Args:
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            text (str): The text input to be converted to speech.
            emotion (str, optional): Emotion label (e.g., HAPPY, SAD, ANGRY, etc.)

        Return:
            str: Input prompt
        """
        if speed is None:
            speed = "high"  # default fallback
        if pitch is None:
            pitch = "moderate"  # optional: set default pitch too

        assert gender in GENDER_MAP
        assert pitch in LEVELS_MAP
        assert speed in LEVELS_MAP


        gender_id = GENDER_MAP[gender]
        pitch_level_id = LEVELS_MAP[pitch]
        speed_level_id = LEVELS_MAP[speed]

        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
        gender_tokens = f"<|gender_{gender_id}|>"

        # Include emotion token if provided
        attribte_tokens = [gender_tokens, pitch_label_tokens, speed_label_tokens]
        if emotion:
            from sparktts.utils.token_parser import TokenParser
            attribte_tokens.append(TokenParser.emotion(emotion))

        attribte_tokens = "".join(attribte_tokens)

        control_tts_inputs = [
            TASK_TOKEN_MAP["controllable_tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            attribte_tokens,
            "<|end_style_label|>",
        ]

        return "".join(control_tts_inputs)


    @torch.no_grad()
    def inference(
        self,
        text: str,
        prompt_speech_path: Path = None,
        prompt_text: str = None,
        gender: str = None,
        pitch: str = None,
        speed: str = None,
        seed: int = None,  # ← ADDED: Deterministic voice control
        emotion: str = None,  # ← ADDED: Emotion conditioning
        temperature: float = 0.8,
        top_k: float = 50,
        top_p: float = 0.95,
    ) -> torch.Tensor:
        """
        Performs inference to generate speech from text, incorporating prompt audio and/or control attributes.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            emotion (str): Emotion label (e.g., HAPPY, SAD, ANGRY, etc.)
            temperature (float): Sampling temperature for randomness control.
            top_k (float): Top-k sampling.
            top_p (float): Top-p (nucleus) sampling.

        Returns:
            torch.Tensor: Generated waveform as a tensor.
        """

        # Build prompt using control tokens if gender is set
        if gender is not None:
            prompt = self.process_prompt_control(gender, pitch, speed, text, emotion=emotion)  # ← ADDED emotion
        else:
            prompt, global_token_ids = self.process_prompt(text, prompt_speech_path, prompt_text)

        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        # Seed setting block (New Code) ← ADDED
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # Optional fix for pad_token_id warning ← ADDED
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id or 0

        # Generate speech using the model
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=3000,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=self.tokenizer.pad_token_id  # ← ADDED Fix 3
        )

        # Trim generated output (remove prompt tokens)
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode tokens to text
        predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Extract semantic token IDs
        pred_semantic_ids = (
            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
            .long()
            .unsqueeze(0)
        )

        # Extract global token IDs if control prompt was used
        if gender is not None:
            global_token_ids = (
                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
                .long()
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Detokenize to waveform
        wav = self.audio_tokenizer.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            pred_semantic_ids.to(self.device),
        )

        return wav

Example 1: Basic text-to-speech

python tts_cli.py --text "Hello there! Welcome to SparkTTS."

🗣️ This generates a default voice without specifying gender, pitch, speed, emotion, or seed. Great for quick testing.


Example 2: Custom voice with gender, pitch, emotion and seed

python tts_cli.py --text "Let's test some voice features." --gender female --pitch high --emotion HAPPY --seed 42

🎙️ This sets a female voice, high pitch, adds an emotion tag (HAPPY), and locks in reproducibility with a seed.


Example 3: Load text from a .txt file instead of typing in CLI

python tts_cli.py --text_file input_text.txt --gender male --pitch low --speed moderate --emotion SAD

📄 Reads input text directly from input_text.txt. Useful when testing longer scripts or multi-line content.


Example 4: Load Voice for Cloning, using seed for constant reproduction

--text_file text.txt --prompt_audio \Path-To\Spark-TTS-main\src\demos\trump\trump_en.wav --seed 123456

Example 5: Load Voice for Cloning, but pass the text inline with --text instead of a .txt file (here some song lyrics copied from the internet - it won't actually sing them), again using a seed for consistent reproduction

--text "Ra Ra - Oh ... Ma Ma - Shake it Baby" --prompt_audio F:\LocalTTS-Code\Spark-TTS-main\src\demos\trump\trump_en.wav --seed 123456

🧪 test_emotions_batch.py — What's it for?

python test_emotions_batch.py

Batch testing all pitch + emotion combinations automatically (a rough sketch follows the list below).

  • It will loop through all available emotions and pitch levels, and generate individual audio files.
  • Files are saved in a timestamped folder like: example/emotion_tests/YYYYMMDD_HHMM/
  • Helps developers or voice designers evaluate how different settings impact the voice output, especially useful for:
    • Creating demo packs
    • Benchmarking TTS behavior
    • Finding bugs/inconsistencies
    • Pre-selecting voice styles before deployment
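
A rough sketch of what such a batch script could look like is below. This is an assumption, not the actual test_emotions_batch.py (which isn't pasted in this thread); it reuses generate_tts_audio from the tts_cli.py above and the EMO_MAP keys from sparktts.utils.token_parser, and it assumes it is run from inside the cli/ folder:

# Hypothetical sketch - not the real test_emotions_batch.py.
import os
from datetime import datetime

from tts_cli import generate_tts_audio
from sparktts.utils.token_parser import EMO_MAP

PITCHES = ["very_low", "low", "moderate", "high", "very_high"]
TEXT = "This is a short calibration sentence for voice testing."

# Timestamped output folder, e.g. example/emotion_tests/20250310_1430/
out_dir = os.path.join("example", "emotion_tests", datetime.now().strftime("%Y%m%d_%H%M"))

for emotion in EMO_MAP.keys():
    for pitch in PITCHES:
        # Fixed seed + gender so only pitch/emotion vary between files.
        path = generate_tts_audio(
            TEXT,
            gender="female",
            pitch=pitch,
            speed="moderate",
            emotion=emotion,
            seed=1337,
            save_dir=out_dir,
        )
        print(f"{emotion:>10} | {pitch:>9} -> {path}")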

📎 That’s it!

Do not forget to copy at least SparkTTS.py and BiCodec.py, as that's where CUDA is enabled.

If you hit issues, feel free to tweak, or improve. Ask ChatGPT with the error/warning you have.😉

@prasannakulkarni333
Copy link

Thank you for the new code. It would be good to have a way to find out the seed, pitch, and speed settings of a generated voice so it can be generated again. @AcTePuKc

@AcTePuKc
Copy link
Author

What do you mean exactly?
If you're using the same seed, you’ll get the same voice output — as long as you also use the same pitch and speed settings.

That’s kind of the whole point of deterministic seeding 😄
But if you're doing voice cloning, then use the (new™ – I forgot to add it earlier 😅) --prompt_audio option.
Just note: if you're cloning from an audio sample, you can't use pitch or speed settings — they’ll be ignored automatically.

Examples:

--text_file text.txt --prompt_audio \Path-To\Spark-TTS-main\src\demos\trump\trump_en.wav --seed 123456

or

--text "Ra Ra - Oh ... Ma Ma - Shake it Baby" --prompt_audio \Path-To\Spark-TTS-main\src\demos\trump\trump_en.wav --seed 123456

TL;DR – If you want reproducibility, stick with --seed + pitch + speed.
And if you want cloning, use --prompt_audio, but pitch/speed don’t apply in that mode. 😄

@prasannakulkarni333
Copy link

prasannakulkarni333 commented Mar 11, 2025

Hi, when the model runs inference on its own and produces a voice, I sometimes like it and want to use it again. That's why I wanted to know the seed value of any particular voice, say Rick from Rick and Morty, so I don't need to have a sample voice. But now I see your point and should just use the clone feature (which I need to give a try).
Thank you for answering!

@AcTePuKc
Copy link
Author

Hi, when the model runs inference on its own and produces a voice, I sometimes like it and want to use it again. That's why I wanted to know the seed value of any particular voice, say Rick from Rick and Morty, so I don't need to have a sample voice. But now I see your point and should just use the clone feature (which I need to give a try). Thank you for answering!

Yep, you’re totally right — voice cloning is deterministic in terms of the speaker’s identity (tone, timbre, overall "who it sounds like"), even without setting a seed.
But if you want the output to be exactly the same — including pauses, pitch curves, rhythm — then using --seed ensures full reproducibility.

Things like temperature, top_k, and top_p can also influence slight variations in how the same cloned voice speaks — especially with expressive or long sentences.
But yeah, in most normal use cases, the voice stays the same, even if minor differences in delivery happen.
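
If you want to clamp that variation down further, those sampling knobs can be passed straight to the patched inference() from the files above. A minimal sketch, where model is a SparkTTS instance loaded as in tts_cli.py, the reference path is a placeholder, and the values are just a starting point rather than tuned settings:

with torch.no_grad():
    wav = model.inference(
        "Same sentence, same cloned voice, tighter delivery.",
        prompt_speech_path="path/to/reference.wav",  # placeholder: your cloning sample
        seed=123456,       # same seed -> same sampling path
        temperature=0.6,   # lower than the 0.8 default = less delivery variation
        top_k=30,
        top_p=0.9,
    )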

@851682852
Copy link

You must specify a gender, which can only be male or female, otherwise an error will be reported.

@AcTePuKc
Copy link
Author

Thanks for the comment! The model currently expects 'male' or 'female' — that’s a backend limitation, not a personal choice. If you'd like to add support for non-binary or flexible gender handling, feel free to fork and enhance it. Open-source means you're free to build it your way — all it takes is the sacred art of Copy → Paste. 🙂

@851682852
Copy link

Thanks for the comment! The model currently expects 'male' or 'female' — that’s a backend limitation, not a personal choice. If you'd like to add support for non-binary or flexible gender handling, feel free to fork and enhance it. Open-source means you're free to build it your way — all it takes is the sacred art of Copy → Paste. 🙂

I mean, in your tts_cli.py, the gender is set to None by default.

parser.add_argument("--gender", type=str, choices=["male", "female"], default=None)

This will cause the command python tts_cli.py --text "Hello there! Welcome to SparkTTS." to throw an error.

Additionally, I have a question: the generated audio is limited to a maximum of 60 seconds, which makes it impossible to generate long texts in one go. How can this issue be resolved?

@prasannakulkarni333
Copy link

prasannakulkarni333 commented Mar 17, 2025

Thanks for the comment! The model currently expects 'male' or 'female' — that’s a backend limitation, not a personal choice. If you'd like to add support for non-binary or flexible gender handling, feel free to fork and enhance it. Open-source means you're free to build it your way — all it takes is the sacred art of Copy → Paste. 🙂

I mean, in your tts_cli.py, the gender is set to None by default.

parser.add_argument("--gender", type=str, choices=["male", "female"], default=None)

This will cause the command python tts_cli.py --text "Hello there! Welcome to SparkTTS." to throw an error.

Additionally, I have a question: the generated audio is limited to a maximum of 60 seconds, which makes it impossible to generate long texts in one go. How can this issue be resolved?

What I am doing is splitting the text into parts, generating audio for each, and stitching it together. That is why I need to reliably replicate a voice across the iterations. @851682852

@prasannakulkarni333
Copy link

prasannakulkarni333 commented Mar 17, 2025

@AcTePuKc If I add the repo to another python program with your above changes, and do

from tts import generate_tts_audio
generate_tts_audio(
                text,
                gender="female",
                seed=42,
                pitch="moderate",
                speed="moderate",
                save_dir=save_path,
            )


then it looks like the seed value gets ignored. If I run the file as Python in the main folder, it seems to work fine. Any ideas why?

@AcTePuKc
Copy link
Author

Because I'm the one who added seed support to make this possible - that's how we can get consistent voice generation.
I've already pointed out what I added - it's in SparkTTS.py - and that's why what you're trying doesn't work.

@torch.no_grad()
    def inference(
        self,
        text: str,
        prompt_speech_path: Path = None,
        prompt_text: str = None,
        gender: str = None,
        pitch: str = None,
        speed: str = None,
        seed: int = None,  # ← ADDED: Deterministic voice control
        emotion: str = None,  # ← ADDED: Emotion conditioning
        temperature: float = 0.8,
        top_k: float = 50,
        top_p: float = 0.95,
    ) -> torch.Tensor:
        """
        Performs inference to generate speech from text, incorporating prompt audio and/or control attributes.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            emotion (str): Emotion label (e.g., HAPPY, SAD, ANGRY, etc.)
            temperature (float): Sampling temperature for randomness control.
            top_k (float): Top-k sampling.
            top_p (float): Top-p (nucleus) sampling.

        Returns:
            torch.Tensor: Generated waveform as a tensor.
        """

        # Build prompt using control tokens if gender is set
        if gender is not None:
            prompt = self.process_prompt_control(gender, pitch, speed, text, emotion=emotion)  # ← ADDED emotion
        else:
            prompt, global_token_ids = self.process_prompt(text, prompt_speech_path, prompt_text)

        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        # Seed setting block (New Code) ← ADDED
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # Optional fix for pad_token_id warning ← ADDED
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id or 0

        # Generate speech using the model
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=3000,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=self.tokenizer.pad_token_id  # ← ADDED Fix 3
        )

        # Trim generated output (remove prompt tokens)
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode tokens to text
        predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Extract semantic token IDs
        pred_semantic_ids = (
            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
            .long()
            .unsqueeze(0)
        )

        # Extract global token IDs if control prompt was used
        if gender is not None:
            global_token_ids = (
                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
                .long()
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Detokenize to waveform
        wav = self.audio_tokenizer.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            pred_semantic_ids.to(self.device),
        )

        return wav

you can try use this

import torch
import numpy as np
import re

@torch.no_grad()
def inference_standalone(
    text: str,
    prompt: str,
    tokenizer,
    model,
    audio_tokenizer,
    device=torch.device("cuda:0"),
    seed: int = None,
    temperature: float = 0.8,
    top_k: float = 50,
    top_p: float = 0.95,
    use_control_prompt: bool = False,
    global_token_ids: torch.Tensor = None
) -> torch.Tensor:
    """
    Standalone SparkTTS inference function without class dependencies.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id or 0

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=3000,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        pad_token_id=tokenizer.pad_token_id
    )

    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    predicts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    pred_semantic_ids = (
        torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
        .long()
        .unsqueeze(0)
    )

    if use_control_prompt:
        global_token_ids = (
            torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
            .long()
            .unsqueeze(0)
            .unsqueeze(0)
        )
    elif global_token_ids is None:
        raise ValueError("global_token_ids must be provided if not using control prompt.")

    wav = audio_tokenizer.detokenize(
        global_token_ids.to(device).squeeze(0),
        pred_semantic_ids.to(device),
    )

    return wav

The above can be added in - just be careful: it relies on two files from sparktts/models.

Tweak the code for your needs - in the first part, look for the "←" markers ("This is new" / "ADDED"), or just Ctrl+F for "New" or "ADDED". I've added a bunch of things that were not available and still aren't ... because for some reason I cannot push an update - someone would have to do the hard task of running a diff between my file and the actual new version.
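
To make the wiring concrete, here's a rough usage sketch of inference_standalone under the same assumptions - the components and the control-prompt construction mirror the SparkTTS.py posted above, but treat this as a sketch, not tested drop-in code:

# Rough wiring sketch for inference_standalone (paths and prompt format assumed from SparkTTS.py above).
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM

from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.token_parser import TASK_TOKEN_MAP, GENDER_MAP, LEVELS_MAP

model_dir = "pretrained_models/Spark-TTS-0.5B"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The same three components SparkTTS._initialize_inference() builds.
tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}/LLM")
model = AutoModelForCausalLM.from_pretrained(f"{model_dir}/LLM").to(device)
audio_tokenizer = BiCodecTokenizer(model_dir, device=device)

# Control prompt assembled the same way process_prompt_control() does it.
text = "Hello from the standalone inference path."
prompt = "".join([
    TASK_TOKEN_MAP["controllable_tts"],
    "<|start_content|>", text, "<|end_content|>",
    "<|start_style_label|>",
    f"<|gender_{GENDER_MAP['female']}|>",
    f"<|pitch_label_{LEVELS_MAP['moderate']}|>",
    f"<|speed_label_{LEVELS_MAP['moderate']}|>",
    "<|end_style_label|>",
])

wav = inference_standalone(
    text, prompt, tokenizer, model, audio_tokenizer,
    device=device, seed=42, use_control_prompt=True,
)
sf.write("standalone_test.wav", wav, samplerate=16000)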

@AcTePuKc
Copy link
Author

What I’m doing in tts_cli.py is splitting long text into segments and stitching the audio together. That way, you don’t hit the model’s length limits, and you can generate longer output smoothly.

This logic wasn’t in the original repo (or at least, not when I checked) — I added it manually. Here’s the snippet:

words = text.split()
if len(words) > segmentation_threshold:
    logging.info("Text exceeds threshold; splitting into segments...")
    segments = [' '.join(words[i:i + segmentation_threshold]) for i in range(0, len(words), segmentation_threshold)]
    wavs = []
    for seg in segments:
        with torch.no_grad():
            wav = model.inference(
                seg,
                prompt_speech_path,
                prompt_text=prompt_text,
                gender=gender,
                pitch=pitch,
                speed=speed,
                emotion=emotion
            )
        wavs.append(wav)
    final_wav = np.concatenate(wavs, axis=0)
else:
    with torch.no_grad():
        final_wav = model.inference(
            text,
            prompt_speech_path,
            prompt_text=prompt_text,
            gender=gender,
            pitch=pitch,
            speed=speed,
            emotion=emotion
        )

# ← The np.concatenate(wavs, axis=0) call in the long-text branch above is what does the stitching

So yeah — that’s how you handle long-form audio generation. Simple workaround, but it works.


As for the gender error — think about it: you ran the command without --gender and without --prompt_audio. So what should the model do, magically guess your intent? 😄
You either need to pass --gender male or --gender female, or just use --prompt_audio your_voice.wav if you're cloning a voice. That’s how the model backend works — not a design opinion, just how it was built.

Also, yeah — a bunch of things like --seed, --text_file, --prompt_audio, --gender, --pitch, --speed weren’t in the original CLI at all — I just added them in manually. If something’s missing or broken, feel free to fork, patch, and ping — maybe the devs will roll it into the official code later 🙂


I've edited tts_cli.py and fixed a minor error in SparkTTS.py (see the earlier comment in #10). Please keep tts_cli.py in the cli/ folder along with SparkTTS.py, inference.py, and optionally (if you're batch testing) test_emotions_batch.py. Check the diff!

That test script can be really useful for stress testing voices, pitch ranges, or verifying emotion conditioning at scale — it auto-generates a whole batch of outputs in a timestamped folder.
Tip: You can even tweak it to run on multiple seeds or test different text prompts if you want to build a small benchmark suite.
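
If you'd rather roll your own, the core of such a batch script is just a couple of nested loops. Here's a minimal sketch, assuming the generate_tts_audio helper from the top of this thread lives in tts_cli.py (the real test_emotions_batch.py may look different):

import os
from datetime import datetime

from tts_cli import generate_tts_audio  # assumed import path for the helper shared above

PITCHES = ["low", "moderate", "high"]
SPEEDS = ["low", "moderate", "high"]
TEXT = "The quick brown fox jumps over the lazy dog."

# One timestamped folder per batch run so outputs never collide.
run_dir = os.path.join("example/results", datetime.now().strftime("batch_%Y%m%d_%H%M%S"))
os.makedirs(run_dir, exist_ok=True)

for pitch in PITCHES:
    for speed in SPEEDS:
        # Each combination becomes its own wav; swap or extend the loops
        # (emotions, genders, multiple texts, seeds) to build a small benchmark suite.
        generate_tts_audio(
            TEXT,
            gender="female",
            pitch=pitch,
            speed=speed,
            save_dir=run_dir,
        )

Note that, as written above, the helper reloads the model on every call, so for big grids you'd want to hoist the model construction out of the loop.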

@prasannakulkarni333 @851682852

@nq4t

nq4t commented Mar 19, 2025

This is neat. I'd been using the webui, but it was less than stable. I couldn't get the "pretty" UI version to work...there's no error display other than "an error occurred" when trying to render. But I'm not here about that...I want to talk seeds.

Normally, AFAIK, the random seed used for generation can come from the CPU or the GPU, and I haven't looked to see if there's a way of getting that back without modifying the core stuff; and I'm not great at Python. I can write a pretty mean bash script; so I wrote a "wrapper" that generates a 32-bit int and feeds that as the seed every run. The batch file even appends the seed to the written filename. I'm lazy. It gives me random seeds I can run till I find one I like.

But, as I said, I'm lazy; and I spend so much time in command line that I like moving away from it when I'm playing around with things. You'll need a box on the UI for seed input so you can pass that to the code that sets a manual seed...but that's obvious. I'm not trying to tell anyone how to do anything...I've just been up all night and I'm in a meeting that could have been an email; so I'm taking the long and winding way around.

What if an option is added to let the UI/CLI control the seed at all times; the difference is whether you enter one or it generates one. CUDA machines can pull a number from the card's RNG...with CPU fallback, probably similar to how it does now.

But why do this? Automate some of the "playing around" I'm sure people are doing. I can just let it randomize the seed each render, similar to how I do now, but with the ability to actually see what seed was used. Maybe the CLI does this already and I just didn't dig deep enough. I might also try to hack this into the first UI myself this weekend, for fun, because I think I need to learn python now.

Of course, I'll push the length of the text to speak and at some point the voice completely changes, which makes the seed moot anyway...but they're a fun place to start.

@AcTePuKc
Author

> This is neat. I'd been using the webui, but it was less than stable. I couldn't get the "pretty" UI version to work...there's no error display other than "an error occurred" when trying to render. But I'm not here about that...I want to talk seeds.
>
> Normally, AFAIK, the random seed used for generation can come from the CPU or the GPU, and I haven't looked to see if there's a way of getting that back without modifying the core stuff; and I'm not great at Python. I can write a pretty mean bash script; so I wrote a "wrapper" that generates a 32-bit int and feeds that as the seed every run. The batch file even appends the seed to the written filename. I'm lazy. It gives me random seeds I can run till I find one I like.
>
> But, as I said, I'm lazy; and I spend so much time in command line that I like moving away from it when I'm playing around with things. You'll need a box on the UI for seed input so you can pass that to the code that sets a manual seed...but that's obvious. I'm not trying to tell anyone how to do anything...I've just been up all night and I'm in a meeting that could have been an email; so I'm taking the long and winding way around.
>
> What if an option is added to let the UI/CLI control the seed at all times; the difference is whether you enter one or it generates one. CUDA machines can pull a number from the card's RNG...with CPU fallback, probably similar to how it does now.
>
> But why do this? Automate some of the "playing around" I'm sure people are doing. I can just let it randomize the seed each render, similar to how I do now, but with the ability to actually see what seed was used. Maybe the CLI does this already and I just didn't dig deep enough. I might also try to hack this into the first UI myself this weekend, for fun, because I think I need to learn python now.
>
> Of course, I'll push the length of the text to speak and at some point the voice completely changes, which makes the seed moot anyway...but they're a fun place to start.

Hey — just to clarify, this is already possible. I’ve made an expanded CLI version that fully supports seed input, and I’ve shared exactly what files need to be patched (check tts_cli.py and SparkTTS.py). It’s already there — just plug them in and you’re good to go.
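
For anyone who wants to wire it up themselves, the whole trick is small. Here's a rough sketch of the idea (not the literal code from tts_cli.py), using the standard torch.manual_seed / torch.cuda.manual_seed_all calls:

import random
import torch

def set_seed(seed=None):
    # Use the given seed, or draw a random 32-bit one, and return whichever was used
    # so it can be logged or appended to the output filename (like the bash wrapper does).
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed

used_seed = set_seed(None)  # or set_seed(args.seed) when --seed was passed
print(f"Seed used for this render: {used_seed}")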

I do have a GUI version too, but it was built way earlier — before I added seed support — and honestly, I don’t feel like going back and retrofitting it just for a seed box. If you really want that in the GUI, feel free to add it yourself — it’s open source, after all.

Scroll up and Ctrl+F for "seed" — everything’s already explained. I don’t even know Python that well myself, but I just went with the whole “vibe coding” approach and got it working.

So yeah, nothing’s stopping you from expanding it further. Have fun with it 🙂

@nq4t

nq4t commented Mar 27, 2025

Okay...I spent some time playing with it.

Seeds are "pointless". I can create the same line of speech with the same voice with the same seed. Change one word, or even misspell one, and the voice changes. This is VERY noticeable when your interface splits up tokens...there are very jarring changes in voice, tone, and emotion. I think...what we're supposed to be doing...is getting a random voice...then one-shot cloning that voice again later.

That is literally the only current way to maintain a consistent voice through an entire process. I mean I've gotten the same voice with different lines on the same seeds; but the tokenized speech plays a huge role in generation.

The problem with this is the context window is too short. A good voice sample with transcription will just eat the entire context. Now that's not your fault...that's the underlying engine. Even Kokoro is limited to 150 tokens at a time.

Of course, that's not to mention how many times it just fails to do anything. I have a lot of 768-byte wav files of nothing and 2 minutes of silence.

But hearing that your interface was vibe coded...I think I'm gonna have to load the unmodded repo and see if the problems exist. I mean I'm all for AI assisted coding; but vibe coding is just...bad. If I wanted no-effort code I'd load qwen-coder locally.

@851682852

> Okay...I spent some time playing with it.
>
> Seeds are "pointless". I can create the same line of speech with the same voice with the same seed. Change one word, or even misspell one, and the voice changes. This is VERY noticeable when your interface splits up tokens...there are very jarring changes in voice, tone, and emotion. I think...what we're supposed to be doing...is getting a random voice...then one-shot cloning that voice again later.
>
> That is literally the only current way to maintain a consistent voice through an entire process. I mean I've gotten the same voice with different lines on the same seeds; but the tokenized speech plays a huge role in generation.
>
> The problem with this is the context window is too short. A good voice sample with transcription will just eat the entire context. Now that's not your fault...that's the underlying engine. Even Kokoro is limited to 150 tokens at a time.
>
> Of course, that's not to mention how many times it just fails to do anything. I have a lot of 768-byte wav files of nothing and 2 minutes of silence.
>
> But hearing that your interface was vibe coded...I think I'm gonna have to load the unmodded repo and see if the problems exist. I mean I'm all for AI assisted coding; but vibe coding is just...bad. If I wanted no-effort code I'd load qwen-coder locally.

A very interesting idea! I noticed that using voice cloning for long-text generation ensures a consistent tone, making the seed meaningless. I think the approach should be to use a random voice, generate a speech sample, clone it, and then use this cloned voice for long-text generation. Your idea has inspired me.
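
For what it's worth, that two-pass workflow is easy to script on top of the generate_tts_audio helper from the top of this thread. A rough sketch (file names are made up, and it assumes the helper returns the path of the saved wav, as its docstring says):

from tts_cli import generate_tts_audio  # assumed import path for the helper shared above

CALIBRATION_TEXT = "This is a short calibration sentence for picking a voice."

# Pass 1: roll a short "audition" clip with controlled parameters (re-run until you like the voice).
sample_path = generate_tts_audio(
    CALIBRATION_TEXT,
    gender="female",
    pitch="moderate",
    speed="moderate",
)

# Pass 2: clone that clip for the long-form text, so every segment keeps the same voice.
with open("long_text.txt", encoding="utf-8") as f:  # made-up input file
    long_text = f.read()

generate_tts_audio(
    long_text,
    prompt_speech_path=sample_path,
    prompt_text=CALIBRATION_TEXT,
)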
