lib_svm.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import numpy as np
  2. from data.loader_ac import DataLoader
  3. import libsvm.svmutil as svm
  4. def train_save():
  5. loader = DataLoader()
  6. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.have]
  7. tags = np.ones(len(data)).tolist()
  8. pro = svm.svm_problem(tags, data)
  9. pra = svm.svm_parameter("-s 2 -g 2.57e-6 -n 7.1e-3")
  10. model = svm.svm_train(pro, pra)
  11. svm.svm_save_model("m2c/real.model", model)
  12. # predict test
  13. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.test]
  14. tags = np.ones(len(data)).tolist()
  15. svm.svm_predict(y=tags, x=data, m=model)
  16. # predict none
  17. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.none]
  18. tags = (np.ones(len(data)) * -1).tolist()
  19. svm.svm_predict(y=tags, x=data, m=model)
  20. def load_test():
  21. np.random.seed(5)
  22. model = svm.svm_load_model("m2c/real.model")
  23. loader = DataLoader("data/csv/both") # "csv/test"
  24. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.have]
  25. tags = np.ones(len(data)).tolist()
  26. svm.svm_predict(y=tags, x=data, m=model)
  27. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.test]
  28. tags = np.ones(len(data)).tolist()
  29. svm.svm_predict(y=tags, x=data, m=model)
  30. data = [{i + 1: line[i] for i in range(len(line))} for line in loader.none]
  31. tags = (np.ones(len(data)) * -1).tolist()
  32. svm.svm_predict(y=tags, x=data, m=model)
  33. if __name__ == "__main__":
  34. load_test()