Skip to content

Commit 7c77d41

Browse files
authored
feat: Support obtaining client IP from CloudFront-Viewer-Address header (#28)
* feat: conditionally add headers for AWS CloudFront * feat: add CloudFront header to Insecure and Secure * fix: remove feature flag
1 parent 2c0e5d2 commit 7c77d41

File tree

3 files changed

+99
-4
lines changed

3 files changed

+99
-4
lines changed

src/insecure.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::rudimental::{
2-
CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, TrueClientIp,
3-
XForwardedFor, XRealIp,
2+
CfConnectingIp, CloudFrontViewerAddress, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader,
3+
TrueClientIp, XForwardedFor, XRealIp,
44
};
55
use axum::{
66
async_trait,
@@ -49,6 +49,7 @@ impl InsecureClientIp {
4949
.or_else(|| FlyClientIp::maybe_ip_from_headers(headers))
5050
.or_else(|| TrueClientIp::maybe_ip_from_headers(headers))
5151
.or_else(|| CfConnectingIp::maybe_ip_from_headers(headers))
52+
.or_else(|| CloudFrontViewerAddress::maybe_ip_from_headers(headers))
5253
.or_else(|| maybe_connect_info(extensions))
5354
.map(Self)
5455
.ok_or((

src/rudimental.rs

+88
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ pub struct TrueClientIp(pub IpAddr);
6868
#[derive(Debug)]
6969
pub struct CfConnectingIp(pub IpAddr);
7070

71+
/// Extracts a valid IP from `CloudFront-Viewer-Address` (AWS CloudFront) header
72+
///
73+
/// Rejects with a 500 error if the header is absent or the IP isn't valid
74+
#[derive(Debug)]
75+
pub struct CloudFrontViewerAddress(pub IpAddr);
76+
7177
pub(crate) trait SingleIpHeader {
7278
const HEADER: &'static str;
7379

@@ -162,6 +168,38 @@ impl_single_header!(FlyClientIp, "Fly-Client-IP");
162168
impl_single_header!(TrueClientIp, "True-Client-IP");
163169
impl_single_header!(CfConnectingIp, "CF-Connecting-IP");
164170

171+
impl SingleIpHeader for CloudFrontViewerAddress {
172+
const HEADER: &'static str = "cloudfront-viewer-address";
173+
174+
fn maybe_ip_from_headers(headers: &HeaderMap) -> Option<IpAddr> {
175+
headers
176+
.get(Self::HEADER)
177+
.and_then(|hv| hv.to_str().ok())
178+
// Spec: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/adding-cloudfront-headers.html#cloudfront-headers-viewer-location
179+
// Note: Both IPv4 and IPv6 addresses (in the specified format) do not contain
180+
// non-ascii characters, so no need to handle percent-encoding.
181+
//
182+
// CloudFront does not use `[::]:12345` style notation for IPv6 (unfortunately),
183+
// otherwise parsing via `SocketAddr` would be possible.
184+
.and_then(|hv| hv.rsplit_once(':').map(|(ip, _port)| ip))
185+
.and_then(|s| s.parse::<IpAddr>().ok())
186+
}
187+
}
188+
189+
#[async_trait]
190+
impl<S> FromRequestParts<S> for CloudFrontViewerAddress
191+
where
192+
S: Sync,
193+
{
194+
type Rejection = StringRejection;
195+
196+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
197+
Ok(Self(
198+
Self::maybe_ip_from_headers(&parts.headers).ok_or_else(Self::rejection)?,
199+
))
200+
}
201+
}
202+
165203
impl MultiIpHeader for XForwardedFor {
166204
const HEADER: &'static str = "X-Forwarded-For";
167205

@@ -532,4 +570,54 @@ mod tests {
532570
let res = app().oneshot(req).await.unwrap();
533571
assert_eq!(body_string(res.into_body()).await, "192.0.2.60");
534572
}
573+
574+
#[tokio::test]
575+
async fn cloudfront_viewer_address_ipv4() {
576+
fn app() -> Router {
577+
Router::new().route(
578+
"/",
579+
get(|ip: super::CloudFrontViewerAddress| async move { ip.0.to_string() }),
580+
)
581+
}
582+
583+
let req = Request::builder().uri("/").body(Body::empty()).unwrap();
584+
let res = app().oneshot(req).await.unwrap();
585+
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
586+
587+
let req = Request::builder()
588+
.uri("/")
589+
.header("CloudFront-Viewer-Address", "198.51.100.10:46532")
590+
.body(Body::empty())
591+
.unwrap();
592+
let res = app().oneshot(req).await.unwrap();
593+
assert_eq!(body_string(res.into_body()).await, "198.51.100.10");
594+
}
595+
596+
#[tokio::test]
597+
async fn cloudfront_viewer_address_ipv6() {
598+
fn app() -> Router {
599+
Router::new().route(
600+
"/",
601+
get(|ip: super::CloudFrontViewerAddress| async move { ip.0.to_string() }),
602+
)
603+
}
604+
605+
let req = Request::builder().uri("/").body(Body::empty()).unwrap();
606+
let res = app().oneshot(req).await.unwrap();
607+
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
608+
609+
let req = Request::builder()
610+
.uri("/")
611+
.header(
612+
"CloudFront-Viewer-Address",
613+
"2a09:bac1:3b20:38::17e:7:51786",
614+
)
615+
.body(Body::empty())
616+
.unwrap();
617+
let res = app().oneshot(req).await.unwrap();
618+
assert_eq!(
619+
body_string(res.into_body()).await,
620+
"2a09:bac1:3b20:38::17e:7"
621+
);
622+
}
535623
}

src/secure.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::rudimental::{
2-
CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, StringRejection,
3-
TrueClientIp, XForwardedFor, XRealIp,
2+
CfConnectingIp, CloudFrontViewerAddress, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader,
3+
StringRejection, TrueClientIp, XForwardedFor, XRealIp,
44
};
55
use axum::async_trait;
66
use axum::extract::{ConnectInfo, Extension, FromRequestParts};
@@ -44,6 +44,8 @@ pub enum SecureClientIpSource {
4444
CfConnectingIp,
4545
/// IP from the [`axum::extract::ConnectInfo`]
4646
ConnectInfo,
47+
/// IP from the `CloudFront-Viewer-Address` header
48+
CloudFrontViewerAddress,
4749
}
4850

4951
impl SecureClientIpSource {
@@ -77,6 +79,7 @@ impl FromStr for SecureClientIpSource {
7779
"TrueClientIp" => Self::TrueClientIp,
7880
"CfConnectingIp" => Self::CfConnectingIp,
7981
"ConnectInfo" => Self::ConnectInfo,
82+
"CloudFrontViewerAddress" => Self::CloudFrontViewerAddress,
8083
_ => return Err(ParseSecureClientIpSourceError(s.to_string())),
8184
})
8285
}
@@ -100,6 +103,9 @@ impl SecureClientIp {
100103
SecureClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(headers),
101104
SecureClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(headers),
102105
SecureClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(headers),
106+
SecureClientIpSource::CloudFrontViewerAddress => {
107+
CloudFrontViewerAddress::ip_from_headers(headers)
108+
}
103109
SecureClientIpSource::ConnectInfo => extensions
104110
.get::<ConnectInfo<SocketAddr>>()
105111
.map(|ConnectInfo(addr)| addr.ip())

0 commit comments

Comments
 (0)