123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
"""
    gmres(A::SparseMatrixCSC, b::Vector, restrt::Int; tol=1e-2, maxIter=100,
          M=identity, x=[], out=0, storeInterm=false)

Restarted GMRES for a sparse matrix `A`.

Wraps `A` in a matrix-vector product closure and delegates to the `A::Function`
method, returning its `(x, flag, err, iter, resvec)` tuple unchanged.

The options are accepted as keywords here but forwarded *positionally*: the
`A::Function` method declares them as optional positional arguments, so
splatting keywords into it would raise a `MethodError`.
"""
function gmres(A::SparseMatrixCSC{T1,Int}, b::Array{T2,1}, restrt::Int;
               tol::Real=1e-2, maxIter::Int=100, M::Function=identity,
               x::Vector=[], out::Int=0, storeInterm::Bool=false) where {T1,T2}
    # A fresh product per call: the result is promoted to the element type of
    # both A and the (Float64) Krylov vector, so integer matrices cannot trigger
    # an InexactError, and the returned vector never aliases a shared buffer.
    Afun = v -> A * v
    return gmres(Afun, b, restrt, tol, maxIter, M, x, out, storeInterm)
end
"""
    x, flag, err, iter, resvec = gmres(A, b, restrt, tol=1e-2, maxIter=100,
                                       M=identity, x=[], out=0, storeInterm=false)

Generalized Minimal RESidual method with restarts (GMRES(m)) applied to `A*x = b`.

# Arguments
- `A::Function`: computes the matrix-vector product `A*v`.
- `b::Vector`: right-hand side.
- `restrt::Int`: number of inner (Arnoldi) steps per restart cycle. It is clamped
  to `length(b)` and must be at least 1.

# Optional positional arguments
- `tol::Real=1e-2`: tolerance on the preconditioned residual norm relative to
  `norm(b)`.
- `maxIter::Int=100`: maximum number of restart cycles; must be at least 1.
- `M::Function=identity`: preconditioner, computes `M\\v`.
- `x::Vector=[]`: initial guess; an empty vector means "start from zero".
- `out::Int=0`: verbosity. `out < 0` prints nothing, `0` reports failure,
  `1` also reports success, and `2` additionally logs every step.
- `storeInterm::Bool=false`: if `true`, the first return value is a matrix whose
  columns are the iterates after each restart cycle.

# Returns
- `x`: approximate solution (or the matrix of iterates when `storeInterm`).
- `flag`: `0` converged, `-1` hit `maxIter` without converging, `-9` if `b == 0`.
- `err`: last recorded relative residual.
- `iter`: number of restart cycles performed.
- `resvec`: history of relative residuals.

# Throws
- `ArgumentError` if `restrt < 1` or `maxIter < 1` (checked after the
  trivial-solution early exits).
"""
function gmres(A::Function, b::Vector, restrt::Int, tol::Real=1e-2, maxIter::Int=100,
               M::Function=identity, x::Vector=[], out::Int=0, storeInterm::Bool=false)
    n = length(b)
    if norm(b) == 0
        return zeros(eltype(b), n), -9, 0.0, 0, [0.0]
    end
    if isempty(x)
        x = zeros(n)
        r = M(b)
    else
        r = M(b - A(x))
    end
    bnrm2 = norm(b)   # nonzero: the b == 0 case returned above
    err = norm(r) / bnrm2
    if err < tol
        # Already converged; return the same 5-tuple shape as the main exit.
        return (storeInterm ? reshape(x, n, 1) : x), 0, err, 0, [err]
    end
    restrt >= 1 || throw(ArgumentError("restrt must be at least 1, got $restrt"))
    maxIter >= 1 || throw(ArgumentError("maxIter must be at least 1, got $maxIter"))

    # Separate from `x` so that the solution vector is never replaced by a matrix.
    X = zeros(n, storeInterm ? maxIter : 0)

    # The Krylov subspace dimension cannot exceed n; any breakdown is handled below.
    restrt = min(restrt, n)
    V = zeros(n, restrt + 1)        # Arnoldi basis vectors v_1..v_{m+1}
    H = zeros(restrt + 1, restrt)   # (m+1) x m Hessenberg matrix, triangularized in place
    cs = zeros(restrt)              # cosines of the Givens rotations G_j
    sn = zeros(restrt)              # sines of the Givens rotations G_j
    e1 = zeros(restrt + 1)          # first unit vector of the small least-squares problem
    e1[1] = 1.0
    # NOTE: complex arithmetic is not supported yet (all work arrays are real).
    resvec = zeros((1 + restrt) * maxIter)
    if out == 2
        println(@sprintf("=== gmres ===\n%4s\t%7s\n", "iter", "relres"))
    end

    flag = -1
    cnt = 1
    iter = 0
    # A `while` loop keeps `iter` visible after the loop; since Julia 0.7 the
    # variable of a `for` loop is always a fresh local.
    while iter < maxIter
        iter += 1
        rnrm = norm(r)
        V[:, 1] = r / rnrm
        s = rnrm * e1   # right-hand side beta*e1, rotated along with H
        if out == 2; print(@sprintf("%3d\t", iter)); end
        for i = 1:restrt
            # Arnoldi step with modified Gram-Schmidt: w = M \ (A v_i), orthogonalized.
            w = M(A(V[:, i]))
            for k = 1:i
                H[k, i] = dot(w, V[:, k])
                w -= H[k, i] * V[:, k]
            end
            H[i+1, i] = norm(w)
            # Apply the previous rotations G_1..G_{i-1} to the new column of H.
            for k = 1:i-1
                temp = cs[k] * H[k, i] + sn[k] * H[k+1, i]
                H[k+1, i] = -sn[k] * H[k, i] + cs[k] * H[k+1, i]
                H[k, i] = temp
            end
            # On breakdown (w == 0) v_{i+1} is never used: the rotation below
            # zeroes s[i+1], so the step converges before V[:, i+1] is read.
            if H[i+1, i] != 0
                V[:, i+1] = w / H[i+1, i]
            end
            # Form G_i to eliminate H[i+1, i], then apply it to H and to s.
            cs[i], sn[i] = symOrtho(H[i, i], H[i+1, i])
            H[i, i] = cs[i] * H[i, i] + sn[i] * H[i+1, i]
            H[i+1, i] = 0.0
            s[i+1] = -sn[i] * s[i]
            s[i] = cs[i] * s[i]
            # |s[i+1]| is the residual norm of the current least-squares iterate.
            err = abs(s[i+1]) / bnrm2
            if out == 2; print(@sprintf("%1.1e", err)); end
            resvec[cnt] = err
            if err <= tol
                y = H[1:i, 1:i] \ s[1:i]
                x += V[:, 1:i] * y
                if storeInterm; X[:, iter] = x; end
                if out == 2; print("\n"); end
                flag = 0
                break
            end
            cnt += 1
        end
        if flag == 0
            break
        end
        # Restart: apply the full m-step correction, then recompute the true residual.
        y = H[1:restrt, 1:restrt] \ s[1:restrt]
        x += V[:, 1:restrt] * y
        if storeInterm; X[:, iter] = x; end
        r = M(b - A(x))
        err = norm(r) / bnrm2
        resvec[cnt] = err   # overwritten by the first inner step of the next cycle
        if out == 2; print(@sprintf("\t %1.1e\n", err)); end
        if err <= tol
            flag = 0
            break
        end
    end
    if out >= 0
        if flag == -1
            println(@sprintf("gmres iterated maxIter (=%d) times without achieving the desired tolerance.", maxIter))
        elseif flag == 0 && out >= 1
            println(@sprintf("gmres achieved desired tolerance at iteration %d. Residual norm is %1.2e.", iter, resvec[cnt]))
        end
    end
    if storeInterm
        return X[:, 1:iter], flag, resvec[cnt], iter, resvec[1:cnt]
    else
        return x, flag, resvec[cnt], iter, resvec[1:cnt]
    end
