Add link_depth option to control maximum depth of links to follow when using --start-at

This commit is contained in:
Joshua Ferguson 2024-01-05 18:42:34 -06:00
parent a7517106e8
commit 4b1a2918a4
2 changed files with 245 additions and 33 deletions

View File

@ -2,6 +2,7 @@ use eyre::{eyre, Result};
use gumdrop::Options;
use obsidian_export::{postprocessors::*, ExportError};
use obsidian_export::{Exporter, FrontmatterStrategy, WalkOptions};
use std::sync::Arc;
use std::{env, path::PathBuf};
const VERSION: &str = env!("CARGO_PKG_VERSION");
@ -23,6 +24,13 @@ struct Opts {
#[options(no_short, help = "Only export notes under this sub-path")]
start_at: Option<PathBuf>,
#[options(
no_short,
help = "Maximum depth of links to follow when using --start-at. Does nothing if --start-at is not specified",
default = "0"
)]
link_depth: usize,
#[options(
help = "Frontmatter strategy (one of: always, never, auto)",
no_short,
@ -91,7 +99,7 @@ fn main() {
..Default::default()
};
let mut exporter = Exporter::new(root, destination);
let mut exporter = Exporter::new(root.clone(), destination.clone());
exporter.frontmatter_strategy(args.frontmatter_strategy);
exporter.process_embeds_recursively(!args.no_recursive_embeds);
exporter.walk_options(walk_options);
@ -102,38 +110,55 @@ fn main() {
let tags_postprocessor = filter_by_tags(args.skip_tags, args.only_tags);
exporter.add_postprocessor(&tags_postprocessor);
let recursive_resolver: RecursiveResolver;
let shared_state: Arc<SharedResolverState> = SharedResolverState::new(args.link_depth);
let mut dont_recurse = true;
let callback;
if let Some(path) = args.start_at {
exporter.start_at(path);
exporter.start_at(path.clone());
if args.link_depth > 0 {
dont_recurse = false;
recursive_resolver =
RecursiveResolver::new(root, path, destination, shared_state.clone());
callback = |ctx: &mut obsidian_export::Context,
events: &mut Vec<pulldown_cmark::Event<'_>>| {
recursive_resolver.postprocess(ctx, events)
};
exporter.add_postprocessor(&callback);
}
}
if let Err(err) = exporter.run() {
match err {
ExportError::FileExportError {
ref path,
ref source,
} => match &**source {
// An arguably better way of enhancing error reports would be to construct a custom
// `eyre::EyreHandler`, but that would require a fair amount of boilerplate and
// reimplementation of basic reporting.
ExportError::RecursionLimitExceeded { file_tree } => {
eprintln!(
"Error: {:?}",
eyre!(
"'{}' exceeds the maximum nesting limit of embeds",
path.display()
)
);
eprintln!("\nFile tree:");
for (idx, path) in file_tree.iter().enumerate() {
eprintln!(" {}-> {}", " ".repeat(idx), path.display());
loop {
if let Err(err) = exporter.run() {
match err {
ExportError::FileExportError {
ref path,
ref source,
} => match &**source {
// An arguably better way of enhancing error reports would be to construct a custom
// `eyre::EyreHandler`, but that would require a fair amount of boilerplate and
// reimplementation of basic reporting.
ExportError::RecursionLimitExceeded { file_tree } => {
eprintln!(
"Error: {:?}",
eyre!(
"'{}' exceeds the maximum nesting limit of embeds",
path.display()
)
);
eprintln!("\nFile tree:");
for (idx, path) in file_tree.iter().enumerate() {
eprintln!(" {}-> {}", " ".repeat(idx), path.display());
}
eprintln!("\nHint: Ensure notes are non-recursive, or specify --no-recursive-embeds to break cycles")
}
eprintln!("\nHint: Ensure notes are non-recursive, or specify --no-recursive-embeds to break cycles")
}
_ => eprintln!("Error: {:?}", eyre!(err)),
},
_ => eprintln!("Error: {:?}", eyre!(err)),
},
_ => eprintln!("Error: {:?}", eyre!(err)),
};
std::process::exit(1);
};
};
std::process::exit(1);
}
if dont_recurse || shared_state.update_and_check_should_continue() {
break;
}
}
}

View File

