Skip to content
Snippets Groups Projects
node.rs 11.7 KiB
Newer Older
use ollama_rs::generation::completion::request::GenerationRequest;
use std::process::Command;
use serde_json::Value;
use std::net::{
    SocketAddr, 
    UdpSocket
};
use std::{
    thread, 
    time
};
use ollama_rs::*;
use tokio::runtime::Runtime;
use chrono::Utc;
use std::sync::{
    Barrier, 
    Arc
};

use crate::data;

/// Size in bytes of each node's UDP receive buffer; longer datagrams may be truncated.
const BUFF_SIZE: usize = 4 * 1024;
/// Ollama model tag used for every generation request.
/// NOTE(review): independent of `flags.model_name`, which only drives the docker setup.
const MODEL: &str = "mymist:latest";
/// Host part of the Ollama API URL; the port comes from each node's `llm` entry.
const LOCALHOST: &str = "http://localhost";

/**
 * Make a node: on a new (detached) thread, start its Ollama container, bind
 * its UDP socket, collect its neighbours and run `node_loop`.
 *
 * Steps performed on the spawned thread:
 * 1. `docker run` an `ollama/ollama` container named `ollama_<node_id>`,
 *    publishing the port taken from the node's `llm` entry onto 11434.
 * 2. If `flags.model` is set: copy the local file named `flags.model_name`
 *    into the container, `ollama create` a model from it, then `ollama run` it.
 * 3. Bind the node's non-blocking UDP socket, parse the neighbour addresses,
 *    wait on `barrier`, then hand over to `node_loop`.
 *
 * Unless `flags.bypass` is set, fixed sleeps give the container/model time to start.
 *
 * @param barrier: Arc<Barrier> barrier to sync the threads
 * @param node_id: String node id
 * @param array: Value node configuration, read as
 *        `[{"node": <bind addr>}, {"llm": "...:<port>"}, {"neighbours": [<addr>, ...]}]`
 *        (the segment after the first ':' of `llm` is the Ollama host port)
 * @param rx: std::sync::mpsc::Receiver<Value> receiver to communicate from the master thread
 * @param flags: data::Flags flags
 *
 * Panics (on the spawned thread) if the configuration is malformed, the
 * `docker` executable cannot be launched, or the node address cannot be bound.
 */
