Skip to content

Commit 520b663

Browse files
committed
Untested CLI system
1 parent ece7e44 commit 520b663

4 files changed

Lines changed: 113 additions & 2 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ edition = "2024"
55

66
[dependencies]
77
targetgen-lib = { path = "targetgen-lib"}
8-
clap = {version = "4.5.20", features = ["default"]}
8+
clap = {version = "4.5.20", features = ["derive"]}
99
simple_logger = "5.0.0"
1010
log = "0.4.22"

src/cli_management.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use std::path::PathBuf;
2+
use clap::Parser;
3+
use log::debug;
4+
use targetgen_lib::generator::TargetGenerator;
5+
6+
#[derive(Parser, Debug)]
7+
#[clap(name = "targetgen", version = "0.1.0", author = "Declan Emery", about = "A tool for generating synthetic bird's eye view images for training machine learning models.")]
8+
pub struct TargetgenCli {
9+
#[clap(short, long, help = "The path to the backgrounds image directory.")]
10+
pub backgrounds: PathBuf,
11+
12+
#[clap(short, long, help = "The path to the objects image directory.")]
13+
pub objects: PathBuf,
14+
15+
#[clap(short, long, help = "The output folder.")]
16+
pub output: PathBuf,
17+
18+
#[clap(short, long, help = "The path to the annotations file.")]
19+
pub annotations: PathBuf,
20+
21+
#[clap(short, long, help = "Enable logging.")]
22+
pub enable_logging: Option<bool>,
23+
24+
#[clap(short, long, help = "The number of target images to generate.")]
25+
pub num_targets: Option<u32>,
26+
27+
#[clap(short, long, help = "The number of objects per image.")]
28+
pub num_objects: Option<u32>,
29+
30+
#[clap(short, long, help = "Whether or not to visualize the bounding boxes of the objects.")]
31+
pub visualize_bboxes: Option<bool>,
32+
33+
#[clap(short, long, help = "The color to use for the maskover effect, which basically fills the bounding box with a color.")]
34+
pub maskover_color: Option<String>,
35+
36+
#[clap(short, long, help = "Whether or not to allow duplicates of the same object within the same generated target image.")]
37+
pub permit_duplicates: Option<bool>,
38+
39+
#[clap(short, long, help = "Whether or not to allow objects to collide with each other, AKA overlap.")]
40+
pub permit_collisions: Option<bool>,
41+
42+
#[clap(short, long, help = "The size of the cache in MBs, which holds resized objects (initialization only).")]
43+
pub cache_size: Option<u8>,
44+
45+
#[clap(short, long, help = "The number of worker threads to use for generating the target images.")]
46+
pub worker_threads: Option<u8>,
47+
48+
#[clap(short, long, help = "Whether or not to compress the generated target images.")]
49+
pub compress: Option<bool>,
50+
51+
#[clap(short, long, help = "Should the objects be randomly rotated (currently only supports 90 degree rotations).")]
52+
pub do_random_rotation: Option<bool>,
53+
}
54+
55+
pub fn run(args: TargetgenCli) {
56+
if let Some(enable_logging) = args.enable_logging {
57+
if enable_logging {
58+
simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Debug).init().unwrap();
59+
}
60+
}
61+
62+
debug!("Running with args: {:?}", args);
63+
64+
let mut tg = TargetGenerator::new(args.backgrounds, args.objects, args.annotations).unwrap();
65+
66+
let num_targets = args.num_targets.unwrap_or(1);
67+
let num_objects = args.num_objects.unwrap_or(6);
68+
69+
if let Some(visualize_bboxes) = args.visualize_bboxes {
70+
tg.config.visualize_bboxes = visualize_bboxes;
71+
}
72+
73+
/*if let Some(maskover_color) = args.maskover_color {
74+
tg.config.maskover_color = Some(maskover_color.parse().unwrap());
75+
}*/
76+
77+
if let Some(permit_duplicates) = args.permit_duplicates {
78+
tg.config.permit_duplicates = permit_duplicates;
79+
}
80+
81+
if let Some(permit_collisions) = args.permit_collisions {
82+
tg.config.permit_collisions = permit_collisions;
83+
}
84+
85+
if let Some(cache_size) = args.cache_size {
86+
tg.config.cache_size = cache_size;
87+
}
88+
89+
if let Some(worker_threads) = args.worker_threads {
90+
tg.config.worker_threads = worker_threads;
91+
}
92+
93+
if let Some(compress) = args.compress {
94+
tg.config.compress = compress;
95+
}
96+
97+
if let Some(do_random_rotation) = args.do_random_rotation {
98+
tg.config.do_random_rotation = do_random_rotation;
99+
}
100+
101+
tg.generate_targets(num_targets, ..num_objects, args.output).unwrap();
102+
103+
tg.close();
104+
105+
debug!("Finished running.");
106+
}

src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
mod cli_management;
2+
3+
use clap::Parser;
14
use log::LevelFilter;
25
use simple_logger::SimpleLogger;
36
use targetgen_lib::generator;
47
use targetgen_lib::generator::TargetGenerator;
8+
use crate::cli_management::run;
59

610
fn main() {
7-
11+
run(cli_management::TargetgenCli::parse());
812
}
913

1014
#[ignore]

targetgen-lib/src/generator/util.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,5 @@ fn test_resize_ratio() {
9494
assert_eq!(resize_ratio(1.0, 35.0), 35.0);
9595
assert_eq!(resize_ratio(1.0, 70.0), 70.0);
9696
assert_eq!(resize_ratio(1.0, 105.0), 105.0);
97+
assert_eq!(resize_ratio(2.0, 140.0), 280.0);
9798
}

0 commit comments

Comments
 (0)