나이브 베이지안 분류 기법(Naive Bayesian classifier)은 학습과 구현이 간단하면서도 성능이 좋은 분류 방법입니다.


이 방법은 클래스가 주어졌을 때 모든 속성이 서로 조건부 독립이라고 가정합니다. 즉, $P(x_1, \dots, x_n \mid c) = \prod_i P(x_i \mid c)$ 로 계산합니다.


간략한 설명을 보시려면 http://bcho.tistory.com/1010 링크를 확인하시기 바랍니다.



이번 포스트에서는 나이브 베이지안 분류기를 자바로 구현하고자 하였습니다.


또한 학습한 모델(오브젝트)을 파일로 저장해 두었다가 나중에 다시 읽어 쓸 수 있게 하여, 매번 학습 데이터로 다시 학습할 필요가 없도록 했습니다.



아래는 나이브 베이지안 분류기를 자바로 구현한 예제 코드 전문입니다.


깃허브 파일 링크 : [1]


import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;

/**
 * A minimal naive Bayes classifier for comma-separated records whose last
 * field is the class label.
 *
 * <p>Categorical attributes are modelled with frequency estimates of
 * P(value | class). Numeric attributes are modelled with a per-class Gaussian
 * (mean and population standard deviation). A record is assigned the class
 * with the largest score P(class) * prod_i P(x_i | class).
 */
public class NaiveBayes {
	/**
	 * Reads a text file into memory, one list element per line.
	 *
	 * <p>Lines that are empty or contain only whitespace are skipped, because
	 * the parsing methods below cannot handle them. On any I/O error the message
	 * is printed to stderr and the JVM exits with status 1.
	 *
	 * @param fileName path of the file to read
	 * @return the non-blank lines of the file, in file order
	 */
	public static ArrayList<String> getData(String fileName) {
		ArrayList<String> data = new ArrayList<String>();

		try {
			BufferedReader inputStream = new BufferedReader(new FileReader(fileName));
			try {
				String l;
				while ((l = inputStream.readLine()) != null) {
					// A blank line would later fail column indexing / number parsing.
					if (l.trim().length() > 0)
						data.add(l);
				}
			} finally {
				// Release the file handle whether or not reading succeeded.
				inputStream.close();
			}
		} catch (IOException e) {
			System.err.println("getData: "+e.getMessage());
			System.exit(1);
		}

		return data;
	}

	/**
	 * Fits a one-dimensional Gaussian to a sample.
	 *
	 * @param values observations of one numeric attribute for one class
	 * @return {@code {mu, sigma}}: the sample mean and the population standard
	 *         deviation (squared deviations divided by n, not n-1). Both are NaN
	 *         for an empty list. sigma is 0 when all values are equal.
	 */
	public static Double[] getGaussianModel(ArrayList<Double> values) {
		double mu = 0d;
		for (double v : values)
			mu += v;
		mu /= values.size();

		double ss = 0d;
		for (double v : values)
			ss += (v - mu) * (v - mu);
		ss /= values.size();

		return new Double[] { mu, Math.sqrt(ss) };
	}

	/**
	 * Builds a per-class Gaussian model for one numeric column.
	 *
	 * @param data training records; the class label is the last field
	 * @param col  index of the numeric column to model
	 * @return class label -> {@code {mu, sigma}} as computed by {@link #getGaussianModel}
	 */
	public static HashMap<String,Double[]> getNumerical(ArrayList<String> data, int col) {
		// Group this column's values by class label.
		HashMap<String,ArrayList<Double>> classes = new HashMap<String,ArrayList<Double>>();
		for (String line : data) {
			String[] arr = line.split(",");
			String cls = arr[arr.length-1];
			ArrayList<Double> values = classes.get(cls);
			if (values == null) {
				values = new ArrayList<Double>();
				classes.put(cls, values);
			}
			values.add(Double.parseDouble(arr[col]));
		}

		HashMap<String,Double[]> prb = new HashMap<String,Double[]>(); // value = {mu, sigma}
		for (String cls : classes.keySet())
			prb.put(cls, getGaussianModel(classes.get(cls)));

		return prb;
	}