pub fn make_node(
    barrier: Arc<Barrier>,
    node_id: String,
    array: Value,
    rx: std::sync::mpsc::Receiver<Value>,
    flags: data::Flags,
) {
    if flags.debug {
        println!("|----------------make_node: {}----------------|", node_id);
    }

    let _child = thread::spawn(move || {
        let llm_string = array[1]
            .get("llm")
            .and_then(Value::as_str)
            .expect("node config: array[1].llm must be a string");
        let port_str = llm_string
            .split(':')
            .nth(1)
            .expect("node config: llm entry must contain ':<port>'");
        // Parse once, before any container is started, so a bad port fails fast.
        let port: u16 = port_str
            .parse()
            .expect("node config: llm port must be a valid u16");

        if flags.debug {
            println!("{} : llm -> {}, {}", node_id, llm_string, port_str);
        }

        let container = format!("ollama_{}", node_id);
        let port_mapping = format!("{}:11434", port);
        run_docker(
            &node_id,
            &[
                "run",
                "-d",
                "-v",
                "ollama:/root/.ollama",
                "-p",
                port_mapping.as_str(),
                "--name",
                container.as_str(),
                "ollama/ollama",
            ],
        );

        if !flags.bypass {
            println!("{} -> arbitrary wait for the container to start", node_id);
            thread::sleep(time::Duration::from_secs(5));
        }

        if flags.model {
            let model: &str = &flags.model_name;
            let container_file = format!("/root/.ollama/{}", model);
            let container_target = format!("{}:{}", container, container_file);

            // Make the local model file available inside the container, then
            // register it with ollama under the same name.
            run_docker(&node_id, &["cp", model, container_target.as_str()]);
            run_docker(
                &node_id,
                &[
                    "exec",
                    container.as_str(),
                    "ollama",
                    "create",
                    model,
                    "--file",
                    container_file.as_str(),
                ],
            );

            if !flags.bypass {
                println!("{} -> arbitrary wait for the container to init model (30s)", node_id);
                thread::sleep(time::Duration::from_secs(30));
            }

            // No -it: there is no TTY under `.output()`; stdin is closed, so
            // `ollama run` only loads the model and exits.
            run_docker(&node_id, &["exec", container.as_str(), "ollama", "run", model]);
        }

        if !flags.bypass {
            println!("{} -> arbitrary wait for the model to start", node_id);
            thread::sleep(time::Duration::from_secs(5));
        }

        let node_addr = array[0]
            .get("node")
            .and_then(Value::as_str)
            .expect("node config: array[0].node must be an address string");
        let node_socket = UdpSocket::bind(node_addr).expect("Could not bind node address");
        // node_loop polls this socket and treats WouldBlock as "nothing received".
        node_socket
            .set_nonblocking(true)
            .expect("Could not set non-blocking mode to node socket");

        let llm = Ollama::new(LOCALHOST.to_string(), port);

        let mut comm_vec: Vec<SocketAddr> = Vec::new();
        if let Some(Value::Array(comm_arr)) = array[2].get("neighbours") {
            for comm in comm_arr {
                match comm.as_str() {
                    Some(comm_str) => match comm_str.parse::<SocketAddr>() {
                        Ok(comm_addr) => comm_vec.push(comm_addr),
                        Err(_) => println!("{} : Invalid comm address: {}", node_id, comm_str),
                    },
                    None => println!("{} : Invalid comm address: {:?}", node_id, comm),
                }
            }
        }

        if flags.debug {
            println!(
                "{} : node -> {}",
                node_id,
                node_socket.local_addr().expect("bound socket has a local address")
            );
            for comm in &comm_vec {
                println!("{} : comm -> {}", node_id, comm);
            }
        }

        barrier.wait();

        if let Err(e) = node_loop(barrier, node_id.clone(), rx, node_socket, llm, comm_vec, flags.debug) {
            println!("{} : node loop failed: {}", node_id, e);
        }
    });
}

/**
 * Run `docker <args>` on behalf of node `node_id` and wait for it to exit.
 *
 * Docker is invoked directly (no `sh -c` / `cmd /C`), so arguments need no
 * shell quoting and the call is the same on every OS.
 *
 * Panics if the `docker` executable cannot be launched. A non-zero exit status
 * is only logged together with docker's stderr, and the caller carries on.
 */
