diff --git a/article_analyzer.py b/article_analyzer.py index 103bbf3..16207c2 100644 --- a/article_analyzer.py +++ b/article_analyzer.py @@ -26,6 +26,7 @@ class ArticleClassification(BaseModel): class ArticleAnalyzer: def __init__(self): self.model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", num_labels=3, cache_dir=None) + self.model.to("cuda") self.base_prompt = """ Classify the following article into one of these categories: - Regulatory News