Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConvFeatures block to represent backbone outputs #278

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions FastVision/src/FastVision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ include("blocks/bounded.jl")
include("blocks/image.jl")
include("blocks/mask.jl")
include("blocks/keypoints.jl")
include("blocks/convfeatures.jl")

include("encodings/onehot.jl")
include("encodings/imagepreprocessing.jl")
Expand Down
52 changes: 52 additions & 0 deletions FastVision/src/blocks/convfeatures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

"""
ConvFeatures{N}(n) <: Block
ConvFeatures(n, size)

Block representing features from a convolutional neural network backbone
with `n` feature channels and `N` spatial dimensions.

For example, a 2D ResNet's convolutional layers may produce a `h`x`w`x`ch` output
that is passed further to the classifier head.

## Examples

A feature block with 512 channels and variable spatial dimensions:

```julia
FastVision.ConvFeatures{2}(512)
# or equivalently
FastVision.ConvFeatures(512, (:, :))
```

A feature block with 512 channels and fixed spatial dimensions:

```julia
FastVision.ConvFeatures(512, (4, 4))
```

"""
struct ConvFeatures{N} <: Block
n::Int
size::NTuple{N, DimSize}
end

ConvFeatures{N}(n) where {N} = ConvFeatures{N}(n, ntuple(_ -> :, N))

function FastAI.checkblock(block::ConvFeatures{N}, a::AbstractArray{T, M}) where {M, N, T}
M == N + 1 || return false
return checksize(block.size, size(a)[begin:N])
end

function FastAI.mockblock(block::ConvFeatures)
rand(Float32, map(l -> l isa Colon ? 8 : l, block.size)..., block.n)
end


@testset "ConvFeatures [block]" begin
@test ConvFeatures(16, (:, :)) == ConvFeatures{2}(16)
@test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 2, 2, 16))
@test checkblock(ConvFeatures(16, (:, :)), rand(Float32, 3, 2, 16))
@test checkblock(ConvFeatures(16, (2, 2)), rand(Float32, 2, 2, 16))
@test !checkblock(ConvFeatures(16, (2, :)), rand(Float32, 3, 2, 16))
end