fn run_docker(node_id: &str, args: &[&str]) {
    let output = Command::new("docker")
        .args(args)
        .output()
        .expect("failed to execute process");
    if !output.status.success() {
        println!(
            "{} : `docker {}` failed ({}): {}",
            node_id,
            args.join(" "),
            output.status,
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }
}

/*
* Node loop
*
* @param barrier: Arc<Barrier> barrier to sync the threads
* @param node_id: String node id
* @param rx: std::sync::mpsc::Receiver<Value> receiver to comunicate from the master thread
* @param node_socket: std::net::UdpSocket node socket
* @param llm: Ollama Ollama instance
* @param comm_vec: Vec<SocketAddr> vector of neighbours
* @param debug: bool debug flag
*
* @return std::io::Result<()>
*/
fn node_loop( barrier: Arc<Barrier>, node_id: String, rx: std::sync::mpsc::Receiver<Value>, node_socket: std::net::UdpSocket, llm: Ollama,comm_vec: Vec<SocketAddr>, debug: bool) -> std::io::Result<()> {
    let mut msg_received = false;
    let mut buf = [0; BUFF_SIZE];
    let mut data: data::Message = data::Message { id: "init".to_string(), message: "".to_string() };
    let mut response: data::Message = data::Message { id: "".to_string(), message: "".to_string() };
    let start_time = Utc::now().timestamp();
    let rt = Runtime::new().unwrap();

    loop {
        // listen to master
        let elapsed = Utc::now().timestamp() - start_time;
        match rx.try_recv() {
            Ok(Value::String(msg)) => {
                if debug == true {
                    println!("{} -> {}", node_id, msg);
                }

                let (first, rest) = match msg.split_whitespace().next() {
                    Some(first) => (first, msg.trim_start_matches(first).trim_start()),
                    None => continue,
                };

                if debug == true {
                    println!("{} (first, rest) -> {} {}", node_id, first, rest);
                }

                match first.trim().replace(" ", "").as_str() {
                    // "send" => {
                    //     msg_received = true;
                    //     data.message = rest.to_string();
                    //     if debug == true {
                    //         println!("{} : send received -> {}", node_id, data.message);
                    //     }
                    "quit" => {
                        drop(node_socket);
                        if cfg!(target_os = "windows") {
                            match Command::new("cmd")
                                .args(&["/C", &format!("docker rm -f /ollama_{}", node_id)])
                                .output() {
                                    Ok(_) => {},
                                    Err(e) => println!("{} : Error: {}", node_id, e)
                                }
                        } else {
                            match Command::new("sh")
                                .arg("-c")
                                .arg(&format!("docker rm -f /ollama_{}", node_id))
                                .output() {
                                    Ok(_) => {},
                                    Err(e) => println!("{} : Error: {}", node_id, e)
                                }
                        };

                        println!("child exited");
                        
                        barrier.wait();

                        return Ok(());
                    },
                    _ => {
                        // println!("{} : Invalid command: {}", node_id, first);
                        data.message = msg.to_string();
                        msg_received = true;
                    }
                    
                }
            },
            Ok(Value::Null) => {
                continue;
            },
            Ok(Value::Bool(_)) => {
                println!("Bool value not supported yet!");
            },
            Ok(Value::Number(_)) => {
                println!("Number value not supported yet!");
            },
            Ok(Value::Array(_)) => {
                println!("Array value not supported yet!");
            },
            Ok(Value::Object(_)) => {
                println!("Object value not supported yet!");
            },
            Err(std::sync::mpsc::TryRecvError::Empty) => {
            },
            Err(e) => println!("Error node_loop: {}", e)
        }

        match node_socket.recv_from(&mut buf) {
            Ok(_) => {
                let json_end = buf.iter().position(|&b| b == b'}');
                if let Some(end) = json_end {
                    let json_slice = &buf[..=end];
                    data = serde_json::from_slice(json_slice).unwrap();
                    if debug == true {
                        println!("{:?}:{} received -> {}", elapsed, node_id, data.message);
                    }
                    msg_received = true;
                } else {
                    println!("Invalid JSON data");
                }
            },
            Err(e) => {
                if e.kind() != std::io::ErrorKind::WouldBlock {
                    println!("{} : Error: {}", node_id, e);
                }
            }
        }
        
        rt.block_on(async {
            //listen to neighbours
            if msg_received == true {
                if debug == true {
                    println!("{:?}:{} request -> {}", elapsed, node_id, data.message);
                }
                response.message = llm.generate(GenerationRequest::new(MODEL.to_string(), data.message.clone())).await.clone().unwrap().response;
                println!("{:?}: {} response -> {}", elapsed, node_id, response.message);
                response.id = node_id.clone();
                if debug == true {
                    println!("{:?}:{} response -> {}", elapsed, node_id, response.message);
                }
                
                let send = serde_json::to_string(&response).unwrap();
                for comm in comm_vec.iter() {
                    let comm_str = comm.to_string();
                    let addr = comm_str.trim_matches(|c| c == '[' || c == ']' || c == '\"');
                    println!("{:?}:{} sending to -> {}", elapsed, node_id, addr);
                    match node_socket.send_to(send.as_bytes(), addr){
                        Ok(_) => {
                            if debug == true {
                                println!("{:?}:{} sent -> {}", elapsed, node_id, send);
                            }
                        },
                        Err(e) => {
                            println!("{} : Error: {}", node_id, e);
                        }
                    }
                }
            }
            msg_received = false;
        });
    }
}