From 4b1a2918a4698f4bc67a89e7f13f2882278997f2 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Fri, 5 Jan 2024 18:42:34 -0600 Subject: [PATCH] Add link_depth option to control maximum depth of links to follow when using --start-at --- src/main.rs | 87 ++++++++++++------- src/postprocessors.rs | 191 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 245 insertions(+), 33 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1798d1b..6d19fc0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use eyre::{eyre, Result}; use gumdrop::Options; use obsidian_export::{postprocessors::*, ExportError}; use obsidian_export::{Exporter, FrontmatterStrategy, WalkOptions}; +use std::sync::Arc; use std::{env, path::PathBuf}; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -23,6 +24,13 @@ struct Opts { #[options(no_short, help = "Only export notes under this sub-path")] start_at: Option, + #[options( + no_short, + help = "Maximum depth of links to follow when using --start-at. Does nothing if --start-at is not specified", + default = "0" + )] + link_depth: usize, + #[options( help = "Frontmatter strategy (one of: always, never, auto)", no_short, @@ -91,7 +99,7 @@ fn main() { ..Default::default() }; - let mut exporter = Exporter::new(root, destination); + let mut exporter = Exporter::new(root.clone(), destination.clone()); exporter.frontmatter_strategy(args.frontmatter_strategy); exporter.process_embeds_recursively(!args.no_recursive_embeds); exporter.walk_options(walk_options); @@ -102,38 +110,55 @@ fn main() { let tags_postprocessor = filter_by_tags(args.skip_tags, args.only_tags); exporter.add_postprocessor(&tags_postprocessor); - + let recursive_resolver: RecursiveResolver; + let shared_state: Arc = SharedResolverState::new(args.link_depth); + let mut dont_recurse = true; + let callback; if let Some(path) = args.start_at { - exporter.start_at(path); + exporter.start_at(path.clone()); + if args.link_depth > 0 { + dont_recurse = false; + recursive_resolver = + RecursiveResolver::new(root, path, destination, shared_state.clone()); + callback = |ctx: &mut obsidian_export::Context, + events: &mut Vec>| { + recursive_resolver.postprocess(ctx, events) + }; + exporter.add_postprocessor(&callback); + } } - - if let Err(err) = exporter.run() { - match err { - ExportError::FileExportError { - ref path, - ref source, - } => match &**source { - // An arguably better way of enhancing error reports would be to construct a custom - // `eyre::EyreHandler`, but that would require a fair amount of boilerplate and - // reimplementation of basic reporting. - ExportError::RecursionLimitExceeded { file_tree } => { - eprintln!( - "Error: {:?}", - eyre!( - "'{}' exceeds the maximum nesting limit of embeds", - path.display() - ) - ); - eprintln!("\nFile tree:"); - for (idx, path) in file_tree.iter().enumerate() { - eprintln!(" {}-> {}", " ".repeat(idx), path.display()); + loop { + if let Err(err) = exporter.run() { + match err { + ExportError::FileExportError { + ref path, + ref source, + } => match &**source { + // An arguably better way of enhancing error reports would be to construct a custom + // `eyre::EyreHandler`, but that would require a fair amount of boilerplate and + // reimplementation of basic reporting. + ExportError::RecursionLimitExceeded { file_tree } => { + eprintln!( + "Error: {:?}", + eyre!( + "'{}' exceeds the maximum nesting limit of embeds", + path.display() + ) + ); + eprintln!("\nFile tree:"); + for (idx, path) in file_tree.iter().enumerate() { + eprintln!(" {}-> {}", " ".repeat(idx), path.display()); + } + eprintln!("\nHint: Ensure notes are non-recursive, or specify --no-recursive-embeds to break cycles") } - eprintln!("\nHint: Ensure notes are non-recursive, or specify --no-recursive-embeds to break cycles") - } + _ => eprintln!("Error: {:?}", eyre!(err)), + }, _ => eprintln!("Error: {:?}", eyre!(err)), - }, - _ => eprintln!("Error: {:?}", eyre!(err)), - }; - std::process::exit(1); - }; + }; + std::process::exit(1); + } + if dont_recurse || shared_state.update_and_check_should_continue() { + break; + } + } } diff --git a/src/postprocessors.rs b/src/postprocessors.rs index 2ca4183..c8a4a98 100644 --- a/src/postprocessors.rs +++ b/src/postprocessors.rs @@ -1,7 +1,16 @@ //! A collection of officially maintained [postprocessors][crate::Postprocessor]. -use super::{Context, MarkdownEvents, PostprocessorResult}; -use pulldown_cmark::Event; +use std::{ + collections::BTreeSet, + fmt::DebugStruct, + path::{Path, PathBuf}, + sync::{Arc, Mutex, RwLock}, +}; + +use super::{Context, MarkdownEvents, PostprocessorResult, PERCENTENCODE_CHARS}; +use percent_encoding::{percent_decode_str, utf8_percent_encode, AsciiSet}; +use pulldown_cmark::{CowStr, Event, Tag}; +use rayon::iter::{ParallelDrainRange, ParallelIterator}; use serde_yaml::Value; /// This postprocessor converts all soft line breaks to hard line breaks. Enabling this mimics @@ -51,6 +60,184 @@ fn filter_by_tags_( } } +pub struct SharedResolverState { + depth: usize, + current_depth: RwLock, + files_to_parse: RwLock>, + linked_files: Mutex>, +} + +impl SharedResolverState { + pub fn new(depth: usize) -> Arc { + Arc::new(SharedResolverState { + depth, + current_depth: RwLock::new(0), + files_to_parse: RwLock::new(BTreeSet::new()), + linked_files: Mutex::new(Vec::new()), + }) + } + pub fn update_and_check_should_continue(&self) -> bool { + let mut current_depth = self.current_depth.write().unwrap(); + + if *current_depth < self.depth { + *current_depth += 1; + let mut files_to_parse = self.files_to_parse.write().unwrap(); + *files_to_parse = self + .linked_files + .lock() + .unwrap() + .par_drain(..) + .collect::>(); + if !files_to_parse.is_empty() { + return false; + } + } + return true; + } +} + +pub struct RecursiveResolver { + root: PathBuf, + start_at: PathBuf, + destination: PathBuf, + shared_state: Arc, +} + +impl<'a: 'url, 'url> RecursiveResolver { + pub fn new( + root: PathBuf, + start_at: PathBuf, + destination: PathBuf, + shared_state: Arc, + ) -> RecursiveResolver { + RecursiveResolver { + root, + start_at, + destination, + + shared_state: shared_state.clone(), + } + } + + pub fn start_at(&mut self, start_at: PathBuf) { + self.start_at = start_at; + } + + /// If this is the first iteration, links to files outside of start_at are changed so + /// that they are to in the root of the destination + pub fn postprocess( + &self, + context: &'a mut Context, + events: &'url mut MarkdownEvents, + ) -> PostprocessorResult { + println!("postprocess: recursive_resolver"); + match *self.shared_state.current_depth.read().unwrap() == 0 { + true => self.first_run(context, events), + false => { + if !self + .shared_state + .files_to_parse + .read() + .unwrap() + .contains(context.current_file()) + { + return PostprocessorResult::StopAndSkipNote; + } + self.other_runs(context, events) + } + } + } + + fn first_run( + &self, + context: &'a mut Context, + events: &'url mut MarkdownEvents, + ) -> PostprocessorResult { + //let path_changed = context.current_file() != &self.start_at; + for event in events.iter_mut() { + if let Event::Start(Tag::Link(_, url, _)) = event { + println!("url: {}", url); + if url.starts_with("https://") || url.starts_with("http://") { + continue; + } + + let vault_path: PathBuf = get_vault_path(url, &self.start_at.as_path()); + println!("vault_path: {}", vault_path.to_string_lossy()); + // may still be within start_at + if vault_path.starts_with(&self.start_at) { + continue; + } + + if vault_path.exists() { + let vaultless_path = vault_path.strip_prefix(self.root.as_path()).unwrap(); + set_url(url, self.destination.join(vaultless_path)); + self.shared_state + .linked_files + .lock() + .unwrap() + .push(vault_path); + } + } + } + PostprocessorResult::Continue + } + + fn other_runs( + &self, + context: &'a mut Context, + events: &'url mut MarkdownEvents, + ) -> PostprocessorResult { + //let path_changed = context.current_file() != self.start_at; + for event in events.iter_mut() { + let relative_start = self.start_at.clone().strip_prefix(&self.root).unwrap(); + if let Event::Start(Tag::Link(_, url, _)) = event { + if url.starts_with("https://") || url.starts_with("http://") { + continue; + } + let vault_path = get_vault_path(url, self.root.as_path()); + + // if it's within start_at, we need to strip the difference between root and start_at + + //let vaultless_path = vault_path.strip_prefix(self.root.as_path()).unwrap(); + if vault_path.exists() { + if vault_path.starts_with(&self.start_at) { + let link_destination = self + .destination + .join(vault_path.strip_prefix(&self.start_at).unwrap()); + set_url(url, link_destination); + } + if *self.shared_state.current_depth.read().unwrap() < self.shared_state.depth { + self.shared_state + .linked_files + .lock() + .unwrap() + .push(vault_path); + } + } + } + } + PostprocessorResult::Continue + } +} +fn get_vault_path(url: &mut CowStr<'_>, root: &Path) -> PathBuf { + let path_stub = PathBuf::from( + percent_decode_str(url.as_ref()) + .decode_utf8() + .unwrap() + .as_ref(), + ); + root.join(path_stub).canonicalize().unwrap() +} +fn set_url(url: &mut CowStr<'_>, link_destination: PathBuf) { + *url = CowStr::from( + utf8_percent_encode( + &format!("{}", link_destination.to_string_lossy()), + PERCENTENCODE_CHARS, + ) + .to_string(), + ); +} + #[test] fn test_filter_tags() { let tags = vec![