diff --git a/src/extensions/parser.rs b/src/extensions/parser.rs index d7b6384..05f3482 100644 --- a/src/extensions/parser.rs +++ b/src/extensions/parser.rs @@ -22,6 +22,7 @@ use crate::{ sum::Sum, }, error::LogriaError, + sanitizers::sanitize_filename, }, }; @@ -54,9 +55,10 @@ impl ExtensionMethods for Parser { /// Create parser file from a Parser struct fn save(self, file_name: &str) -> Result<(), LogriaError> { let parser_json = serde_json::to_string_pretty(&self).unwrap(); - let path = format!("{}/{}", patterns(), file_name); + let sanitized_filename = sanitize_filename(file_name); + let path = format!("{}/{}", patterns(), sanitized_filename); - match write(format!("{}/{}", patterns(), file_name), parser_json) { + match write(&path, parser_json) { Ok(()) => Ok(()), Err(why) => Err(LogriaError::CannotWrite(path, ::to_string(&why))), } diff --git a/src/extensions/session.rs b/src/extensions/session.rs index 3bbd97a..bbc4037 100644 --- a/src/extensions/session.rs +++ b/src/extensions/session.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use crate::{ constants::{cli::excludes::SESSION_FILE_EXCLUDES, directories::sessions}, extensions::extension::ExtensionMethods, - util::error::LogriaError, + util::{error::LogriaError, sanitizers::sanitize_filename}, }; #[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug)] @@ -39,7 +39,8 @@ impl ExtensionMethods for Session { /// Create session file from a Session struct fn save(self, file_name: &str) -> Result<(), LogriaError> { let session_json = serde_json::to_string_pretty(&self).unwrap(); - let path = format!("{}/{}", sessions(), file_name); + let sanitized_filename = sanitize_filename(file_name); + let path = format!("{}/{}", sessions(), sanitized_filename); match write(&path, session_json) { Ok(()) => Ok(()), Err(why) => Err(LogriaError::CannotWrite(path, ::to_string(&why))), diff --git a/src/util/sanitizers.rs b/src/util/sanitizers.rs index 5101399..5bac030 100644 --- a/src/util/sanitizers.rs +++ 
/// Characters that are replaced when they appear in a filename.
///
/// Covers path separators, wildcard / redirection / quoting characters, and `.`.
static FILENAME_DISALLOWED_CHARS: LazyLock<HashSet<char>> =
    LazyLock::new(|| HashSet::from(['*', '"', '/', '\\', '<', '>', ':', '|', '?', '.']));
/// The character substituted for every disallowed or control character.
const FILENAME_REPLACEMENT_CHAR: char = '_';
/// Upper bound on the sanitized name, measured in UTF-8 bytes rather than characters,
/// because filesystem name limits are expressed in bytes.
const FILENAME_MAX_BYTES: usize = 255;

/// Produce a filename that is safe to join onto a directory path.
///
/// The input is processed as follows:
///
/// 1. Leading and trailing whitespace is trimmed.
/// 2. Every control character and every character in [`FILENAME_DISALLOWED_CHARS`] is
///    replaced with [`FILENAME_REPLACEMENT_CHAR`] (characters are replaced, never dropped).
/// 3. The result is cut at the last whole character that keeps it within
///    [`FILENAME_MAX_BYTES`] bytes, so a multi-byte character is never split.
/// 4. Whitespace left dangling at the end by the cut is trimmed again.
///
/// # Returns
///
/// A newly allocated `String`. An empty or whitespace-only input yields an empty string;
/// callers that need a non-empty name must check for that themselves.
pub fn sanitize_filename(filename: &str) -> String {
    let mut sanitized = String::with_capacity(filename.len().min(FILENAME_MAX_BYTES));

    for letter in filename.trim().chars() {
        let safe = if letter.is_control() || FILENAME_DISALLOWED_CHARS.contains(&letter) {
            FILENAME_REPLACEMENT_CHAR
        } else {
            letter
        };

        // Budget by encoded size: counting chars would let multi-byte names exceed the limit.
        if sanitized.len() + safe.len_utf8() > FILENAME_MAX_BYTES {
            break;
        }
        sanitized.push(safe);
    }

    // The cut can land right after an interior space; drop it so the name never ends in
    // whitespace. `trim_end` always stops on a char boundary, so `truncate` cannot panic.
    let kept = sanitized.trim_end().len();
    sanitized.truncate(kept);
    sanitized
}
/// Control characters (NUL, unit separator, DEL) are each replaced by an underscore.
#[test]
fn test_sanitize_filename_control_chars() {
    for raw in ["file\x00name", "file\x1fname", "file\x7fname"] {
        assert_eq!(sanitize_filename(raw), "file_name");
    }
}

/// Surrounding whitespace is stripped, while surrounding dots are replaced rather than stripped.
#[test]
fn test_sanitize_filename_trim() {
    let cases = [
        (" filename ", "filename"),
        ("..filename..", "__filename__"),
        ("\tfilename\t", "filename"),
    ];
    for (raw, expected) in cases {
        assert_eq!(sanitize_filename(raw), expected);
    }
}

/// A 300-character ASCII name is cut down to 255 bytes.
#[test]
fn test_sanitize_filename_long() {
    let oversized = "a".repeat(300);
    assert_eq!(sanitize_filename(&oversized).len(), 255);
}