1234567891011121314151617181920212223242526272829303132333435 |
- import matplotlib
- import numpy as np
- from sklearn.svm import OneClassSVM
- from data.loader_dc import DataLoader
- import matplotlib.pyplot as plt
class OcSvm:
    """Thin wrapper around scikit-learn's ``OneClassSVM``.

    After each call it prints how many samples received a given label.
    ``OneClassSVM.predict`` labels inliers ``+1`` and outliers ``-1``.
    """

    def __init__(self, nu: float = 7.1e-3, gamma: "float | str" = 1.6e-7):
        """Create the underlying one-class SVM.

        Args:
            nu: Passed straight to ``OneClassSVM(nu=...)``. The default is the
                previously hard-coded value.
            gamma: Passed straight to ``OneClassSVM(gamma=...)``. The default
                is the previously hard-coded value.
                NOTE(review): the original carried a bare ``# scale`` comment
                here, presumably hinting at sklearn's ``gamma="scale"``
                alternative — confirm.
        """
        self._clf: "OneClassSVM" = OneClassSVM(nu=nu, gamma=gamma)

    def fit(self, data: "np.ndarray") -> "np.ndarray":
        """Fit the model on ``data``, then report how much of it is accepted.

        Args:
            data: Training samples; ``data.shape[0]`` is treated as the
                sample count.

        Returns:
            The model's labels for ``data`` after fitting.
        """
        self._clf.fit(data)
        res: "np.ndarray" = self._clf.predict(data)
        # Count samples the freshly fitted model labels +1 (accepted).
        print(f"train detail: {np.count_nonzero(res == 1)}/{data.shape[0]}")
        return res

    def predict(self, data: "np.ndarray", which: "int") -> "np.ndarray":
        """Predict labels for ``data`` and report how many equal ``which``.

        Args:
            data: Samples to classify; ``data.shape[0]`` is treated as the
                sample count.
            which: The label to count, e.g. ``1`` when every sample is
                expected to be accepted, ``-1`` when every sample is expected
                to be rejected.

        Returns:
            The predicted labels for ``data``.
        """
        res: "np.ndarray" = self._clf.predict(data)
        print(f"predict detail: {np.count_nonzero(res == which)}/{data.shape[0]}")
        return res
def main():
    """Train the one-class model and report acceptance on held-out data.

    The model is fitted on ``DataLoader().have``. Predictions are then counted
    against label ``1`` on ``.test`` and against label ``-1`` on ``.none``.
    """
    np.random.seed(5)  # pin numpy's global RNG before any data is loaded
    data_loader = DataLoader()
    detector = OcSvm()
    detector.fit(data_loader.have)
    # Each entry is (loader attribute, label the samples are expected to get).
    evaluations = (("test", 1), ("none", -1))
    for attr_name, expected_label in evaluations:
        detector.predict(getattr(data_loader, attr_name), expected_label)
if __name__ == "__main__":
    # Select the Tk-based matplotlib backend before main() runs.
    matplotlib.use("TkAgg")
    rc = plt.rcParams
    # Use the SimHei font so Chinese labels display correctly.
    rc["font.sans-serif"] = ["SimHei"]
    # Draw minus signs as ASCII hyphens (presumably because SimHei lacks the
    # Unicode minus glyph — confirm).
    rc["axes.unicode_minus"] = False
    main()
|