	/**
	 * Estimates P(value | class) for one categorical column.
	 *
	 * <p>Keys have the form {@code "value,class"}. Each probability is
	 * count(value, class) / count(class), so the entries for one class sum to 1.
	 * The class prior is deliberately not folded in here. It is applied exactly
	 * once per record by {@link #getScores}, however many categorical columns
	 * there are.
	 *
	 * @param data training records; the class label is the last field
	 * @param col  index of the categorical column to model
	 * @return {@code "value,class"} -> P(value | class)
	 */
	public static HashMap<String,Double> getCategorical(ArrayList<String> data, int col) {
		HashMap<String,Integer> counts = new HashMap<String,Integer>();      // "value,class" -> count
		HashMap<String,Integer> classCounts = new HashMap<String,Integer>(); // class -> count

		for (String line : data) {
			String[] arr = line.split(",");
			String cls = arr[arr.length-1];
			increment(counts, arr[col]+","+cls);
			increment(classCounts, cls);
		}

		HashMap<String,Double> prb = new HashMap<String,Double>();
		for (String key : counts.keySet()) {
			// Fields come from split(","), so neither part contains a comma:
			// the class label is everything after the last comma.
			String cls = key.substring(key.lastIndexOf(',') + 1);
			prb.put(key, (double)counts.get(key) / (double)classCounts.get(cls));
		}

		return prb;
	}

	/**
	 * Estimates the class priors P(class) from the training records.
	 *
	 * @param data training records; the class label is the last field
	 * @return class label -> fraction of records carrying that label
	 */
	public static HashMap<String,Double> getClass(ArrayList<String> data) {
		HashMap<String,Integer> counts = new HashMap<String,Integer>();
		for (String line : data) {
			String[] arr = line.split(",");
			increment(counts, arr[arr.length-1]);
		}

		HashMap<String,Double> prb = new HashMap<String,Double>();
		for (String cls : counts.keySet())
			prb.put(cls, (double)counts.get(cls) / (double)data.size());

		return prb;
	}

	/** Adds one to the counter stored under {@code key}, treating a missing key as zero. */
	private static void increment(HashMap<String,Integer> counts, String key) {
		Integer current = counts.get(key);
		counts.put(key, current == null ? 1 : current + 1);
	}

	/** Prints each key and probability of a map, tab separated, under a "printPrb" header. */
	public static void printPrb(HashMap<String,Double> prb) {
		System.out.println("printPrb");
		for (String key : prb.keySet())
			System.out.println(key+"\t"+prb.get(key));
		System.out.println();
	}

	/** Prints each key with its {@code {mu, sigma}} pair, tab separated, under a "printPrb" header. */
	public static void printPrb2(HashMap<String,Double[]> prb) {
		System.out.println("printPrb");
		for (String key : prb.keySet())
			System.out.println(key+"\t"+prb.get(key)[0]+"\t"+prb.get(key)[1]);
		System.out.println();
	}

	/**
	 * Evaluates the normal probability density at {@code val}.
	 *
	 * @param ms  {@code {mu, sigma}}
	 * @param val point at which to evaluate the density
	 * @return the density; NaN when sigma is 0 (degenerate model)
	 */
	public static double getGassusianValue(Double[] ms, double val) {
		return 1.0/(ms[1]*Math.sqrt(2*Math.PI))*Math.exp(-(val-ms[0])*(val-ms[0]) / (2.0*ms[1]*ms[1]));
	}

	/** Prints the comma-separated fields of a record separated by spaces. */
	public static void printAttribute(String str) {
		for (String field : str.split(","))
			System.out.print(field+" ");
		System.out.println();
	}