end
"""
    c, s, r = symOrtho(a, b)

Compute a Givens rotation such that `[c s; -s c] * [a; b] == [r; 0]`, with
`c^2 + s^2 == 1` and `r >= 0` (so `r == hypot(a, b)` up to rounding).

The algorithm avoids overflow by dividing by the larger of `|a|` and `|b|`.

Implementation is based on Table 2.9 in
Choi, S.-C. T. (2006).
Iterative Methods for Singular Linear Equations and Least-squares Problems.
PhD thesis, Stanford University.
"""
function symOrtho(a, b)
    if b == 0
        # Nothing to eliminate; by convention c = 1 when a == 0 as well.
        c = a == 0 ? 1.0 : sign(a)
        s = 0.0
        r = abs(a)
    elseif a == 0
        c = 0.0
        s = sign(b)
        r = abs(b)
    elseif abs(b) > abs(a)
        tau = a / b
        s = sign(b) / sqrt(1 + tau^2)
        c = s * tau
        r = b / s
    else
        # |a| >= |b| > 0. This must be a plain `else`: a strict |a| > |b| test
        # would leave |a| == |b| unhandled and produce an invalid (0, 0, 0).
        tau = b / a
        c = sign(a) / sqrt(1 + tau^2)
        s = c * tau
        r = a / c
    end
    return c, s, r
end
|