MiniTorch 笔记

MiniTorch 学习笔记 | Fundamentals

前言

MiniTorch 的 Fundamental 部分主要讲的 Python 编程在 DL 的工程实现中的框架性方法,以及在这些方法下对基本数学运算的工程实现。官方 Doc 是先讲的 Assignment,然后是其背后的脚手架介绍,这样在做 Assignment 的时候有很多地方会不明其所以然。

本文将其倒过来总结,先讲脚手架,再讲 Assignment,逻辑上会更加清晰。下面是本文结构:

  • Contributing | 进行开源项目贡献时的应知应会
  • Property Testing | 模块测试方法
  • Modules | 深度学习中的模块化思想
  • Functional | 函数式编程
  • Visualization | 利用可视化方法调试代码
  • Assignment 0 | 基础数学运算的工程实现

1 Contributing

这部分主要讲如何向 MiniTorch 官方分支贡献代码。

1.1 Style Guide

MiniTorch 推荐用 Rust 实现的 Ruff 作为 Python formatter,看起来确实很快。https://user-images.githubusercontent.com/1309177/232603514-c95e9b0f-6b31-43de-9a80-9e844173fd6a.svg

1.2 Testing

MiniTorch 使用 pytest 进行测试。

repo 源码中以 test 开头的文件均作为测试文件,使用修饰符 @pytest.mark. 的函数均作为测试函数,比如:

1
2
3
4
5
6
7
# tests/test_module.py

@pytest.mark.task0_4
def test_module_forward() -> None:
    mod = ModuleRun()
    assert mod.forward() == 10
    assert mod() == 10
  • task0_4 group 进行测试时,使用 pytest -m task0_4 命令。

  • 对整个 test_module.py 中的 function 进行测试,使用 pytest tests/test_module.py 命令

  • 对某个 function 进行测试,使用 pytest tests/test_module.py -k test_stacked_demo 命令

1.3 Type Checking

现代的 Python 编程 一般都会有静态类型检查,以确保输入输出是符合设计规范的。MiniTorch 用的微软的 Pyright 静态类型检查库。

下面是两个在实现 fuction 时进行类型声明的例子:

1
2
3
4
5
def mul(x: float, y: float) -> float:
   # ...

def negList(ls: Iterable[float]) -> Iterable[float]:
   # ...

mul() 函数接受两个 float 类型的参数,返回一个 float 类型的结果。negList() 函数接受一个 float 类型的可迭代对象,返回一个 float 类型的可迭代对象。

1.4 Documentation

文档注解没啥好讲的,参考 Google 的 Python Docstrings 规范。

1.5 Pre-Commit

暂略(等用到时再回来补充)。

1.6 Continuous Integration (CI)

暂略(等用到时再回来补充)。

2 Property Testing

对于大规模的测试,手动构造的案例在数据、规模、类型等方面很难覆盖全面。MiniTorch 介绍了一种高效便捷的方案,利用 Python 的 Hypothesis 库进行属性测试。

其基本用法如下:

1. 随机数据

1
2
3
4
5
6
@given(integers(), integers())
def test_add(a, b):
    # Check same as slow system add
    assert my_add(a, b) == a + b
    # Check that order doesn't matter
    assert my_add(a, b) == my_add(b, a)

2. 随机数据 + 特殊案例

1
2
3
4
5
6
7
@given(integers(), integers())
@example(5, 7)
def test_add2(a, b):
    # Check same as slow system add
    assert my_add(a, b) == a + b
    # Check that order doesn't matter
    assert my_add(a, b) == my_add(b, a)

3. 指定数据类型和范围

1
2
3
4
5
6
7
8
@given(
    lists(small_floats, min_size=5, max_size=500),
    lists(small_floats, min_size=1, max_size=100),
)
def test_sum_distribute(ls1: List[float], ls2: List[float]) -> None:
    sum1 = sum(ls1) + sum(ls2)
    sum2 = sum(addLists(ls1, ls2))
    assert is_close(sum1, sum2)

3 Modules

MiniTorch 中的 Module 类是一个抽象类,用于实现深度学习中的模块化思想。Module 类的实例可以包含多个子模块,每个子模块可以包含多个实例变量。Module 类的实例可以通过 parameters() 方法获取所有实例变量,通过 named_parameters() 方法获取所有实例变量的名字和值。

Modules 类一般会存储三个属性:

  • parameters: 当前模块的所有参数
  • user_data:当前模块的用户数据
  • modules:子模块

比如:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class OtherModule(Module):
    pass


