@@ -4,39 +4,42 @@ using LinearAlgebra
44 par:: NamedTuple{N,T}
55end
66
7+ quoteof (f:: AffineTransform ) = :(AffineTransform ($ (quoteof (f. par))))
8+
79params (f:: AffineTransform ) = getfield (f, :par )
810
911@inline Base. getproperty (d:: AffineTransform , s:: Symbol ) = getfield (getfield (d, :par ), s)
1012
11- Base. propertynames (d:: AffineTransform{N} ) where {N} = N
13+ Base. propertynames (d:: AffineTransform{N} ) where {N} = N
1214
13- @inline Base. inv (f:: AffineTransform{(:μ,:σ)} ) = AffineTransform ((μ = - (f. σ \ f. μ), ω = f. σ))
14- @inline Base. inv (f:: AffineTransform{(:μ,:ω)} ) = AffineTransform ((μ = - f. ω * f. μ, σ = f. ω))
15+ @inline Base. inv (f:: AffineTransform{(:μ, :σ)} ) =
16+ AffineTransform ((μ = - (f. σ \ f. μ), ω = f. σ))
17+ @inline Base. inv (f:: AffineTransform{(:μ, :ω)} ) = AffineTransform ((μ = - f. ω * f. μ, σ = f. ω))
1518@inline Base. inv (f:: AffineTransform{(:σ,)} ) = AffineTransform ((ω = f. σ,))
1619@inline Base. inv (f:: AffineTransform{(:ω,)} ) = AffineTransform ((σ = f. ω,))
1720@inline Base. inv (f:: AffineTransform{(:μ,)} ) = AffineTransform ((μ = - f. μ,))
1821
1922# `size(f) == (m,n)` means `f : ℝⁿ → ℝᵐ`
20- Base. size (f:: AffineTransform{(:μ,:σ)} ) = size (f. σ)
21- Base. size (f:: AffineTransform{(:μ,:ω)} ) = size (f. ω)
22- Base. size (f:: AffineTransform{(:σ,)} ) = size (f. σ)
23- Base. size (f:: AffineTransform{(:ω,)} ) = size (f. ω)
23+ Base. size (f:: AffineTransform{(:μ, :σ)} ) = size (f. σ)
24+ Base. size (f:: AffineTransform{(:μ, :ω)} ) = size (f. ω)
25+ Base. size (f:: AffineTransform{(:σ,)} ) = size (f. σ)
26+ Base. size (f:: AffineTransform{(:ω,)} ) = size (f. ω)
2427
2528function Base. size (f:: AffineTransform{(:μ,)} )
2629 (n,) = size (f. μ)
27- return (n,n)
30+ return (n, n)
2831end
2932
3033Base. size (f:: AffineTransform , n:: Int ) = @inbounds size (f)[n]
3134
3235(f:: AffineTransform{(:μ,)} )(x) = x + f. μ
3336(f:: AffineTransform{(:σ,)} )(x) = f. σ * x
3437(f:: AffineTransform{(:ω,)} )(x) = f. ω \ x
35- (f:: AffineTransform{(:μ,:σ)} )(x) = f. σ * x + f. μ
36- (f:: AffineTransform{(:μ,:ω)} )(x) = f. ω \ x + f. μ
38+ (f:: AffineTransform{(:μ, :σ)} )(x) = f. σ * x + f. μ
39+ (f:: AffineTransform{(:μ, :ω)} )(x) = f. ω \ x + f. μ
3740
3841rowsize (x) = ()
39- rowsize (x:: AbstractArray ) = (size (x,1 ),)
42+ rowsize (x:: AbstractArray ) = (size (x, 1 ),)
4043
4144function rowsize (f:: AffineTransform )
4245 size_f = size (f)
@@ -46,7 +49,7 @@ function rowsize(f::AffineTransform)
4649end
4750
4851colsize (x) = ()
49- colsize (x:: AbstractArray ) = (size (x,2 ),)
52+ colsize (x:: AbstractArray ) = (size (x, 2 ),)
5053
5154function colsize (f:: AffineTransform )
5255 size_f = size (f)
6568 return x
6669end
6770
68- @inline function apply! (x, f:: AffineTransform{(:ω,), Tuple{F}} , z) where {F<: Factorization }
71+ @inline function apply! (x, f:: AffineTransform{(:ω,),Tuple{F}} , z) where {F<: Factorization }
6972 ldiv! (x, f. ω, z)
7073 return x
7174end
7578 return x
7679end
7780
78- @inline function apply! (x, f:: AffineTransform{(:μ,:σ)} , z)
81+ @inline function apply! (x, f:: AffineTransform{(:μ, :σ)} , z)
7982 apply! (x, AffineTransform ((σ = f. σ,)), z)
8083 apply! (x, AffineTransform ((μ = f. μ,)), x)
8184 return x
8285end
8386
84- @inline function apply! (x, f:: AffineTransform{(:μ,:ω)} , z)
87+ @inline function apply! (x, f:: AffineTransform{(:μ, :ω)} , z)
8588 apply! (x, AffineTransform ((ω = f. ω,)), z)
8689 apply! (x, AffineTransform ((μ = f. μ,)), x)
8790 return x
8891end
8992
90- function logjac (x:: AbstractMatrix )
91- (m,n) = size (x)
93+ function logjac (x:: AbstractMatrix )
94+ (m, n) = size (x)
9295 m == n && return first (logabsdet (x))
9396
9497 # Equivalent to sum(log, svdvals(x)), but much faster
99102logjac (x:: Number ) = log (abs (x))
100103
101104# TODO : `log` doesn't work for the multivariate case, we need the log absolute determinant
102- logjac (f:: AffineTransform{(:μ,:σ)} ) = logjac (f. σ)
103- logjac (f:: AffineTransform{(:μ,:ω)} ) = - logjac (f. ω)
105+ logjac (f:: AffineTransform{(:μ, :σ)} ) = logjac (f. σ)
106+ logjac (f:: AffineTransform{(:μ, :ω)} ) = - logjac (f. ω)
104107logjac (f:: AffineTransform{(:σ,)} ) = logjac (f. σ)
105108logjac (f:: AffineTransform{(:ω,)} ) = - logjac (f. ω)
106109logjac (f:: AffineTransform{(:μ,)} ) = 0.0
@@ -130,16 +133,16 @@ function params(μ::Affine)
130133 return merge (nt1, nt2)
131134end
132135
133- function paramnames (:: Type{A} ) where {N,M, A<: Affine{N,M} }
136+ function paramnames (:: Type{A} ) where {N,M,A<: Affine{N,M} }
134137 tuple (union (N, paramnames (M))... )
135138end
136139
137- Base. propertynames (d:: Affine{N} ) where {N} = N ∪ (:parent ,:f )
140+ Base. propertynames (d:: Affine{N} ) where {N} = N ∪ (:parent , :f )
138141
139- @inline function Base. getproperty (d:: Affine , s:: Symbol )
142+ @inline function Base. getproperty (d:: Affine , s:: Symbol )
140143 if s === :parent
141144 return getfield (d, :parent )
142- elseif s === :f
145+ elseif s === :f
143146 return getfield (d, :f )
144147 else
145148 return getproperty (getfield (d, :f ), s)
@@ -166,18 +169,18 @@ end
166169
167170function logdensity (d:: Affine{(:μ,)} , x)
168171 z = x - d. μ
169- logdensity (d. parent, z)
172+ logdensity (d. parent, z)
170173end
171174
172- function logdensity (d:: Affine{(:μ,:σ)} , x)
175+ function logdensity (d:: Affine{(:μ, :σ)} , x)
173176 z = d. σ \ (x - d. μ)
174- logdensity (d. parent, z)
177+ logdensity (d. parent, z)
175178end
176179
177- function logdensity (d:: Affine{(:μ,:ω)} , x)
180+ function logdensity (d:: Affine{(:μ, :ω)} , x)
178181 z = d. ω * (x - d. μ)
179182 logdensity (d. parent, z)
180- end
183+ end
181184
182185# # logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
183186# @inline function logdensity(d::Affine{(:μ,:σ), P, Tuple{V,M}}, x) where {P, V<:AbstractVector, M<:AbstractMatrix}
190193# end
191194# sum(zⱼ -> logdensity(d.parent, zⱼ), z)
192195# end
193-
196+
194197# # logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
195198# @inline function logdensity(d::Affine{(:μ,:ω), P,Tuple{V,M}}, x) where {P,V<:AbstractVector, M<:AbstractMatrix}
196199# z = x - d.μ
@@ -202,31 +205,35 @@ basemeasure(d::Affine) = affine(getfield(d, :f), basemeasure(d.parent))
202205
203206# We can't do this until we know we're working with Lebesgue measure, since for
204207# example it wouldn't make sense to apply a log-Jacobian to a point measure
205- basemeasure (d:: Affine{N,L} ) where {N, L<: Lebesgue } = weightedmeasure (- logjac (d), d. parent)
208+ basemeasure (d:: Affine{N,L} ) where {N,L<: Lebesgue } = weightedmeasure (- logjac (d), d. parent)
206209
207- function basemeasure (d:: Affine{N,M} ) where {N,L<: Lebesgue , M<: ProductMeasure{Returns{L}} }
210+ function basemeasure (d:: Affine{N,M} ) where {N,L<: Lebesgue ,M<: ProductMeasure{Returns{L}} }
208211 weightedmeasure (- logjac (d), d. parent)
209212end
210213
211214logjac (d:: Affine ) = logjac (getfield (d, :f ))
212215
213- function Random. rand! (rng:: Random.AbstractRNG , d:: Affine , x:: AbstractVector{T} , z= Vector {T} (undef, size (getfield (d,:f ),2 ))) where {T}
216+ function Random. rand! (
217+ rng:: Random.AbstractRNG ,
218+ d:: Affine ,
219+ x:: AbstractVector{T} ,
220+ z = Vector {T} (undef, size (getfield (d, :f ), 2 ))
221+ ) where {T}
214222 rand! (rng, parent (d), z)
215223 f = getfield (d, :f )
216224 apply! (x, f, z)
217225 return x
218226end
219227
220-
221228# function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
222229# f = getfield(d, :f)
223230# z = rand(rng, T, parent(d))
224231# apply!(x, f, z)
225232# return z
226233# end
227234
228- supportdim (nt:: NamedTuple{(:μ,:σ)} ) = colsize (nt. σ)
229- supportdim (nt:: NamedTuple{(:μ,:ω)} ) = rowsize (nt. ω)
230- supportdim (nt:: NamedTuple{(:σ,)} ) = colsize (nt. σ)
231- supportdim (nt:: NamedTuple{(:ω,)} ) = rowsize (nt. ω)
232- supportdim (nt:: NamedTuple{(:μ,)} ) = size (nt. μ)
235+ supportdim (nt:: NamedTuple{(:μ, :σ)} ) = colsize (nt. σ)
236+ supportdim (nt:: NamedTuple{(:μ, :ω)} ) = rowsize (nt. ω)
237+ supportdim (nt:: NamedTuple{(:σ,)} ) = colsize (nt. σ)
238+ supportdim (nt:: NamedTuple{(:ω,)} ) = rowsize (nt. ω)
239+ supportdim (nt:: NamedTuple{(:μ,)} ) = size (nt. μ)
0 commit comments