怎么用 Matlab 实现 aurograd 自动微分

注意:

  1. 本文是之前的文章《投影梯度下降法(ee227c Lecture4 学习笔记)》的一部分。因为我觉得可能还会有人需要这个方法,为了方便被搜索引擎搜到,所以单独拿出来写一篇文章。如果要下载源码,请点击原文。
  2. 本文所描述的方法只适用于有显式数学表达式的函数,自变量可以是标量、向量、矩阵,但是必须有数学表达式。一般的编程函数,比如里面有for循环啊、if啊、递归等,是用不了的。

在查文档的时候,发现有gradientdiff2等函数,但是它们都不满足我的需要。在问LLM的时候,LLM给出了这样的回答:

LLMの胡言乱语

LLM完全理解了我的需要,然后给出了一个完全错误的回答。所以,接下来的内容,就是实现这个图里面的东西。如果你发现图里面的东西不是你想要的,那可以不用看了。

接下来利用稀疏逆协方差估计的背景介绍一下所用的方法。

稀疏逆协方差估计是一种统计方法,用于估计变量之间的依赖关系。该方法通过优化一个包含正则化项的目标函数来实现。其优化目标为: \[ \min _{X \in \mathbb{R}^{n \times n}, X \succeq 0}\langle S, X\rangle-\log \operatorname{det}(X)+\alpha\|X\|_1 \] 在这里,我们定义: \[ \langle S, X\rangle=\text{tr} (S^\top X) \]\[ \|X\|_1=\sum_{ij}|X_{ij}| \] 在原文中,使用到了pythonautograd功能。但是Matlab不带有这个功能,为此,我颇下了一番功夫,最终得到一个解决方案。

这个问题解决起来很棘手,因为函数的参数都是矩阵。如果只用 syms A 指令,生成的只是单个变量。例如:

