Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/handlers/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use actix_web::{
Error, HttpMessage, HttpRequest, Route,
dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
error::{ErrorBadRequest, ErrorForbidden, ErrorUnauthorized},
http::header::{self, HeaderName, HeaderValue},
http::header::{self, HeaderMap, HeaderName, HeaderValue},
};
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use chrono::{Duration, TimeDelta, Utc};
Expand Down Expand Up @@ -194,7 +194,7 @@ where
}

let auth_result: Result<_, Error> = (self.auth_method)(&mut req, self.action);

let headers = req.headers().clone();
let fut = self.service.call(req);
Box::pin(async move {
let Ok(key) = key else {
Expand All @@ -209,7 +209,7 @@ where

// if session is expired, refresh token
if sessions().is_session_expired(&key) {
refresh_token(user_and_tenant_id, &key).await?;
refresh_token(user_and_tenant_id, &key, headers).await?;
}

match auth_result? {
Expand Down Expand Up @@ -296,6 +296,7 @@ fn get_user_and_tenant(
pub async fn refresh_token(
user_and_tenant_id: Result<(Result<String, RBACError>, Option<String>), RBACError>,
key: &SessionKey,
headers: HeaderMap,
) -> Result<(), Error> {
let oidc_client = OIDC_CLIENT.get();

Expand All @@ -320,8 +321,7 @@ pub async fn refresh_token(
let refreshed_token = match client
.read()
.await
.client()
.refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str()))
.refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str()), headers)
.await
{
Ok(bearer) => bearer,
Expand Down Expand Up @@ -571,6 +571,10 @@ where
header::COOKIE,
HeaderValue::from_str(&format!("session={}", id)).unwrap(),
);

// remove basic auth header
req.headers_mut().remove(header::AUTHORIZATION);

let session = SessionKey::SessionId(id);
req.extensions_mut().insert(session.clone());
Users.new_session(&user, session, TimeDelta::seconds(20));
Expand Down
35 changes: 6 additions & 29 deletions src/handlers/http/modal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use base64::{Engine, prelude::BASE64_STANDARD};
use bytes::Bytes;
use futures::future;
use once_cell::sync::OnceCell;
use openid::Discovered;
use relative_path::RelativePathBuf;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
Expand All @@ -40,15 +39,15 @@ use crate::{
correlation::CORRELATIONS,
hottier::{HotTierManager, StreamHotTier},
metastore::metastore_traits::MetastoreObject,
oidc::{Claims, DiscoveredClient},
oauth::{OAuthProvider, connect_oidc},
option::Mode,
parseable::{DEFAULT_TENANT, PARSEABLE},
storage::{ObjectStorageProvider, PARSEABLE_ROOT_DIRECTORY},
users::{dashboards::DASHBOARDS, filters::FILTERS},
utils::get_node_id,
};

use super::{API_BASE_PATH, API_VERSION, cross_origin_config, health_check, resource_check};
use super::{cross_origin_config, health_check, resource_check};

pub mod ingest;
pub mod ingest_server;
Expand All @@ -58,28 +57,7 @@ pub mod server;
pub mod ssl_acceptor;
pub mod utils;

pub type OpenIdClient = Arc<openid::Client<Discovered, Claims>>;

pub static OIDC_CLIENT: OnceCell<Arc<RwLock<GlobalClient>>> = OnceCell::new();

#[derive(Debug)]
pub struct GlobalClient {
client: DiscoveredClient,
}

impl GlobalClient {
pub fn set(&mut self, client: DiscoveredClient) {
self.client = client;
}

pub fn client(&self) -> &DiscoveredClient {
&self.client
}

pub fn new(client: DiscoveredClient) -> Self {
Self { client }
}
}
pub static OIDC_CLIENT: OnceCell<Arc<RwLock<Box<dyn OAuthProvider>>>> = OnceCell::new();

// to be decided on what the Default version should be
pub const DEFAULT_VERSION: &str = "v4";
Expand Down Expand Up @@ -114,10 +92,9 @@ pub trait ParseableServer {
Self: Sized,
{
if let Some(config) = oidc_client {
let client = config
.connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code"))
.await?;
OIDC_CLIENT.get_or_init(|| Arc::new(RwLock::new(GlobalClient::new(client))));
let gc = connect_oidc(config).await?;
OIDC_CLIENT
.get_or_init(|| Arc::new(RwLock::new(Box::new(gc) as Box<dyn OAuthProvider>)));
}

// get the ssl stuff
Expand Down
1 change: 1 addition & 0 deletions src/handlers/http/modal/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl ParseableServer for Server {
.service(Self::get_llm_webscope())
.service(Self::get_oauth_webscope())
.service(Self::get_user_role_webscope())
.service(Self::get_roles_webscope())
.service(Self::get_counts_webscope().wrap(from_fn(
resource_check::check_resource_utilization_middleware,
)))
Expand Down
Loading
Loading