From 268c09be9e7d48dbf4af54e1ab4457433742423e Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Sat, 22 Mar 2025 04:52:17 +0800 Subject: [PATCH] updated finbert calls --- article_analyzer.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/article_analyzer.py b/article_analyzer.py index b6409a0..0476585 100644 --- a/article_analyzer.py +++ b/article_analyzer.py @@ -1,5 +1,6 @@ from enum import Enum -from transformers import pipeline +from finBERT.finbert import predict +from transformers import AutoModelForSequenceClassification import ollama from pydantic import BaseModel import markdownify @@ -24,7 +25,7 @@ class ArticleClassification(BaseModel): class ArticleAnalyzer: def __init__(self): - self.classifier = pipeline("text-classification", model="ProsusAI/finbert") + self.model = AutoModelForSequenceClassification.from_pretrained('args.model_path', num_labels=3, cache_dir=None) self.base_prompt = """ Classify the following article into one of these categories: - Regulatory News @@ -74,12 +75,6 @@ class ArticleAnalyzer: def classify_article_finbert(self, article_html): article_md = self.convert_to_markdown(article_html) - chunk_size = 512 - chunks = [article_md[i:i + chunk_size] for i in range(0, len(article_md), chunk_size)] - - results = [] - for chunk in chunks: - result = self.classifier(chunk) - results.append(result) + results = predict(article_md, model=self.model, use_gpu=True) return results \ No newline at end of file