class MyModule(Module):
    def __init__(self):
        # Must initialize the super class!
        super().__init__()

        # Type 1, a parameter.
        self.parameter1 = Parameter(15)

        # Type 2, user data
        self.data = 25

        # Type 3. another Module
        self.sub_module = OtherModule()

MyModule 类的实例有三个实例变量:parameter1, data, sub_module。其中,parameter1 是一个 Parameter 类的实例,data 是一个普通的用户数据,sub_module 是一个 Module 子类的实例。

因此,Module 类会常用作类的嵌套定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Module1(Module):
    def __init__(self):
        super().__init__()
        self.p1 = Parameter(5)
        self.a = Module2()
        self.b = Module3()


class Module2(Module):
    def __init__(self):
        super().__init__()
        self.p2 = Parameter(10)


class Module3(Module):
    def __init__(self):
        super().__init__()
        self.c = Module4()


class Module4(Module):
    def __init__(self):
        super().__init__()
        self.p3 = Parameter(15)


Module1().named_parameters()

然后每个 Module 实例会有一个 mode 变量,用于标记当前模块的状态,train() 方法用于将当前模块及其所有子孙模块设置为训练状态,eval() 方法用于将当前模块及其所有子孙模块设置为评价状态。

4 Functional

Python 中,把为输入参数是 function,返回参数也是 function 的函数称为 高阶函数,它们被定义为一个统一的对象类型: Callable

MiniTorch 中实现了几个高阶函数,包括 map, zipWithreduce

在 Python 中,一个函数可以被定义为一个 callable 的对象,比如:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
def add(a: float, b: float) -> float:
    return a + b

# v 是一个 Callable 类型的对象,接受两个 float 类型的参数,返回一个 float 类型的结果
v: Callable[[float, float], float] = add

def mul(a: float, b: float) -> float:
    return a * b

v: Callable[[float, float], float] = mul

我们可以把 一个 function 作为参数传递给另一个 function,比如:

1
2
3
4
def combine3(
    fn: Callable[[float, float], float], a: float, b: float, c: float
) -> float:
    return fn(fn(a, b), c)

combine3() 函数接受一个 Callable 类型的参数 fn,以及三个 float 类型的参数 a, b, c,返回一个 float 类型的结果。combine3() 函数的逻辑是先将 ab 传递给 fn 函数,然后再将 fn(a,b) 的结果和 c 传递给 fn() 函数,返回最终结果。例如:

1
2
3
4
5
6
print(combine3(add, 1, 3, 5))
print(combine3(mul, 1, 3, 5))

# Output:
# 9
# 15

还可以利用输入参数的 Callable 对象,构建新的函数,并作为返回值,比如:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def combine3(fn):
    def new_fn(a: float, b: float, c: float) -> float:
        return fn(fn(a, b), c)

    return new_fn

add3: Callable[[float, float, float], float] = combine3(add)
mul3: Callable[[float, float, float], float] = combine3(mul)

print(add3(1, 3, 5))
print(mul3(1, 3, 5))
# Output: 9
# Output: 15

再来一个稍微复杂点的例子,用 Callable 对象实现一个 Filter 函数,即一个函数接受一个 Filter 函数和一个可迭代对象,返回这个可迭代对象中满足 Filter 函数的结果:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def filter(fn: Callable[[float], bool]) -> Callable[[Iterable[float]], Iterable[float]]:
    def apply(ls: Iterable[float]):
        ret = []
        for x in ls:
            if fn(x):
                ret.append(x)
        return ret

    return apply

def is_positive(x: float) -> bool:
    return x > 0

filter_positive = filter(is_positive)
print(filter_positive([1, -1, 2, -2, 3, -3]))

# Output: [1, 2, 3]

5 Visualization

(暂略)用到的时候再来补充。

6 Assignment 0

Assignment 0 主要是对 MiniTorch 的基本操作和模块的实现。

Task 6.1: Operators

主要是一些基本的数学运算,下面记录几个关键的运算符。

6.1.1: is_close(x, y)

1
2
3
4
5
# minitorch/operators.py

# 检查两个数的差值是否小于给定的阈值
def is_close(x, y, eps=1e-5):
    return abs(x - y) < eps

6.1.2: relu(x)

ReLU 函数是一个常用的激活函数,用于将输入值映射到 0 到正无穷之间,其数学定义为:$f(x) = \max(0, x)$。

详细见 ReLU 函数

1
2
3
4
5
# minitorch/operators.py

# ReLU 激活函数
def relu(x):
    return max(0, x)

6.1.3: sigmoid(x)

