欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Leetcode 311. Sparse Matrix Multiplication (python),nuro面经

程序员文章站 2022-03-23 18:15:08
...

Leetcode 311. Sparse Matrix Multiplication

题目

Leetcode 311. Sparse Matrix Multiplication (python),nuro面经

解法:暴力

这边主要讲一下怎么做矩阵乘法。假设我们有矩阵A,B,结果是C,那么

C[i][k] += A[i][j]*B[j][k]

从矩阵乘法的维度变化来考虑就可以理解了。
遍历的顺序不重要,下面几种都是ok的

class Solution:
    def multiply(self, A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
        ans = [[0]*len(B[0]) for _ in range(len(A))]
        
        # for i in range(len(A)):
        #     for j in range(len(A[0])):
        #         for b_j in range(len(B[0])):
        #             ans[i][b_j] += A[i][j]*B[j][b_j]
        
        for i in range(len(A)):
            for b_i in range(len(B)):
                for b_j in range(len(B[0])):
                    ans[i][b_j] += A[i][b_i]*B[b_i][b_j]
        
        return ans

解法2:

利用hashmap来储存非0元素。这应该是面试中比较标准的解法。一个encode来把sparse_matrx转换成dense_matrix。然后对dense_matrix做乘法,最后用decode把dense_matrix的结果转换为sparse_matrix的结果

implementation很直观的

class Solution:
    def multiply(self, A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
        def encode(sparse_matrix):
            dense_matrix = {}
            for i in range(len(sparse_matrix)):
                for j in range(len(sparse_matrix[0])):
                    if sparse_matrix[i][j]:
                        dense_matrix[(i,j)] = sparse_matrix[i][j]
            return dense_matrix
        
        def decode(dense_matrix,row,col):
            sparse_matrix = [[0]*col for _ in range(row)]
            for (i,j),val in dense_matrix.items():
                sparse_matrix[i][j] = val
            
            return sparse_matrix
        
        A_dense = encode(A)
        B_dense = encode(B)
        ans_dense = collections.defaultdict(int)
        
        for (i,j) in A_dense.keys():
            for k in range(len(B[0])):
                if (j,k) in B_dense:
                    ans_dense[(i,k)] += A_dense[(i,j)]*B_dense[(j,k)]
                    
        return decode(ans_dense,len(A),len(B[0]))