@ -1,7 +1,16 @@
//! A collection of officially maintained [postprocessors][crate::Postprocessor].
use super::{Context, MarkdownEvents, PostprocessorResult};
use pulldown_cmark::Event;
use std::{
collections::BTreeSet,
fmt::DebugStruct,
path::{Path, PathBuf},
sync::{Arc, Mutex, RwLock},
};
use super::{Context, MarkdownEvents, PostprocessorResult, PERCENTENCODE_CHARS};
use percent_encoding::{percent_decode_str, utf8_percent_encode, AsciiSet};
use pulldown_cmark::{CowStr, Event, Tag};
use rayon::iter::{ParallelDrainRange, ParallelIterator};
use serde_yaml::Value;
/// This postprocessor converts all soft line breaks to hard line breaks. Enabling this mimics
@ -51,6 +60,184 @@ fn filter_by_tags_(
}
}
/// State shared between the export loop in `main` and the recursive link
/// resolver postprocessor, tracking which files each pass discovered.
pub struct SharedResolverState {
    /// Maximum number of link levels to follow (from `--link-depth`).
    depth: usize,
    /// Index of the export pass currently running; starts at 0 and is
    /// advanced by `update_and_check_should_continue`.
    current_depth: RwLock<usize>,
    /// Files discovered during the previous pass that the next pass should
    /// export (deduplicated, ordered).
    files_to_parse: RwLock<BTreeSet<PathBuf>>,
    /// Files linked to during the current pass; drained into
    /// `files_to_parse` between passes.
    linked_files: Mutex<Vec<PathBuf>>,
}
impl SharedResolverState {
    /// Creates shared resolver state wrapped in an [`Arc`] so it can be held
    /// by both the export loop and the postprocessor closure.
    ///
    /// `depth` is the maximum number of link levels to follow
    /// (`--link-depth`); 0 disables recursive resolution.
    pub fn new(depth: usize) -> Arc<SharedResolverState> {
        Arc::new(SharedResolverState {
            depth,
            current_depth: RwLock::new(0),
            files_to_parse: RwLock::new(BTreeSet::new()),
            linked_files: Mutex::new(Vec::new()),
        })
    }

    /// Advances to the next export pass and reports whether the export loop
    /// is finished.
    ///
    /// Returns `false` when another pass is needed (the depth limit has not
    /// been reached and the pass that just finished discovered new linked
    /// files); returns `true` once the depth limit is hit or no new files
    /// remain.
    ///
    /// NOTE(review): despite the name, `true` means "stop" — the caller's
    /// loop breaks when this returns `true`; confirm against `main`.
    pub fn update_and_check_should_continue(&self) -> bool {
        let mut current_depth = self.current_depth.write().unwrap();
        if *current_depth < self.depth {
            *current_depth += 1;
            // Move everything collected during the finished pass into the
            // work set for the next pass; BTreeSet deduplicates and orders.
            // A plain drain suffices here — parallelizing this tiny
            // drain-and-collect costs more than it saves.
            let mut files_to_parse = self.files_to_parse.write().unwrap();
            *files_to_parse = self
                .linked_files
                .lock()
                .unwrap()
                .drain(..)
                .collect::<BTreeSet<PathBuf>>();
            if !files_to_parse.is_empty() {
                return false;
            }
        }
        true
    }
}
/// Postprocessor that follows links leading out of `--start-at`, rewriting
/// them to destination-relative paths and recording their targets so
/// subsequent export passes can pick them up.
pub struct RecursiveResolver {
    /// Root of the vault being exported.
    root: PathBuf,
    /// Sub-path the export was started at (`--start-at`).
    start_at: PathBuf,
    /// Destination directory of the export.
    destination: PathBuf,
    /// Per-pass state shared with the export loop in `main`.
    shared_state: Arc<SharedResolverState>,
}
impl RecursiveResolver {
    /// Creates a resolver for the given vault `root`, `start_at` sub-path
    /// and export `destination`, sharing per-pass state with the export
    /// loop.
    pub fn new(
        root: PathBuf,
        start_at: PathBuf,
        destination: PathBuf,
        shared_state: Arc<SharedResolverState>,
    ) -> RecursiveResolver {
        RecursiveResolver {
            root,
            start_at,
            destination,
            // `shared_state` is already an owned Arc; no extra clone needed.
            shared_state,
        }
    }

