Recurrent Neural Networks (RNNs) and Long Short-Term Memory networks (LSTMs) are types of neural networks designed to handle sequential data, like music. In music generation, their ability to remember and use previous inputs is crucial. Unlike feedforward networks, which treat each input independently, they carry information forward from one step of a sequence to the next. They can learn patterns in music sequences, enabling the generation of coherent and stylistically consistent musical pieces.
Quick Read (<5min): Understanding the Basics of RNNs and LSTMs
RNNs work by processing sequences step by step, maintaining information from previous steps to inform the current one. An RNN achieves this with a simple neural network layer that, at each step, combines the current input with the hidden state carried over from the previous step, so that earlier inputs are taken into account. However, RNNs struggle with long sequences due to the “vanishing gradient problem”: during training, the learning signal from distant steps shrinks toward zero, so the network has trouble learning how earlier data affects later outputs. Imagine a composer writing a symphony. Using an RNN is like the composer relying on their short-term memory, recalling only the last few notes to decide what to write next. This works well for short sequences but becomes challenging for longer compositions, as the memory of earlier notes fades.
LSTMs are upgraded RNNs that address this issue with a more complex structure. They include ‘gates’, mechanisms that regulate the flow of information. These gates decide what to remember (or forget) from past data, making LSTMs adept at handling long sequences without losing relevant information from earlier steps. An LSTM is like a composer who keeps a detailed notebook. They can refer back to earlier themes or motifs even after much time has passed, ensuring that the entire composition is harmonious and interconnected. The notebook plays the role of the LSTM’s memory (its cell state), and the ‘gates’ are the composer’s decisions about what to write down and what to cross out, allowing for more complex and coherent musical pieces.
Deeper Dive (10min): A Breakdown of RNN and LSTM Structure
The conceptual design of an RNN looks like the left part of the graph, where the network processes the input x through a loop and produces the output o. The ‘unfolded’ section shows each step in the sequence separately, which clarifies how the RNN handles each piece of data over time. Here, you can see the network’s recurrent transfer from one hidden state, Sₜ₋₁, to the next, Sₜ, embodying the network’s memory that affects subsequent states. The same weight matrix W is applied at every time step, which is how sequential context from the past is carried forward.
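Written out, one step of the unfolded network can be expressed as follows. This is the standard vanilla-RNN form, under our own notation assumptions: U is our label for the weight matrix applied to the input (the text only names W and V), f is the hidden activation such as tanh, g is the output activation, and bias terms are omitted for brevity.

```latex
\begin{aligned}
S_t &= f\left(U x_t + W S_{t-1}\right) \\
o_t &= V S_t \\
y_t &= g\left(o_t\right)
\end{aligned}
```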
At each step, the RNN performs a series of calculations to update its hidden state and produce an output, depicted in the above equations. First, the network takes in an input xₜ. In the context of music generation, this could be a note or a chord. Then, the new hidden state Sₜ is computed from a weighted sum of the current input xₜ and the previous hidden state Sₜ₋₁, passed through an activation function such as tanh or sigmoid. Because the previous hidden state enters through the weight matrix W, the hidden states (blue nodes in the above graph) capture information from previous inputs. The output oₜ is generated by applying another set of weights V to the current hidden state Sₜ, which now embodies the information processed up to that step. Finally, the actual predicted value yₜ (e.g. the next note or chord in the sequence) is generated from the output oₜ through an activation function, usually chosen to match the format of the target data (for example, a softmax when choosing among possible notes).
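The step-by-step computation just described can be sketched in a few lines of NumPy. This is a minimal illustration, not a trainable model: it assumes tanh for the hidden activation, softmax for the output activation, no bias terms, and uses U as our own name for the input weight matrix. Random weights stand in for trained ones.

```python
import numpy as np

def rnn_forward(xs, U, W, V, s0=None):
    """Run a vanilla RNN over a sequence.

    xs: (T, input_dim), one row per time step (e.g. an encoded note).
    U:  (hidden_dim, input_dim) input weights (our name; the text names only W and V).
    W:  (hidden_dim, hidden_dim) recurrent weights, reused at every time step.
    V:  (output_dim, hidden_dim) output weights.
    Returns hidden states (T, hidden_dim) and output probabilities (T, output_dim).
    """
    s = np.zeros(W.shape[0]) if s0 is None else s0
    states, probs = [], []
    for x_t in xs:
        s = np.tanh(U @ x_t + W @ s)   # S_t = tanh(U x_t + W S_{t-1})
        o = V @ s                      # o_t = V S_t
        e = np.exp(o - o.max())        # y_t = softmax(o_t), shifted for stability
        states.append(s)
        probs.append(e / e.sum())
    return np.array(states), np.array(probs)

# Toy example: 5 one-hot "notes" drawn from a 4-symbol vocabulary.
rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim = 4, 8, 4
U = rng.normal(scale=0.5, size=(hidden_dim, input_dim))
W = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))
V = rng.normal(scale=0.5, size=(output_dim, hidden_dim))
xs = np.eye(input_dim)[[0, 2, 1, 3, 0]]
states, probs = rnn_forward(xs, U, W, V)
print(states.shape, probs.shape)  # (5, 8) (5, 4)
```

Note that the same W is used inside the loop at every step; that reuse is exactly the loop drawn on the left of the graph.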
The Vanishing Gradient Problem
To understand the vanishing gradient problem, we need to understand how RNNs are trained. Imagine you’re trying to solve a puzzle by working backward from the last piece to the first. You look at each piece and figure out how it fits with the one before it, step by step. Training an RNN works in a similar way, through a procedure called backpropagation (for sequences, backpropagation through time). It starts from the end of a sequence and works back towards the beginning, calculating how each step was influenced by the previous ones; then it adjusts the network’s weights in the direction that reduces its error.
However, there’s a hitch in this process. As the error signal moves backward from the last data point to the first, it gets smaller and smaller. This is because each step backward multiplies the signal by a factor that depends on the recurrent weights and the slope of the activation function, and if you multiply by numbers less than one repeatedly, the result gets tinier each time. The graph above represents how well the network retains information from previous time steps. As time progresses, the colors fade, symbolizing how the network’s memory of the initial information fades away. By the time the model reaches a far-off point (like t=100), the network has almost forgotten the earlier information.
It’s like the network’s ability to learn from the early data ‘vanishes’ over time. Learning relies on the gradient, a concept from calculus that measures how the error changes when each weight changes; if this gradient becomes very small or vanishes, the network can’t learn effectively from the earlier data points.
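The repeated multiplication can be seen directly in a toy example. This sketch assumes a scalar RNN, S_t = tanh(w·S_{t-1} + u·x_t), with made-up values for w, u, and the inputs; it tracks how much S_t still depends on the starting state S_0. By the chain rule, each step contributes a factor w·(1 − S_t²), the recurrent weight times the slope of tanh, so with |w| < 1 every factor is below one.

```python
import math

def gradient_through_time(w, u, xs, s0=0.0):
    """Return |dS_t/dS_0| for t = 1..T in the scalar RNN S_t = tanh(w*S_{t-1} + u*x_t).

    Each step multiplies the running gradient by dS_t/dS_{t-1} = w * (1 - S_t**2).
    """
    s, grad, history = s0, 1.0, []
    for x in xs:
        s = math.tanh(w * s + u * x)
        grad *= w * (1.0 - s * s)   # one more factor in the chain-rule product
        history.append(abs(grad))
    return history

xs = [1.0] * 100                    # toy input sequence of 100 steps
g = gradient_through_time(w=0.9, u=0.5, xs=xs)
for t in (1, 10, 50, 100):
    print(f"t={t:3d}  |dS_t/dS_0| = {g[t - 1]:.2e}")
```

Since every factor here is at most 0.9, the dependence after 100 steps is at most 0.9¹⁰⁰ (about 2.7 × 10⁻⁵), which is the numerical version of the fading colors in the graph.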
LSTMs and How They Solve the Problem
LSTMs combat the vanishing gradient problem primarily through their cell state, which acts like a highway for gradients. Because the cell state is updated by addition (the old memory scaled by the forget gate, plus new information) rather than being pushed through a weight matrix and a squashing function at every step, gradients can flow back along it with little change whenever the forget gate stays close to 1. The gates of an LSTM (forget, input, and output) regulate the flow of information and gradients: the forget gate selectively removes information that’s no longer relevant; the input gate controls the addition of new information into the cell state, so that only significant updates are made; and the output gate controls how much of the cell state is exposed as the cell’s output, allowing the network to make fine-grained decisions about the data to pass on. Together, these mechanisms help keep the gradients from shrinking toward zero (vanishing) as they are propagated back through the time steps during training. They do less to stop gradients from growing too large (exploding), which is usually handled separately, for example with gradient clipping.
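In equations, one common formulation of the LSTM cell is shown below. The notation is our own: σ is the sigmoid, ⊙ is element-wise multiplication, [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current input, and the W and b terms are the learned weights and biases of each gate. Note that here o_t denotes the output gate, not the RNN output from the previous section. The last line shows the "highway": along the direct cell-state path, the gradient from one step to the previous is just the forget gate, with no weight matrix or tanh slope in between.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f [h_{t-1}, x_t] + b_f\right) && \text{forget gate} \\
i_t &= \sigma\left(W_i [h_{t-1}, x_t] + b_i\right) && \text{input gate} \\
\tilde{C}_t &= \tanh\left(W_C [h_{t-1}, x_t] + b_C\right) && \text{candidate memory} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t && \text{cell state update} \\
o_t &= \sigma\left(W_o [h_{t-1}, x_t] + b_o\right) && \text{output gate} \\
h_t &= o_t \odot \tanh\left(C_t\right) && \text{hidden state} \\[6pt]
\frac{\partial C_T}{\partial C_k}\bigg|_{\text{direct path}} &= \prod_{t=k+1}^{T} \operatorname{diag}\left(f_t\right)
\end{aligned}
```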
Let’s explain how an LSTM cell works with our composer analogy:
- Input (x(t)): This is the new musical note or chord that's being added to the sequence. It's like giving our composer a new note to consider in their improvisation.
- Cell State (C(t-1)): This represents the composer's memory. Just as a musician remembers the flow of the music piece, the cell state holds information about the notes that have come before.
- Hidden State (h(t-1)): This is akin to the feeling or the current vibe of the music that the musician is playing. It is the cell's output from the previous step, summarizing the immediate past that influences the next note.
Gates: These are like decisions that our composer makes:
- Forget Gate: It decides what parts of the previous memory are no longer relevant for the current step, like deciding that a certain musical theme is finished and should not influence the next notes.
- Input Gate: It decides which of the new information should be added to the memory, like introducing a new theme or variation in our music piece.
- Output Gate: It decides how much of the updated memory to reveal as the cell's output at this step, which is then used to choose the next note to be played based on the current state of the music.
Operations:
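- The tanh and sigmoid functions are activation functions that keep the values in the network within a fixed range. Sigmoid squashes values to between 0 and 1, which is why the gates use it (0 means "let nothing through", 1 means "let everything through"); tanh squashes values to between -1 and 1 and is used for the candidate memory and the output. The sigmoid gates are like the composer deciding how much importance to give to each piece of new information.
- Outputs (C(t) and h(t)): The cell state C(t) is the updated memory of our musician, carried forward to the next step. The hidden state h(t) is the cell's new output; it is passed both to the next time step and to the layers that predict the next part of the music.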
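The walkthrough above maps onto a single function. Below is a minimal NumPy sketch of one LSTM time step following the standard equations, not any library's internals: the stacked weight layout and the forget/input/candidate/output order are our own choices (libraries may order these blocks differently), and the random weights and toy inputs are placeholders.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, Wx, Wh, b):
    """One LSTM time step.

    Wx: (4*H, input_dim), Wh: (4*H, H), b: (4*H,) stack the parameters of the
    forget gate, input gate, candidate memory and output gate, in that order.
    """
    H = h_prev.shape[0]
    z = Wx @ x_t + Wh @ h_prev + b
    f = sigmoid(z[0:H])             # forget gate: how much old memory to keep (0..1)
    i = sigmoid(z[H:2 * H])         # input gate: how much new material to write (0..1)
    c_tilde = np.tanh(z[2 * H:3 * H])  # candidate memory content (-1..1)
    o = sigmoid(z[3 * H:4 * H])     # output gate: how much memory to expose (0..1)
    c_t = f * c_prev + i * c_tilde  # updated memory C(t)
    h_t = o * np.tanh(c_t)          # new hidden state h(t)
    return h_t, c_t

rng = np.random.default_rng(0)
input_dim, H = 3, 4                 # e.g. (pitch, step, duration) in, 4 units
Wx = rng.normal(scale=0.3, size=(4 * H, input_dim))
Wh = rng.normal(scale=0.3, size=(4 * H, H))
b = np.zeros(4 * H)

h, c = np.zeros(H), np.zeros(H)
for x_t in rng.normal(size=(10, input_dim)):  # a toy sequence of 10 "notes"
    h, c = lstm_step(x_t, h, c, Wx, Wh, b)
print(h.shape, c.shape)  # (4,) (4,)
```

The only place old memory meets new information is the line computing c_t, which is the additive update that lets gradients travel back through many steps.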
For a more comprehensive breakdown of LSTM structure, here is a great post by Yan that goes into each component in detail, and here is a great technical breakdown of how LSTMs work mathematically.
Real-world Application: How LSTM is used in Music Generation
Let’s see LSTM in action with TensorFlow's tutorial on music generation, which teaches how to generate musical notes using an LSTM. It trains a model on piano MIDI files from the MAESTRO dataset, where the model learns to predict the next note in a sequence. The tutorial includes complete code for parsing and creating MIDI files, and you can run the notebook yourself!
Below is the tutorial’s model definition.
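```python
import tensorflow as tf

seq_length = 25                  # notes per input window (the tutorial's setting)
input_shape = (seq_length, 3)    # each note: (pitch, step, duration)

inputs = tf.keras.Input(input_shape)
# return_sequences defaults to False, so only the last hidden state, shape (batch, 128), is returned
x = tf.keras.layers.LSTM(128)(inputs)
outputs = {
    'pitch': tf.keras.layers.Dense(128, name='pitch')(x),      # logits over 128 MIDI pitches
    'step': tf.keras.layers.Dense(1, name='step')(x),          # time since the previous note
    'duration': tf.keras.layers.Dense(1, name='duration')(x),  # how long the note is held
}
model = tf.keras.Model(inputs, outputs)
```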
The structure of the model is very simple: there is an Input layer, an LSTM layer, and then three output layers, one each for pitch, step, and duration.
The Input layer is essentially the starting point of a neural network. It is used to specify the shape and type of data that the network will receive. Think of it as the “doorway” through which your data enters the network. It doesn’t perform any computation or transformation on the data; it merely defines the type and size of the input you’ll be feeding into the model.
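The shape the Input layer declares, (seq_length, 3), comes from how training examples are built. Here is a minimal NumPy sketch, assuming each note is a row of (pitch, step, duration) and a window length of 25: every input is a window of consecutive notes, and the label is the note that follows it. The function name make_windows and the toy data are hypothetical; the tutorial builds its windows with tf.data and may normalize the values, which is omitted here.

```python
import numpy as np

def make_windows(notes, seq_length):
    """Split an (N, 3) array of notes into (input window, next note) pairs.

    Each input has shape (seq_length, 3), matching Input((seq_length, 3));
    the label is the note that immediately follows the window.
    """
    notes = np.asarray(notes, dtype=np.float32)
    n_windows = len(notes) - seq_length
    if n_windows <= 0:
        raise ValueError("need more notes than seq_length")
    X = np.stack([notes[i:i + seq_length] for i in range(n_windows)])
    y = notes[seq_length:]          # the note right after each window
    return X, y

# Toy data: 30 notes with columns (pitch, step, duration).
rng = np.random.default_rng(0)
notes = np.column_stack([
    rng.integers(21, 109, size=30),  # MIDI pitches on an 88-key piano (21..108)
    rng.uniform(0.0, 0.5, size=30),  # seconds since the previous note started
    rng.uniform(0.1, 1.0, size=30),  # seconds the note is held
])
X, y = make_windows(notes, seq_length=25)
print(X.shape, y.shape)  # (5, 25, 3) (5, 3)
```

Keras adds the batch dimension itself, which is why the Input layer only declares (seq_length, 3) while X has a leading dimension of 5.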
The following LSTM layer processes the data fed from the Input layer and has 128 ‘units’. Each unit holds one entry of the cell state and hidden state, so the layer’s memory is a 128-dimensional vector updated by the kind of cell we broke down above. In terms of visualization, you might see LSTM being presented as below:
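Note that in the above graph, it’s the same memory cell A at different time steps, not three memory cells. To picture the “128 units” in the model specification, imagine the state flowing through cell A as a vector of 128 numbers: at each time step all 128 entries are updated together, and each unit’s gates see the full previous hidden state through the shared recurrent weights, so the units learn from the input data jointly rather than in isolation.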
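The number of units also determines how many weights the layer has to learn. This small sketch counts the trainable parameters of an LSTM layer, assuming the standard formulation with one input kernel, one recurrent kernel, and one bias per gate (as in Keras's default LSTM). For the tutorial's setup of 3 input features and 128 units it gives 67,584, which should match what model.summary() reports for the LSTM layer.

```python
def lstm_param_count(units, input_dim):
    """Trainable parameters in one LSTM layer with biases.

    Each of the 4 gate/candidate computations has an input kernel
    (input_dim x units), a recurrent kernel (units x units), and a bias (units).
    """
    return 4 * (input_dim * units + units * units + units)

for units in (32, 128, 512):
    print(units, lstm_param_count(units, input_dim=3))
```

The loop prints 4608, 67584, and 1056768: quadrupling the units from 128 to 512 multiplies the parameter count by roughly 16, because the units × units recurrent kernel dominates.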
The choice of unit number is a balance between computational efficiency and the complexity of patterns the model can learn. More units could potentially capture more complex patterns or more nuanced features in the data, but it comes with a higher computational cost and memory usage; there’s also a higher risk of overfitting, where the model learns the training data too well, including the noise, and performs poorly on unseen data. Conversely, fewer units come at the risk of not capturing complex patterns effectively.
The above graph is the output, or generated notes, of our LSTM. You can see what the output layer produces here.
- The pitch output layer has 128 units, one for each MIDI pitch number (0 to 127). It outputs raw scores (logits), which a softmax turns into a probability distribution over pitches. During generation the pitch is sampled from this distribution, so the most likely pitch is favored but not always chosen. The 128 units reflect the size of the MIDI pitch range rather than anything about model capacity.
- The step output layer, with a single unit, predicts the time elapsed since the previous note started, which determines when the current note starts.
- The duration output layer, with a single unit, predicts how long each note should be sustained. In the generated data, some durations are zero; this likely reflects near-zero or negative predictions being clipped to zero rather than rests, since in this representation a pause shows up as a gap between one note's end and the next note's start.
For intuition, the generated notes look like the above. You can listen to the output by going to the tutorial site.
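The two post-processing steps described in the output list can be sketched in NumPy: sampling a pitch from the 128 logits with a temperature, and turning predicted step and duration values into absolute start and end times. The function names, the temperature parameter, and the clipping of negative values are our assumptions for illustration, not the tutorial's exact code.

```python
import numpy as np

def sample_pitch(logits, temperature=1.0, rng=None):
    """Sample a MIDI pitch (0-127) from the pitch head's 128 logits.

    Lower temperature sharpens the distribution toward the most likely pitch;
    higher temperature makes less likely pitches more frequent.
    """
    rng = np.random.default_rng() if rng is None else rng
    z = np.asarray(logits, dtype=np.float64) / temperature
    p = np.exp(z - z.max())         # softmax, shifted for numerical stability
    p /= p.sum()
    return int(rng.choice(len(p), p=p))

def to_absolute_times(steps, durations):
    """Turn (step, duration) predictions into (start, end) times.

    Negative values are clipped to zero, so a note can end up with duration 0;
    a rest appears as a gap between one note's end and a later note's start.
    """
    start, events = 0.0, []
    for step, duration in zip(steps, durations):
        start += max(0.0, step)
        events.append((start, start + max(0.0, duration)))
    return events

rng = np.random.default_rng(0)
logits = np.zeros(128)
logits[60] = 5.0                    # make middle C (MIDI 60) the most likely pitch
print(sample_pitch(logits, temperature=0.1, rng=rng))  # almost always 60
print(to_absolute_times([0.0, 0.5, -0.25], [0.25, 0.0, 0.75]))
# [(0.0, 0.25), (0.5, 0.5), (0.5, 1.25)]
```

With temperature 1.0 the same logits would still favor pitch 60 but pick another pitch a noticeable fraction of the time, which is the "favored but not always chosen" behavior described above.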
Conclusion
In conclusion, RNNs and LSTMs have served as a backbone of AI music generation by handling sequential data with an aptitude for memory and learning from context. RNNs lay the groundwork by processing sequences and carrying forward information, but they falter with longer sequences due to the vanishing gradient problem. LSTMs excel by introducing an architecture that carefully manages memory through gates, retaining important long-term information and discarding the irrelevant, thereby largely mitigating the vanishing gradient issue. This makes LSTMs the “composer’s detailed notebook” in the symphony of AI music generation, helping every note and chord contribute meaningfully to the final composition, with a coherence and stylistic consistency that is much harder to achieve with simpler models.
References:
Zhu, Juncheng & Yang, Zhile & Mourshed, Monjur & Guo, Yuanjun & Zhou, Yimin & Chang, Yan & Wei, Yanjie & Feng, Shengzhong. (2019). Electric Vehicle Charging Load Forecasting: A Comparative Study of Deep Learning Approaches. Energies. 12. 2692. 10.3390/en12142692.
Engati. (n.d.). Vanishing gradient problem. Engati. Retrieved December 15, 2023, from https://www.engati.com/glossary/vanishing-gradient-problem
Data Basecamp. (2022). Long Short-Term Memory Networks (LSTM)- simply explained! Data Basecamp. https://databasecamp.de/en/ml/lstms
TensorFlow. (n.d.). Generate music with an RNN. Retrieved December 15, 2023, from https://www.tensorflow.org/tutorials/audio/music_generation