test calls to finbert working - need to chunk at 512

This commit is contained in:
Jericho 2025-03-22 06:25:14 +08:00
parent b4ef1ad8a2
commit 3864d7e93c

View File

@ -4,8 +4,8 @@ from sqlalchemy import create_engine, Column, String, Float, MetaData, Table
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from article_analyzer import ArticleAnalyzer from article_analyzer import ArticleAnalyzer
import nltk
sys.path.append(os.path.join(os.path.dirname(__file__), 'finBERT', 'finbert')) import logging
Base = declarative_base() Base = declarative_base()
@ -28,6 +28,11 @@ def read_html_files(folder_path):
if __name__ == "__main__": if __name__ == "__main__":
# nltk.set_proxy('http://127.0.0.1:7890')
# nltk.download('punkt_tab')
logging.basicConfig(level=logging.CRITICAL)
analyzer = ArticleAnalyzer() analyzer = ArticleAnalyzer()
engine = create_engine('sqlite:///article_analysis.db') engine = create_engine('sqlite:///article_analysis.db')
@ -39,29 +44,33 @@ if __name__ == "__main__":
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
result = analyzer.classify_article_finbert("Strong earning growth and expending market shares have positionned the company for long term success.")
print(f'result {result}')
for file, content in html_files.items():
result = analyzer.classify_article_finbert(content)
filename = os.path.basename(file) # for file, content in html_files.items():
# result = analyzer.classify_article_finbert(content)
label = result[0]['label'] # filename = os.path.basename(file)
score = result[0]['score'] # print(f'result {result}')
analysis = ArticleAnalysis(filename=filename, label=label, score=score) # label = result[0]['label']
# score = result[0]['score']
try: # analysis = ArticleAnalysis(filename=filename, label=label, score=score)
session.add(analysis)
session.commit()
except:
session.rollback()
existing = session.query(ArticleAnalysis).filter_by(filename=filename).first() # try:
if existing: # session.add(analysis)
existing.label = label # session.commit()
existing.score = score # except:
session.commit() # session.rollback()
finally:
session.close()
print(f"article [{file}] - analyzed as [{result}]\n") # existing = session.query(ArticleAnalysis).filter_by(filename=filename).first()
# if existing:
# existing.label = label
# existing.score = score
# session.commit()
# finally:
# session.close()
# print(f"article [{file}] - analyzed as [{result}]\n")