Python 实现一个基于类的矩阵类(长文解析)

更新时间:

💡一则或许对你有用的小广告

欢迎加入小哈的星球 ,你将获得:专属的项目实战 / 1v1 提问 / Java 学习路线 / 学习打卡 / 每月赠书 / 社群讨论

截止目前, 星球 内专栏累计输出 90w+ 字,讲解图 3441+ 张,还在持续爆肝中.. 后续还会上新更多项目,目标是将 Java 领域典型的项目都整一波,如秒杀系统, 在线商城, IM 即时通讯,权限管理,Spring Cloud Alibaba 微服务等等,已有 3100+ 小伙伴加入学习 ,欢迎点击围观

前言

在编程与数据科学领域,矩阵运算是一个核心能力。无论是处理图像、数据分析,还是机器学习模型的构建,矩阵操作都贯穿始终。然而,手动编写矩阵运算的代码不仅繁琐,还容易出错。通过面向对象编程(OOP)的方式,我们可以将矩阵抽象为一个自定义类,从而高效地实现各种运算功能。本文将从零开始,逐步讲解如何用 Python 设计一个基于类的矩阵类,帮助读者理解面向对象编程与矩阵运算的结合逻辑。


矩阵类的设计思路

矩阵的基本概念

矩阵可以理解为一个二维表格,由行和列构成的数值集合。例如,一个 2×3 的矩阵可以表示为:
$$
\begin{bmatrix}
1 & 2 & 3 \
4 & 5 & 6
\end{bmatrix}
$$
在编程中,矩阵通常用嵌套列表(List of Lists)来存储。例如,上述矩阵可以用 Python 表示为 [[1, 2, 3], [4, 5, 6]]

面向对象编程的优势

通过将矩阵封装为一个类,我们可以:

  1. 封装数据与行为:将矩阵的数值和相关操作(如加法、乘法)绑定在一起。
  2. 提高代码复用性:通过继承或组合,可以轻松扩展功能(如添加转置、行列式计算等)。
  3. 增强可读性:通过类方法和属性,代码逻辑更清晰。

矩阵类的核心功能实现

第一步:初始化与基本结构

首先定义一个名为 Matrix 的类,并实现 __init__ 方法来初始化矩阵的数据和维度。

class Matrix:  
    def __init__(self, data):  
        """  
        初始化矩阵,检查输入数据是否为二维列表  
        :param data: 矩阵的数值列表,例如 [[1, 2], [3, 4]]  
        """  
        if not isinstance(data, list) or len(data) == 0 or not all(isinstance(row, list) for row in data):  
            raise ValueError("输入数据必须为非空的二维列表")  
        # 检查所有行的长度是否一致  
        row_length = len(data[0])  
        for row in data:  
            if len(row) != row_length:  
                raise ValueError("矩阵的每行元素个数必须相同")  
        self.data = data  
        self.rows = len(data)  
        self.cols = row_length  

关键点解析

  • __init__ 方法通过参数 data 接收矩阵的数值,并进行类型和格式校验。
  • 通过 self.rowsself.cols 记录矩阵的行数和列数,方便后续运算时使用。

第二步:实现字符串表示(__repr__ 方法)

为了让矩阵对象在打印时更直观,需要定义 __repr__ 方法。

    def __repr__(self):  
        return f"Matrix({self.rows}x{self.cols}):\n" + "\n".join([  
            " ".join(map(str, row)) for row in self.data  
        ])  

效果演示

m = Matrix([[1, 2], [3, 4]])  
print(m)  

第三步:实现加法与减法运算

矩阵的加减法要求两个矩阵的维度完全一致。通过重载 __add____sub__ 方法,可以实现自然的语法表达。

    def __add__(self, other):  
        if self.rows != other.rows or self.cols != other.cols:  
            raise ValueError("矩阵维度不匹配,无法相加")  
        # 逐元素相加  
        new_data = [  
            [self.data[i][j] + other.data[i][j]  
             for j in range(self.cols)]  
            for i in range(self.rows)  
        ]  
        return Matrix(new_data)  

    def __sub__(self, other):  
        # 减法逻辑与加法类似,仅符号不同  
        if self.rows != other.rows or self.cols != other.cols:  
            raise ValueError("矩阵维度不匹配,无法相减")  
        new_data = [  
            [self.data[i][j] - other.data[i][j]  
             for j in range(self.cols)]  
            for i in range(self.rows)  
        ]  
        return Matrix(new_data)  

示例代码

a = Matrix([[1, 2], [3, 4]])  
b = Matrix([[5, 6], [7, 8]])  
print(a + b)  

第四步:实现矩阵乘法(核心难点)

矩阵乘法的规则较为复杂:结果矩阵的第 (i) 行第 (j) 列元素,等于第一个矩阵的第 (i) 行与第二个矩阵第 (j) 列的对应元素乘积之和。此外,运算要求第一个矩阵的列数等于第二个矩阵的行数。

    def __matmul__(self, other):  
        """  
        矩阵乘法(使用 @ 运算符)  
        :param other: 另一个 Matrix 对象  
        :return: 新的 Matrix 对象  
        """  
        if self.cols != other.rows:  
            raise ValueError(  
                f"矩阵维度不匹配:{self.rows}x{self.cols} 与 {other.rows}x{other.cols} 无法相乘")  
        # 初始化结果矩阵的行和列  
        result_data = []  
        for i in range(self.rows):  
            new_row = []  
            for j in range(other.cols):  
                # 计算每个元素的点积  
                dot_product = sum(  
                    self.data[i][k] * other.data[k][j]  
                    for k in range(self.cols)  
                )  
                new_row.append(dot_product)  
            result_data.append(new_row)  
        return Matrix(result_data)  

