# gmres.jl
  1. function gmres{T1,T2}(A::SparseMatrixCSC{T1,Int},b::Array{T2,1},restrt::Int; kwargs...)
  2. Ax = zeros(promote_type(T1,T2),size(A,1))
  3. return gmres(x -> A_mul_B!(1.0,A,x,0.0,Ax),b,restrt;kwargs...)
  4. end
  5. """
  6. x,flag,err,iter,resvec = gmres(A,b,restrt,tol=1e-2,maxIter=100,M=1,x=[],out=0)
  7. Generalized Minimal residual (GMRESm) method with restarts applied to A*x = b.
  8. Input:
  9. A - function computing A*xor
  10. M -preconditioner,function computing M\\x
  11. """
  12. function gmres(A::Function,b::Vector,restrt::Int,tol::Real=1e-2,maxIter::Int=100,M::Function=identity,x::Vector=[],out::Int=0,storeInterm::Bool=false)
  13. n = length(b)
  14. if norm(b)==0;return zeros(eltype(b),n),-9,0.0,0,[0.0];end
  15. if isempty(x)
  16. x = zeros(n)
  17. r = M(b)
  18. else
  19. r = M(b-A(x))
  20. end
  21. if storeInterm
  22. x = zeros(n,maxIter)
  23. end
  24. bnrm2 = norm(b)
  25. if bnrm2 == 0.0;bnrm2 = 1.0;end
  26. err = norm( r ) / bnrm2
  27. if err < tol; return x,err;end
  28. #重启次数 m
  29. restrt = min(restrt,n-1)
  30. #Arnoldi向量 m+1
  31. V = zeros(n,restrt+1)
  32. #m+1,m
  33. H = zeros(restrt+1,restrt)
  34. #Gj的cs
  35. cs = zeros(restrt)
  36. #Gj的sn
  37. sn = zeros(restrt)
  38. #e1
  39. e1 = zeros(n)
  40. e1[1] = 1.0
  41. #暂时不考虑complex
  42. #残差数组
  43. resvec = zeros((1+restrt)*maxIter)
  44. if out==2
  45. println(@sprintf("=== gmres ===\n%4s\t%7s\n","iter","relres"))
  46. end
  47. #初始化
  48. iter = 0
  49. flag = -1
  50. cnt = 1
  51. for iter = 1:maxIter
  52. # v1 = r0/belta
  53. V[:,1] = r / norm(r)
  54. #theta = belta*e1
  55. s = norm(r)*e1;
  56. #显示迭代次数
  57. if out==2;;print(@sprintf("%3d\t",iter));end
  58. #开始迭代 带重启
  59. for i = 1:restrt
  60. #wi = Avi
  61. w = A(V[:,i])
  62. w = M(w)
  63. #Arnoldi process
  64. for k = 1:i
  65. # hij = (wj,vi)
  66. H[k,i] = dot(w,V[:,k])
  67. # wj = wj - hijVi
  68. w -= H[k,i] * V[:,k]
  69. end
  70. #hj+1,j = ||wj||2
  71. H[i+1,i] = norm(w)
  72. #Apply Givens rotation
  73. for k = 1:i-1
  74. temp = cs[k]*H[k,i] + sn[k]*H[k+1,i]
  75. H[k+1,i] = -sn[k] *H[k,i] + cs[k]*H[k+1,i]
  76. H[k,i] = temp
  77. end
  78. #跳出
  79. if H[i+1,i] == 0
  80. end
  81. #vj+1 = wj/h(j+1,j)
  82. V[:,i+1] = w/H[i+1,i]
  83. #From the Givens rotation Gj
  84. cs[i],sn[i] = symOrtho(H[i,i],H[i+1,i])
  85. #Apply Gj to last column of Hj+1,j
  86. H[i,i] = cs[i]*H[i,i]+sn[i,i]*H[i+1,i]
  87. H[i+1,i] = 0.0
  88. #Apply Gj to right-hand side
  89. s[i+1] = -sn[i]*s[i]
  90. s[i] = cs[i]*s[i]
  91. #check convergence
  92. err = abs(s[i+1])/bnrm2
  93. if out == 2;print(@sprintf("%1.1e",err));end
  94. resvec[cnt] = err
  95. if err <= tol
  96. y = H[1:i,1:i] \ s[1:i]
  97. x += V[:,1:i]*y
  98. if out == 2;print("\n"); end
  99. flag = 0;break;
  100. end
  101. cnt = cnt + 1
  102. end
  103. if err <= tol
  104. flag = 0
  105. break
  106. end
  107. y = H[1:restrt,1:restrt]\s[1:restrt]
  108. x += V[:,1:restrt]*y
  109. if storeInterm; X[:,iter] = x; end
  110. r = b - A(x)
  111. r = M(r)
  112. s[restrt+1] = norm(r)
  113. resvec[cnt] = abs(s[restrt+1]) / bnrm2
  114. if out==2; print(@sprintf("\t %1.1e\n", err)); end
  115. end
  116. if out>=0
  117. if flag==-1
  118. println(@sprintf("gmres iterated maxIter (=%d) times without achieving the desired tolerance.",maxIter))
  119. elseif flag==0 && out>=1
  120. println(@sprintf("gmres achieved desired tolerance at iteration %d. Residual norm is %1.2e.",iter,resvec[cnt]))
  121. end
  122. end
  123. if storeInterm
  124. return X[:,1:iter],flag,resvec[cnt],iter,resvec[1:cnt]
  125. else
  126. return x,flag,resvec[cnt],iter,resvec[1:cnt]
  127. end
  128. end
  129. """
  130. c,s,r = SymOrtho(a,b)
  131. Computes a Givens rotation
  132. Implementation is based on Table 2.9 in
  133. Choi, S.-C. T. (2006).
  134. Iterative Methods for Singular Linear Equations and Least-squares Problems.
  135. Phd thesis, Stanford University.
  136. """
  137. function symOrtho(a,b)
  138. c = 0.0; s = 0.0; r = 0.0
  139. if b==0
  140. s = 0.0
  141. r = abs(a)
  142. c = (a==0) ? c=1.0 : c = sign(a)
  143. elseif a == 0
  144. c = 0.0
  145. s = sign(b)
  146. r = abs(b)
  147. elseif abs(b) > abs(a)
  148. tau = a/b
  149. s = sign(b)/sqrt(1+tau^2)
  150. c = s*tau
  151. r = b/s
  152. elseif abs(a) > abs(b)
  153. tau = b/a
  154. c = sign(a)/sqrt(1+tau^2)
  155. s = c*tau
  156. r = a/c
  157. end
  158. return c,s,r
  159. end