Swift Algorithm Club: Strassen’s Algorithm
In this tutorial, you’ll learn how to implement Strassen’s Matrix Multiplication in Swift. This was the first matrix multiplication algorithm to beat the naive O(n³) implementation, and is a fantastic example of the Divide and Conquer coding paradigm — a favorite topic in coding interviews. By Richard Ash.
Contents
- Getting Started
- Understanding Matrix Multiplication
- How Does It Work?
- Challenge
- Multiplying Matrices Correctly
- Challenge
- Naive Matrix Multiplication
- Implementing matrixMultiply
- Analyzing Time Complexity
- Trying It Out
- Challenge 1
- Challenge 2
- Strassen’s Matrix Multiplication
- Implementing in Swift
- nextPowerOfTwo
- Challenge
- strassenRecursive
- Example
- Challenge
- Time Complexity
- Trying It Out
- Challenge
- Where to Go From Here?
Implementing in Swift
Now, to the implementation! Start by adding the following extension to the bottom of your playground:
extension Matrix {
  public func strassenMatrixMultiply(by other: Matrix) -> Matrix {
    // More code to come!
  }
}
Now, just like in the naive implementation, you need to check that the first matrix’s column count is equal to the second matrix’s row count.
Replace the comment with the following:
precondition(columnCount == other.rowCount, """
  Two matrices can only be matrix multiplied if the first matrix's column count is \
  equal to the second matrix's row count.
  """)
Time for some prep work! Add the following right below precondition:
// 1
let n = Swift.max(rowCount, columnCount, other.rowCount, other.columnCount)
// 2
let m = nextPowerOfTwo(after: n)
// 3
var firstPrep = Matrix(rows: m, columns: m)
var secondPrep = Matrix(rows: m, columns: m)
// 4
for index in indices {
  firstPrep[index.row, index.column] = self[index]
}
for index in other.indices {
  secondPrep[index.row, index.column] = other[index]
}
Reviewing what’s going on here, you:
- Find the largest of the two matrices’ row and column counts.
- Find the smallest power of two that is greater than or equal to that number.
- Create two new square matrices whose row and column counts equal that power of two.
- Copy the elements from the first and second matrices into their respective prep matrices, leaving the remaining entries as zero.
This seems like extra work: Why is this necessary? Great question! Next, you’ll investigate with an example.
Say you have a 3×2 matrix, A. How should you split this up? Should the middle row go with the top split or the bottom? Because there’s no even way to split this matrix, this edge case would need to be explicitly handled. While this seems difficult, the above prep work will remove this possibility completely.
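Evenly sized matrices aren’t quite enough on their own, because the split happens again at every level of the recursion. The sketch below uses a hypothetical splitSizes(startingAt:) helper, which isn’t part of the tutorial, to trace the sizes a square matrix passes through as it’s halved:

```swift
// Hypothetical helper (not part of the tutorial): lists the sizes a
// square matrix passes through as it is halved recursively, stopping
// at 1 or at an odd size that can't be split evenly.
func splitSizes(startingAt n: Int) -> [Int] {
  var sizes = [n]
  var current = n
  // Keep halving while the size is even and greater than 1.
  while current > 1 && current % 2 == 0 {
    current /= 2
    sizes.append(current)
  }
  return sizes
}

print(splitSizes(startingAt: 6)) // [6, 3]: a 3×3 block can't be halved
print(splitSizes(startingAt: 8)) // [8, 4, 2, 1]: every level splits evenly
```

A 6×6 matrix splits once and then gets stuck at 3×3. Only a power-of-two size halves cleanly all the way down to 1.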
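By padding each matrix into a square matrix whose row and column counts are a power of two, you ensure this edge case never occurs: halving a power of two always gives another power of two, so every split divides evenly, all the way down to 1×1 matrices. Additionally, because the prep work only adds rows and columns of zeros, the answer doesn’t change: the extra rows and columns of the padded product are all zero, and its top-left corner is exactly the product of the original matrices. 🎉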
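To convince yourself the zeros really don’t change the answer, here’s a small sketch on plain [[Int]] arrays rather than the tutorial’s Matrix type. The naiveMultiply(_:_:) and padded(_:to:) helpers are hypothetical stand-ins. The sketch pads a 3×2 and a 2×3 matrix to 4×4, multiplies them, and compares the top-left 3×3 corner with the direct product:

```swift
// Hypothetical naive product of two matrices stored as arrays of rows.
// Assumes lhs's column count equals rhs's row count.
func naiveMultiply(_ lhs: [[Int]], _ rhs: [[Int]]) -> [[Int]] {
  let rows = lhs.count
  let inner = rhs.count
  let columns = rhs.first?.count ?? 0
  var result = Array(repeating: Array(repeating: 0, count: columns), count: rows)
  for i in 0..<rows {
    for j in 0..<columns {
      for k in 0..<inner {
        result[i][j] += lhs[i][k] * rhs[k][j]
      }
    }
  }
  return result
}

// Hypothetical helper: copies `matrix` into the top-left corner
// of a size×size matrix of zeros.
func padded(_ matrix: [[Int]], to size: Int) -> [[Int]] {
  var result = Array(repeating: Array(repeating: 0, count: size), count: size)
  for (i, row) in matrix.enumerated() {
    for (j, value) in row.enumerated() {
      result[i][j] = value
    }
  }
  return result
}

let a = [[1, 2], [3, 4], [5, 6]]   // 3×2
let b = [[7, 8, 9], [10, 11, 12]]  // 2×3

let direct = naiveMultiply(a, b)                                   // 3×3
let viaPadding = naiveMultiply(padded(a, to: 4), padded(b, to: 4)) // 4×4

// Crop the padded product back to 3×3.
let cropped = viaPadding.prefix(3).map { Array($0.prefix(3)) }
print(cropped == direct) // true
print(viaPadding[3])     // [0, 0, 0, 0]
```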
Now, to finish the method, add the following to strassenMatrixMultiply:
// 1
let resultPrep = firstPrep.strassenRecursive(by: secondPrep)
// 2
var result = Matrix(rows: rowCount, columns: other.columnCount)
// 3
for index in result.indices {
  result[index] = resultPrep[index.row, index.column]
}
// 4
return result
Here, you:
- Recursively multiply the two prep matrices.
- Initialize a new matrix with the correct dimensions: the first matrix’s row count by the second matrix’s column count.
- Iterate through the result matrix and copy over the element at the same row and column from the prep result.
- Return the result!
Good job! Almost done. You have two unimplemented methods left, nextPowerOfTwo and strassenRecursive. You’ll tackle those now.
nextPowerOfTwo
Add the following method below strassenMatrixMultiply:
private func nextPowerOfTwo(after n: Int) -> Int {
  // 1
  let logBaseTwoN = log2(Double(n))
  // 2
  let ceilLogBaseTwoN = ceil(logBaseTwoN)
  // 3
  let nextPowerOfTwo = pow(2, ceilLogBaseTwoN)
  return Int(nextPowerOfTwo)
}
This method takes a number and returns the smallest power of two that is greater than or equal to it. If the number is already a power of two, it’s returned unchanged.
Reviewing, you:
- Calculate the log base 2 of the input number.
- Take the ceiling of logBaseTwoN. This rounds logBaseTwoN up to the nearest whole number.
- Calculate 2 to the ceilLogBaseTwoN power and convert it to an Int.
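Because log2, ceil and pow work with Double values, this version relies on the floating-point results being exact for powers of two, which should hold for the small sizes used here. If you’d rather avoid floating point altogether, here’s a minimal integer-only sketch. The name nextPowerOfTwoInteger(after:) is hypothetical, and the function simply doubles a counter until it reaches n:

```swift
// Hypothetical integer-only variant of nextPowerOfTwo(after:).
// Returns the smallest power of two that is >= n, and 1 for n <= 1.
// Assumes n is small enough that the result fits in an Int.
func nextPowerOfTwoInteger(after n: Int) -> Int {
  var power = 1
  // Double `power` until it reaches or passes n.
  while power < n {
    power <<= 1
  }
  return power
}

for n in [1, 5, 8, 17] {
  print(n, nextPowerOfTwoInteger(after: n))
}
// 1 1
// 5 8
// 8 8
// 17 32
```

One small difference: for n = 0 the floating-point version returns 0 (2 raised to negative infinity), while this sketch returns 1.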
Challenge
To get a better idea of how this method works, try applying it to the following numbers. Don’t use code! Write out each step and use WolframAlpha to do the calculations.
- 3
- 4
Solution:
For 3:
log2(3) ≈ 1.585
ceil(1.585) = 2
pow(2, 2) = 4
nextPowerOfTwo = 4
For 4:
log2(4) = 2
ceil(2) = 2
pow(2, 2) = 4
nextPowerOfTwo = 4
strassenRecursive
Next up, you need to implement strassenRecursive(by:). Start by adding the following below nextPowerOfTwo:
private func strassenRecursive(by other: Matrix) -> Matrix {
  assert(isSquare && other.isSquare, "This method requires square matrices!")
  guard rowCount > 1 && other.rowCount > 1 else { return self * other }
}
Here, you set the base case for the recursion: if either matrix has a single row (both are square, so that means it’s a 1×1 matrix), the term-by-term product of the two matrices is the same as their matrix product, so you simply return self * other.
Then, you need to split each input matrix into four submatrices, eight in total. Add this initialization to the method:
// 1
let n = rowCount
let nBy2 = n / 2
// Assume submatrices are allocated as follows
// matrix self = |a b|, matrix other = |e f|
//               |c d|                 |g h|
// 2
var a = Matrix(rows: nBy2, columns: nBy2)
var b = Matrix(rows: nBy2, columns: nBy2)
var c = Matrix(rows: nBy2, columns: nBy2)
var d = Matrix(rows: nBy2, columns: nBy2)
var e = Matrix(rows: nBy2, columns: nBy2)
var f = Matrix(rows: nBy2, columns: nBy2)
var g = Matrix(rows: nBy2, columns: nBy2)
var h = Matrix(rows: nBy2, columns: nBy2)
// 3
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    a[i, j] = self[i, j]
    b[i, j] = self[i, j+nBy2]
    c[i, j] = self[i+nBy2, j]
    d[i, j] = self[i+nBy2, j+nBy2]
    e[i, j] = other[i, j]
    f[i, j] = other[i, j+nBy2]
    g[i, j] = other[i+nBy2, j]
    h[i, j] = other[i+nBy2, j+nBy2]
  }
}
OK! You:
- Store the size of the current matrix, n, and the size of each submatrix, nBy2.
- Initialize all eight submatrices.
- Fill each of the eight submatrices with the appropriate elements from the original matrices. A really cool optimization here is that you only need to loop over 0..<nBy2 instead of 0..<n. Because all eight submatrices share the same dimensions, you can update all eight on each pass through the for loop!
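The same single-pass idea works on any square grid. As a sketch, here’s a hypothetical quadrants(of:) function on plain [[Int]] arrays, not the tutorial’s Matrix type, that fills all four quadrants of one matrix with a single pair of loops over 0..<n/2:

```swift
// Hypothetical helper on plain [[Int]] arrays (not the tutorial's Matrix):
// splits an n×n matrix (n even) into its four (n/2)×(n/2) quadrants,
// filling all four in one pass over 0..<n/2.
func quadrants(of m: [[Int]]) -> (topLeft: [[Int]], topRight: [[Int]],
                                  bottomLeft: [[Int]], bottomRight: [[Int]]) {
  let half = m.count / 2
  let zeros = Array(repeating: Array(repeating: 0, count: half), count: half)
  var tl = zeros, tr = zeros, bl = zeros, br = zeros
  for i in 0..<half {
    for j in 0..<half {
      tl[i][j] = m[i][j]
      tr[i][j] = m[i][j + half]
      bl[i][j] = m[i + half][j]
      br[i][j] = m[i + half][j + half]
    }
  }
  return (tl, tr, bl, br)
}

let m = [[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12],
         [13, 14, 15, 16]]
let q = quadrants(of: m)
print(q.topLeft)     // [[1, 2], [5, 6]]
print(q.bottomRight) // [[11, 12], [15, 16]]
```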
Next, add the following to the bottom of the method:
let p1 = a.strassenRecursive(by: f-h) // a * (f - h)
let p2 = (a+b).strassenRecursive(by: h) // (a + b) * h
let p3 = (c+d).strassenRecursive(by: e) // (c + d) * e
let p4 = d.strassenRecursive(by: g-e) // d * (g - e)
let p5 = (a+d).strassenRecursive(by: e+h) // (a + d) * (e + h)
let p6 = (b-d).strassenRecursive(by: g+h) // (b - d) * (g + h)
let p7 = (a-c).strassenRecursive(by: e+f) // (a - c) * (e + f)
Here, you recursively compute the seven matrix multiplications required by Strassen's algorithm. They are the exact same seven you saw in the section above!
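To see that these seven products really do add up to a matrix product, you can check them on the smallest case, where a through h are plain numbers instead of submatrices. This sketch uses a hypothetical strassen2x2 function. It computes p1 through p7 exactly as above, combines them into the four result entries with Strassen’s formulas, and compares the result with the textbook 2×2 product:

```swift
// Hypothetical sanity check on 2×2 matrices whose "submatrices" are Ints.
// self = |a b|, other = |e f|
//        |c d|          |g h|
func strassen2x2(_ a: Int, _ b: Int, _ c: Int, _ d: Int,
                 _ e: Int, _ f: Int, _ g: Int, _ h: Int) -> [[Int]] {
  let p1 = a * (f - h)
  let p2 = (a + b) * h
  let p3 = (c + d) * e
  let p4 = d * (g - e)
  let p5 = (a + d) * (e + h)
  let p6 = (b - d) * (g + h)
  let p7 = (a - c) * (e + f)
  return [[p5 + p4 - p2 + p6, p1 + p2],
          [p3 + p4, p1 + p5 - p3 - p7]]
}

// Compare against the textbook definition of the 2×2 product.
let (a, b, c, d) = (1, 2, 3, 4)
let (e, f, g, h) = (5, 6, 7, 8)
let expected: [[Int]] = [[a * e + b * g, a * f + b * h],
                         [c * e + d * g, c * f + d * h]]
print(strassen2x2(a, b, c, d, e, f, g, h) == expected) // true
```

Only seven multiplications of a through h appear, instead of the eight the textbook formula uses. That saving is what the recursion compounds at every level.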
Next, add the following:
let result11 = p5 + p4 - p2 + p6 // top-left: a*e + b*g
let result12 = p1 + p2           // top-right: a*f + b*h
let result21 = p3 + p4           // bottom-left: c*e + d*g
let result22 = p1 + p5 - p3 - p7 // bottom-right: c*f + d*h
Above, you compute the submatrices of the result matrix. Now for the final step! Add the following:
var result = Matrix(rows: n, columns: n)
for i in 0..<nBy2 {
  for j in 0..<nBy2 {
    result[i, j] = result11[i, j]
    result[i, j+nBy2] = result12[i, j]
    result[i+nBy2, j] = result21[i, j]
    result[i+nBy2, j+nBy2] = result22[i, j]
  }
}
return result
Phew! Good work. In the final step, you combine the four submatrices into your result matrix. Notice that you only need to loop over 0..<nBy2 because, in each iteration of the loop, you can fill four elements of the final result matrix. Yay for efficiency!
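If you want to experiment with the whole algorithm outside the playground, here’s a compact, self-contained sketch of the same pipeline: pad to a power-of-two square, recurse with seven products, stitch the quadrants together, and crop. It works on plain [[Int]] arrays. Every name in it (Grid, zeros, combine, add, sub, strassenRecursive, strassenMultiply) is hypothetical rather than part of the tutorial’s Matrix API:

```swift
// Hypothetical free-function sketch on [[Int]] arrays; not the
// tutorial's Matrix extension.
typealias Grid = [[Int]]

// An n×n grid filled with zeros.
func zeros(_ n: Int) -> Grid {
  Array(repeating: Array(repeating: 0, count: n), count: n)
}

// Applies `op` element by element to two grids of the same shape.
func combine(_ x: Grid, _ y: Grid, _ op: (Int, Int) -> Int) -> Grid {
  var result = x
  for i in x.indices {
    for j in x[i].indices {
      result[i][j] = op(x[i][j], y[i][j])
    }
  }
  return result
}

func add(_ x: Grid, _ y: Grid) -> Grid { combine(x, y, +) }
func sub(_ x: Grid, _ y: Grid) -> Grid { combine(x, y, -) }

// Strassen recursion; assumes x and y are n×n with n a power of two.
func strassenRecursive(_ x: Grid, _ y: Grid) -> Grid {
  let n = x.count
  if n == 1 { return [[x[0][0] * y[0][0]]] } // base case: 1×1
  let half = n / 2
  // Quadrants: x = |a b|, y = |e f|
  //                |c d|      |g h|
  var (a, b, c, d) = (zeros(half), zeros(half), zeros(half), zeros(half))
  var (e, f, g, h) = (zeros(half), zeros(half), zeros(half), zeros(half))
  for i in 0..<half {
    for j in 0..<half {
      a[i][j] = x[i][j]
      b[i][j] = x[i][j + half]
      c[i][j] = x[i + half][j]
      d[i][j] = x[i + half][j + half]
      e[i][j] = y[i][j]
      f[i][j] = y[i][j + half]
      g[i][j] = y[i + half][j]
      h[i][j] = y[i + half][j + half]
    }
  }
  // The seven recursive products.
  let p1 = strassenRecursive(a, sub(f, h))
  let p2 = strassenRecursive(add(a, b), h)
  let p3 = strassenRecursive(add(c, d), e)
  let p4 = strassenRecursive(d, sub(g, e))
  let p5 = strassenRecursive(add(a, d), add(e, h))
  let p6 = strassenRecursive(sub(b, d), add(g, h))
  let p7 = strassenRecursive(sub(a, c), add(e, f))
  // The four quadrants of the result, stitched back together.
  let r11 = add(sub(add(p5, p4), p2), p6)
  let r12 = add(p1, p2)
  let r21 = add(p3, p4)
  let r22 = sub(sub(add(p1, p5), p3), p7)
  var result = zeros(n)
  for i in 0..<half {
    for j in 0..<half {
      result[i][j] = r11[i][j]
      result[i][j + half] = r12[i][j]
      result[i + half][j] = r21[i][j]
      result[i + half][j + half] = r22[i][j]
    }
  }
  return result
}

// Pads both inputs to a power-of-two square, multiplies, then crops.
func strassenMultiply(_ x: Grid, _ y: Grid) -> Grid {
  let rows = x.count, inner = y.count, columns = y.first?.count ?? 0
  precondition((x.first?.count ?? 0) == inner,
               "x's column count must equal y's row count")
  var m = 1
  while m < max(rows, inner, columns) { m <<= 1 }
  var px = zeros(m), py = zeros(m)
  for i in 0..<rows { for j in 0..<inner { px[i][j] = x[i][j] } }
  for i in 0..<inner { for j in 0..<columns { py[i][j] = y[i][j] } }
  let product = strassenRecursive(px, py)
  return (0..<rows).map { Array(product[$0].prefix(columns)) }
}

let x: Grid = [[1, 2], [3, 4], [5, 6]]
let y: Grid = [[7, 8, 9], [10, 11, 12]]
print(strassenMultiply(x, y)) // [[27, 30, 33], [61, 68, 75], [95, 106, 117]]
```

Because this sketch allocates new arrays at every level of the recursion, it’s better suited to checking results than to measuring speed. For small matrices, the naive triple loop will usually be faster.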