ocsvm-v1.py 1011 B

1234567891011121314151617181920212223242526272829303132333435
  1. import matplotlib
  2. import numpy as np
  3. from sklearn.svm import OneClassSVM
  4. from data.loader_dc import DataLoader
  5. import matplotlib.pyplot as plt
  6. class OcSvm:
  7. def __init__(self):
  8. self._clf: "OneClassSVM" = OneClassSVM(nu=7.1e-3, gamma=1.6e-7) # scale
  9. def fit(self, data: "np.ndarray"):
  10. self._clf.fit(data)
  11. res: "np.ndarray" = self._clf.predict(data)
  12. print(f"train detail: {res[res == 1].size}/{data.shape[0]}")
  13. def predict(self, data: "np.ndarray", which: "int"):
  14. res: "np.ndarray" = self._clf.predict(data)
  15. print(f"predict detail: {res[res == which].size}/{data.shape[0]}")
  16. def main():
  17. np.random.seed(5)
  18. loader = DataLoader()
  19. cls = OcSvm()
  20. cls.fit(loader.have)
  21. cls.predict(loader.test, 1)
  22. cls.predict(loader.none, -1)
  23. if __name__ == "__main__":
  24. matplotlib.use("TkAgg")
  25. plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签SimHei
  26. plt.rcParams["axes.unicode_minus"] = False
  27. main()