Swift Algorithm Club: Strassen’s Algorithm
In this tutorial, you’ll learn how to implement Strassen’s Matrix Multiplication in Swift. This was the first matrix multiplication algorithm to beat the naive O(n³) implementation, and is a fantastic example of the Divide and Conquer coding paradigm — a favorite topic in coding interviews. By Richard Ash.
Contents
- Getting Started
- Understanding Matrix Multiplication
- How Does It Work?
- Challenge
- Multiplying Matrices Correctly
- Challenge
- Naive Matrix Multiplication
- Implementing matrixMultiply
- Analyzing Time Complexity
- Trying It Out
- Challenge 1
- Challenge 2
- Strassen’s Matrix Multiplication
- Implementing in Swift
- nextPowerOfTwo
- Challenge
- strassenRecursive
- Example
- Challenge
- Time Complexity
- Trying It Out
- Challenge
- Where to Go From Here?
Naive Matrix Multiplication
Start by downloading the materials using the Download Materials button found at the top and bottom of this tutorial. Open StrassenAlgorithm-starter.playground.
Instead of dealing with implementing the specifics of a matrix in Swift, the starter project includes helper methods and a Matrix class to help you focus on learning matrix multiplication and Strassen’s algorithm.
Subscript Methods
- subscript(row:column:): Returns the element at a specified row and column.
- subscript(row:): Returns the row at a specified index.
- subscript(column:): Returns the column at a specified index.
Term-by-Term Matrix Math
- * (lhs:rhs:): Multiplies two matrices term-by-term.
- + (lhs:rhs:): Adds two matrices term-by-term.
- - (lhs:rhs:): Subtracts two matrices term-by-term.
Array
- dot(_:): Computes the dot product with a specified array and returns the result.
Functions
- printMatrix(_:name:): Prettily prints a specified matrix to the console.
For more details, look at Matrix.swift and Array+Extension.swift, both located under Sources. These files contain the implementation and documentation of all the methods and functions above!
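To get a feel for what the dot product helper does, here is a minimal sketch. The name dotSketch(_:) is hypothetical, and the starter project's actual dot(_:) in Array+Extension.swift may well be written differently.

```swift
// Hypothetical stand-in for the starter project's dot(_:).
// The real implementation may differ.
extension Array where Element: Numeric {
  func dotSketch(_ other: [Element]) -> Element {
    precondition(count == other.count,
                 "Both arrays must have the same length.")
    // Multiply matching elements, then add up the products.
    var sum: Element = 0
    for (x, y) in zip(self, other) {
      sum += x * y
    }
    return sum
  }
}

let sampleRow = [2, 1, -1, 0]
let sampleColumn = [3, 2, -1, 2]
print(sampleRow.dotSketch(sampleColumn)) // 2*3 + 1*2 + (-1)*(-1) + 0*2 = 9
```

The dot product is the only arithmetic that matrix multiplication needs: every element of the result is one row dotted with one column.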
Implementing matrixMultiply
To begin implementing matrixMultiply, add a new extension to your playground with the following method:
extension Matrix {
  // 1
  public func matrixMultiply(by other: Matrix) -> Matrix {
    // 2
    precondition(columnCount == other.rowCount,
      """
      Two matrices can only be matrix multiplied if the first's \
      column count is equal to the second's row count.
      """)
  }
}
Reviewing your work:
1. You created an extension of Matrix with a public method named matrixMultiply(by:).
2. You checked that the current matrix’s column count matches the other matrix’s row count. Recall that this is a requirement, as highlighted when you learned about matrix multiplication.
Next, add the following to matrixMultiply(by:) below precondition:
// 1
var result = Matrix(rows: rowCount, columns: other.columnCount)
// 2
for index in result.indices {
  // 3
  let ithRow = self[row: index.row]
  let jthColumn = other[column: index.column]
  // 4
  result[index] = ithRow.dot(jthColumn)
}
return result
Going over what you just added, you:
1. Initialized the result matrix with dimensions of the first matrix’s row count by the second matrix’s column count.
2. Looped through the indices of the matrix. Matrix is a MutableCollection, so this is exactly the same as looping through the indices of a regular Array.
3. Initialized constants for both the ith row from the first matrix and the jth column from the second matrix.
4. Set the element at result[index] to be the dot product of the ith row of the first matrix and the jth column of the second matrix.
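If you want to experiment with the same steps outside the starter project, the sketch below mirrors them on plain nested arrays. It assumes a row-major [[Int]] representation, and multiplySketch(_:_:) is a hypothetical name rather than part of the tutorial's Matrix type.

```swift
// Hypothetical helper: the same row-times-column idea on [[Int]],
// stored row by row. Not part of the starter project.
func multiplySketch(_ lhs: [[Int]], _ rhs: [[Int]]) -> [[Int]] {
  precondition(lhs.allSatisfy { $0.count == rhs.count },
               "lhs column count must equal rhs row count.")
  let columnCount = rhs.first?.count ?? 0
  var result = Array(repeating: Array(repeating: 0, count: columnCount),
                     count: lhs.count)
  for i in 0..<lhs.count {
    for j in 0..<columnCount {
      let ithRow = lhs[i]
      let jthColumn = rhs.map { row in row[j] }
      // Dot product of the ith row and the jth column.
      result[i][j] = zip(ithRow, jthColumn)
        .map { pair in pair.0 * pair.1 }
        .reduce(0, +)
    }
  }
  return result
}

print(multiplySketch([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
// [[19, 22], [43, 50]]
```

Here the two loops over i and j play the role of looping over result.indices in the extension.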
Analyzing Time Complexity
Next, you’ll analyze the time complexity of this implementation. The algorithm above runs in O(n³) time.
But how…? You only used one for loop!
There are actually three loops in this implementation. By using Collection iteration and the Array method dot, you’ve hidden two of them. Take a closer look at the following line:
for index in result.indices {
This line deceptively contains two for loops! The matrix has n rows and n columns, so iterating over every row and column is an O(n²) operation.
Next, look at the following:
result[index] = ithRow.dot(jthColumn)
The dot product takes O(n) time because you need to loop over a row of length n. Since it’s inside the for loop, you multiply the two time complexities together, resulting in O(n³).
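To see the cubic growth for yourself, you can write the three loops out explicitly and count how often the innermost step runs. This is only an illustrative sketch. scalarMultiplicationCount(forSize:) is a hypothetical helper that counts loop iterations and does not touch the Matrix type.

```swift
// Hypothetical helper: counts the scalar multiplications needed to
// multiply two n×n matrices with three explicit loops.
func scalarMultiplicationCount(forSize n: Int) -> Int {
  var count = 0
  for _ in 0..<n {      // every row of the result
    for _ in 0..<n {    // every column of the result
      for _ in 0..<n {  // every term of one dot product
        count += 1
      }
    }
  }
  return count
}

for n in [2, 4, 8] {
  print("n = \(n): \(scalarMultiplicationCount(forSize: n)) multiplications")
}
// n = 2: 8 multiplications
// n = 4: 64 multiplications
// n = 8: 512 multiplications
```

Doubling n multiplies the work by eight, which is the signature of an O(n³) algorithm.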
Trying It Out
Add the following outside of the extension at the bottom of your playground:
var A = Matrix<Int>(rows: 2, columns: 4)
A[row: 0] = [2, 1, -1, 0]
A[row: 1] = [0, 10, 0, 0]
printMatrix(A, name: "A")
Here, you initialize a 2×4 matrix named A and update its rows using subscript(row:). You then print the matrix to the console with printMatrix(_:name:). You should see the following output on the console:
Matrix A:
2 1 -1 0
0 10 0 0
Challenge 1
Initialize a 4×2 matrix named B that prints the following to the console:
Matrix B:
3 4
2 1
-1 2
2 7
This time, use subscript(column:) to update the matrix’s elements.
[spoiler title=”Solution”]
var B = Matrix<Int>(rows: 4, columns: 2)
B[column: 0] = [3, 2, -1, 2]
B[column: 1] = [4, 1, 2, 7]
printMatrix(B, name: "B")
[/spoiler]
Next, add the following below printMatrix(B, name: "B"):
let C = A.matrixMultiply(by: B)
printMatrix(C, name: "C")
You should see the following output:
Matrix C:
9 7
20 10
Cool! Now how do you know that this is the correct matrix? You could write it out yourself (if you want to practice), but that approach becomes quite challenging as the number of rows and columns increases. Fortunately, you can check your answer on WolframAlpha and indeed Matrix C is correct.
Challenge 2
1. Initialize the following 3×3 matrices, D and E.
2. Compute their matrix multiplication, F.
3. Print F to the console.
Matrix D:
1 2 3
3 2 1
1 2 3
Matrix E:
4 5 6
6 5 4
4 5 6
[spoiler title=”Solution”]
var D = Matrix<Int>(rows: 3, columns: 3)
D[row: 0] = [1, 2, 3]
D[row: 1] = [3, 2, 1]
D[row: 2] = [1, 2, 3]
var E = Matrix<Int>(rows: 3, columns: 3)
E[row: 0] = [4, 5, 6]
E[row: 1] = [6, 5, 4]
E[row: 2] = [4, 5, 6]
let F = D.matrixMultiply(by: E)
printMatrix(F, name: "F")
You should see the following output:
Matrix F:
28 30 32
28 30 32
28 30 32
[/spoiler]
On to Strassen’s algorithm!
Strassen’s Matrix Multiplication
Good job! You made it this far. Now for the fun bit: You’ll dive into Strassen’s algorithm. The basic idea behind Strassen’s algorithm is to split the two matrices, A and B, into eight submatrices and then recursively compute the submatrices of C. This strategy is called divide and conquer.
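Consider the following, where a, b, c and d are the top-left, top-right, bottom-left and bottom-right quadrants of A, and e, f, g and h are the corresponding quadrants of B.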
There are eight recursive calls:
- a * e
- b * g
- a * f
- b * h
- c * e
- d * g
- c * f
- d * h
These combine to form the four quadrants of C.
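As a small illustration of how the eight products combine, the sketch below treats each quadrant as a plain number. It assumes the layout A = [a b; c d] and B = [e f; g h], which is the layout the products above imply. In the real algorithm every quadrant is itself a matrix.

```swift
// Quadrants are plain Ints here purely for illustration; in the real
// algorithm each one is itself a matrix.
let (a, b, c, d) = (1, 2, 3, 4) // assumed layout: A = [a b; c d]
let (e, f, g, h) = (5, 6, 7, 8) // assumed layout: B = [e f; g h]

// Each quadrant of C is the sum of two of the eight products.
let topLeft = a * e + b * g
let topRight = a * f + b * h
let bottomLeft = c * e + d * g
let bottomRight = c * f + d * h

print([[topLeft, topRight], [bottomLeft, bottomRight]])
// [[19, 22], [43, 50]]
```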
This step alone, however, doesn’t improve the complexity. Using the Master Theorem with T(n) = 8T(n/2) + O(n²), you still get a running time of O(n^(log₂ 8)) = O(n³).
Strassen’s insight was that you don’t actually need eight recursive calls to complete this process. You can finish the operation with seven recursive calls with a little bit of addition and subtraction.
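To get a rough feel for what dropping one recursive call buys you, the sketch below counts scalar multiplications when every level of the recursion makes either eight or seven half-size calls. It assumes n is a power of two and ignores the cost of the extra additions and subtractions. recursiveMultiplicationCount(size:branches:) is a hypothetical helper, not part of the tutorial's code.

```swift
// Hypothetical helper: number of scalar multiplications for an n×n
// product when each level makes `branches` half-size recursive calls.
// Additions and subtractions are deliberately not counted.
func recursiveMultiplicationCount(size n: Int, branches: Int) -> Int {
  precondition(n > 0 && (n & (n - 1)) == 0, "n must be a power of two.")
  if n == 1 { return 1 } // base case: a single 1×1 product
  return branches * recursiveMultiplicationCount(size: n / 2,
                                                 branches: branches)
}

for n in [2, 8, 64] {
  let eight = recursiveMultiplicationCount(size: n, branches: 8)
  let seven = recursiveMultiplicationCount(size: n, branches: 7)
  print("n = \(n): 8 calls -> \(eight), 7 calls -> \(seven)")
}
// n = 2: 8 calls -> 8, 7 calls -> 7
// n = 8: 8 calls -> 512, 7 calls -> 343
// n = 64: 8 calls -> 262144, 7 calls -> 117649
```

The saving looks small for tiny matrices, but it compounds at every level of the recursion.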
Strassen’s seven calls are as follows:
- p1 = a * (f - h)
- p2 = (a + b) * h
- p3 = (c + d) * e
- p4 = d * (g - e)
- p5 = (a + d) * (e + h)
- p6 = (b - d) * (g + h)
- p7 = (a - c) * (e + f)
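Now, you can compute matrix C’s new quadrants, where p1 through p7 are the seven products above in the order listed:
- Top left: p5 + p4 - p2 + p6
- Top right: p1 + p2
- Bottom left: p3 + p4
- Bottom right: p1 + p5 - p3 - p7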
A great reaction right now would be !!??!?!?!!?! How does this even work?
Next, you’ll prove it!
1. First submatrix:
p5+p4-p2+p6 = (a+d)*(e+h) + d*(g-e) - (a+b)*h + (b-d)*(g+h)
= (ae+de+ah+dh) + (dg-de) - (ah+bh) + (bg-dg+bh-dh)
= ae+bg ✅
Exactly what you got the first time!
Now, on to proving the others.
2. Second submatrix:
p1+p2 = a*(f-h) + (a+b)*h
= (af-ah) + (ah+bh)
= af+bh ✅
3. Third submatrix:
p3+p4 = (c+d)*e + d*(g-e)
= (ce+de) + (dg-de)
= ce+dg ✅
4. Fourth submatrix:
p1+p5-p3-p7 = a*(f-h) + (a+d)*(e+h) - (c+d)*e - (a-c)*(e+f)
= (af-ah) + (ae+de+ah+dh) -(ce+de) - (ae-ce+af-cf)
= cf+dh ✅
Great! The math checks out!
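If you’d like a numeric spot check to go with the algebra, the sketch below applies one level of Strassen’s seven products to a pair of 4×4 matrices and compares the result with a plain triple-loop product. Everything here (Block, product, plus, minus, quadrant and strassenOneLevel) is a hypothetical helper on nested arrays, not the tutorial’s Matrix type, and the seven sub-products are computed naively instead of recursively.

```swift
typealias Block = [[Int]] // hypothetical square, row-major matrix

// Naive product of two square blocks of equal size.
func product(_ x: Block, _ y: Block) -> Block {
  let n = x.count
  var out = Block(repeating: [Int](repeating: 0, count: n), count: n)
  for i in 0..<n {
    for j in 0..<n {
      for k in 0..<n { out[i][j] += x[i][k] * y[k][j] }
    }
  }
  return out
}

// Term-by-term addition and subtraction.
func plus(_ x: Block, _ y: Block) -> Block {
  return x.indices.map { i in x[i].indices.map { j in x[i][j] + y[i][j] } }
}

func minus(_ x: Block, _ y: Block) -> Block {
  return x.indices.map { i in x[i].indices.map { j in x[i][j] - y[i][j] } }
}

// Extracts one quadrant; row and column are each 0 or 1.
func quadrant(_ m: Block, _ row: Int, _ column: Int) -> Block {
  let half = m.count / 2
  return (0..<half).map { i in
    Array(m[row * half + i][(column * half)..<((column + 1) * half)])
  }
}

// One level of Strassen: seven block products instead of eight.
func strassenOneLevel(_ lhs: Block, _ rhs: Block) -> Block {
  let a = quadrant(lhs, 0, 0), b = quadrant(lhs, 0, 1)
  let c = quadrant(lhs, 1, 0), d = quadrant(lhs, 1, 1)
  let e = quadrant(rhs, 0, 0), f = quadrant(rhs, 0, 1)
  let g = quadrant(rhs, 1, 0), h = quadrant(rhs, 1, 1)

  let p1 = product(a, minus(f, h))
  let p2 = product(plus(a, b), h)
  let p3 = product(plus(c, d), e)
  let p4 = product(d, minus(g, e))
  let p5 = product(plus(a, d), plus(e, h))
  let p6 = product(minus(b, d), plus(g, h))
  let p7 = product(minus(a, c), plus(e, f))

  let topLeft = plus(minus(plus(p5, p4), p2), p6)
  let topRight = plus(p1, p2)
  let bottomLeft = plus(p3, p4)
  let bottomRight = minus(minus(plus(p1, p5), p3), p7)

  // Stitch the quadrants back together row by row.
  let top = topLeft.indices.map { topLeft[$0] + topRight[$0] }
  let bottom = bottomLeft.indices.map { bottomLeft[$0] + bottomRight[$0] }
  return top + bottom
}

let left: Block = [[1, 2, 3, 4], [5, 6, 7, 8],
                   [9, 10, 11, 12], [13, 14, 15, 16]]
let right: Block = [[2, 0, 1, 3], [1, 1, 0, 2],
                    [0, 4, 2, 1], [3, 1, 1, 0]]
print(strassenOneLevel(left, right) == product(left, right)) // true
```

Because the quadrants here are 2×2 matrices and not single numbers, the check also exercises the order of the factors in each product, which matters since matrix multiplication is not commutative.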