【bzoj2707】[SDOI2012]走迷宫

2015.04.14 19:08 Tue | 22次阅读 | 旧日oi | 固定链接 | 源码

Description

Morenan被困在了一个迷宫里。迷宫可以视为N个点M条边的有向图,其中Morenan处于起点S,迷宫的终点设为T。可惜的是,Morenan非常的脑小,他只会从一个点出发随机沿着一条从该点出发的有向边,到达另一个点。这样,Morenan走的步数可能很长,也可能是无限,更可能到不了终点。若到不了终点,则步数视为无穷大。但你必须想方设法求出Morenan所走步数的期望值。

Input

第1行4个整数,N,M,S,T
第[2, M+1]行每行两个整数o1, o2,表示有一条从o1到o2的边。

Output

一个浮点数,保留小数点3位,为步数的期望值。若期望值为无穷大,则输出"INF"。
【样例输入1】
6 6 1 6
1 2
1 3
2 4
3 5
4 6
5 6
【样例输出1】
3.000
【样例输入2】
9 12 1 9
1 2
2 3
3 1
3 4
3 7
4 5
5 6
6 4
6 7
7 8
8 9
9 7
【样例输出2】
9.500
【样例输入3】
2 0 1 2
【样例输出3】
INF
【数据范围】
测试点
N
M
Hint
[1, 6]
<=10
<=100
[7, 12]
<=200
<=10000
[13, 20]
<=10000
<=1000000
保证强连通分量的大小不超过100
另外,均匀分布着40%的数据,图中没有环,也没有自环

题解

强连通分量+高斯消元解数学期望
答案为inf的条件是存在一个能从S出发到达不了T的点,扫一遍就行。
如果是有向无环图的话我们可以直接用递推解期望方程,但是它不是……
所以我们先用强连通分量把它变成一个DAG,然后在强连通分量内部用高斯消元解期望方程,把扩展出去的节点的期望值加到方程结果那边去就行了。
列个方程吧,exp[x]=1/outdegree[y]*exp[y]+1,其中y有边指向x,exp[T]=0;
170多行,不短啊……

我的程序

#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<bitset>
#include<stack>
#include<ctime>
#include<cmath>
#include<queue>
#include<set>
#include<map>
#define maxn 100010
#define maxm 1000010
#define ll long long
#define mod 1000000007
#define inf 0x3f3f3f3f
using namespace std;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,m,s,t;
int out[maxn];
struct edge{
    int to;
    int next;
}e[maxm];
int h[maxn],tp;
void ae(int u,int v)
{
    e[++tp].to=v;
    e[tp].next=h[u];
    h[u]=tp;
}
int scc_cnt,sccno[maxn],rank[maxn];
vector<int> scc[maxn];
stack<int> st;
int dfn[maxn],low[maxn],dfs_clock;
int vis[maxn];
bool calc[maxn];
double a[205][205],ex[10005];
void dfs(int u)
{
    dfn[u]=++dfs_clock;
    low[u]=dfn[u];
    st.push(u);
    for(int v,i=h[u];e[i].to;i=e[i].next)
    {
        if(!dfn[v=e[i].to])
        {
            dfs(v=e[i].to);
            low[u]=min(low[u],low[v]);
        }
        else if(!sccno[v]) low[u]=min(low[u],dfn[v]);
    }
    if(dfn[u]==low[u])
    {
        scc_cnt++;
        int x=-1;
        while(x!=u)
        {
            x=st.top();st.pop();
            sccno[x]=scc_cnt;
            scc[scc_cnt].push_back(x);
            rank[x]=scc[scc_cnt].size()-1;
        }
    }
}
void tarjan()
{
    for(int i=1;i<=n;i++)
    if(!dfn[i]) dfs(i);
}
void check(int x)
{
    vis[x]=0;
    if(x==sccno[t])
    {
        vis[x]=1;
        return;
    }
    for(int u,k=0;k<scc[x].size();k++)
    {
        u=scc[x][k];
        for(int v,i=h[u];e[i].to;i=e[i].next)
        if(sccno[v=e[i].to]!=x)
        {
            if(vis[sccno[v]]==-1) check(sccno[v]);
            if(vis[sccno[v]]==1) vis[x]=1;
        }
    }
}
void gauss(int equ,int var)
{
    int i,j,k,id;
    for(i=id=0;i<var;i++,id++)
    {
        for(k=j=id;j<equ;j++) if(fabs(a[j][i])>fabs(a[id][i])) k=j;
        if(k!=id) for(j=i;j<=var;j++) swap(a[k][j],a[id][j]);
        for(int j=id+1;j<equ;j++)
        {
            double rate=a[j][i]/a[id][i];
            for(k=i;k<=var;k++)
            a[j][k]-=rate*a[id][k];
        }
    }
    for(i=equ-1;i>=0;i--)
    {
        for(j=i+1;j<var;j++)
        a[i][var]-=a[i][j]*a[j][var];
        a[i][var]/=a[i][i];
    }
}
void solve(int x)
{
    int size=scc[x].size();
    for(int u,i=0;i<size;i++)
    for(int v,j=h[u=scc[x][i]];e[j].to;j=e[j].next)
    if(sccno[v=e[j].to]!=x&&!calc[sccno[v]]) 
    solve(sccno[v]);
    memset(a,0,sizeof(a));
    for(int u,k=0;k<size;k++)
    {
        a[k][k]=1;
        if((u=scc[x][k])==t) continue;
        a[k][size]=1;
        for(int v,i=h[u];e[i].to;i=e[i].next)
        if(sccno[v=e[i].to]==x) a[k][rank[v]]-=1.0/out[u];
        else a[k][scc[x].size()]+=ex[v]/out[u];
    }
    gauss(size,size);
    for(int i=0;i<size;i++) ex[scc[x][i]]=a[rank[scc[x][i]]][size];
    calc[x]=1;
}
int main()
{
    n=read();m=read();s=read();t=read();
    if(s==t)
    {
        puts("0.000");
        return 0;
    }
    for(int u,v,i=1;i<=m;i++) 
    {
        u=read();v=read();
        ae(u,v);out[u]++;
    }
    tarjan();
    memset(vis,-1,sizeof(vis));
    check(sccno[s]);
    for(int i=1;i<=n;i++) 
    if(vis[i]==0)
    {
        puts("INF");
        return 0;
    }
    solve(sccno[s]);
    printf("%.3lf",ex[s]);
}```