	/**
	 * Computes the unnormalised naive Bayes score P(c) * prod_i P(x_i | c) for every class.
	 *
	 * @param line           one comma-separated record; only the columns listed in
	 *                       {@code categoric} and {@code numeric} are read
	 * @param prb_class      class priors, as returned by {@link #getClass(ArrayList)}
	 * @param prb_attributes per-column models indexed by column number: maps from
	 *                       {@link #getCategorical} for categorical columns and from
	 *                       {@link #getNumerical} for numeric columns
	 * @param categoric      indices of the categorical columns
	 * @param numeric        indices of the numeric columns
	 * @param N              number of training records, used by the unseen-value fallback
	 * @return class label -> score, with one entry per key of {@code prb_class}
	 */
	public static HashMap<String,Double> getScores(String line, HashMap<String,Double> prb_class, HashMap<String,Object>[] prb_attributes, int[] categoric, int[] numeric, int N) {
		String[] attr = line.split(",");
		HashMap<String,Double> scores = new HashMap<String,Double>();

		for (String cls : prb_class.keySet()) {
			// Apply the prior once; the per-attribute factors are conditionals.
			double cond_prb = prb_class.get(cls);

			for (int col : categoric) {
				HashMap<String,Object> temp_prb = prb_attributes[col];
				Object p = temp_prb.get(attr[col]+","+cls);
				if (p != null) {
					cond_prb *= (Double) p;
				} else {
					// Value never seen with this class during training: multiply by a small
					// non-zero probability (simple smoothing) instead of zeroing the score.
					cond_prb *= 1.0 / ((double) N + temp_prb.size());
				}
			}

			for (int col : numeric) {
				HashMap<String,Object> temp_prb = prb_attributes[col];
				Double[] ms = (Double[]) temp_prb.get(cls);
				cond_prb *= getGassusianValue(ms, Double.parseDouble(attr[col]));
			}

			scores.put(cls, cond_prb);
		}

		return scores;
	}

	/**
	 * Predicts the class of one record.
	 *
	 * @return the class with the highest {@link #getScores} value; ties go to the
	 *         class that comes first in {@code prb_class.keySet()}. Returns ""
	 *         only if every score is NaN or there are no classes.
	 */
	public static String classify(String line, HashMap<String,Double> prb_class, HashMap<String,Object>[] prb_attributes, int[] categoric, int[] numeric, int N) {
		return bestClass(getScores(line, prb_class, prb_attributes, categoric, numeric, N), prb_class.keySet());
	}

	/** Returns the arg-max of {@code scores} over {@code classes} (first wins ties, NaN never wins, "" if none). */
	private static String bestClass(HashMap<String,Double> scores, Set<String> classes) {
		String best = "";
		double bestScore = -1d; // real scores are >= 0, so the first usable class always replaces this
		for (String cls : classes) {
			double score = scores.get(cls);
			if (score > bestScore) { // comparisons with NaN are false, so NaN scores are skipped
				bestScore = score;
				best = cls;
			}
		}
		return best;
	}

	/**
	 * Classifies each test record and prints the result.
	 *
	 * <p>For every record this prints the record's fields, then one
	 * {@code "class\tscore"} line per class, then the predicted class followed
	 * by a blank line. Scores are those of {@link #getScores}, compared only after
	 * every attribute has been multiplied in.
	 */
	public static void getTest(ArrayList<String> test, HashMap<String,Double> prb_class, HashMap<String,Object>[] prb_attributes, int[] categoric, int[] numeric, int N) {
		Set<String> classes = prb_class.keySet();

		for (String line : test) {
			printAttribute(line);
			HashMap<String,Double> scores = getScores(line, prb_class, prb_attributes, categoric, numeric, N);
			for (String cls : classes)
				System.out.println(cls+"\t"+scores.get(cls));
			System.out.println(bestClass(scores, classes)+"\n");
		}
	}

