svm-predict.c 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  #include <ctype.h>
  #include <errno.h>
  #include <limits.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>

  #include "svm.h"
  /*
   * No-op replacement for printf: accepts a printf-style format and arguments,
   * ignores all of them, and reports that zero characters were written.
   */
  int print_null(const char *s, ...)
  {
      (void) s; /* intentionally unused */
      return 0;
  }
  /* Sink for informational messages; points at printf unless reassigned. */
  static int (*info)(const char *fmt, ...) = printf;

  /* Feature-vector buffer for the instance being predicted, and its capacity in nodes. */
  struct svm_node *x;
  int max_nr_attr = 64;

  /* Model used for prediction. */
  struct svm_model *model;

  /* Non-zero when probability estimates are requested. */
  int predict_probability = 0;

  /* Buffer holding the current input line, and its capacity in bytes. */
  static char *line;
  static int max_line_len;
  14. static char *readline(FILE *input) {
  15. int len;
  16. if (fgets(line, max_line_len, input) == NULL)
  17. return NULL;
  18. while (strrchr(line, '\n') == NULL) {
  19. max_line_len *= 2;
  20. line = (char *) realloc(line, max_line_len);
  21. len = (int) strlen(line);
  22. if (fgets(line + len, max_line_len - len, input) == NULL)
  23. break;
  24. }
  25. return line;
  26. }
  /*
   * Report a malformed input line on stderr and terminate with status 1.
   *
   * line_num: the line number printed in the message.
   */
  void exit_input_error(int line_num)
  {
      fprintf(stderr,
              "Wrong input format at line %d\n",
              line_num);

      exit(1);
  }
  31. void predict(FILE *input, FILE *output) {
  32. int correct = 0;
  33. int total = 0;
  34. double error = 0;
  35. double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
  36. int svm_type = svm_get_svm_type(model);
  37. int nr_class = svm_get_nr_class(model);
  38. double *prob_estimates = NULL;
  39. int j;
  40. if (predict_probability) {
  41. if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
  42. info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",
  43. svm_get_svr_probability(model));
  44. else if (svm_type == ONE_CLASS) {
  45. // nr_class = 2 for ONE_CLASS
  46. prob_estimates = (double *) malloc(nr_class * sizeof(double));
  47. fprintf(output, "label normal outlier\n");
  48. } else {
  49. int *labels = (int *) malloc(nr_class * sizeof(int));
  50. svm_get_labels(model, labels);
  51. prob_estimates = (double *) malloc(nr_class * sizeof(double));
  52. fprintf(output, "labels");
  53. for (j = 0; j < nr_class; j++)
  54. fprintf(output, " %d", labels[j]);
  55. fprintf(output, "\n");
  56. free(labels);
  57. }
  58. }
  59. max_line_len = 1024;
  60. line = (char *) malloc(max_line_len * sizeof(char));
  61. while (readline(input) != NULL) {
  62. int i = 0;
  63. double target_label, predict_label;
  64. char *idx, *val, *label, *endptr;
  65. int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
  66. label = strtok(line, " \t\n");
  67. if (label == NULL) // empty line
  68. exit_input_error(total + 1);
  69. target_label = strtod(label, &endptr);
  70. if (endptr == label || *endptr != '\0')
  71. exit_input_error(total + 1);
  72. while (1) {
  73. if (i >= max_nr_attr - 1) // need one more for index = -1
  74. {
  75. max_nr_attr *= 2;
  76. x = (struct svm_node *) realloc(x, max_nr_attr * sizeof(struct svm_node));
  77. }
  78. idx = strtok(NULL, ":");
  79. val = strtok(NULL, " \t");
  80. if (val == NULL)
  81. break;
  82. errno = 0;
  83. x[i].index = (int) strtol(idx, &endptr, 10);
  84. if (endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
  85. exit_input_error(total + 1);
  86. else
  87. inst_max_index = x[i].index;
  88. errno = 0;
  89. x[i].value = strtod(val, &endptr);
  90. if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
  91. exit_input_error(total + 1);
  92. ++i;
  93. }
  94. x[i].index = -1;
  95. if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC || svm_type == ONE_CLASS)) {
  96. predict_label = svm_predict_probability(model, x, prob_estimates);
  97. fprintf(output, "%g", predict_label);
  98. for (j = 0; j < nr_class; j++)
  99. fprintf(output, " %g", prob_estimates[j]);
  100. fprintf(output, "\n");
  101. } else {
  102. predict_label = svm_predict(model, x);
  103. fprintf(output, "%.17g\n", predict_label);
  104. }
  105. if (predict_label == target_label)
  106. ++correct;
  107. error += (predict_label - target_label) * (predict_label - target_label);
  108. sump += predict_label;
  109. sumt += target_label;
  110. sumpp += predict_label * predict_label;
  111. sumtt += target_label * target_label;
  112. sumpt += predict_label * target_label;
  113. ++total;
  114. }
  115. if (svm_type == NU_SVR || svm_type == EPSILON_SVR) {
  116. info("Mean squared error = %g (regression)\n", error / total);
  117. info("Squared correlation coefficient = %g (regression)\n",
  118. ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt)) /
  119. ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt))
  120. );
  121. } else
  122. info("Accuracy = %g%% (%d/%d) (classification)\n",
  123. (double) correct / total * 100, correct, total);
  124. if (predict_probability)
  125. free(prob_estimates);
  126. }
  /*
   * Print the command-line usage summary to stdout and terminate with status 1.
   */
  void exit_with_help()
  {
      /* The text contains no conversion specifiers, so it is written verbatim. */
      static const char usage[] =
          "Usage: svm-predict [options] test_file model_file output_file\n"
          "options:\n"
          "-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); for one-class SVM only 0 is supported\n"
          "-q : quiet mode (no outputs)\n";

      fputs(usage, stdout);
      exit(1);
  }
  136. int main(int argc, char **argv) {
  137. FILE *input, *output;
  138. int i;
  139. // parse options
  140. for (i = 1; i < argc; i++) {
  141. if (argv[i][0] != '-') break;
  142. ++i;
  143. switch (argv[i - 1][1]) {
  144. case 'b':
  145. predict_probability = atoi(argv[i]);
  146. break;
  147. case 'q':
  148. info = &print_null;
  149. i--;
  150. break;
  151. default:
  152. fprintf(stderr, "Unknown option: -%c\n", argv[i - 1][1]);
  153. exit_with_help();
  154. }
  155. }
  156. if (i >= argc - 2)
  157. exit_with_help();
  158. input = fopen(argv[i], "r");
  159. if (input == NULL) {
  160. fprintf(stderr, "can't open input file %s\n", argv[i]);
  161. exit(1);
  162. }
  163. output = fopen(argv[i + 2], "w");
  164. if (output == NULL) {
  165. fprintf(stderr, "can't open output file %s\n", argv[i + 2]);
  166. exit(1);
  167. }
  168. if ((model = svm_load_model(argv[i + 1])) == 0) {
  169. fprintf(stderr, "can't open model file %s\n", argv[i + 1]);
  170. exit(1);
  171. }
  172. x = (struct svm_node *) malloc(max_nr_attr * sizeof(struct svm_node));
  173. if (predict_probability) {
  174. if (svm_check_probability_model(model) == 0) {
  175. fprintf(stderr, "Model does not support probabiliy estimates\n");
  176. exit(1);
  177. }
  178. } else {
  179. if (svm_check_probability_model(model) != 0)
  180. info("Model supports probability estimates, but disabled in prediction.\n");
  181. }
  182. predict(input, output);
  183. svm_free_and_destroy_model(&model);
  184. free(x);
  185. free(line);
  186. fclose(input);
  187. fclose(output);
  188. return 0;
  189. }