1
2
>> syms S X
>> sparse_inv_cov_syms = trace(S' * X) - log(det(X)) + 0.1 * sum(abs(X),'all')

输出

1
2
3
sparse_inv_cov_syms =

abs(X)/10 - log(X) + X*conj(S)

这显然不是我们要的结果。这时,我们需要生成矩阵符号:

1
2
3
4
>> n=5
>> syms S [n n]
>> syms X [n n]
>> sparse_inv_cov_syms = trace(S' * X) - log(det(X)) + 0.1 * sum(abs(X),'all')

输出

1
2
3
sparse_inv_cov_syms =

abs(X1_1)/10 + abs(X1_2)/10 + abs(X1_3)/10 + abs(X1_4)/10 + abs(X1_5)/10 + abs(X2_1)/10 + ...

但是此时如果我们直接用MatlabFunction函数转换,会得到一个以25个变量为参数的函数:

1
>> ans=MatlabFunction(sparse_inv_cov_syms)

输出:

1
2
3
4
5
ans =

包含以下值的 function_handle:

@(S1_1,S1_2,S1_3,S1_4,S1_5,S2_1,S2_2,S2_3,S2_4,S2_5,S3_1,S3_2,S3_3,S3_4,S3_5,S4_1,S4_2,S4_3,S4_4,S4_5,S5_1,S5_2,S5_3,S5_4,S5_5,X1_1,X1_2,X1_3,X1_4,X1_5,X2_1,X2_2,X2_3,X2_4,X2_5,X3_1,X3_2,X3_3,X3_4,X3_5,X4_1,X4_2,X4_3,X4_4,X4_5,X5_1,X5_2,X5_3,X5_4,X5_5)...

这并不是我们需要的形式。在生成函数时,我们使用Vars字符串,然后把SX用大括号括起来,这样就行了。

1
>> ans=MatlabFunction(sparse_inv_cov_syms,"Vars",{S,X})

输出

1
2
3
4
5
ans =

包含以下值的 function_handle:

@(in1,in2)abs(in2(1))./1.0e+1+abs(in2(6))./1.0e+1+abs(in2(11))./1.0e+1+abs(in2(16))./1.0e+1+abs(in2(21))./1.0e+1+abs(in2(2))./1.0e+1+abs(in2(7))./1.0e+1+abs(in2(12))./1.0e+1+abs(in2(17))./1.0e+1+abs(in2(22))./1.0e+1+abs(in2(3))./1.0e+1+abs(in2(8))./1.0e+1+abs(in2(13))./1.0e+1+abs(in2(18))./1.0e+1+abs(in2(23))./1.0e+1+abs(in2(4))./1.0e+1+abs(in2(9))./1.0e+1+abs(in2(14))./1.0e+1+abs(in2(19))./1.0e+1+abs(in2(24))./1.0e+1+abs(in2(5))./1.0e+1+abs(in2(10))./1.0e+1+abs(in2(15))./1.0e+1+abs(in2(20))./1.0e+1+abs(in2(25))./1.0e+1-log(in2(1).*in2(7).*in2(13).*in2(19).*in2(25)-in2(1).*in2(7).*in2(13).*in2(24).*in2(20)-in2(1).*in2(7).*in2(18).*in2(14).*in2(25)+in2(1).*in2(7).*in2(18).*in2(24).*in2(15)+in2(1).*in2(7).*in2(23).*in2(14).*in2(20)-in2(1).*in2(7).*in2(23).*in2(19).*in2(15)-in2(1).*in2(12).*in2(8).*in2(19).*in2(25)+in2(1).*in2(12).*in2(8).*in2(24).*in2(20)+in2(1).*in2(12).*in2(18).*in2(9).*in2(25)-in2(1).*in2(12).*in2(18).*in2(24).*in2(10)-in2(1).*in2(12).*in2(23).*in2(9).*in2(20)+in2(1).*in2(12).*in2(23).*in2(19).*in2(10)+in2(1).*in2(17).*in2(8).*in2(14).*in2(25)-in2(1).*in2(17).*in2(8).*in2(24).*in2(15)-in2(1).*in2(17).*in2(13).*in2(9).*in2(25)+in2(1).*in2(17).*in2(13).*in2(24).*in2(10)+in2(1).*in2(17).*in2(23).*in2(9).*in2(15)-in2(1).*in2(17).*in2(23).*in2(14).*in2(10)-in2(1).*in2(22).*in2(8).*in2(14).*in2(20)+in2(1).*in2(22).*in2(8).*in2(19).*in2(15)+in2(1).*in2(22).*in2(13).*in2(9).*in2(20)-in2(1).*in2(22).*in2(13).*in2(19).*in2(10)-in2(1).*in2(22).*in2(18).*in2(9).*in2(15)+in2(1).*in2(22).*in2(18).*in2(14).*in2(10)-in2(6).*in2(2).*in2(13).*in2(19).*in2(25)+in2(6).*in2(2).*in2(13).*in2(24).*in2(20)+in2(6).*in2(2).*in2(18).*in2(14).*in2(25)-in2(6).*in2(2).*in2(18).*in2(24).*in2(15)-in2(6).*in2(2).*in2(23).*in2(14).*in2(20)+in2(6).*in2(2).*in2(23).*in2(19).*in2(15)+in2(6).*in2(12).*in2(3).*in2(19).*in2(25)-in2(6).*in2(12).*in2(3).*in2(24).*in2(20)-in2(6).*in2(12).*in2(18).*in2(4).*in2(25)+in2(6).*in2(12).*in2(18).*in2(24).*in2(5)+in2(6).*in2(12).*in2(23).*in2(4).*in2(20)-in2(6).*in2(12).*in2(23).*in2(19).*in2(5)-in2(6).*in2(17).*in2(3).*in2(14).*in2(25)+in2(6).*in2(17).*in2(3).*in2(24).*in2(15)+in2(6).*in2(17).*in2(13).*in2(4).*in2(25)-in2(6).*in2(17).*in2(13).*in2(24).*in2(5)-in2(6).*in2(17).*in2(23).*in2(4).*in2(15)+in2(6).*in2(17).*in2(23).*in2(14).*in2(5)+in2(6).*in2(22).*in2(3).*in2(14).*in2(20)-in2(6).*in2(22).*in2(3).*in2(19).*in2(15)-in2(6).*in2(22).*in2(13).*in2(4).*in2(20)+in2(6).*in2(22).*in2(13).*in2(19).*in2(5)+in2(6).*in2(22).*in2(18).*in2(4).*in2(15)-in2(6).*in2(22).*in2(18).*in2(14).*in2(5)+in2(11).*in2(2).*in2(8).*in2(19).*in2(25)-in2(11).*in2(2).*in2(8).*in2(24).*in2(20)-in2(11).*in2(2).*in2(18).*in2(9).*in2(25)+in2(11).*in2(2).*in2(18).*in2(24).*in2(10)+in2(11).*in2(2).*in2(23).*in2(9).*in2(20)-in2(11).*in2(2).*in2(23).*in2(19).*in2(10)-in2(11).*in2(7).*in2(3).*in2(19).*in2(25)+in2(11).*in2(7).*in2(3).*in2(24).*in2(20)+in2(11).*in2(7).*in2(18).*in2(4).*in2(25)-in2(11).*in2(7).*in2(18).*in2(24).*in2(5)-in2(11).*in2(7).*in2(23).*in2(4).*in2(20)+in2(11).*in2(7).*in2(23).*in2(19).*in2(5)+in2(11).*in2(17).*in2(3).*in2(9).*in2(25)-in2(11).*in2(17).*in2(3).*in2(24).*in2(10)-in2(11).*in2(17).*in2(8).*in2(4).*in2(25)+in2(11).*in2(17).*in2(8).*in2(24).*in2(5)+in2(11).*in2(17).*in2(23).*in2(4).*in2(10)-in2(11).*in2(17).*in2(23).*in2(9).*in2(5)-in2(11).*in2(22).*in2(3).*in2(9).*in2(20)+in2(11).*in2(22).*in2(3).*in2(19).*in2(10)+in2(11).*in2(22).*in2(8).*in2(4).*in2(20)-in2(11).*in2(22).*in2(8).*in2(19).*in2(5)-in2(11).*in2(22).*in2(18).*in2(4).*in2(10)+in2(11).*in2(22).*in2(18).*in2(9).*in2(5)-in2(16).*in2(2).*in2(8).*in2(14).*in2(25)+in2(16).*in2(2).*in2(8).*in2(24).*in2(15)+in2(16).*in2(2).*in2(13).*in2(9).*in2(25)-in2(16).*in2(2).*in2(13).*in2(24).*in2(10)-in2(16).*in2(2).*in2(23).*in2(9).*in2(15)+in2(16).*in2(2).*in2(23).*in2(14).*in2(10)+in2(16).*in2(7).*in2(3).*in2(14).*in2(25)-in2(16).*in2(7).*in2(3).*in2(24).*in2(15)-in2(16).*in2(7).*in2(13).*in2(4).*in2(25)+in2(16).*in2(7).*in2(13).*in2(24).*in2(5)+in2(16).*in2(7).*in2(23).*in2(4).*in2(15)-in2(16).*in2(7).*in2(23).*in2(14).*in2(5)-in2(16).*in2(12).*in2(3).*in2(9).*in2(25)+in2(16).*in2(12).*in2(3).*in2(24).*in2(10)+in2(16).*in2(12).*in2(8).*in2(4).*in2(25)-in2(16).*in2(12).*in2(8).*in2(24).*in2(5)-in2(16).*in2(12).*in2(23).*in2(4).*in2(10)+in2(16).*in2(12).*in2(23).*in2(9).*in2(5)+in2(16).*in2(22).*in2(3).*in2(9).*in2(15)-in2(16).*in2(22).*in2(3).*in2(14).*in2(10)-in2(16).*in2(22).*in2(8).*in2(4).*in2(15)+in2(16).*in2(22).*in2(8).*in2(14).*in2(5)+in2(16).*in2(22).*in2(13).*in2(4).*in2(10)-in2(16).*in2(22).*in2(13).*in2(9).*in2(5)+in2(21).*in2(2).*in2(8).*in2(14).*in2(20)-in2(21).*in2(2).*in2(8).*in2(19).*in2(15)-in2(21).*in2(2).*in2(13).*in2(9).*in2(20)+in2(21).*in2(2).*in2(13).*in2(19).*in2(10)+in2(21).*in2(2).*in2(18).*in2(9).*in2(15)-in2(21).*in2(2).*in2(18).*in2(14).*in2(10)-in2(21).*in2(7).*in2(3).*in2(14).*in2(20)+in2(21).*in2(7).*in2(3).*in2(19).*in2(15)+in2(21).*in2(7).*in2(13).*in2(4).*in2(20)-in2(21).*in2(7).*in2(13).*in2(19).*in2(5)-in2(21).*in2(7).*in2(18).*in2(4).*in2(15)+in2(21).*in2(7).*in2(18).*in2(14).*in2(5)+in2(21).*in2(12).*in2(3).*in2(9).*in2(20)-in2(21).*in2(12).*in2(3).*in2(19).*in2(10)-in2(21).*in2(12).*in2(8).*in2(4).*in2(20)+in2(21).*in2(12).*in2(8).*in2(19).*in2(5)+in2(21).*in2(12).*in2(18).*in2(4).*in2(10)-in2(21).*in2(12).*in2(18).*in2(9).*in2(5)-in2(21).*in2(17).*in2(3).*in2(9).*in2(15)+in2(21).*in2(17).*in2(3).*in2(14).*in2(10)+in2(21).*in2(17).*in2(8).*in2(4).*in2(15)-in2(21).*in2(17).*in2(8).*in2(14).*in2(5)-in2(21).*in2(17).*in2(13).*in2(4).*in2(10)+in2(21).*in2(17).*in2(13).*in2(9).*in2(5))+in2(1).*conj(in1(1))+in2(6).*conj(in1(6))+in2(11).*conj(in1(11))+in2(16).*conj(in1(16))+in2(21).*conj(in1(21))+in2(2).*conj(in1(2))+in2(7).*conj(in1(7))+in2(12).*conj(in1(12))+in2(17).*conj(in1(17))+in2(22).*conj(in1(22))+in2(3).*conj(in1(3))+in2(8).*conj(in1(8))+in2(13).*conj(in1(13))+in2(18).*conj(in1(18))+in2(23).*conj(in1(23))+in2(4).*conj(in1(4))+in2(9).*conj(in1(9))+in2(14).*conj(in1(14))+in2(19).*conj(in1(19))+in2(24).*conj(in1(24))+in2(5).*conj(in1(5))+in2(10).*conj(in1(10))+in2(15).*conj(in1(15))+in2(20).*conj(in1(20))+in2(25).*conj(in1(25))

此时如果想用gradient函数求微分,又会报错:

1
>> g = gradient(sparse_inv_cov_syms,Xsyms)

输出:

1
2
错误使用 sym/gradient (line 39)
Second argument must be a vector of variables.

既然他说 must be a vector,那就把它向量化:

1
2
>> Xvec=X(:)
>> g = gradient(sparse_inv_cov_syms,Xvec)

输出:

1
2
3
4
5
6
7
8
g =

conj(Ssyms1_1) + sign(Xsyms1_1)/10 - ......
conj(Ssyms2_1) + sign(Xsyms2_1)/10 + ......
conj(Ssyms3_1) + sign(Xsyms3_1)/10 - ......
conj(Ssyms4_1) + sign(Xsyms4_1)/10 + ......
conj(Ssyms5_1) + sign(Xsyms5_1)/10 - ......
......

成功了。

所以我先用Matlab的符号运算生成了矩阵符号,书写函数以后,将被求导的矩阵转换成向量,用gradient函数进行符号求导,然后再进行一次矩阵向量化的包装,这样一来,基本可以实现autograd自动求导的效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
% 定义基本变量
n = 5;
A = randn(n,n);
S = A*A';

% 定义目标函数,这里使用了符号运算
syms Ssyms 5; % 表示生成5x5的矩阵符号
syms Xsyms 5;
sparse_inv_cov_syms = trace(Ssyms' * Xsyms) - log(det(Xsyms)) + 0.1 * sum(abs(Xsyms),'all'); % 目标函数
X_vec=Xsyms(:);% 将矩阵 X 向量化,因为后面的gradient函数的第二个参数必须是vector
sparse_inv_cov_grad_syms = gradient(sparse_inv_cov_syms,X_vec); % 微分函数

objective=MatlabFunction(sparse_inv_cov_syms,"Vars",{Ssyms,Xsyms});
% 将符号函数转换为Matlab函数。这里要加"Vars",并且把两个矩阵符号用大括号括起来
% 这样生成的Matlab函数的参数才是矩阵
grad=MatlabFunction(sparse_inv_cov_grad_syms,"Vars",{Ssyms,X_vec});
gradien=@(X) grad(S,X(:)); % 用矩阵向量化包装,形成真正的梯度函数

这时,只要运行

1
grad(X)

它就会返回函数

1
sparse_inv_cov_syms = trace(Ssyms' * Xsyms) - log(det(Xsyms)) + 0.1 * sum(abs(Xsyms),'all');

也就是 \[ \text{tr}(S^\top X)-\log \operatorname{det}(X)+0.1\|X\|_1 \] 在矩阵X处的导数。


本站的运行成本约为每个月5元人民币,如果您觉得本站有用,欢迎打赏:

image-20250112002710775


怎么用 Matlab 实现 aurograd 自动微分
https://suzumiyaakizuki.github.io/2025/01/11/怎么用matlab实现自动求导/
作者
SuzumiyaAkizuki
发布于
2025年1月11日
许可协议