updated finbert calls
This commit is contained in:
parent
885f51b83d
commit
268c09be9e
@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from transformers import pipeline
|
||||
from finBERT.finbert import predict
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
import ollama
|
||||
from pydantic import BaseModel
|
||||
import markdownify
|
||||
@ -24,7 +25,7 @@ class ArticleClassification(BaseModel):
|
||||
|
||||
class ArticleAnalyzer:
|
||||
def __init__(self):
|
||||
self.classifier = pipeline("text-classification", model="ProsusAI/finbert")
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained('args.model_path', num_labels=3, cache_dir=None)
|
||||
self.base_prompt = """
|
||||
Classify the following article into one of these categories:
|
||||
- Regulatory News
|
||||
@ -74,12 +75,6 @@ class ArticleAnalyzer:
|
||||
|
||||
def classify_article_finbert(self, article_html):
|
||||
article_md = self.convert_to_markdown(article_html)
|
||||
chunk_size = 512
|
||||
chunks = [article_md[i:i + chunk_size] for i in range(0, len(article_md), chunk_size)]
|
||||
|
||||
results = []
|
||||
for chunk in chunks:
|
||||
result = self.classifier(chunk)
|
||||
results.append(result)
|
||||
results = predict(article_md, model=self.model, use_gpu=True)
|
||||
|
||||
return results
|
||||
Loading…
x
Reference in New Issue
Block a user