sigmoid 函数也是一个常用的激活函数,用于将输入值映射到 0 到 1 之间,其数学定义为:$f(x) = \frac{1}{1 + e^{-x}}$。

详细见 Sigmoid 函数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# minitorch/operators.py

# Sigmoid 激活函数
def sigmoid(x):
    return 1 / (1 + math.exp(-x))


# 为了实现稳定性,我们可以使用下面的方式
def sigmoid(x):
    return 1 / (1 + math.exp(-x)) if x >= 0 else math.exp(x) / (1 + math.exp(x))

6.1.4: log_back(x)

1
2
3
4
5
# minitorch/operators.py

# 对数函数
def log_back(x):
    return math.log(x)

6.1.5: inv_back(x)

1
2
3
4
5
# minitorch/operators.py

# 反函数的导数
def inv_back(x):
    return 1 / x**2

6.1.6: relu_back(x)

relu_back() 函数是 ReLU 激活函数的反函数,用于反向传播时计算 ReLU 的梯度。

ReLU 函数求导的结果是:

当 $x > 0$ 时,$f’(x) = 1$;当 $x \leq 0$ 时,$f’(x) = 0$。

当 $x > 0$ 时,relu_back(x, y) 返回 y(因为梯度是 1);当 $x \leq 0$ 时,返回 0

1
2
3
4
5
# minitorch/operators.py

# ReLU 反函数
def relu_back(x, y):
    return y if x > 0 else 0

Task 6.2: Testing and Debugging

这里主要讲一下测试 sigmoid 函数时遇到的问题。

sigmoid 函数的数学定义为:

$$ \sigmoid(x) = \frac{1}{1 + e^{-x}} $$

在测试时,发现 sigmoid 发生指数上溢,导致测试没过。当 正值x 很大时,分母直接变成1了,导致数学上sigmoid函数的有界性取到了1。

暂时不知道咋办,把测试用例加上了边界值。包括 sigmoid() 函数的严格单调递增,测试时也只测试了单调增的情况,严格性过不了。 下面是测试用例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# tests/test_operators.py

@pytest.mark.task0_2
@given(small_floats)
def test_sigmoid(a: float) -> None:
    """Check properties of the sigmoid function, specifically
    * It is always between 0.0 and 1.0.
    * one minus sigmoid is the same as sigmoid of the negative
    * It crosses 0 at 0.5
    * It is strictly increasing.
    """
    # TODO: Implement for Task 0.2.
    
    try:
        # 值域
        assert sigmoid(a) <= 1.0
        assert sigmoid(-a) >= 0.0
        
        # 对称性(指数函数精确度为 0.01)
        assert is_close(1.0 - sigmoid(a), sigmoid(-a))
        
        # 过定点
        assert sigmoid(0.0) == 0.5
        
        # (严格)单调递增
        assert sigmoid(a) <= sigmoid(a + 1.0)
    except:
        raise NotImplementedError("The sig() operator is implemented incorrectly")

Task 6.3: Functional Python

这里带我们实现了几个 Python 的高阶函数,包括 map, zipWithreduce。以前没用过,发现确实很有意思。

6.3.1: map

map 高阶函数接受一个函数 fn, 和一个可迭代对象 arr,返回一个新的可迭代对象。其中,fn 函数接受含有 1 个 float 类型的参数,返回一个 float 类型的结果。而 高阶函数 map 是将原可迭代对象 arr 中的元素逐一映射到 fn 函数的参数中,然后将这些结果塞到返回的列表中。

1
2
3
4
# minitorch/operators.py

def map(fn: Callable[[float], float], arr: Iterable[float]) -> Iterable[float]:
    return [fn(a) for a in arr]

6.3.2: zipWith

zipWith 高阶函数接受一个函数 fn, 和两个可迭代对象 arr1arr2,返回一个新的可迭代对象。其中,fn 函数接受含有 2 个 float 类型的参数,返回一个 float 类型的结果。而 高阶函数 zipWith 是将原可迭代对象 arr1arr2 中的元素分别逐一合并成一个元组,然后将这些元组塞到返回的列表中。

1
2
3
4
# minitorch/operators.py

def zipWith(fn: Callable[[float, float], float], arr1: Iterable[float], arr2: Iterable[float]) -> Iterable[float]:
    return [fn(a, b) for a, b in zip(arr1, arr2)]

6.3.3: reduce

reduce 高阶函数接受一个函数 fn, 和一个可迭代对象 arr,返回一个新的可迭代对象。其中,fn 函数接受含有 2 个 float 类型的参数,返回一个 float 类型的结果。而 高阶函数 reduce 是将原可迭代对象 arr 中的元素逐一进行 fn 函数的操作,这些 fn 函数的结果将进行累加,最终返回一个结果。

