import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import OneClassSVM

from data.loader_dc import DataLoader


class OcSvm:
    """Thin wrapper around scikit-learn's ``OneClassSVM`` that reports label counts.

    ``OneClassSVM.predict`` labels inliers ``+1`` and outliers ``-1``; the
    methods below print how many samples received a given label.
    """

    def __init__(self, nu: "float" = 7.1e-3, gamma: "float | str" = 1.6e-7):
        """Create the underlying one-class SVM.

        Args:
            nu: sklearn's ``nu`` (upper bound on the fraction of training
                errors / lower bound on the fraction of support vectors).
                The default is the value previously hard-coded here.
            gamma: kernel coefficient passed to ``OneClassSVM``. sklearn also
                accepts the strings ``"scale"`` or ``"auto"``. The default is
                the value previously hard-coded here.
        """
        self._clf: "OneClassSVM" = OneClassSVM(nu=nu, gamma=gamma)

    @staticmethod
    def _count(labels: "np.ndarray", which: "int") -> "int":
        """Return how many entries of ``labels`` equal ``which``."""
        return int(np.count_nonzero(labels == which))

    def fit(self, data: "np.ndarray") -> "np.ndarray":
        """Fit the model on ``data`` and print how many training rows are inliers.

        Args:
            data: training samples, one per row (``data.shape[0]`` is used as
                the sample count).

        Returns:
            The model's +1/-1 predictions on the training data.
        """
        self._clf.fit(data)
        res: "np.ndarray" = self._clf.predict(data)
        print(f"train detail: {self._count(res, 1)}/{data.shape[0]}")
        return res

    def predict(self, data: "np.ndarray", which: "int") -> "np.ndarray":
        """Predict labels for ``data`` and print how many equal ``which``.

        Args:
            data: samples to classify, one per row.
            which: the label to count, ``1`` (inlier) or ``-1`` (outlier).

        Returns:
            The model's +1/-1 predictions for ``data``.
        """
        res: "np.ndarray" = self._clf.predict(data)
        print(f"predict detail: {self._count(res, which)}/{data.shape[0]}")
        return res


def main():
    """Train on ``loader.have``, then score ``loader.test`` and ``loader.none``.

    ``loader.test`` is counted against the inlier label (1) and
    ``loader.none`` against the outlier label (-1).
    """
    # Seeds NumPy's global RNG; presumably so DataLoader's sampling is
    # reproducible -- TODO confirm against data.loader_dc.
    np.random.seed(5)
    loader = DataLoader()
    cls = OcSvm()
    cls.fit(loader.have)
    cls.predict(loader.test, 1)
    cls.predict(loader.none, -1)


if __name__ == "__main__":
    # NOTE(review): nothing in this file plots; these matplotlib settings are
    # presumably kept for plotting added later -- confirm they are needed.
    matplotlib.use("TkAgg")
    # Use the SimHei font so Chinese labels render correctly.
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    # Draw minus signs with ASCII '-' instead of U+2212.
    plt.rcParams["axes.unicode_minus"] = False
    main()