    /// Replaces the `start_at` sub-path used for subsequent passes.
    pub fn start_at(&mut self, start_at: PathBuf) {
        self.start_at = start_at;
    }

    /// Postprocessor entry point.
    ///
    /// On the first pass (depth 0), links to files outside of `start_at` are
    /// rewritten so that they point into the root of the destination. On
    /// later passes only files collected during the previous pass are
    /// processed; every other note is skipped.
    pub fn postprocess(
        &self,
        context: &mut Context,
        events: &mut MarkdownEvents,
    ) -> PostprocessorResult {
        if *self.shared_state.current_depth.read().unwrap() == 0 {
            self.first_run(context, events)
        } else {
            if !self
                .shared_state
                .files_to_parse
                .read()
                .unwrap()
                .contains(context.current_file())
            {
                return PostprocessorResult::StopAndSkipNote;
            }
            self.other_runs(context, events)
        }
    }

    /// First pass: rewrites links pointing outside of `start_at` to their
    /// location under the destination root and records their targets for
    /// the next pass.
    fn first_run(
        &self,
        _context: &mut Context,
        events: &mut MarkdownEvents,
    ) -> PostprocessorResult {
        for event in events.iter_mut() {
            if let Event::Start(Tag::Link(_, url, _)) = event {
                // External links are left untouched.
                if url.starts_with("https://") || url.starts_with("http://") {
                    continue;
                }
                let vault_path: PathBuf = get_vault_path(url, self.start_at.as_path());
                // Links that stay within start_at need no rewriting.
                if vault_path.starts_with(&self.start_at) {
                    continue;
                }
                if vault_path.exists() {
                    let vaultless_path = vault_path.strip_prefix(self.root.as_path()).unwrap();
                    set_url(url, self.destination.join(vaultless_path));
                    self.shared_state
                        .linked_files
                        .lock()
                        .unwrap()
                        .push(vault_path);
                }
            }
        }
        PostprocessorResult::Continue
    }

    /// Later passes: rewrites links that point back into `start_at` (whose
    /// contents live at the destination root) and keeps collecting newly
    /// linked files while below the depth limit.
    fn other_runs(
        &self,
        _context: &mut Context,
        events: &mut MarkdownEvents,
    ) -> PostprocessorResult {
        for event in events.iter_mut() {
            if let Event::Start(Tag::Link(_, url, _)) = event {
                if url.starts_with("https://") || url.starts_with("http://") {
                    continue;
                }
                let vault_path = get_vault_path(url, self.root.as_path());
                if vault_path.exists() {
                    // Files inside start_at were exported to the destination
                    // root, so strip the start_at prefix rather than the
                    // vault root.
                    if vault_path.starts_with(&self.start_at) {
                        let link_destination = self
                            .destination
                            .join(vault_path.strip_prefix(&self.start_at).unwrap());
                        set_url(url, link_destination);
                    }
                    // Only queue more files while below the depth limit.
                    if *self.shared_state.current_depth.read().unwrap() < self.shared_state.depth
                    {
                        self.shared_state
                            .linked_files
                            .lock()
                            .unwrap()
                            .push(vault_path);
                    }
                }
            }
        }
        PostprocessorResult::Continue
    }
}
/// Resolves a (possibly percent-encoded) link `url` against `root` and
/// returns the canonicalized vault path.
///
/// Falls back to the non-canonicalized joined path when canonicalization
/// fails (e.g. the link target does not exist), so callers can still test
/// the result with `exists()` instead of panicking on broken links.
fn get_vault_path(url: &CowStr<'_>, root: &Path) -> PathBuf {
    let path_stub = PathBuf::from(
        percent_decode_str(url.as_ref())
            .decode_utf8()
            .expect("percent-decoded link URL is not valid UTF-8")
            .as_ref(),
    );
    let joined = root.join(path_stub);
    joined.canonicalize().unwrap_or_else(|_| joined)
}
/// Overwrites `url` in place with the percent-encoded form of
/// `link_destination`.
fn set_url(url: &mut CowStr<'_>, link_destination: PathBuf) {
    // `to_string_lossy` already yields a string; routing it through
    // `format!("{}", ...)` was a redundant allocation.
    *url = CowStr::from(
        utf8_percent_encode(&link_destination.to_string_lossy(), PERCENTENCODE_CHARS).to_string(),
    );
}
#[test]
fn test_filter_tags() {
let tags = vec![