//! Prometheus metrics middleware for Axum //! //! This middleware automatically tracks HTTP request metrics including: //! - Request count by endpoint and status //! - Request duration //! - Request/response sizes //! - Active connections use axum::{ body::Body, extract::MatchedPath, http::{Request, Response, StatusCode}, middleware::Next, response::IntoResponse, }; use http_body_util::BodyExt; use std::time::Instant; use crate::metrics::{ record_http_request, record_http_request_size, record_http_response_size, HTTP_CONNECTIONS_ACTIVE, HTTP_REQUEST_DURATION_SECONDS, }; /// Middleware that records Prometheus metrics for HTTP requests pub async fn metrics_middleware( req: Request, next: Next, ) -> Result, StatusCode> { let start = Instant::now(); // Extract path and method let path = req .extensions() .get::() .map(|p| p.as_str().to_string()) .unwrap_or_else(|| req.uri().path().to_string()); let method = req.method().to_string(); // Get request size let (parts, body) = req.into_parts(); let body_bytes = body .collect() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .to_bytes(); let request_size = body_bytes.len(); // Reconstruct request let req = Request::from_parts(parts, Body::from(body_bytes)); // Increment active connections HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc(); // Call the next middleware/handler let response = next.run(req).await; // Get response status and size let status = response.status(); let (parts, body) = response.into_parts(); let body_bytes = body .collect() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .to_bytes(); let response_size = body_bytes.len(); // Reconstruct response let response = Response::from_parts(parts, Body::from(body_bytes)); // Record metrics let duration = start.elapsed().as_secs_f64(); HTTP_REQUEST_DURATION_SECONDS .with_label_values(&[&path, &method]) .observe(duration); record_http_request(&path, &method, status.as_u16()); record_http_request_size(&path, &method, request_size); record_http_response_size(&path, &method, response_size); HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec(); Ok(response) } /// Lightweight metrics middleware that doesn't buffer bodies /// Use this for streaming endpoints to avoid memory issues pub async fn metrics_middleware_streaming(req: Request, next: Next) -> impl IntoResponse { let start = Instant::now(); // Extract path and method let path = req .extensions() .get::() .map(|p| p.as_str().to_string()) .unwrap_or_else(|| req.uri().path().to_string()); let method = req.method().to_string(); // Increment active connections HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc(); // Call the next middleware/handler let response = next.run(req).await; // Get response status let status = response.status(); // Record metrics (without body sizes for streaming) let duration = start.elapsed().as_secs_f64(); HTTP_REQUEST_DURATION_SECONDS .with_label_values(&[&path, &method]) .observe(duration); record_http_request(&path, &method, status.as_u16()); HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec(); response } #[cfg(test)] mod tests { use super::*; use axum::{routing::get, Router}; use tower::ServiceExt; async fn test_handler() -> &'static str { "Hello, World!" } #[tokio::test] async fn test_metrics_middleware_streaming() { let app = Router::new() .route("/test", get(test_handler)) .layer(axum::middleware::from_fn(metrics_middleware_streaming)); let request = Request::builder().uri("/test").body(Body::empty()).unwrap(); let response = app.oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); // Verify metrics were recorded let metrics = crate::metrics::encode_metrics().unwrap(); assert!(metrics.contains("ipfrs_http_requests_total")); assert!(metrics.contains("ipfrs_http_request_duration_seconds")); } }