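//! Command-line front end: parses a JSON file whose top-level keys name the
//! nodes, hands each entry to `node::make_node`, and relays user commands to
//! the nodes over channels.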
use serde_json::Value;
use std::fs::File;
use std::io::prelude::*;
use std::io::stdin;
use std::sync::mpsc::channel;
use std::sync::{
    Arc,
    Barrier
};
use std::path::Path;

mod data;
mod node;

/**
 * Parse the JSON file at `arg` and store the result in `parsed`.
 *
 * If the file cannot be opened, the user is prompted for another path until
 * one opens or they type "exit".
 *
 * @param arg: String file path (the literal "exit" aborts)
 * @param parsed: &mut Value variable to store the result
 * @param debug: bool debug flag, prints the parsed value when set
 *
 * @return std::io::Result<()> Err("exit") when the user exits, or the read /
 *         JSON error when the file cannot be read or parsed
 */
fn parse_json(arg: String, parsed: &mut Value, debug: bool) -> std::io::Result<()> {
    let mut path = arg;

    // Re-prompt in a loop (no recursion) so "exit" and parse errors typed at
    // the re-prompt reach the caller instead of being discarded.
    let mut file = loop {
        if path == "exit" {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, "exit"));
        }
        match File::open(&path) {
            Ok(opened) => break opened,
            Err(_) => path = get_input(Some("Invalid path")),
        }
    };

    let mut content = String::new();
    file.read_to_string(&mut content)?;
    *parsed = serde_json::from_str(&content)?;

    if debug {
        println!(
            "|----------------parse_json--------------|\n{}",
            parsed
        );
    }

    Ok(())
}

/**
 * Print the help message listing the available commands.
 */
fn print_help() {
    println!("|----------------HELP-------------------|");
    println!("<node> <command> => send command to node");
    println!("exit => exit program");
    println!("list => list available nodes");
    println!("<node> quit => quit node");
    println!("command:");
    println!("  send <message> => send message to llm");
    println!("  quit => quit node");
    println!("  sendjson <json> => send json to node");
    println!("help => display this message");
    println!("----------------------------------------");
}

/**
 * Print an optional prompt and read one line from stdin.
 *
 * @param msg: Option<&str> prompt printed on its own line (an empty line when None)
 *
 * @return String the line with surrounding whitespace trimmed; an empty string
 *         on a read error or at end of input
 */
fn get_input(msg: Option<&str>) -> String {
    println!("{}", msg.unwrap_or(""));
    let mut input = String::new();
    match stdin().read_line(&mut input) {
        Ok(_) => input.trim().to_string(),
        Err(_) => {
            println!("Error reading input");
            String::new()
        }
    }
}

/**
 * Main loop
 *
 * @param barrier: Arc<Barrier> barrier to sync the threads
 * @param nodes: &mut Vec<Node> vector of nodes
 * @param debug: bool debug flag
 *
 * @return std::io::Result<()>
 */
fn main_loop(barrier: Arc<Barrier>, nodes: &mut Vec<data::Node>, flags: data::Flags) -> std::io::Result<()> {
    // init input

    barrier.wait();

    let mut input = String::new();

    print_help();
    println!("nodes :");
    for node in nodes.iter() {
        println!("  {}", node.id);
    }
    println!("----------------------------------------");

    let mut did_something = false;

    loop {
        match stdin().read_line(&mut input) {
            Ok(_) => {
                match input.trim() {
                    "exit" => {
                        println!("exiting the program");
                        for node in nodes.iter() {
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                            node.chan_s.send(Value::String("quit".to_string())).unwrap();
                        }
                        return Ok(());
                    }
                    "list" => {
                        for node in nodes.iter() {
                            println!("{}", node.id);
                        }
                    }
                    _ => {
                        // check if the input is a node command
                        let mut index_to_remove = None;
                        for (index, node) in nodes.iter().enumerate() {
                            let (target, mut command) =
                                input.trim().split_at(input.trim().find(" ").unwrap_or(0));
                            command = command.trim();
                            if target == node.id {
                                let (control, data) = command.split_once("%s ").unwrap_or(("",""));
                                if flags.debug == true {
                                    println!("control: {}, data: {}", control, data);
                                }
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                                if control == "sendjson" {
                                    node.chan_s.send(serde_json::Value::String(make_request(data, command))).unwrap();
                                } else {
                                    node.chan_s.send(serde_json::Value::String(command.to_string())).unwrap();
                                }

                                did_something = true;
                                // check if the command is quit, need to remove the node if true
                                if command == "quit" {
                                    index_to_remove = Some(index);
                                }
                            }
                        }
                        if let Some(index) = index_to_remove {
                            nodes.remove(index);
                        }

                        if did_something == false {
                            print_help();
                        } else {
                            did_something = false;
                        }
                    }
                }
                input.clear();
            }
            Err(_) => print_help(),
        }
CARDILE VINCENT's avatar
CARDILE VINCENT committed

        for node in nodes.iter() {
            match node.chan_r.try_recv() {
                Ok(value) => {
                    println!("{}: {}", node.id, value.to_string());
                }
                Err(_) => (),
            }
        }
    }
}

/**
 * Get the arguments from the command line
 *
 * Recognised flags: "-d" (debug), "-m <name>" (model name), "-b" (bypass).
 * Any other argument must be an existing path and becomes the JSON file path.
 * When no path is given, the user is prompted until an existing path or
 * "exit" is entered.
 *
 * @param args: std::env::Args arguments (the program name is skipped)
 * @param arg: &mut String receives the JSON file path
 * @param flags: &mut data::Flags flags updated from the arguments
 *
 * @return std::io::Result<()> Err on an invalid argument, a missing model
 *         name after "-m", or when the user types "exit"
 */
fn get_args(args: std::env::Args, arg: &mut String, flags: &mut data::Flags) -> std::io::Result<()> {
    let mut got_path = false;
    // Set after "-m": the next argument is the model name.
    let mut expect_model_name = false;

    for current in args.skip(1) {
        if expect_model_name {
            flags.model_name = current;
            expect_model_name = false;
            continue;
        }

        match current.as_str() {
            "-d" => {
                println!("Debug mode enabled");
                flags.debug = true;
            }
            "-m" => {
                flags.model = true;
                expect_model_name = true;
            }
            "-b" => {
                flags.bypass = true;
                println!("WARNING: Bypass mode enabled");
            }
            _ => {
                if Path::new(&current).exists() {
                    *arg = current;
                    got_path = true;
                } else {
                    println!("Invalid argument: {}", current);
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, "Invalid argument"));
                }
            }
        }
    }

    if expect_model_name {
        println!("Missing model name after -m");
        return Err(std::io::Error::new(std::io::ErrorKind::Other, "Missing model name"));
    }

    while !got_path {
        let path = get_input(Some("Enter the path to the JSON file"));
        if path == "exit" {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, "exit"));
        } else if Path::new(&path).exists() {
            // Store the answer; previously it was validated and then dropped.
            *arg = path;
            got_path = true;
        } else {
            println!("Invalid path");
        }
    }

    Ok(())
}

/**
 * Build a request by filling the placeholders of `data::REQUEST_BODY`.
 *
 * "{1}" is replaced first, so any "{2}" contained in `control` is replaced
 * by `data` as well.
 *
 * @param control: &str text substituted for every "{1}"
 * @param data: &str text substituted for every "{2}"
 *
 * @return String the template with both placeholders replaced
 */
pub fn make_request(control: &str, data: &str) -> String {
    let with_control = data::REQUEST_BODY.replace("{1}", control);
    with_control.replace("{2}", data)
}

fn main() -> std::io::Result<()> {
    let mut arg: String = String::new();
    let mut flags: data::Flags = data::Flags {
        debug: false,
        model: false,
        model_name: String::new(),
    };

    match get_args(std::env::args(), &mut arg, &mut flags){
        Ok(_) => (),
        Err(e) => {
            println!("Error getting arguments: {}", e.to_string());
            return Ok(());
        }
    }

    if flags.debug == true {
        println!("|----------------get_arg-------------------|");
        println!("arg: {}", arg);
        println!("model: {}", flags.model_name);
    }

    let mut parsed = serde_json::Value::Null;
    let mut nodes: Vec<data::Node> = Vec::new();

    match parse_json(arg, &mut parsed, flags.debug) {
        Ok(_) => (),
        Err(e) => {
            println!("Error parsing JSON file: {}", e.to_string());
            return Ok(());
        }
    }

    let barrier = Arc::new(Barrier::new(
        parsed.as_object().expect("Expected a JSON object").len() + 1,
    )); // +1 for the main thread

    if let Value::Object(map) = &parsed {
        for (key, value) in map {
            let (tx, rx) = channel();
CARDILE VINCENT's avatar
CARDILE VINCENT committed
            let (tx2, rx2) = channel();
            let new_node: data::Node = data::Node {
                id: key.to_string(),
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                chan_s: tx,
                chan_r: rx2,
            };
            let new_com: data::Com = data::Com {
                chan_s: tx2,
                chan_r: rx,
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                node::make_node(Arc::clone(&barrier), key.to_string() , value.clone(), new_com, flags.clone());
            }
            nodes.push(new_node);
        }
    }

    main_loop(Arc::clone(&barrier), &mut nodes, flags)?;

    barrier.wait();

    Ok(())
}