欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

拉格朗日乘子法总结(等式约束、不等式约束、非线性规划、KKT 条件)

程序员文章站 2022-06-13 16:35:20
...

前言

本文主要对拉格朗日乘子法进行总结,具体原理可以参考这两篇文章:

  1. 友情链接 1
  2. 友情链接 2

如果你对这篇文章可感兴趣,可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」,查看完整博客分类与对应链接。


拉格朗日乘子法

基本的拉格朗日乘子法主要是针对等式约束下的非线性方程最优化问题,具体形式如下:
{max  f(x)s.t.  g(x)=0 \left\{ \begin{aligned} \max & \ \ f(\textbf{x}) \\ s.t.& \ \ g(\textbf{x})=0 \\ \end{aligned} \right.
对于上述问题,我们可以构造拉格朗日算符函数 L(x,λ)=f(x)+λg(x)L(\textbf{x},\lambda)=f(\textbf{x})+\lambda g(\textbf{x}),然后可以得到最优解的必要条件如下:
{xL=Lx=f+λg=0λL=Lλ=g(x)=0 \left\{ \begin{aligned} & \nabla_{\textbf{x}}L=\displaystyle\frac{\partial L}{\partial \textbf{x}}=\nabla f+\lambda \nabla g=0 \\ & \nabla_{\lambda}L=\displaystyle\frac{\partial L}{\partial \lambda}=g(\textbf{x})=0 \end{aligned} \right.


Karush-Kuhn-Tucker (KKT) 条件

我们可以将上述的等式约束推广到不等式约束下,依然可以有方法进行解决,具体形式如下:
{min  f(x)s.t.  g(x)0 \left\{ \begin{aligned} \min & \ \ f(\textbf{x}) \\ s.t.& \ \ g(\textbf{x})\leq 0 \\ \end{aligned} \right.
对于上述问题,我们依然是构造拉格朗日算符函数 L(x,λ)=f(x)+λg(x)L(\textbf{x},\lambda)=f(\textbf{x})+\lambda g(\textbf{x}),但是最优解的条件发生了改变,具体条件如下:
{xL=Lx=f+λg=0g(x)0λ0λg(x)=0 \left\{ \begin{aligned} & \nabla_{\textbf{x}}L=\displaystyle\frac{\partial L}{\partial \textbf{x}}=\nabla f+\lambda \nabla g=0 \\ & g(\textbf{x})\leq 0\\ & \lambda \geq 0\\ & \lambda g(\textbf{x})=0 \end{aligned} \right.
上述条件即为 KKT\text{KKT} 条件,接下来我们给出非线性规划的标准形式:
{min  f(x)s.t.  gj(x)0,j=1,...,m       hk(x)0,k=1,...,p \left\{ \begin{aligned} \min & \ \ f(\textbf{x}) \\ s.t.& \ \ g_j(\textbf{x})\leq 0,j=1,...,m \\ \ \ \ \ \ & \ \ h_k(\textbf{x})\leq 0,k=1,...,p \\ \end{aligned} \right.
对于上述标准形式的非线性规划问题,我们可以定义拉格朗日算符函数 L(x,{λj},{μk})=f(x)+j=1mλjgj(x)+k=1pμkhk(x)L(\textbf{x},\{\lambda_j\},\{\mu_k\})=f(\textbf{x})+\sum\limits_{j=1}^m\lambda_j g_j(\textbf{x})+\sum\limits_{k=1}^p \mu_kh_k(\textbf{x}),其中 λj\lambda_j 对应等式约束 gj(x)g_j(\textbf{x}) 的拉格朗日乘数,μk\mu_k 对应不等式约束 hk(x)h_k(\textbf{x}) 的拉格朗日乘数。由此我们可以综合上述等式约束与不等式约束的最优值必要条件,具体如下:
{xL=Lx=f+j=1mλjgj+k=1pμkhk=0λjL=Lλj=gj(x)=0,j=1,...,mhk(x)0,k=1,...,pμk0,k=1,...,pμkhk(x)=0,k=1,...,p \left\{ \begin{aligned} & \nabla_{\textbf{x}}L=\displaystyle\frac{\partial L}{\partial \textbf{x}}=\nabla f+\sum\limits_{j=1}^m \lambda_j \nabla g_j+\sum\limits_{k=1}^p \mu_k\nabla h_k=0 \\ & \nabla_{\lambda_j}L=\displaystyle\frac{\partial L}{\partial \lambda_j}=g_j(\textbf{x})=0,j=1,...,m\\ &h_k(\textbf{x})\leq 0,k=1,...,p\\ & \mu_k\geq 0,k=1,...,p\\ & \mu_kh_k(\textbf{x})=0,k=1,...,p \end{aligned} \right.
接下来我们以一道习题为例来进行讲解。


2020牛客暑期多校训练营(第一场)D.Quadratic Form

题意

给定一个 n x nn \ \text{x} \ n 的正定对称矩阵 AA 以及一个 nn 维向量 b\textbf{b},想要求解如下非线性规划问题:
{max  bTxs.t.  g(x)=xTAx10,xR \left\{ \begin{aligned} \max & \ \ \textbf{b}^T\textbf{x} \\ s.t.& \ \ g(\textbf{x})=\textbf{x}^TA\textbf{x}-1\leq 0,\textbf{x}\in \mathbb{R} \\ \end{aligned} \right.
1n200,0Ai,j,bi109,g(x)>0,det(A)0(1\leq n\leq 200,0\leq |A_{i,j}|,|b_i|\leq 10^9,g(\textbf{x})>0,det(A)\not=0),最终答案 (bTx)2mod 998244353(\textbf{b}^T\textbf{x})^2 \text{mod} \ 998244353

思路

很明显,这是一道基于 KKT\text{KKT} 条件的问题。由于本问题是求 maxmax,因此我们需要转换成求 bTx-\textbf{b}^T\textbf{x} 的最小值,构造拉格朗日算符函数如下:
L(x,λ)=bTx+λ(xTAx1) L(\textbf{x},\lambda)=-\textbf{b}^T\textbf{x}+\lambda (\textbf{x}^TA\textbf{x}-1)
接下来我们可以得到如下的 KKT\text{KKT} 条件,即取到极值的必要条件:
{xL=Lx=f+λg=0g(x)0λ0λg(x)=0 \left\{ \begin{aligned} & \nabla_{\textbf{x}}L=\displaystyle\frac{\partial L}{\partial \textbf{x}}=\nabla f+\lambda \nabla g=0 \\ & g(\textbf{x})\leq 0\\ & \lambda \geq 0\\ & \lambda g(\textbf{x})=0 \end{aligned} \right.
由此可以得到如下两条等式:
{(1)b+2λAx=0(2)λxTAx=λ \left\{ \begin{aligned} & (1)-\textbf{b}+2\lambda A\textbf{x}=0 \\ & (2) \lambda \textbf{x}^TA\textbf{x}=\lambda \end{aligned} \right.
化简 (1)(1) 式,可以得到 (3)Ax=12λb(3)A\textbf{x}=\displaystyle\frac{1}{2\lambda}\textbf{b},代入 (2)(2) 式可以得到 xTb=2λ\textbf{x}^T\textbf{b}=2\lambda,注意 xTb=bTx=2λ\textbf{x}^T\textbf{b}=\textbf{b}^T\textbf{x}=2\lambda,继续化简 (3)(3) 式如下:
2λx=A1b2λbTx=bTA1bbTxbTx=(bTx)2=bTA1b 2\lambda \textbf{x}=A^{-1}\textbf{b} \\ 2\lambda \textbf{b}^T \textbf{x}=\textbf{b}^TA^{-1}\textbf{b} \\ \textbf{b}^T \textbf{x} \textbf{b}^T \textbf{x}=(\textbf{b}^T \textbf{x})^2=\textbf{b}^TA^{-1}\textbf{b} \\
因此本题的最终答案为 (bTA1b)%mod(\textbf{b}^TA^{-1}\textbf{b})\%mod,只需套入矩阵求逆的板子即可解决。

代码

#include <bits/stdc++.h>
#define rep(i,a,b) for(int i = a; i <= b; i++)
typedef long long ll;
const int N = 405;
const ll mod = 998244353;
using namespace std;

ll A[N][N], b[N], ans[N];
int n, B1[N], B2[N];

ll poww(ll a, ll b) {
    ll base = a, ans = 1;
    while(b) {
        if(b & 1) ans = ans * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return ans;
}

bool matrixInv() {
    for(int k = 1; k <= n; k++) {
        for(int i = k; i <= n; i++)
            for(int j = k; j <= n; j++)
                if(A[i][j]) {
                    B1[k] = i, B2[k] = j; break;
                }
        for(int i = 1; i <= n; i++)
            swap(A[k][i], A[B1[k]][i]);
        for(int i = 1; i <= n; i++)
            swap(A[i][k], A[i][B2[k]]);
        if(!A[k][k]) {
            return false; // 不可逆
        }
        A[k][k] = poww(A[k][k], mod-2);
        for(int j = 1; j <= n; j++)
            if(j!=k) (A[k][j] *= A[k][k]) %= mod;
        for(int i = 1; i <= n; i++)
            if(i != k)
                for(int j = 1; j <= n; j++)
                    if(j != k)
                        (A[i][j] += mod - A[i][k] * A[k][j] %mod) %= mod;
        for(int i = 1; i <= n; i++)
            if(i != k)
                A[i][k] = (mod - A[i][k] * A[k][k] % mod) % mod;
    }
    for(int k = n; k; k--) {
        for(int i = 1; i <= n; i++)
            swap(A[B2[k]][i], A[k][i]);
        for(int i = 1; i <= n; i++)
            swap(A[i][B1[k]], A[i][k]);
    }
    return true; // 可逆
}

int main()
{
    while(~scanf("%d",&n)) {
        rep(i,1,n)
            rep(j,1,n) scanf("%lld",&A[i][j]);
        rep(i,1,n) scanf("%lld", &b[i]);
        matrixInv();
        rep(i,1,n) {
            ll v = 0;
            rep(j,1,n) v = (v + b[j] * A[i][j]) % mod; 
            ans[i] = v;
        }
        ll res = 0;
        rep(i,1,n) res = (res + ans[i] * b[i]) % mod;
        printf("%lld\n", res);
    }
    return 0;
}