Skip to content

Commit

Permalink
fix session
Browse files Browse the repository at this point in the history
  • Loading branch information
levkk committed Dec 4, 2024
1 parent dfdad53 commit 8271b74
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 82 deletions.
2 changes: 1 addition & 1 deletion examples/auth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl Controller for ProtectedAreaController {
}

async fn handle(&self, request: &Request) -> Result<Response, Error> {
let session = request.session().unwrap();
let session = request.session();
let welcome = format!("<h1>Welcome, user {:?}</h1>", session.session_id);
Ok(Response::new().html(welcome))
}
Expand Down
6 changes: 2 additions & 4 deletions examples/turbo/src/controllers/signup/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ pub struct LoggedInCheck;
#[rwf::async_trait]
impl Middleware for LoggedInCheck {
async fn handle_request(&self, request: Request) -> Result<Outcome, Error> {
if let Some(session) = request.session() {
if session.authenticated() {
return Ok(Outcome::Stop(request, Response::new().redirect("/chat")));
}
if request.session().authenticated() {
return Ok(Outcome::Stop(request, Response::new().redirect("/chat")));
}

Ok(Outcome::Forward(request))
Expand Down
9 changes: 4 additions & 5 deletions rwf-tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,10 @@ impl RestController for BasePlayerController {
type Resource = i64;

async fn get(&self, request: &Request, id: &i64) -> Result<Response, Error> {
if let Some(session) = request.session() {
session
.websocket()
.send(websocket::Message::Text("controller websocket".into()))?;
}
request
.session()
.websocket()
.send(websocket::Message::Text("controller websocket".into()))?;
Ok(Response::new().html(format!("<h1>base player controller, id: {}</h1>", id)))
}

Expand Down
6 changes: 1 addition & 5 deletions rwf/src/controller/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,7 @@ impl SessionAuth {
#[async_trait]
impl Authentication for SessionAuth {
async fn authorize(&self, request: &Request) -> Result<bool, Error> {
if let Some(session) = request.session() {
Ok(session.authenticated())
} else {
Ok(false)
}
Ok(request.session().authenticated())
}

async fn denied(&self, _request: &Request) -> Result<Response, Error> {
Expand Down
5 changes: 1 addition & 4 deletions rwf/src/controller/middleware/csrf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ impl Middleware for Csrf {
}

let header = request.header(CSRF_HEADER);
let session_id = match request.session_id() {
Some(session_id) => session_id.to_string(),
None => return Ok(Outcome::Stop(request, Response::csrf_error())),
};
let session_id = request.session_id().to_string();

if let Some(header) = header {
if csrf_token_validate(header, &session_id) {
Expand Down
6 changes: 1 addition & 5 deletions rwf/src/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -728,11 +728,7 @@ pub trait WebsocketController: Controller {
) -> Result<bool, Error> {
use tokio::sync::broadcast::error::RecvError;

let session_id = if let Some(session) = request.session() {
session.session_id.clone()
} else {
return Err(Error::SessionMissingError);
};
let session_id = request.session().session_id.clone();

info!(
"{} {} {} connected",
Expand Down
58 changes: 18 additions & 40 deletions rwf/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
#[derive(Debug, Clone)]
pub struct Request {
head: Head,
session: Option<Session>,
session: Session,
inner: Arc<Inner>,
params: Option<Arc<Params>>,
received_at: OffsetDateTime,
Expand All @@ -35,7 +35,7 @@ impl Default for Request {
fn default() -> Self {
Self {
head: Head::default(),
session: None,
session: Session::default(),
inner: Arc::new(Inner::default()),
params: None,
received_at: OffsetDateTime::now_utc(),
Expand Down Expand Up @@ -100,8 +100,8 @@ impl Request {
let cookies = head.cookies();

let (session, renew_session) = match cookies.get_session()? {
Some(session) => (Some(session), false),
None => (Some(Session::anonymous()), true),
Some(session) => (session, false),
None => (Session::anonymous(), true),
};

Ok(Request {
Expand Down Expand Up @@ -214,8 +214,8 @@ impl Request {
/// Get the session set on the request, if any. While all requests served
/// by Rwf should have a session (guest or authenticated), some HTTP clients
/// may not send the cookie back (e.g. cURL won't).
pub fn session(&self) -> Option<&Session> {
self.session.as_ref()
pub fn session(&self) -> &Session {
&self.session
}

/// Was the CSRF protection bypassed on this request?
Expand All @@ -235,22 +235,16 @@ impl Request {
///
/// This should uniquely identify a browser if it's a guest session,
/// or a user if the user is logged in.
pub fn session_id(&self) -> Option<SessionId> {
self.session
.as_ref()
.map(|session| session.session_id.clone())
pub fn session_id(&self) -> SessionId {
self.session.session_id.clone()
}

/// Get the authenticated user's ID. Combined with the `?` operator,
/// will return `403 - Unauthorized` if not logged in.
pub fn user_id(&self) -> Result<i64, Error> {
if let Some(session_id) = self.session_id() {
match session_id {
SessionId::Authenticated(id) => Ok(id),
_ => Err(Error::Forbidden),
}
} else {
Err(Error::Forbidden)
match self.session_id() {
SessionId::Authenticated(id) => Ok(id),
_ => Err(Error::Forbidden),
}
}

Expand All @@ -273,9 +267,7 @@ impl Request {
/// ```
pub async fn user<T: Model>(&self, conn: &mut ConnectionGuard) -> Result<Option<T>, Error> {
match self.session_id() {
Some(SessionId::Authenticated(user_id)) => {
Ok(Some(T::find(user_id).fetch(conn).await?))
}
SessionId::Authenticated(user_id) => Ok(Some(T::find(user_id).fetch(conn).await?)),

_ => Ok(None),
}
Expand All @@ -294,8 +286,9 @@ impl Request {
///
/// This is automatically done by the HTTP server,
/// if the session is available.
pub fn set_session(mut self, session: Option<Session>) -> Self {
pub(crate) fn set_session(mut self, session: Session) -> Self {
self.session = session;
self.renew_session = true;
self
}

Expand Down Expand Up @@ -328,10 +321,7 @@ impl Request {
/// let response = request.login(1234);
/// ```
pub fn login(&self, user_id: i64) -> Response {
let mut session = self
.session()
.map(|s| s.clone())
.unwrap_or(Session::empty());
let mut session = self.session.clone();
session.session_id = SessionId::Authenticated(user_id);
Response::new().set_session(session).html("")
}
Expand Down Expand Up @@ -385,12 +375,7 @@ impl Request {
/// let response = request.logout();
/// ```
pub fn logout(&self) -> Response {
let mut session = self
.session()
.map(|s| s.clone())
.unwrap_or(Session::empty());
session.session_id = SessionId::default();
Response::new().set_session(session).html("")
Response::new().set_session(Session::anonymous()).html("")
}

pub(crate) fn renew_session(&self) -> bool {
Expand All @@ -416,13 +401,7 @@ impl ToTemplateValue for Request {
"query".to_string(),
self.path().query().to_string().to_template_value()?,
);
hash.insert(
"session".to_string(),
match self.session() {
Some(session) => session.to_template_value()?,
None => Value::Null,
},
);
hash.insert("session".to_string(), self.session().to_template_value()?);
Ok(Value::Hash(hash))
}
}
Expand Down Expand Up @@ -479,12 +458,11 @@ pub mod test {
assert_eq!(req.peer(), &dummy_ip());
assert_eq!(req.upgrade_websocket(), false);
assert_eq!(req.skip_csrf(), false);
assert_eq!(req.session(), None);
assert!(!req.session().authenticated());
assert!(req.user_id().is_err());
assert_eq!(req.body(), b"12345");
assert_eq!(req.string(), "12345".to_string());
assert!(req.form_data().is_err());
assert!(req.session_id().is_none());
assert_eq!(req.query().len(), 1);
assert_eq!(req.path().base(), "/apples");

Expand Down
25 changes: 9 additions & 16 deletions rwf/src/http/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,28 +228,21 @@ impl Response {
///
/// This makes sure a valid session cookie is set on all responses.
pub fn from_request(mut self, request: &Request) -> Result<Self, Error> {
// Set an anonymous session if none is set on the request.
if self.session.is_none() && request.session().is_none() {
self.session = Some(Session::anonymous());
}

// Session set manually on the request already.
if let Some(ref session) = self.session {
self.cookies.add_session(&session)?;
} else {
let session = request.session();

if let Some(session) = session {
if session.should_renew() || request.renew_session() {
let session = session
.clone()
.renew(get_config().general.session_duration());
self.cookies.add_session(&session)?;

// Set the session on the response, so it can be
// passed down in handle_stream.
self.session = Some(session);
}
if session.should_renew() || request.renew_session() {
let session = session
.clone()
.renew(get_config().general.session_duration());
self.cookies.add_session(&session)?;

// Set the session on the response, so it can be
// passed down in handle_stream.
self.session = Some(session);
}
}

Expand Down
5 changes: 4 additions & 1 deletion rwf/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ impl Server {

// Set the session on the request before we pass it down
// to the stream handler.
let request = request.set_session(response.session().clone());
let request = match response.session().clone() {
Some(session) => request.set_session(session),
None => request,
};
let ok = response.status().ok();

// Calculate duration.
Expand Down
2 changes: 1 addition & 1 deletion rwf/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::colors::MaybeColorize;
use crate::config::get_config;

use pool::{ConnectionRequest, ToConnectionRequest};
use pool::ToConnectionRequest;
use std::time::{Duration, Instant};
use tracing::{error, info};

Expand Down

0 comments on commit 8271b74

Please sign in to comment.