# form.jl

  """
      form(A::SparseMatrixCSC, b::Vector; tol=1e-2, maxIter=100, M=identity,
           x=[], out=2, storeInterm=false)

  Sparse-matrix convenience method: wrap `A` in the product closure `v -> A * v`
  and forward every option, in order, to the positional
  `form(A::Function, b, tol, maxIter, M, x, out, storeInterm)` method.
  Returns whatever that method returns.

  The keyword defaults mirror the positional defaults of the function-based method.
  """
  function form(A::SparseMatrixCSC{<:Any,Int}, b::Vector;
                tol::Real=1e-2, maxIter::Int=100, M::Function=identity,
                x::Vector=[], out::Int=2, storeInterm::Bool=false)
      # `A * v` returns a fresh vector of the promoted element type. This avoids
      # both the removed 5-argument `A_mul_B!` API and a preallocated buffer typed
      # from `A` and `b` alone: an integer buffer cannot hold the products of the
      # solver's Float64 Krylov vectors.
      Afun = v -> A * v
      # The function-based method takes its options positionally, so the keywords
      # are passed explicitly (previously they were accepted and then dropped).
      return form(Afun, b, tol, maxIter, M, x, out, storeInterm)
  end
  """
      x, flag, err, iter, resvec = form(A, b, tol=1e-2, maxIter=100, M=identity,
                                        x=[], out=2, storeInterm=false)

  Solve `A*x = b` with one cycle (no restarts) of the Generalized Minimal
  RESidual method (GMRES), left-preconditioned by `M`. The options are
  positional arguments, in the order shown.

  # Input
  - `A`           -- function computing `A*x`
  - `b`           -- right-hand side vector
  - `tol`         -- tolerance on the relative residual estimate (default `1e-2`)
  - `maxIter`     -- maximum number of iterations, i.e. of Krylov vectors (default `100`)
  - `M`           -- preconditioner, a function computing `M\\x` (default `identity`)
  - `x`           -- starting guess; an empty vector means zero (default `[]`)
  - `out`         -- verbosity: `2` prints every iteration, `>= 1` prints the
                     convergence message, `>= 0` prints the message when `tol`
                     was not reached, `< 0` prints nothing (default `2`)
  - `storeInterm` -- if `true`, return the iterates as the columns of a matrix
                     instead of only the final `x` (default `false`)

  # Output
  - `x`      -- approximate solution (a matrix of iterates if `storeInterm`)
  - `flag`   -- `0`: `tol` reached, `-1`: `tol` not reached, `-9`: `b` is zero
  - `err`    -- final relative residual estimate `‖M(b - A*x)‖ / ‖b‖`
  - `iter`   -- number of iterations performed
  - `resvec` -- relative residual estimate after each iteration

  If the starting guess already satisfies `tol`, the call returns with `iter = 0`
  and `resvec = [err]`.

  # Throws
  - `ArgumentError` if `maxIter` is negative.

  Only real arithmetic is supported: the Krylov basis is stored as `Float64`.
  """
  function form(A::Function, b::Vector, tol::Real=1e-2, maxIter::Int=100,
                M::Function=identity, x::Vector=[], out::Int=2,
                storeInterm::Bool=false)
      maxIter >= 0 || throw(ArgumentError("maxIter must be non-negative, got $maxIter"))
      n = length(b)
      # Early exits still honour `storeInterm` by returning a one-column matrix.
      pack(v) = storeInterm ? reshape(v, n, 1) : v

      bnrm2 = norm(b)
      if bnrm2 == 0
          return pack(zeros(eltype(b), n)), -9, 0.0, 0, [0.0]
      end

      # Initial preconditioned residual r0 = M(b - A*x0).
      if isempty(x)
          x = zeros(n)
          r = M(b)
      else
          r = M(b - A(x))
      end
      rnrm = norm(r)
      err = rnrm / bnrm2
      if err < tol
          # Same 5-tuple shape as every other exit (previously only `x, err`).
          return pack(x), 0, err, 0, [err]
      end

      # Orthonormal Krylov basis V (n × (maxIter+1)) and Hessenberg matrix H.
      V = zeros(n, maxIter + 1)
      H = zeros(maxIter + 1, maxIter)
      # Right-hand side ‖r0‖·e1 of the small least-squares problem; its length is
      # the Krylov dimension maxIter+1, not n.
      s = zeros(maxIter + 1)
      s[1] = rnrm
      resvec = zeros(maxIter)
      # Iterates, filled only when storeInterm is requested.
      X = storeInterm ? zeros(n, maxIter) : zeros(0, 0)

      if out == 2
          println(@sprintf("=== gmres ===\n%4s\t%7s\n", "iter", "relres"))
      end

      flag = -1
      # Counted explicitly: a `for` loop variable is loop-local in Julia >= 1.0,
      # so it cannot be read after the loop.
      iter = 0
      y = zeros(0)
      V[:, 1] = r / rnrm
      for i in 1:maxIter
          iter = i
          if out == 2; print(@sprintf("%3d\t", i)); end
          # New direction w = M(A*v_i), orthogonalised by modified Gram-Schmidt.
          w = M(A(V[:, i]))
          for k in 1:i
              H[k, i] = dot(w, V[:, k])
              w -= H[k, i] * V[:, k]
          end
          H[i+1, i] = norm(w)
          # GMRES step: y minimises ‖s - H*y‖ over the (i+1)×i Hessenberg block
          # (least squares), rather than solving the square i×i block (FOM).
          Hi = H[1:i+1, 1:i]
          y = Hi \ s[1:i+1]
          err = norm(s[1:i+1] - Hi * y) / bnrm2
          if out == 2; print(@sprintf("%1.1e\n", err)); end
          resvec[i] = err
          if storeInterm
              X[:, i] = x + V[:, 1:i] * y
          end
          if err <= tol
              flag = 0
              break
          end
          # Breakdown: the Krylov space is invariant, so no new direction exists
          # (dividing by H[i+1, i] below would produce NaNs).
          H[i+1, i] == 0 && break
          V[:, i+1] = w / H[i+1, i]
      end
      if iter > 0
          x += V[:, 1:iter] * y
      end

      if out == 2; print(@sprintf("\t %1.1e\n", err)); end
      if out >= 0
          if flag == -1
              println(@sprintf("gmres iterated maxIter (=%d) times without achieving the desired tolerance.", maxIter))
          elseif flag == 0 && out >= 1
              println(@sprintf("gmres achieved desired tolerance at iteration %d. Residual norm is %1.2e.", iter, err))
          end
      end
      xout = storeInterm ? X[:, 1:iter] : x
      return xout, flag, err, iter, resvec[1:iter]
  end
  """
      c, s, r = symOrtho(a, b)

  Compute a Givens rotation for the pair `(a, b)`. The returned `c`, `s` and `r`
  satisfy `c*a + s*b == r` and `-s*a + c*b == 0` (up to rounding), with
  `r = sqrt(a^2 + b^2) >= 0`. For `a == b == 0` the result is `c = 1, s = 0, r = 0`.

  Implementation is based on Table 2.9 in
  Choi, S.-C. T. (2006).
  Iterative Methods for Singular Linear Equations and Least-squares Problems.
  PhD thesis, Stanford University.
  """
  function symOrtho(a, b)
      if b == 0
          s = 0.0
          r = abs(a)
          c = a == 0 ? 1.0 : sign(a)
      elseif a == 0
          c = 0.0
          s = sign(b)
          r = abs(b)
      elseif abs(b) > abs(a)
          # Divide by the larger magnitude so that tau is at most 1 in magnitude.
          tau = a / b
          s = sign(b) / sqrt(1 + tau^2)
          c = s * tau
          r = b / s
      else
          # |a| >= |b| > 0. This branch also covers |a| == |b|, which two strict
          # comparisons missed (returning c = s = r = 0).
          tau = b / a
          c = sign(a) / sqrt(1 + tau^2)
          s = c * tau
          r = a / c
      end
      return c, s, r
  end