示例代码

a = Matrix([[1, 2], [3, 4]])  
b = Matrix([[5, 6], [7, 8]])  
print(a @ b)  

第五步:实现转置操作

矩阵转置是将行和列互换的操作。例如,一个 2×3 的矩阵转置后变为 3×2。

    def transpose(self):  
        """  
        返回转置后的矩阵  
        """  
        # 使用 zip 和列表推导式简化代码  
        transposed = zip(*self.data)  
        # 将元组转换为列表  
        new_data = [list(row) for row in transposed]  
        return Matrix(new_data)  

示例代码

m = Matrix([[1, 2], [3, 4]])  
print(m.transpose())  

扩展功能与异常处理

支持标量运算

除了矩阵之间的运算,我们还可以让矩阵与标量(如数字)进行相加或相乘。

    def __mul__(self, scalar):  
        """  
        与标量相乘(例如 2 * matrix 或 matrix * 2)  
        """  
        if not isinstance(scalar, (int, float)):  
            raise TypeError("仅支持与数字相乘")  
        new_data = [  
            [element * scalar for element in row]  
            for row in self.data  
        ]  
        return Matrix(new_data)  

示例代码

m = Matrix([[1, 2], [3, 4]])  
print(m * 3)  

异常处理与友好提示

在运算过程中,通过抛出 ValueErrorTypeError,可以明确告知用户错误原因。例如:

try:  
    a = Matrix([[1, 2], [3, 4]])  
    b = Matrix([[5], [6]])  # 列数为1  
    print(a + b)  # 维度不匹配  
except ValueError as e:  
    print(e)  # 输出:"矩阵维度不匹配,无法相加"  

完整代码与测试案例

完整的 Matrix 类

class Matrix:  
    def __init__(self, data):  
        if not isinstance(data, list) or len(data) == 0 or not all(isinstance(row, list) for row in data):  
            raise ValueError("输入数据必须为非空的二维列表")  
        row_length = len(data[0])  
        for row in data:  
            if len(row) != row_length:  
                raise ValueError("矩阵的每行元素个数必须相同")  
        self.data = data  
        self.rows = len(data)  
        self.cols = row_length  

    def __repr__(self):  
        return f"Matrix({self.rows}x{self.cols}):\n" + "\n".join([  
            " ".join(map(str, row)) for row in self.data  
        ])  

    def __add__(self, other):  
        if self.rows != other.rows or self.cols != other.cols:  
            raise ValueError("矩阵维度不匹配,无法相加")  
        new_data = [  
            [self.data[i][j] + other.data[i][j]  
             for j in range(self.cols)]  
            for i in range(self.rows)  
        ]  
        return Matrix(new_data)  

    def __sub__(self, other):  
        if self.rows != other.rows or self.cols != other.cols:  
            raise ValueError("矩阵维度不匹配,无法相减")  
        new_data = [  
            [self.data[i][j] - other.data[i][j]  
             for j in range(self.cols)]  
            for i in range(self.rows)  
        ]  
        return Matrix(new_data)  

    def __matmul__(self, other):  
        if self.cols != other.rows:  
            raise ValueError(  
                f"矩阵维度不匹配:{self.rows}x{self.cols} 与 {other.rows}x{other.cols} 无法相乘")  
        result_data = []  
        for i in range(self.rows):  
            new_row = []  
            for j in range(other.cols):  
                dot_product = sum(  
                    self.data[i][k] * other.data[k][j]  
                    for k in range(self.cols)  
                )  
                new_row.append(dot_product)  
            result_data.append(new_row)  
        return Matrix(result_data)  

    def transpose(self):  
        transposed = zip(*self.data)  
        new_data = [list(row) for row in transposed]  
        return Matrix(new_data)  

    def __mul__(self, scalar):  
        if not isinstance(scalar, (int, float)):  
            raise TypeError("仅支持与数字相乘")  
        new_data = [  
            [element * scalar for element in row]  
            for row in self.data  
        ]  
        return Matrix(new_data)  

测试案例

a = Matrix([[1, 2], [3, 4]])  
b = Matrix([[5, 6], [7, 8]])  

print("加法结果:")  
print(a + b)  

print("\n乘法结果:")  
print(a @ b)  

print("\n转置后的矩阵:")  
print(a.transpose())  

print("\n标量乘法结果(3 * a):")  
print(3 * a)  

总结与扩展方向

通过本文的讲解,我们完成了以下目标:

  1. 类的结构设计:将矩阵的属性(如行数、列数)和行为(如加法、乘法)封装到一个类中。
  2. 面向对象的核心思想:通过方法重载(如 __add__)实现自然的语法,提升代码的可读性。
  3. 异常处理:确保代码在维度不匹配或输入错误时提供清晰的反馈。

未来扩展方向

  • 优化性能:使用 NumPy 库加速大规模矩阵运算。
  • 添加更多运算:如求逆矩阵、行列式、特征值等。
  • 支持矩阵的切片操作:例如提取子矩阵。

希望本文能帮助读者掌握如何将数学概念转化为可复用的 Python 类,并为后续学习高级算法或数据分析打下基础。编程的本质在于将复杂问题拆解为可管理的模块,而面向对象编程正是这一思想的完美体现。

最新发布