1
2
3
4
# minitorch/operators.py

def reduce(fn: Callable[[float, float], float], arr: Iterable[float], init: float) -> float:
    return builtins.sum([fn(init, a) for a in arr]) # 使用内置 sum() 函数求和

下面用上面的几个高阶函数实现 negList(), addLists(), sum() 函数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# minitorch/operators.py

def negList(x):
    return map(neg, x)

def addLists(x, y):
    return zipWith(add, x, y)

def sum(x):
    return reduce(add, x, 0)

其中,neg()add() 函数分别是取负数和加法函数:

1
2
3
4
5
6
7
# minitorch/operators.py

def neg(x):
    return -x

def add(x, y):
    return x+y

Task 6.4: Modules

这部分是一个树状类模块的部分功能实现。包括:Train() 训练函数,eval() 评价/测试函数,以及获取当前模块及所有子孙模块的实例变量的参数 parameters() 和列表序列 named_parameters()

由于 Module 类是可以嵌套定义的,因此当使用上述函数时,需要确保所有子孙 Module 实例均调用被 call 的函数,因此我们要递归实现这些方法。

6.4.1: Train()

Module 类内定义了实例变量 training,用于标记当前模块是否处于训练状态。Train() 函数用于将当前模块及其所有子孙模块设置为训练状态。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# minitorch/module.py

# modules() 是类内已默认实现的方法
def modules(self) -> Sequence[Module]:
    """Return the direct child modules of this module."""
    m: Dict[str, Module] = self.__dict__["_modules"] # 获取当前模块的所有子模块,用 __dict__ 获取当前实例的"_modules"属性
    return list(m.values())

def train(self) -> None:
    """Set the mode of this module and all descendent modules to `train`."""
    # TODO: Implement for Task 0.4.
    self.training = True            # 设置当前模块的 training 为 True
    for module in self.modules():   # 遍历当前实例的所有子模块
        module.train()              # 递归调用子模块的 train() 方法,从而设置所有子孙模块的 training 为 True

6.4.2: eval()

eval() 函数用于将当前模块及其所有子孙模块设置为评价/测试状态,其实就是将当前实例的所有子孙 Module 实例的 training 实例变量设置为 False

1
2
3
4
5
6
7
8
# minitorch/module.py

def eval(self) -> None:
    """Set the mode of this module and all descendent modules to `eval`."""
    # TODO: Implement for Task 0.4.
    self.training = False           # 思路同上
    for module in self.modules():
        module.eval()

6.4.3: parameters()

named_parameters() 函数用于将当前模块实例以及其所有子孙模块中,所有实例变量以(name, value)的形式返回。

Modules 类是一个可以嵌套定义的树状类,举个例子:

模块 A1 实例内部有 1 个实例变量 p1 和 non_param,有两个子模块实例 A2 和 A3。A2 内部有 1 个实例变量 p2。A3 内部有 1 个子模块 A4。A4 内部有 1 个实例变量 P3,那么这个树状结构的模块定义如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
# minitorch/module.py

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
    """Collect all the parameters of this module and its descendents.
    Returns
    -------
        The name and `Parameter` of each ancestor parameter.
    """
    # TODO: Implement for Task 0.4.
    named_params: list[Tuple[str, Parameter]] = []
    for name, param in self._parameters.items(): # 获取当前模块的所有实例变量
        named_params.append((name, param))
    for module_name, module in self._modules.items():   # 遍历当前模块的所有子模块
        for name, param in module.named_parameters():    # 递归调用子模块的 named_parameters() 方法
            full_name = f"{module_name}.{name}"
            named_params.append((full_name, param))
    return named_params

6.4.4: parameters()

parameters() 函数用于将当前模块实例以及其所有子孙模块中,所有实例变量以列表的形式返回,逻辑上和 named_parameters() 函数类似,只是返回的是 Parameter 实例。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# minitorch/module.py

def parameters(self) -> Sequence[Parameter]:
    """Enumerate over all the parameters of this module and its descendents."""
    # TODO: Implement for Task 0.4.
    params: list[Parameter] = []
    for param in self._parameters.values(): # 获取当前模块的所有实例变量
        params.append(param)
    for module in self._modules.values():   # 遍历当前模块的所有子模块
        params.extend(module.parameters())
    return params
给作者倒杯卡布奇诺 ~
Albresky 支付宝支付宝
Albresky 微信微信