	/**
	 * Serializes {@code obj} to {@code fileName} using Java serialization.
	 * On any I/O error the message is printed to stderr and the JVM exits with status 1.
	 */
	public static void saveModel(Object obj, String fileName) {
		try {
			ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName));
			try {
				oos.writeObject(obj);
			} finally {
				// close() flushes any buffered bytes and releases the file handle.
				// A failure here means the file may be incomplete, so it is reported below.
				oos.close();
			}
		} catch (IOException e) {
			System.err.println("saveModel: "+e.getMessage());
			System.exit(1);
		}
	}

	/**
	 * Reads back an object written by {@link #saveModel}.
	 *
	 * <p>This uses Java native deserialization. Only load files you created
	 * yourself, never untrusted input. On any error the message is printed to
	 * stderr and the JVM exits with status 1.
	 */
	public static Object getModel(String fileName) {
		Object obj = null;
		try {
			ObjectInputStream ios = new ObjectInputStream(new FileInputStream(fileName));
			try {
				obj = ios.readObject();
			} finally {
				ios.close();
			}
		} catch (IOException e) {
			System.err.println("getModel: "+e.getMessage());
			System.exit(1);
		} catch (ClassNotFoundException e) {
			System.err.println("getModel: "+e.getMessage());
			System.exit(1);
		}
		return obj;
	}

	/**
	 * Trains on the training CSV, classifies the test CSV, saves the model and
	 * classifies the test CSV again with the reloaded model.
	 *
	 * @param args optional: {@code args[0]} training CSV path, {@code args[1]} test CSV path
	 *             (the original hard-coded paths are used when absent)
	 */
	@SuppressWarnings({"unchecked", "rawtypes"})
	public static void main(String[] args) {
		String fileName = args.length > 0 ? args[0] : "/home/spark/bigDataProcessing/module2/play_tennis.csv";
		String fileName2 = args.length > 1 ? args[1] : "/home/spark/bigDataProcessing/module2/play_tennis_test.csv";
		ArrayList<String> training = getData(fileName);
		ArrayList<String> test = getData(fileName2);

		int[] col_categoric = {0,3};
		int[] col_numeric = {1,2};

		// PROBABILITY of CLASSES
		HashMap<String,Double> prb_class = getClass(training);

		// PROBABILITY of ATTRIBUTES, indexed by column number. The array is sized
		// from the largest attribute column used, so adding a column needs no other change.
		int width = 0;
		for (int col : col_categoric)
			width = Math.max(width, col + 1);
		for (int col : col_numeric)
			width = Math.max(width, col + 1);
		HashMap[] prb_attributes = new HashMap[width];

		// For CATEGORIC ATTRIBUTES
		for (int col : col_categoric)
			prb_attributes[col] = getCategorical(training, col);

		// For NUMERIC ATTRIBUTES
		for (int col : col_numeric)
			prb_attributes[col] = getNumerical(training, col);

		// TEST with MODEL OBJECT
		getTest(test, prb_class, prb_attributes, col_categoric, col_numeric, training.size());

		// SAVE MODEL as FILE
		String file1 = "prb_class";
		String file2 = "prb_attributes";
		saveModel(prb_class, file1);
		saveModel(prb_attributes, file2);

		// GET MODEL from FILE (the casts match the types saved just above)
		HashMap<String,Double> prb_class2 = (HashMap<String,Double>) getModel(file1);
		HashMap[] prb_attributes2 = (HashMap[]) getModel(file2);

		// TEST from FILE OBJECT
		getTest(test, prb_class2, prb_attributes2, col_categoric, col_numeric, training.size());
	}
}



소스에는 saveModel, getModel 함수로 오브젝트를 파일에 쓰고 읽는 예제가 포함되어 있습니다. 이 함수들은 자바 직렬화(ObjectOutputStream/ObjectInputStream)를 사용하므로, 직접 저장한 신뢰할 수 있는 파일만 읽어야 합니다.


	/**
	 * Serializes {@code obj} to {@code fileName} using Java serialization.
	 * On any I/O error the message is printed to stderr and the JVM exits with status 1.
	 */
	public static void saveModel(Object obj, String fileName) {
		try {
			ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName));
			try {
				oos.writeObject(obj);
			} finally {
				// close() flushes any buffered bytes and releases the file handle.
				// A failure here means the file may be incomplete, so it is reported below.
				oos.close();
			}
		} catch (IOException e) {
			System.err.println("saveModel: "+e.getMessage());
			System.exit(1);
		}
	}

	/**
	 * Reads back an object written by saveModel.
	 *
	 * <p>This uses Java native deserialization. Only load files you created
	 * yourself, never untrusted input. On any error the message is printed to
	 * stderr and the JVM exits with status 1.
	 */
	public static Object getModel(String fileName) {
		Object obj = null;
		try {
			ObjectInputStream ios = new ObjectInputStream(new FileInputStream(fileName));
			try {
				obj = ios.readObject();
			} finally {
				ios.close();
			}
		} catch (IOException e) {
			System.err.println("getModel: "+e.getMessage());
			System.exit(1);
		} catch (ClassNotFoundException e) {
			System.err.println("getModel: "+e.getMessage());
			System.exit(1);
		}
		return obj;
	}



메인 함수 하단부에서는 파일에서 읽어 온 오브젝트를 알맞은 타입으로 형변환하여 모델을 테스트합니다.


		// GET MODEL from FILE
		// Cast the deserialized objects back to the types that were saved.
		// These casts are unchecked: they assume file1 was written by saveModel
		// from prb_class and file2 from prb_attributes.
		HashMap<String,Double> prb_class2 = (HashMap<String,Double>) getModel(file1);
		HashMap[] prb_attributes2 = (HashMap[]) getModel(file2);


Posted by 공돌이pooh
,