FunAudioLLM/InspireMusic

InspireMusic is a fundamental AIGC toolkit with models designed for music, song, and audio generation.

Please support our community 💖 by starring the repository. Thank you for your support! 🙏

Highlights | News | Introduction | Installation | Quick Start | Tutorial | Models | Contact


Highlights

InspireMusic focuses on music generation, song generation, and audio generation.

  • A unified framework for music, song, and audio generation. Controllable with text prompts, music genres, music structures, etc.
  • Supports music generation tasks with high audio quality, with sampling rates of 24kHz and 48kHz available.
  • Supports long-form audio generation.
  • Convenient fine-tuning and inference. Supports mixed-precision training (e.g., BF16, FP16, FP32) and provides fine-tuning and inference scripts, allowing users to easily fine-tune and try out their own music generation models.

What's New 🔥

  • 2025/03: InspireMusic Technical Report is available.
  • 2025/02: Online demo is available on ModelScope Space and HuggingFace Space.
  • 2025/01: Open-sourced the InspireMusic-Base, InspireMusic-Base-24kHz, InspireMusic-1.5B, InspireMusic-1.5B-24kHz, and InspireMusic-1.5B-Long models for music generation. Models are available on both ModelScope and HuggingFace.
  • 2024/12: Added support for generating 48kHz audio with super-resolution flow matching.
  • 2024/11: Welcome to preview 👉🏻 InspireMusic Demos 👈🏻. We're excited to share this with you and are working hard to bring even more features and models soon. Your support and feedback mean a lot to us!
  • 2024/11: We are thrilled to announce the open-sourcing of the InspireMusic code repository and demos. InspireMusic is a unified framework for music, song, and audio generation, featuring capabilities such as text-to-music generation and music continuation. InspireMusic shows performance on music generation comparable to current top-tier open-source models.

Introduction

Note

This repo contains the algorithm infrastructure and some simple examples. Currently, only English text prompts are supported.

Tip

To preview the performance, please refer to the InspireMusic Demo Page.

InspireMusic is a unified framework for music, song, and audio generation that combines audio tokenization with an autoregressive transformer and a flow-matching-based model. The original motivation of this toolkit is to empower everyday users to create soundscapes and advance research through music, song, and audio crafting. The toolkit provides both training and inference code for AI-based generative models that create high-quality music. Featuring a unified framework, InspireMusic incorporates audio tokenizers with autoregressive transformer and super-resolution flow-matching modeling, allowing controllable generation of music, song, and audio with both text and audio prompts. The toolkit currently supports music generation; song generation and audio generation will be supported in the future.

Figure 1: An overview of the InspireMusic framework. We introduce InspireMusic, a unified framework for music, song, and audio generation capable of producing high-quality long-form audio. InspireMusic consists of three key components. Audio tokenizers convert the raw audio waveform into discrete audio tokens that can be efficiently processed and trained by the autoregressive transformer model; audio at a lower sampling rate is converted to discrete tokens via a high-bitrate compression audio tokenizer[1]. The autoregressive transformer model uses Qwen2.5[2] as the backbone and is trained with a next-token prediction objective on both text and audio tokens, enabling it to generate coherent and contextually relevant token sequences. The super-resolution flow-matching model maps the generated tokens to latent features with high-resolution, fine-grained acoustic details[3] obtained from audio at a higher sampling rate, ensuring that acoustic information flows through the models with high fidelity. A vocoder then generates the final audio waveform from these enhanced latent features. InspireMusic supports a range of tasks including text-to-music, music continuation, music reconstruction, and music super-resolution.
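
To make the data flow in Figure 1 concrete, here is a minimal, self-contained Python sketch of the three stages. Every function below is an illustrative stand-in with dummy arithmetic, not the real InspireMusic API; the actual inference entry points are shown in the Quick Start section.

# Conceptual sketch of the InspireMusic generation flow in Figure 1.
# All functions are illustrative stand-ins, NOT the real InspireMusic API.

def tokenize_text(prompt):
    # Stand-in for the text tokenizer: text prompt -> token ids.
    return [ord(ch) % 256 for ch in prompt]

def tokenize_audio(waveform_24k):
    # Stand-in for the audio tokenizer (e.g., WavTokenizer[1]): a lower-sampling-rate
    # waveform -> discrete audio tokens; used when an audio prompt is given
    # (music continuation task).
    return [int(abs(x) * 255) % 256 for x in waveform_24k]

def autoregressive_transformer(prompt_tokens, max_new_tokens):
    # Stand-in for the Qwen2.5-based AR model[2]: next-token prediction over the
    # concatenated text/audio token sequence.
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        tokens.append((tokens[-1] * 31 + 7) % 256)  # dummy "prediction"
    return tokens[len(prompt_tokens):]

def super_resolution_flow_matching(audio_tokens):
    # Stand-in for the flow-matching model: maps generated tokens to
    # high-resolution latent features with fine-grained acoustic detail[3].
    return [t / 255.0 for t in audio_tokens]

def vocoder(latent_features):
    # Stand-in for the vocoder: latent features -> the final audio waveform.
    return latent_features

if __name__ == "__main__":
    text_tokens = tokenize_text("Relaxing instrumental jazz with a touch of Bossa Nova.")
    audio_tokens = autoregressive_transformer(text_tokens, max_new_tokens=64)
    latents = super_resolution_flow_matching(audio_tokens)
    waveform = vocoder(latents)
    print(f"generated {len(waveform)} samples (conceptual sketch only)")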

Installation

Clone

  • Clone the repo
git clone --recursive https://github.com/FunAudioLLM/InspireMusic.git
# If you failed to clone the submodule due to network failures, please run the following commands until they succeed
cd InspireMusic
git submodule update --init --recursive
# or you can download the third_party repo Matcha-TTS manually
cd third_party && git clone https://github.com/shivammehta25/Matcha-TTS.git

Install from Source

InspireMusic requires Python >= 3.8, PyTorch >= 2.0.1, flash-attn == 2.6.2/2.6.3, and CUDA >= 11.2. You can install the dependencies with the following commands:

conda create -n inspiremusic python=3.8
conda activate inspiremusic
cd InspireMusic
# pynini is required by WeTextProcessing; use conda to install it, as the conda package works on all platforms.
conda install -y -c conda-forge pynini==2.1.5
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# install flash attention to speedup training
pip install flash-attn --no-build-isolation
  • Install as a package:
cd InspireMusic
# Run the following command to install the package
python setup.py install
pip install flash-attn --no-build-isolation
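
After installation, you can optionally run a quick sanity check to confirm that PyTorch, CUDA, and flash-attn are visible in your environment. This is a minimal snippet relying only on the dependencies listed above:

# Optional sanity check for the dependencies listed above.
import torch

print("PyTorch version:", torch.__version__)          # expected >= 2.0.1
print("CUDA available:", torch.cuda.is_available())   # requires CUDA >= 11.2

try:
    import flash_attn
    print("flash-attn version:", flash_attn.__version__)  # expected 2.6.2 or 2.6.3
except ImportError:
    print("flash-attn not installed; training will be slower without it.")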

We also recommend having sox or ffmpeg installed, either through your system or Anaconda:

# Install sox
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel

# Install ffmpeg
# ubuntu
sudo apt-get install ffmpeg
# centos
sudo yum install ffmpeg

Use Docker

Run the following command to build a Docker image from the provided Dockerfile.

docker build -t inspiremusic .

Run the following command to start the Docker container in interactive mode.

docker run -ti --gpus all -v .:/workspace/InspireMusic inspiremusic

Quick Start

Here is a quick example inference script for music generation.

cd InspireMusic
mkdir -p pretrained_models

# Download models
# ModelScope
git clone https://www.modelscope.cn/iic/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic-1.5B-Long
# HuggingFace
git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic-1.5B-Long

cd examples/music_generation
# run a quick inference example
sh infer_1.5b_long.sh

Here is a quick-start script to run the music generation task, including the data preparation pipeline, model training, and inference.

cd InspireMusic/examples/music_generation/
sh run.sh

One-line Inference

Text-to-music Task

One-line shell script for the text-to-music task.

cd examples/music_generation
# with flow matching, use one-line command to get a quick try
python -m inspiremusic.cli.inference

# customize the configuration with a one-line command like the following
python -m inspiremusic.cli.inference --task text-to-music -m "InspireMusic-1.5B-Long" -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." -c intro -s 0.0 -e 30.0 -r "exp/inspiremusic" -o output -f wav 

# without flow matching, use one-line command to get a quick try
python -m inspiremusic.cli.inference --task text-to-music -g 0 -t "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance." --fast True

Alternatively, you can run the inference with just a few lines of Python code.

from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
  set_env_variables()
  model = InspireMusicUnified(model_name = "InspireMusic-1.5B-Long")
  model.inference("text-to-music", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.")

Music Continuation Task

One-line shell script for the music continuation task.

cd examples/music_generation
# with flow matching
python -m inspiremusic.cli.inference --task continuation -g 0 -a audio_prompt.wav
# without flow matching
python -m inspiremusic.cli.inference --task continuation -g 0 -a audio_prompt.wav --fast True

Alternatively, you can run the inference with just a few lines of Python code.

from inspiremusic.cli.inference import InspireMusicUnified
from inspiremusic.cli.inference import set_env_variables
if __name__ == "__main__":
  set_env_variables()
  model = InspireMusicUnified(model_name = "InspireMusic-1.5B-Long")
  # just use audio prompt
  model.inference("continuation", None, "audio_prompt.wav")
  # use both text prompt and audio prompt
  model.inference("continuation", "Continue to generate jazz music.", "audio_prompt.wav")

Models

Download Models

We strongly recommend that you download our pretrained InspireMusic models, especially InspireMusic-1.5B-Long for music generation.

# Use git to download models. Please make sure git-lfs is installed.
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/InspireMusic-1.5B-Long.git pretrained_models/InspireMusic
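
Alternatively, if you prefer downloading from Python rather than git, the standard HuggingFace Hub downloader also works. This is an optional alternative and assumes the huggingface_hub package is installed; the equivalent ModelScope downloader can be used in the same way.

# Optional: download the model from Python instead of git.
# Assumes `huggingface_hub` is installed (pip install huggingface_hub).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="FunAudioLLM/InspireMusic-1.5B-Long",
    local_dir="pretrained_models/InspireMusic",
)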

Available Models

Currently, we open-source music generation models that support 24kHz mono and 48kHz stereo audio. The table below presents links to the ModelScope and HuggingFace model hubs. More models will be available soon.

Model name | Model links | Remarks
InspireMusic-Base-24kHz | ModelScope, HuggingFace | Pre-trained music generation model, 24kHz mono, 30s
InspireMusic-Base | ModelScope, HuggingFace | Pre-trained music generation model, 48kHz, 30s
InspireMusic-1.5B-24kHz | ModelScope, HuggingFace | Pre-trained 1.5B music generation model, 24kHz mono, 30s
InspireMusic-1.5B | ModelScope, HuggingFace | Pre-trained 1.5B music generation model, 48kHz, 30s
InspireMusic-1.5B-Long | ModelScope, HuggingFace | Pre-trained 1.5B music generation model, 48kHz, supports long-form music generation of more than 5 minutes
InspireSong-1.5B | ModelScope, HuggingFace | Pre-trained 1.5B song generation model, 48kHz stereo
InspireAudio-1.5B | ModelScope, HuggingFace | Pre-trained 1.5B audio generation model, 48kHz stereo
WavTokenizer[1] (75Hz) | ModelScope, HuggingFace | An extremely low-bitrate audio tokenizer for music with a single codebook, for 24kHz audio
Music_tokenizer (75Hz) | ModelScope, HuggingFace | A music tokenizer based on HifiCodec[3], for 24kHz audio
Music_tokenizer (150Hz) | ModelScope, HuggingFace | A music tokenizer based on HifiCodec[3], for 48kHz audio

Our models have been trained within a limited budget, and there is still significant potential for performance improvement. We are actively working on enhancing the model's performance.

Basic Usage

At the moment, InspireMusic contains the training and inference code for music generation. More tasks, such as song generation and audio generation, will be supported in the future.

Training

Here is an example of training the LLM model, which supports BF16/FP16 training.

torchrun --nnodes=1 --nproc_per_node=8 \
    --rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    inspiremusic/bin/train.py \
    --train_engine "torch_ddp" \
    --config conf/inspiremusic.yaml \
    --train_data data/train.data.list \
    --cv_data data/dev.data.list \
    --model llm \
    --model_dir `pwd`/exp/music_generation/llm/ \
    --tensorboard_dir `pwd`/tensorboard/music_generation/llm/ \
    --ddp.dist_backend "nccl" \
    --num_workers 8 \
    --prefetch 100 \
    --pin_memory \
    --deepspeed_config ./conf/ds_stage2.json \
    --deepspeed.save_states model+optimizer \
    --fp16

Here is an example of training the flow-matching model, which does not support FP16 training.

torchrun --nnodes=1 --nproc_per_node=8 \
    --rdzv_id=1024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    inspiremusic/bin/train.py \
    --train_engine "torch_ddp" \
    --config conf/inspiremusic.yaml \
    --train_data data/train.data.list \
    --cv_data data/dev.data.list \
    --model flow \
    --model_dir `pwd`/exp/music_generation/flow/ \
    --tensorboard_dir `pwd`/tensorboard/music_generation/flow/ \
    --ddp.dist_backend "nccl" \
    --num_workers 8 \
    --prefetch 100 \
    --pin_memory \
    --deepspeed_config ./conf/ds_stage2.json \
    --deepspeed.save_states model+optimizer

Inference

Here is an example script to quickly run model inference.

cd InspireMusic/examples/music_generation/
sh infer.sh

Here is an example script to run inference in normal mode, i.e., with the flow-matching model, for the text-to-music and music continuation tasks.

pretrained_model_dir="pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
  python inspiremusic/bin/inference.py --task $task \
      --gpu 0 \
      --config conf/inspiremusic.yaml \
      --prompt_data data/test/parquet/data.list \
      --flow_model $pretrained_model_dir/flow.pt \
      --llm_model $pretrained_model_dir/llm.pt \
      --music_tokenizer $pretrained_model_dir/music_tokenizer \
      --wavtokenizer $pretrained_model_dir/wavtokenizer \
      --result_dir `pwd`/exp/inspiremusic/${task}_test \
      --chorus verse \
      --min_generate_audio_seconds 8 \
      --max_generate_audio_seconds 30 
done

Here is an example script to run inference in fast mode, i.e., without the flow-matching model, for the text-to-music and music continuation tasks.

pretrained_model_dir="pretrained_models/InspireMusic/"
for task in 'text-to-music' 'continuation'; do
  python inspiremusic/bin/inference.py --task $task \
      --gpu 0 \
      --config conf/inspiremusic.yaml \
      --prompt_data data/test/parquet/data.list \
      --flow_model $pretrained_model_dir/flow.pt \
      --llm_model $pretrained_model_dir/llm.pt \
      --music_tokenizer $pretrained_model_dir/music_tokenizer \
      --wavtokenizer $pretrained_model_dir/wavtokenizer \
      --result_dir `pwd`/exp/inspiremusic/${task}_test \
      --chorus verse \
      --fast \
      --min_generate_audio_seconds 8 \
      --max_generate_audio_seconds 30 
done

Hardware & Execution Time

In previous tests on an H800 GPU, InspireMusic-1.5B-Long could generate 30 seconds of audio with a real-time factor (RTF) of around 1.6~1.8. For normal mode, we recommend hardware with at least 24GB of GPU memory for a better experience; for fast mode, 12GB of GPU memory is enough. If you want to generate longer audio, you can increase the --max_generate_audio_seconds inference parameter, given larger GPU memory.
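
As a rough back-of-the-envelope estimate (assuming the usual definition RTF = generation time / audio duration), the RTF range above translates into wall-clock generation time as follows:

# Rough wall-clock estimate from the real-time factor (RTF) reported above.
audio_seconds = 30.0           # length of the clip to generate
for rtf in (1.6, 1.8):         # measured RTF range on an H800 GPU (normal mode)
    print(f"RTF {rtf}: ~{rtf * audio_seconds:.0f} s to generate {audio_seconds:.0f} s of audio")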

Roadmap

  • 2024/12

    • 75Hz InspireMusic-Base model for music generation
  • 2025/01

    • Support for generating 48kHz audio
    • 75Hz InspireMusic-1.5B model for music generation
    • 75Hz InspireMusic-1.5B-Long model for long-form music generation
  • 2025/02

    • Technical report v1
    • Provide Dockerfile
  • 2025/03

    • Support audio generation task
    • InspireAudio model for audio generation
  • 2025/05

    • Support song generation task
    • InspireSong model for song generation
  • TBD

    • Diverse sampling strategies
    • 25Hz InspireMusic model
    • Runtime SDK
    • Support streaming inference mode
    • Support more diverse instruction modes, multilingual instructions
    • InspireSong trained with more multi-lingual data
    • More...

Citation

@misc{InspireMusic2025,
      title={InspireMusic: Integrating Super Resolution and Large Language Model for High-Fidelity Long-Form Music Generation}, 
      author={Chong Zhang and Yukun Ma and Qian Chen and Wen Wang and Shengkui Zhao and Zexu Pan and Hao Wang and Chongjia Ni and Trung Hieu Nguyen and Kun Zhou and Yidi Jiang and Chaohong Tan and Zhifu Gao and Zhihao Du and Bin Ma},
      year={2025},
      eprint={2503.00084},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2503.00084}, 
}

Friend Links

Check out some awesome open-source projects from Tongyi Lab, Alibaba Group.


Community & Discussion

  • You are welcome to join our DingTalk and WeChat groups to share and discuss algorithms, technology, and user experience feedback. Scan the following QR codes to join our official chat groups.

FunAudioLLM in DingTalk InspireMusic in WeChat

  • GitHub Discussions: for sharing feedback and asking questions.
  • GitHub Issues: for reporting issues and suggestions encountered when using InspireMusic, and for feature proposals.

Acknowledgement

  1. We borrowed a lot of code from CosyVoice[4].
  2. We borrowed a lot of code from WavTokenizer[1].
  3. We borrowed a lot of code from AcademiCodec[3].
  4. We borrowed a lot of code from FunASR.
  5. We borrowed a lot of code from FunCodec.
  6. We borrowed a lot of code from Matcha-TTS.
  7. We borrowed a lot of code from WeNet.

References

[1] Shengpeng Ji, Ziyue Jiang, Wen Wang, Yifu Chen, Minghui Fang, Jialong Zuo, Qian Yang, Xize Cheng, Zehan Wang, Ruiqi Li, Ziang Zhang, Xiaoda Yang, Rongjie Huang, Yidi Jiang, Qian Chen, Siqi Zheng, Wen Wang, Zhou Zhao, WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling, The Thirteenth International Conference on Learning Representations, 2025.

[2] Qwen: An Yang, Baosong Yang, Beichen Zhang, Binyuan Hui, Bo Zheng, Bowen Yu, Chengyuan Li, Dayiheng Liu, Fei Huang, Haoran Wei, Huan Lin, Jian Yang, Jianhong Tu, Jianwei Zhang, Jianxin Yang, Jiaxi Yang, Jingren Zhou, Junyang Lin, Kai Dang, Keming Lu, Keqin Bao, Kexin Yang, Le Yu, Mei Li, Mingfeng Xue, Pei Zhang, Qin Zhu, Rui Men, Runji Lin, Tianhao Li, Tianyi Tang, Tingyu Xia, Xingzhang Ren, Xuancheng Ren, Yang Fan, Yang Su, Yichang Zhang, Yu Wan, Yuqiong Liu, Zeyu Cui, Zhenru Zhang, Zihan Qiu, Qwen2.5 Technical Report, arXiv preprint arXiv:2412.15115, 2025.

[3] Yang, Dongchao, Songxiang Liu, Rongjie Huang, Jinchuan Tian, Chao Weng, and Yuexian Zou, Hifi-codec: Group-residual vector quantization for high fidelity audio codec, arXiv preprint arXiv:2305.02765, 2023.

[4] Du, Zhihao, Qian Chen, Shiliang Zhang, Kai Hu, Heng Lu, Yexin Yang, Hangrui Hu et al. Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens. arXiv preprint arXiv:2407.05407, 2024.

Disclaimer

The content provided above is for research purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.